@@ -2,69 +2,207 @@ package io.airbyte.integrations.destination.s3_v2
 
 import io.airbyte.cdk.load.command.DestinationCatalog
 import io.airbyte.cdk.load.file.object_storage.PathFactory
+import io.airbyte.cdk.load.file.object_storage.StreamingUpload
 import io.airbyte.cdk.load.file.s3.S3Client
+import io.airbyte.cdk.load.file.s3.S3Object
+import io.airbyte.cdk.load.file.s3.S3StreamingUpload
+import io.airbyte.cdk.load.message.CheckpointMessage
+import io.airbyte.cdk.load.message.DestinationFile
+import io.airbyte.cdk.load.message.DestinationFileStreamComplete
+import io.airbyte.cdk.load.message.DestinationFileStreamIncomplete
+import io.airbyte.cdk.load.message.DestinationRecord
+import io.airbyte.cdk.load.message.DestinationRecordStreamComplete
+import io.airbyte.cdk.load.message.DestinationRecordStreamIncomplete
+import io.airbyte.cdk.load.message.GlobalCheckpoint
+import io.airbyte.cdk.load.message.StreamCheckpoint
+import io.airbyte.cdk.load.message.Undefined
+import io.airbyte.cdk.load.state.SyncManager
 import io.airbyte.cdk.load.task.SelfTerminating
 import io.airbyte.cdk.load.task.TerminalCondition
+import io.airbyte.cdk.load.task.internal.ReservingDeserializingInputFlow
 import io.airbyte.cdk.load.write.WriteOpOverride
+import io.airbyte.protocol.models.v0.AirbyteMessage
 import io.github.oshai.kotlinlogging.KotlinLogging
 import jakarta.inject.Singleton
+import java.io.File
+import java.io.RandomAccessFile
+import java.nio.file.Path
+import java.util.concurrent.atomic.AtomicLong
+import java.util.function.Consumer
 import kotlin.random.Random
 import kotlin.time.measureTime
 import kotlinx.coroutines.Dispatchers
 import kotlinx.coroutines.ExperimentalCoroutinesApi
 import kotlinx.coroutines.async
 import kotlinx.coroutines.awaitAll
+import kotlinx.coroutines.channels.Channel
 import kotlinx.coroutines.coroutineScope
+import kotlinx.coroutines.flow.consumeAsFlow
+import kotlinx.coroutines.flow.flowOn
+import kotlinx.coroutines.launch
 import kotlinx.coroutines.withContext
 
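+// One slice of a staged local file, mapped to a single part of an S3
+// multipart upload. `callback` runs after the part is uploaded, so the
+// worker that finishes the last outstanding part can complete the upload.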
+data class FileSegment(
+    val fileUrl: String,
+    val objectKey: String,
+    val upload: StreamingUpload<S3Object>,
+    val partNumber: Int,
+    val partSize: Long,
+    val callback: suspend () -> Unit = {}
+)
+
 @Singleton
 class S3V2WriteOpOverride(
     private val client: S3Client,
     private val catalog: DestinationCatalog,
     private val config: S3V2Configuration<*>,
     private val pathFactory: PathFactory,
+    private val reservingDeserializingInputFlow: ReservingDeserializingInputFlow,
+    private val outputConsumer: Consumer<AirbyteMessage>,
+    private val syncManager: SyncManager,
 ): WriteOpOverride {
     private val log = KotlinLogging.logger { }
 
     override val terminalCondition: TerminalCondition = SelfTerminating
 
     @OptIn(ExperimentalCoroutinesApi::class)
     override suspend fun execute() = coroutineScope {
-        val prng = Random(System.currentTimeMillis())
-        val randomPart = prng.nextBytes(config.partSizeBytes.toInt())
-        val randomString = randomPart.take(32).joinToString("") { "%02x".format(it) }
-        val stream = catalog.streams.first()
-        val objectKey = pathFactory.getFinalDirectory(stream) + "/mock-perf-test-$randomString"
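+        // Unbounded channel is the work queue: the reader coroutine below
+        // enqueues one FileSegment per part, and the upload workers drain it.
+        // (Assumes staged files already sit on local disk, so queueing
+        // lightweight descriptors for every part is cheap.)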
+        val mockPartQueue: Channel<FileSegment> = Channel(Channel.UNLIMITED)
+        val streamCount = AtomicLong(catalog.streams.size.toLong())
+        val totalBytesLoaded = AtomicLong(0L)
+        try {
+            withContext(Dispatchers.IO) {
+                val duration = measureTime {
+                    launch {
+                        reservingDeserializingInputFlow.collect { (_, reservation) ->
+                            when (val message = reservation.value) {
+                                is GlobalCheckpoint -> {
+                                    outputConsumer.accept(
+                                        message.withDestinationStats(CheckpointMessage.Stats(0))
+                                            .asProtocolMessage()
+                                    )
+                                }
+                                is StreamCheckpoint -> {
+                                    val (_, count) = syncManager.getStreamManager(message.checkpoint.stream)
+                                        .markCheckpoint()
+                                    log.info { "Flushing state" }
+                                    outputConsumer.accept(
+                                        message.withDestinationStats(
+                                            CheckpointMessage.Stats(
+                                                count
+                                            )
+                                        )
+                                            .asProtocolMessage()
+                                    )
+                                    log.info { "Done flushing state" }
+                                }
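+                                // A DestinationFile message describes a staged local
+                                // file. Split it into part-sized chunks (ceiling
+                                // division: any remainder becomes the short final
+                                // part) and enqueue each chunk for the workers.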
+                                is DestinationFile -> {
+                                    syncManager.getStreamManager(message.stream)
+                                        .incrementReadCount()
+                                    if (message.fileMessage.bytes == null) {
+                                        throw IllegalStateException("This can't work unless you set FileMessage.bytes!")
+                                    }
+                                    val size = message.fileMessage.bytes!!
+                                    val numWholeParts = (size / config.partSizeBytes).toInt()
+                                    val numParts =
+                                        numWholeParts + if (size % config.partSizeBytes > 0) 1 else 0
+                                    // Guard the exact-multiple case: the final part is
+                                    // then a full part, not a zero-byte one.
+                                    val lastPartSize =
+                                        (size % config.partSizeBytes).let {
+                                            if (it == 0L) config.partSizeBytes else it
+                                        }
+                                    val fileUrl = message.fileMessage.fileUrl!!
+                                    log.info {
+                                        "Breaking file $fileUrl (size=${size}B) into $numParts ${config.partSizeBytes}B parts"
+                                    }
+                                    val stream = catalog.getStream(message.stream)
+                                    val directory = pathFactory.getFinalDirectory(stream)
+                                    val sourceFileName = message.fileMessage.sourceFileUrl!!
+                                    val objectKey = Path.of(directory, sourceFileName).toString()
+                                    val upload = client.startStreamingUpload(objectKey)
+                                    val partCounter = AtomicLong(numParts.toLong())
+                                    repeat(numParts) { partNumber ->
+                                        mockPartQueue.send(
+                                            FileSegment(
+                                                fileUrl,
+                                                objectKey,
+                                                upload,
+                                                partNumber + 1,
+                                                if (partNumber == numParts - 1) lastPartSize else config.partSizeBytes
+                                            ) {
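+                                                // Per-part completion callback: the worker
+                                                // that uploads the final outstanding part
+                                                // deletes the staged file and completes the
+                                                // multipart upload; the AtomicLong countdown
+                                                // makes this safe across workers.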
+                                                val partsRemaining = partCounter.decrementAndGet()
+                                                if (partsRemaining == 0L) {
+                                                    log.info {
+                                                        "Finished uploading $numParts parts of $fileUrl; deleting file and finishing upload"
+                                                    }
+                                                    File(fileUrl).delete()
+                                                    log.info {
+                                                        "Finished deleting"
+                                                    }
+                                                    upload.complete()
+                                                    log.info {
+                                                        "Finished completing the upload"
+                                                    }
+                                                } else {
+                                                    log.info {
+                                                        "Finished uploading part ${partNumber + 1} of $fileUrl. $partsRemaining parts remaining"
+                                                    }
+                                                }
+                                            }
+                                        )
+                                    }
+                                }
 
-        val numParts = (config.objectSizeBytes / config.partSizeBytes).toInt()
-        val partsPerWorker = numParts / config.numUploadWorkers
-        val actualSizeBytes = partsPerWorker * config.numUploadWorkers * config.partSizeBytes
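+                                // End of a file stream: once the last stream
+                                // finishes, close the queue so the workers'
+                                // consumeAsFlow() loops can terminate.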
+                                is DestinationFileStreamComplete,
+                                is DestinationFileStreamIncomplete -> {
+                                    if (streamCount.decrementAndGet() == 0L) {
+                                        log.info {
+                                            "Read final stream complete, closing mockPartQueue"
+                                        }
+                                        mockPartQueue.close()
+                                    } else {
+                                        log.info {
+                                            "Read stream complete, ${streamCount.get()} streams remaining"
+                                        }
+                                    }
+                                }
 
-        log.info {
-            "root key=$objectKey; part_size=${config.partSizeBytes}b; num_parts=$numParts (per_worker=$partsPerWorker); total_size=${actualSizeBytes}b; num_workers=${config.numUploadWorkers}"
-        }
+                                is DestinationRecordStreamComplete,
+                                is DestinationRecordStreamIncomplete,
+                                is DestinationRecord -> throw NotImplementedError("This hack is only for files")
 
-        val duration = measureTime {
-            withContext(Dispatchers.IO.limitedParallelism(config.numUploadWorkers)) {
-                (0 until config.numUploadWorkers).map {
-                    async {
-                        val workerKey = "$objectKey-worker-$it"
-                        log.info { "Starting upload to $workerKey" }
-                        val upload = client.startStreamingUpload(workerKey)
-                        repeat(partsPerWorker) {
-                            log.info { "Uploading part ${it + 1} of $workerKey" }
-                            upload.uploadPart(randomPart, it + 1)
+                                Undefined ->
+                                    log.warn {
+                                        "Undefined message received. This should not happen."
+                                    }
+                            }
+                            reservation.release()
                         }
-                        log.info { "Completing upload to $workerKey" }
-                        upload.complete()
                     }
-                }.awaitAll()
+
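+                    // Fixed pool of upload workers: consumeAsFlow() spreads the
+                    // queued segments across workers, each of which reads its
+                    // part's byte range from the staged file and uploads it.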
+                    (0 until config.numUploadWorkers).map {
+                        async {
+                            mockPartQueue.consumeAsFlow().collect { segment ->
+                                log.info { "Starting upload to ${segment.objectKey} part ${segment.partNumber}" }
+                                RandomAccessFile(segment.fileUrl, "r").use { file ->
+                                    val partBytes = ByteArray(segment.partSize.toInt())
+                                    file.seek((segment.partNumber - 1) * config.partSizeBytes)
+                                    // readFully: a plain read() may return fewer
+                                    // bytes than requested.
+                                    file.readFully(partBytes)
+                                    segment.upload.uploadPart(partBytes, segment.partNumber)
+                                    log.info {
+                                        "Finished uploading part ${segment.partNumber} of ${segment.fileUrl}"
+                                    }
+                                    totalBytesLoaded.addAndGet(segment.partSize)
+                                    segment.callback()
+                                }
+                            }
+                        }
+                    }.awaitAll()
+                }
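+                // Rough throughput summary. Note inWholeSeconds is 0 for
+                // sub-second runs, which makes the MB/s figure Infinity.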
+                log.info {
+                    val mbs = totalBytesLoaded.get()
+                        .toDouble() / 1024 / 1024 / duration.inWholeSeconds.toDouble()
+                    "Uploaded ${totalBytesLoaded.get()} bytes in ${duration.inWholeSeconds}s (${mbs} MB/s)"
+                }
             }
-        }
-        val mbs = actualSizeBytes.toFloat() / duration.inWholeSeconds.toFloat() / 1024 / 1024
-        log.info {
-            // format mbs to 2 decimal places
-            "Uploaded $actualSizeBytes bytes in $duration seconds (${"%.2f".format(mbs)} MB/s)"
+        } catch (e: Throwable) {
+            log.error(e) { "Error uploading file, bailing" }
         }
     }
 }