@@ -2,69 +2,176 @@ package io.airbyte.integrations.destination.s3_v2

import io.airbyte.cdk.load.command.DestinationCatalog
import io.airbyte.cdk.load.file.object_storage.PathFactory
+ import io.airbyte.cdk.load.file.object_storage.StreamingUpload
import io.airbyte.cdk.load.file.s3.S3Client
+ import io.airbyte.cdk.load.file.s3.S3Object
+ import io.airbyte.cdk.load.file.s3.S3StreamingUpload
+ import io.airbyte.cdk.load.message.CheckpointMessage
+ import io.airbyte.cdk.load.message.DestinationFile
+ import io.airbyte.cdk.load.message.DestinationFileStreamComplete
+ import io.airbyte.cdk.load.message.DestinationFileStreamIncomplete
+ import io.airbyte.cdk.load.message.DestinationRecord
+ import io.airbyte.cdk.load.message.DestinationRecordStreamComplete
+ import io.airbyte.cdk.load.message.DestinationRecordStreamIncomplete
+ import io.airbyte.cdk.load.message.GlobalCheckpoint
+ import io.airbyte.cdk.load.message.StreamCheckpoint
+ import io.airbyte.cdk.load.message.Undefined
+ import io.airbyte.cdk.load.state.SyncManager
import io.airbyte.cdk.load.task.SelfTerminating
import io.airbyte.cdk.load.task.TerminalCondition
+ import io.airbyte.cdk.load.task.internal.ReservingDeserializingInputFlow
import io.airbyte.cdk.load.write.WriteOpOverride
+ import io.airbyte.protocol.models.v0.AirbyteMessage
import io.github.oshai.kotlinlogging.KotlinLogging
import jakarta.inject.Singleton
+ import java.io.RandomAccessFile
+ import java.nio.file.Path
+ import java.util.concurrent.atomic.AtomicLong
+ import java.util.function.Consumer
import kotlin.random.Random
import kotlin.time.measureTime
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.async
import kotlinx.coroutines.awaitAll
+ import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.coroutineScope
+ import kotlinx.coroutines.flow.consumeAsFlow
+ import kotlinx.coroutines.flow.flowOn
+ import kotlinx.coroutines.launch
import kotlinx.coroutines.withContext

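+ // One unit of upload work: a single part of a local staged file, plus the shared
+ // multipart upload it belongs to and a callback invoked after the part is uploaded.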
+ data class FileSegment(
+     val fileUrl: String,
+     val objectKey: String,
+     val upload: StreamingUpload<S3Object>,
+     val partNumber: Int,
+     val callback: suspend () -> Unit = {}
+ )
+
@Singleton
class S3V2WriteOpOverride(
    private val client: S3Client,
    private val catalog: DestinationCatalog,
    private val config: S3V2Configuration<*>,
    private val pathFactory: PathFactory,
+     private val reservingDeserializingInputFlow: ReservingDeserializingInputFlow,
+     private val outputConsumer: Consumer<AirbyteMessage>,
+     private val syncManager: SyncManager,
): WriteOpOverride {
    private val log = KotlinLogging.logger {}

    override val terminalCondition: TerminalCondition = SelfTerminating

    @OptIn(ExperimentalCoroutinesApi::class)
    override suspend fun execute() = coroutineScope {
-         val prng = Random(System.currentTimeMillis())
-         val randomPart = prng.nextBytes(config.partSizeBytes.toInt())
-         val randomString = randomPart.take(32).joinToString("") { "%02x".format(it) }
-         val stream = catalog.streams.first()
-         val objectKey = pathFactory.getFinalDirectory(stream) + "/mock-perf-test-$randomString"
-
-         val numParts = (config.objectSizeBytes / config.partSizeBytes).toInt()
-         val partsPerWorker = numParts / config.numUploadWorkers
-         val actualSizeBytes = partsPerWorker * config.numUploadWorkers * config.partSizeBytes
-
-         log.info {
-             "root key=$objectKey; part_size=${config.partSizeBytes}b; num_parts=$numParts (per_worker=$partsPerWorker); total_size=${actualSizeBytes}b; num_workers=${config.numUploadWorkers}"
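+         // Unbounded queue of pending parts: the reader coroutine below enqueues one
+         // FileSegment per part, and the upload workers drain it. streamCount tracks how
+         // many streams are still open so the queue is closed exactly once.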
+         val mockPartQueue: Channel<FileSegment> = Channel(Channel.UNLIMITED)
+         val streamCount = AtomicLong(catalog.streams.size.toLong())
+         launch {
+             reservingDeserializingInputFlow.collect { (_, reservation) ->
+                 when (val message = reservation.value) {
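+                     // Checkpoints are acked back to the platform right away, with the
+                     // stream's record count attached as destination stats.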
+                     is GlobalCheckpoint -> launch {
+                         outputConsumer.accept(
+                             message.withDestinationStats(CheckpointMessage.Stats(0)).asProtocolMessage()
+                         )
+                     }
+                     is StreamCheckpoint -> launch {
+                         val (_, count) = syncManager.getStreamManager(message.checkpoint.stream).markCheckpoint()
+                         outputConsumer.accept(
+                             message.withDestinationStats(CheckpointMessage.Stats(count)).asProtocolMessage()
+                         )
+                     }
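+                     // Split each incoming file into partSizeBytes-sized chunks, all feeding
+                     // a single multipart upload for that file.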
+                     is DestinationFile -> {
+                         syncManager.getStreamManager(message.stream).incrementReadCount()
+                         if (message.fileMessage.bytes == null) {
+                             throw IllegalStateException("This can't work unless you set FileMessage.bytes!")
+                         }
+                         val size = message.fileMessage.bytes!!
+                         val numWholeParts = (size / config.partSizeBytes).toInt()
+                         val numParts = numWholeParts + if (size % config.partSizeBytes > 0) 1 else 0
+                         val fileUrl = message.fileMessage.fileUrl!!
+                         log.info {
+                             "Breaking file $fileUrl (size=${size}B) into $numParts ${config.partSizeBytes}B parts"
+                         }
+                         val stream = catalog.getStream(message.stream)
+                         val directory = pathFactory.getFinalDirectory(stream)
+                         val sourceFileName = message.fileMessage.sourceFileUrl!!
+                         val objectKey = Path.of(directory, sourceFileName).toString()
+                         val upload = client.startStreamingUpload(objectKey)
+                         val partCounter = AtomicLong(numParts.toLong())
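+                         // Each part's callback decrements the shared counter; whichever worker
+                         // uploads the final part completes the multipart upload.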
+                         repeat(numParts) { partNumber ->
+                             mockPartQueue.send(FileSegment(fileUrl, objectKey, upload, partNumber + 1) {
+                                 val partsRemaining = partCounter.decrementAndGet()
+                                 if (partsRemaining == 0L) {
+                                     log.info {
+                                         "Finished uploading $numParts parts of $fileUrl"
+                                     }
+                                     upload.complete()
+                                 } else {
+                                     log.info {
+                                         "Finished uploading part ${partNumber + 1} of $fileUrl. $partsRemaining parts remaining"
+                                     }
+                                 }
+                             })
+                         }
+                     }
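+                     // Closing the channel after the last stream lets the workers'
+                     // consumeAsFlow() collectors finish once the remaining parts drain.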
+                     is DestinationFileStreamComplete,
+                     is DestinationFileStreamIncomplete -> {
+                         if (streamCount.decrementAndGet() == 0L) {
+                             log.info {
+                                 "Read final stream complete, closing mockPartQueue"
+                             }
+                             mockPartQueue.close()
+                         } else {
+                             log.info {
+                                 "Read stream complete, ${streamCount.get()} streams remaining"
+                             }
+                         }
+                     }
+                     is DestinationRecordStreamComplete,
+                     is DestinationRecordStreamIncomplete,
+                     is DestinationRecord -> throw NotImplementedError("This hack is only for files")
+                     Undefined ->
+                         log.warn {
+                             "Undefined message received. This should not happen."
+                         }
+                 }
+                 reservation.release()
+             }
        }

+         val totalBytesLoaded = AtomicLong(0L)
        val duration = measureTime {
-             withContext(Dispatchers.IO.limitedParallelism(config.numUploadWorkers)) {
+             val dispatcher = Dispatchers.IO.limitedParallelism(config.numUploadWorkers)
+             withContext(dispatcher) {
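+                 // Fixed pool of upload workers, each pulling segments off the shared
+                 // queue until it is closed and empty.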
                (0 until config.numUploadWorkers).map {
                    async {
-                         val workerKey = "$objectKey-worker-$it"
-                         log.info { "Starting upload to $workerKey" }
-                         val upload = client.startStreamingUpload(workerKey)
-                         repeat(partsPerWorker) {
-                             log.info { "Uploading part ${it + 1} of $workerKey" }
-                             upload.uploadPart(randomPart, it + 1)
+                         mockPartQueue.consumeAsFlow().collect { segment ->
+                             log.info { "Starting upload to ${segment.objectKey} part ${segment.partNumber}" }
+                             RandomAccessFile(segment.fileUrl, "r").use { file ->
+                                 // The final part is usually short: size it to whatever remains in the file.
+                                 val startOffset = (segment.partNumber - 1) * config.partSizeBytes
+                                 val partSize = minOf(config.partSizeBytes, file.length() - startOffset)
+                                 val partBytes = ByteArray(partSize.toInt())
+                                 file.seek(startOffset)
+                                 file.readFully(partBytes)
+                                 segment.upload.uploadPart(partBytes, segment.partNumber)
+                                 log.info {
+                                     "Finished uploading part ${segment.partNumber} of ${segment.fileUrl}"
+                                 }
+                                 totalBytesLoaded.addAndGet(partSize)
+                                 segment.callback()
+                             }
                        }
-                         log.info { "Completing upload to $workerKey" }
-                         upload.complete()
                    }
                }.awaitAll()
            }
        }
-         val mbs = actualSizeBytes.toFloat() / duration.inWholeSeconds.toFloat() / 1024 / 1024
        log.info {
-             // format mbs to 2 decimal places
-             "Uploaded $actualSizeBytes bytes in $duration seconds (${"%.2f".format(mbs)} MB/s)"
+             val mbs = totalBytesLoaded.get().toDouble() / 1024 / 1024 / duration.inWholeSeconds.toDouble()
+             "Uploaded ${totalBytesLoaded.get()} bytes in ${duration.inWholeSeconds}s (${"%.2f".format(mbs)} MB/s)"
        }
    }
}