@@ -2,69 +2,218 @@ package io.airbyte.integrations.destination.s3_v2
import io.airbyte.cdk.load.command.DestinationCatalog
import io.airbyte.cdk.load.file.object_storage.PathFactory
+ import io.airbyte.cdk.load.file.object_storage.StreamingUpload
import io.airbyte.cdk.load.file.s3.S3Client
+ import io.airbyte.cdk.load.file.s3.S3Object
+ import io.airbyte.cdk.load.file.s3.S3StreamingUpload
+ import io.airbyte.cdk.load.message.CheckpointMessage
+ import io.airbyte.cdk.load.message.DestinationFile
+ import io.airbyte.cdk.load.message.DestinationFileStreamComplete
+ import io.airbyte.cdk.load.message.DestinationFileStreamIncomplete
+ import io.airbyte.cdk.load.message.DestinationRecord
+ import io.airbyte.cdk.load.message.DestinationRecordStreamComplete
+ import io.airbyte.cdk.load.message.DestinationRecordStreamIncomplete
+ import io.airbyte.cdk.load.message.GlobalCheckpoint
+ import io.airbyte.cdk.load.message.StreamCheckpoint
+ import io.airbyte.cdk.load.message.Undefined
+ import io.airbyte.cdk.load.state.SyncManager
import io.airbyte.cdk.load.task.SelfTerminating
import io.airbyte.cdk.load.task.TerminalCondition
+ import io.airbyte.cdk.load.task.internal.ReservingDeserializingInputFlow
import io.airbyte.cdk.load.write.WriteOpOverride
+ import io.airbyte.protocol.models.v0.AirbyteMessage
import io.github.oshai.kotlinlogging.KotlinLogging
import jakarta.inject.Singleton
+ import java.io.File
+ import java.io.RandomAccessFile
+ import java.nio.MappedByteBuffer
+ import java.nio.channels.FileChannel
+ import java.nio.file.Path
+ import java.util.concurrent.atomic.AtomicLong
+ import java.util.function.Consumer
import kotlin.random.Random
import kotlin.time.measureTime
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.async
import kotlinx.coroutines.awaitAll
+ import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.coroutineScope
+ import kotlinx.coroutines.flow.consumeAsFlow
+ import kotlinx.coroutines.flow.flowOn
+ import kotlinx.coroutines.launch
import kotlinx.coroutines.withContext

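+ // One part of a staged file: a read-only slice of the memory-mapped source file, the
+ // multipart upload it belongs to, and a callback invoked after the part is uploaded.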
+ data class FileSegment(
+     val fileUrl: String,
+     val objectKey: String,
+     val upload: StreamingUpload<S3Object>,
+     val partNumber: Int,
+     val partSize: Long,
+     val mappedbuffer: MappedByteBuffer,
+     val callback: suspend () -> Unit = {}
+ )
+
@Singleton
class S3V2WriteOpOverride(
    private val client: S3Client,
    private val catalog: DestinationCatalog,
    private val config: S3V2Configuration<*>,
    private val pathFactory: PathFactory,
+     private val reservingDeserializingInputFlow: ReservingDeserializingInputFlow,
+     private val outputConsumer: Consumer<AirbyteMessage>,
+     private val syncManager: SyncManager,
): WriteOpOverride {
    private val log = KotlinLogging.logger { }

    override val terminalCondition: TerminalCondition = SelfTerminating

    @OptIn(ExperimentalCoroutinesApi::class)
    override suspend fun execute() = coroutineScope {
-         val prng = Random(System.currentTimeMillis())
-         val randomPart = prng.nextBytes(config.partSizeBytes.toInt())
-         val randomString = randomPart.take(32).joinToString("") { "%02x".format(it) }
-         val stream = catalog.streams.first()
-         val objectKey = pathFactory.getFinalDirectory(stream) + "/mock-perf-test-$randomString"
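+         // Work queue connecting the reader (which splits incoming files into parts) to the
+         // upload workers (which drain the queue and upload each part).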
+         val mockPartQueue: Channel<FileSegment> = Channel(Channel.UNLIMITED)
+         val streamCount = AtomicLong(catalog.streams.size.toLong())
+         val totalBytesLoaded = AtomicLong(0L)
+         try {
+             withContext(Dispatchers.IO) {
+                 val duration = measureTime {
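+                     // Reader coroutine: consume the deserialized input flow, forward state
+                     // messages, and break each incoming file into FileSegments on the queue.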
+                     launch {
+                         reservingDeserializingInputFlow.collect { (_, reservation) ->
+                             when (val message = reservation.value) {
+                                 is GlobalCheckpoint -> {
+                                     outputConsumer.accept(
+                                         message.withDestinationStats(CheckpointMessage.Stats(0))
+                                             .asProtocolMessage()
+                                     )
+                                 }
+                                 is StreamCheckpoint -> {
+                                     val (_, count) = syncManager.getStreamManager(message.checkpoint.stream)
+                                         .markCheckpoint()
+                                     log.info { "Flushing state" }
+                                     outputConsumer.accept(
+                                         message.withDestinationStats(
+                                             CheckpointMessage.Stats(
+                                                 count
+                                             )
+                                         )
+                                             .asProtocolMessage()
+                                     )
+                                     log.info { "Done flushing state" }
+                                 }
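+                                 // A staged file: compute its part layout, memory-map it,
+                                 // and start a multipart upload for it.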
+                                 is DestinationFile -> {
+                                     syncManager.getStreamManager(message.stream)
+                                         .incrementReadCount()
+                                     if (message.fileMessage.bytes == null) {
+                                         throw IllegalStateException("This can't work unless you set FileMessage.bytes!")
+                                     }
+                                     val size = message.fileMessage.bytes!!
+                                     val numWholeParts = (size / config.partSizeBytes).toInt()
+                                     val numParts =
+                                         numWholeParts + if (size % config.partSizeBytes > 0) 1 else 0
+                                     val lastPartSize = size % config.partSizeBytes
+                                     val fileUrl = message.fileMessage.fileUrl!!
+                                     log.info {
+                                         "Breaking file $fileUrl (size=${size}B) into $numParts ${config.partSizeBytes}B parts"
+                                     }
+                                     val stream = catalog.getStream(message.stream)
+                                     val directory = pathFactory.getFinalDirectory(stream)
+                                     val sourceFileName = message.fileMessage.sourceFileUrl!!
+                                     val objectKey = Path.of(directory, sourceFileName).toString()
+                                     val upload = client.startStreamingUpload(objectKey)
+                                     val partCounter = AtomicLong(numParts.toLong())
+                                     val raf = RandomAccessFile(fileUrl, "r")
+                                     val memoryMap = raf.channel.map(
+                                         FileChannel.MapMode.READ_ONLY,
+                                         0,
+                                         size
+                                     )
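+                                     // Enqueue one FileSegment per part. The callback on the final part
+                                     // closes and deletes the local file and completes the multipart upload.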
+                                     repeat(numParts) { partNumber ->
+                                         mockPartQueue.send(
+                                             FileSegment(
+                                                 fileUrl,
+                                                 objectKey,
+                                                 upload,
+                                                 partNumber + 1,
+                                                 if (partNumber == numParts - 1) lastPartSize else config.partSizeBytes,
+                                                 memoryMap.slice(
+                                                     (partNumber * config.partSizeBytes).toInt(),
+                                                     (if (partNumber == numParts - 1) lastPartSize else config.partSizeBytes).toInt()
+                                                 ),
+                                             ) {
+                                                 val partsRemaining = partCounter.decrementAndGet()
+                                                 if (partsRemaining == 0L) {
+                                                     log.info {
+                                                         "Finished uploading $numParts parts of $fileUrl; deleting file and finishing upload"
+                                                     }
+                                                     raf.close()
+                                                     File(fileUrl).delete()
+                                                     log.info {
+                                                         "Finished deleting"
+                                                     }
+                                                     upload.complete()
+                                                     log.info {
+                                                         "Finished completing the upload"
+                                                     }
+                                                 } else {
+                                                     log.info {
+                                                         "Finished uploading part ${partNumber + 1} of $fileUrl. $partsRemaining parts remaining"
+                                                     }
+                                                 }
+                                             }
+                                         )
+                                     }
+                                 }

-         val numParts = (config.objectSizeBytes / config.partSizeBytes).toInt()
-         val partsPerWorker = numParts / config.numUploadWorkers
-         val actualSizeBytes = partsPerWorker * config.numUploadWorkers * config.partSizeBytes
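+                                 // A file stream has finished. When the last stream completes, close
+                                 // the part queue so the upload workers can drain it and exit.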
+                                 is DestinationFileStreamComplete,
+                                 is DestinationFileStreamIncomplete -> {
+                                     if (streamCount.decrementAndGet() == 0L) {
+                                         log.info {
+                                             "Read final stream complete, closing mockPartQueue"
+                                         }
+                                         mockPartQueue.close()
+                                     } else {
+                                         log.info {
+                                             "Read stream complete, ${streamCount.get()} streams remaining"
+                                         }
+                                     }
+                                 }

-         log.info {
-             "root key=$objectKey; part_size=${config.partSizeBytes}b; num_parts=$numParts (per_worker=$partsPerWorker); total_size=${actualSizeBytes}b; num_workers=${config.numUploadWorkers}"
-         }
+                                 is DestinationRecordStreamComplete,
+                                 is DestinationRecordStreamIncomplete,
+                                 is DestinationRecord -> throw NotImplementedError("This hack is only for files")

-         val duration = measureTime {
-             withContext(Dispatchers.IO.limitedParallelism(config.numUploadWorkers)) {
-                 (0 until config.numUploadWorkers).map {
-                     async {
-                         val workerKey = "$objectKey-worker-$it"
-                         log.info { "Starting upload to $workerKey" }
-                         val upload = client.startStreamingUpload(workerKey)
-                         repeat(partsPerWorker) {
-                             log.info { "Uploading part ${it + 1} of $workerKey" }
-                             upload.uploadPart(randomPart, it + 1)
+                                 Undefined ->
+                                     log.warn {
+                                         "Undefined message received. This should not happen."
+                                     }
+                             }
+                             reservation.release()
                        }
-                         log.info { "Completing upload to $workerKey" }
-                         upload.complete()
                    }
-                 }.awaitAll()
+
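+                     // Upload workers: config.numUploadWorkers concurrent consumers drain the part
+                     // queue, copy each mapped slice into a byte array, and upload it as one part.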
+                     (0 until config.numUploadWorkers).map {
+                         async {
+                             mockPartQueue.consumeAsFlow().collect { segment ->
+                                 log.info { "Starting upload to ${segment.objectKey} part ${segment.partNumber}" }
+                                 val partBytes = ByteArray(segment.partSize.toInt())
+                                 segment.mappedbuffer.get(partBytes)
+                                 segment.upload.uploadPart(partBytes, segment.partNumber)
+                                 log.info {
+                                     "Finished uploading part ${segment.partNumber} of ${segment.fileUrl}"
+                                 }
+                                 totalBytesLoaded.addAndGet(segment.partSize)
+                                 segment.callback()
+                             }
+                         }
+                     }.awaitAll()
+                 }
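+                 // Log aggregate throughput for the sync.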
+                 log.info {
+                     val mbs = totalBytesLoaded.get()
+                         .toDouble() / 1024 / 1024 / duration.inWholeSeconds.toDouble()
+                     "Uploaded ${totalBytesLoaded.get()} bytes in ${duration.inWholeSeconds}s (${mbs} MB/s)"
+                 }
            }
-         }
-         val mbs = actualSizeBytes.toFloat() / duration.inWholeSeconds.toFloat() / 1024 / 1024
-         log.info {
-             // format mbs to 2 decimal places
-             "Uploaded $actualSizeBytes bytes in $duration seconds (${"%.2f".format(mbs)} MB/s)"
+         } catch (e: Throwable) {
+             log.error(e) { "Error uploading file, bailing" }
        }
    }
}