@@ -2,69 +2,209 @@ package io.airbyte.integrations.destination.s3_v2
 
 import io.airbyte.cdk.load.command.DestinationCatalog
 import io.airbyte.cdk.load.file.object_storage.PathFactory
+import io.airbyte.cdk.load.file.object_storage.StreamingUpload
 import io.airbyte.cdk.load.file.s3.S3Client
+import io.airbyte.cdk.load.file.s3.S3Object
+import io.airbyte.cdk.load.file.s3.S3StreamingUpload
+import io.airbyte.cdk.load.message.CheckpointMessage
+import io.airbyte.cdk.load.message.DestinationFile
+import io.airbyte.cdk.load.message.DestinationFileStreamComplete
+import io.airbyte.cdk.load.message.DestinationFileStreamIncomplete
+import io.airbyte.cdk.load.message.DestinationRecord
+import io.airbyte.cdk.load.message.DestinationRecordStreamComplete
+import io.airbyte.cdk.load.message.DestinationRecordStreamIncomplete
+import io.airbyte.cdk.load.message.GlobalCheckpoint
+import io.airbyte.cdk.load.message.StreamCheckpoint
+import io.airbyte.cdk.load.message.Undefined
+import io.airbyte.cdk.load.state.SyncManager
 import io.airbyte.cdk.load.task.SelfTerminating
 import io.airbyte.cdk.load.task.TerminalCondition
+import io.airbyte.cdk.load.task.internal.ReservingDeserializingInputFlow
 import io.airbyte.cdk.load.write.WriteOpOverride
+import io.airbyte.protocol.models.v0.AirbyteMessage
 import io.github.oshai.kotlinlogging.KotlinLogging
 import jakarta.inject.Singleton
+import java.io.File
+import java.io.RandomAccessFile
+import java.nio.file.Path
+import java.util.concurrent.atomic.AtomicLong
+import java.util.function.Consumer
 import kotlin.random.Random
 import kotlin.time.measureTime
 import kotlinx.coroutines.Dispatchers
 import kotlinx.coroutines.ExperimentalCoroutinesApi
 import kotlinx.coroutines.async
 import kotlinx.coroutines.awaitAll
+import kotlinx.coroutines.channels.Channel
 import kotlinx.coroutines.coroutineScope
+import kotlinx.coroutines.flow.consumeAsFlow
+import kotlinx.coroutines.flow.flowOn
+import kotlinx.coroutines.launch
 import kotlinx.coroutines.withContext
 
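+// One part of a staged source file: the multipart upload it belongs to, its
+// 1-based part number and size, and a callback run after the part uploads
+// (the last part's callback deletes the staged file and completes the upload).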
+data class FileSegment(
+    val fileUrl: String,
+    val objectKey: String,
+    val upload: StreamingUpload<S3Object>,
+    val partNumber: Int,
+    val partSize: Long,
+    val callback: suspend () -> Unit = {}
+)
+
 @Singleton
 class S3V2WriteOpOverride(
     private val client: S3Client,
     private val catalog: DestinationCatalog,
     private val config: S3V2Configuration<*>,
     private val pathFactory: PathFactory,
+    private val reservingDeserializingInputFlow: ReservingDeserializingInputFlow,
+    private val outputConsumer: Consumer<AirbyteMessage>,
+    private val syncManager: SyncManager,
 ): WriteOpOverride {
     private val log = KotlinLogging.logger { }
 
     override val terminalCondition: TerminalCondition = SelfTerminating
 
     @OptIn(ExperimentalCoroutinesApi::class)
     override suspend fun execute() = coroutineScope {
-        val prng = Random(System.currentTimeMillis())
-        val randomPart = prng.nextBytes(config.partSizeBytes.toInt())
-        val randomString = randomPart.take(32).joinToString("") { "%02x".format(it) }
-        val stream = catalog.streams.first()
-        val objectKey = pathFactory.getFinalDirectory(stream) + "/mock-perf-test-$randomString"
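+        // Producer/consumer pipeline: the reader coroutine below splits each
+        // incoming file into FileSegments and enqueues them; a fixed pool of
+        // upload workers drains the queue. The channel is closed once every
+        // stream has finished reading.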
+        val mockPartQueue: Channel<FileSegment> = Channel(Channel.UNLIMITED)
+        val streamCount = AtomicLong(catalog.streams.size.toLong())
+        withContext(Dispatchers.IO) {
+            launch {
+                reservingDeserializingInputFlow.collect { (_, reservation) ->
+                    when (val message = reservation.value) {
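+                        // Global state carries no per-stream record counts;
+                        // echo it back immediately with zero stats.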
+                        is GlobalCheckpoint -> launch {
+                            outputConsumer.accept(
+                                message.withDestinationStats(CheckpointMessage.Stats(0))
+                                    .asProtocolMessage()
+                            )
+                        }
 
-        val numParts = (config.objectSizeBytes / config.partSizeBytes).toInt()
-        val partsPerWorker = numParts / config.numUploadWorkers
-        val actualSizeBytes = partsPerWorker * config.numUploadWorkers * config.partSizeBytes
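+                        // Mark the checkpoint against the stream manager and
+                        // flush the state message back with the record count.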
+                        is StreamCheckpoint -> {
+                            val (_, count) = syncManager.getStreamManager(message.checkpoint.stream)
+                                .markCheckpoint()
+                            launch {
+                                log.info { "Flushing state" }
+                                outputConsumer.accept(
+                                    message.withDestinationStats(CheckpointMessage.Stats(count))
+                                        .asProtocolMessage()
+                                )
+                                log.info { "Done flushing state" }
+                            }
+                        }
 
-        log.info {
-            "root key=$objectKey; part_size=${config.partSizeBytes}b; num_parts=$numParts (per_worker=$partsPerWorker); total_size=${actualSizeBytes}b; num_workers=${config.numUploadWorkers}"
-        }
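+                        // Split each incoming file into parts and enqueue them
+                        // onto a fresh multipart upload keyed by source file name.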
+                        is DestinationFile -> {
+                            syncManager.getStreamManager(message.stream).incrementReadCount()
+                            if (message.fileMessage.bytes == null) {
+                                throw IllegalStateException("This can't work unless you set FileMessage.bytes!")
+                            }
+                            val size = message.fileMessage.bytes!!
+                            val numWholeParts = (size / config.partSizeBytes).toInt()
+                            val numParts =
+                                numWholeParts + if (size % config.partSizeBytes > 0) 1 else 0
+                            // If size is an exact multiple of the part size, the
+                            // final part is a full part, not zero bytes.
+                            val lastPartSize =
+                                if (size % config.partSizeBytes == 0L) config.partSizeBytes
+                                else size % config.partSizeBytes
+                            val fileUrl = message.fileMessage.fileUrl!!
+                            log.info {
+                                "Breaking file $fileUrl (size=${size}B) into $numParts ${config.partSizeBytes}B parts"
+                            }
+                            val stream = catalog.getStream(message.stream)
+                            val directory = pathFactory.getFinalDirectory(stream)
+                            val sourceFileName = message.fileMessage.sourceFileUrl!!
+                            val objectKey = Path.of(directory, sourceFileName).toString()
+                            val upload = client.startStreamingUpload(objectKey)
+                            val partCounter = AtomicLong(numParts.toLong())
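+                            // The last segment's callback tears down: it deletes
+                            // the staged file and completes the multipart upload.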
+                            repeat(numParts) { partNumber ->
+                                mockPartQueue.send(
+                                    FileSegment(
+                                        fileUrl,
+                                        objectKey,
+                                        upload,
+                                        partNumber + 1,
+                                        if (partNumber == numParts - 1) lastPartSize else config.partSizeBytes
+                                    ) {
+                                        val partsRemaining = partCounter.decrementAndGet()
+                                        if (partsRemaining == 0L) {
+                                            log.info {
+                                                "Finished uploading $numParts parts of $fileUrl; deleting file and finishing upload"
+                                            }
+                                            File(fileUrl).delete()
+                                            log.info {
+                                                "Finished deleting"
+                                            }
+                                            upload.complete()
+                                            log.info {
+                                                "Finished completing the upload"
+                                            }
+                                        } else {
+                                            log.info {
+                                                "Finished uploading part ${partNumber + 1} of $fileUrl. $partsRemaining parts remaining"
+                                            }
+                                        }
+                                    }
+                                )
+                            }
+                        }
 
-        val duration = measureTime {
-            withContext(Dispatchers.IO.limitedParallelism(config.numUploadWorkers)) {
-                (0 until config.numUploadWorkers).map {
-                    async {
-                        val workerKey = "$objectKey-worker-$it"
-                        log.info { "Starting upload to $workerKey" }
-                        val upload = client.startStreamingUpload(workerKey)
-                        repeat(partsPerWorker) {
-                            log.info { "Uploading part ${it + 1} of $workerKey" }
-                            upload.uploadPart(randomPart, it + 1)
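+                        // Once the last stream finishes reading, close the queue
+                        // so the upload workers drain the remaining parts and exit.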
+                        is DestinationFileStreamComplete,
+                        is DestinationFileStreamIncomplete -> {
+                            if (streamCount.decrementAndGet() == 0L) {
+                                log.info {
+                                    "Read final stream complete, closing mockPartQueue"
+                                }
+                                mockPartQueue.close()
+                            } else {
+                                log.info {
+                                    "Read stream complete, ${streamCount.get()} streams remaining"
+                                }
+                            }
                         }
-                        log.info { "Completing upload to $workerKey" }
-                        upload.complete()
+
+                        is DestinationRecordStreamComplete,
+                        is DestinationRecordStreamIncomplete,
+                        is DestinationRecord -> throw NotImplementedError("This hack is only for files")
+
+                        Undefined ->
+                            log.warn {
+                                "Undefined message received. This should not happen."
+                            }
                     }
-                }.awaitAll()
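+                    // Free the reservation associated with this message.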
+                    reservation.release()
+                }
             }
         }
-        val mbs = actualSizeBytes.toFloat() / duration.inWholeSeconds.toFloat() / 1024 / 1024
-        log.info {
-            // format mbs to 2 decimal places
-            "Uploaded $actualSizeBytes bytes in $duration seconds (${"%.2f".format(mbs)} MB/s)"
+
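+        // Consumer side: numUploadWorkers coroutines share the part queue,
+        // each reading its segment's byte range from the staged file and
+        // uploading it as one multipart part.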
+        val totalBytesLoaded = AtomicLong(0L)
+        try {
+            val duration = measureTime {
+                val dispatcher = Dispatchers.IO.limitedParallelism(config.numUploadWorkers)
+                withContext(dispatcher) {
+                    (0 until config.numUploadWorkers).map {
+                        async {
+                            mockPartQueue.consumeAsFlow().collect { segment ->
+                                log.info { "Starting upload to ${segment.objectKey} part ${segment.partNumber}" }
+                                RandomAccessFile(segment.fileUrl, "r").use { file ->
+                                    val partBytes = ByteArray(segment.partSize.toInt())
+                                    file.seek((segment.partNumber - 1) * config.partSizeBytes)
+                                    // read() may stop short; readFully() fills the buffer.
+                                    file.readFully(partBytes)
+                                    segment.upload.uploadPart(partBytes, segment.partNumber)
+                                    log.info {
+                                        "Finished uploading part ${segment.partNumber} of ${segment.fileUrl}"
+                                    }
+                                    totalBytesLoaded.addAndGet(segment.partSize)
+                                    segment.callback()
+                                }
+                            }
+                        }
+                    }.awaitAll()
+                }
+            }
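+            // Average MB/s across all workers for the whole sync.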
+            log.info {
+                val mbs = totalBytesLoaded.get()
+                    .toDouble() / 1024 / 1024 / duration.inWholeSeconds.toDouble()
+                "Uploaded ${totalBytesLoaded.get()} bytes in ${duration.inWholeSeconds}s (${mbs} MB/s)"
+            }
+        } catch (e: Throwable) {
+            log.error(e) { "Error uploading file, bailing" }
         }
     }
 }