
Commit 230d96e

DO NOT MERGE: S3 write-op override for fast file xfer
1 parent e4e9168 commit 230d96e

File tree

5 files changed: +183 −37

airbyte-integrations/connectors/destination-s3/metadata.yaml
+1 −1

@@ -2,7 +2,7 @@ data:
   connectorSubtype: file
   connectorType: destination
   definitionId: 4816b78f-1489-44c1-9060-4b19d5fa9362
-  dockerImageTag: 1.5.6
+  dockerImageTag: 1.5.7
   dockerRepository: airbyte/destination-s3
   githubIssueLabel: destination-s3
   icon: s3.svg

airbyte-integrations/connectors/destination-s3/src/main/kotlin/S3V2Configuration.kt
+2 −3

@@ -73,10 +73,9 @@ class S3V2ConfigurationFactory :
             objectStoragePathConfiguration = pojo.toObjectStoragePathConfiguration(),
            objectStorageFormatConfiguration = pojo.toObjectStorageFormatConfiguration(),
            objectStorageCompressionConfiguration = pojo.toCompressionConfiguration(),
-            numUploadWorkers = pojo.numObjectLoaders ?: 25,
-            partSizeBytes = (pojo.partSizeMb ?: 50) * 1024L * 1024L,
+            numUploadWorkers = pojo.numObjectLoaders ?: 10,
+            partSizeBytes = (pojo.partSizeMb ?: 10) * 1024L * 1024L,
            useLegacyClient = pojo.useLegacyClient ?: false,
-            objectSizeBytes = (pojo.totalDataMb ?: 2024) * 1024L * 1024L,
        )
    }
}
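Note: the defaults drop from 25 workers × 50 MB parts to 10 × 10 MB, and `objectSizeBytes` disappears because real file sizes now drive the math (see the override below). One constraint worth remembering when tuning `part_size_mb`: S3 multipart uploads reject any part smaller than 5 MiB except the final one. A minimal sketch of the same elvis-default conversion with that floor made explicit (`resolvePartSizeBytes` is a hypothetical helper, not part of this commit):

// S3 rejects non-final multipart parts smaller than 5 MiB.
const val MIN_S3_PART_BYTES: Long = 5L * 1024 * 1024

// Hypothetical helper mirroring the factory's `(pojo.partSizeMb ?: 10) * 1024L * 1024L`.
fun resolvePartSizeBytes(partSizeMb: Int?): Long {
    val bytes = (partSizeMb ?: 10) * 1024L * 1024L
    require(bytes >= MIN_S3_PART_BYTES) {
        "part_size_mb=$partSizeMb is below S3's 5 MiB multipart minimum"
    }
    return bytes
}

fun main() {
    println(resolvePartSizeBytes(null))  // 10485760 (the new 10 MiB default)
    println(resolvePartSizeBytes(50))    // 52428800 (the old default)
}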

airbyte-integrations/connectors/destination-s3/src/main/kotlin/S3V2Specification.kt
−2

@@ -80,8 +80,6 @@ class S3V2Specification :
     val numObjectLoaders: Int? = null
     @get:JsonProperty("part_size_mb")
     val partSizeMb: Int? = null
-    @get:JsonProperty("total_data_mb")
-    val totalDataMb: Int? = null
     @get:JsonProperty("use_legacy_client")
     val useLegacyClient: Boolean? = null
 }
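Dropping `total_data_mb` removes the last spec knob that assumed a fixed total transfer size. For reference, a minimal, self-contained sketch of how Jackson binds these snake_case spec fields (the `UploadTuning` class and JSON payload are illustrative, not from the commit):

import com.fasterxml.jackson.annotation.JsonProperty
import com.fasterxml.jackson.module.kotlin.jacksonObjectMapper
import com.fasterxml.jackson.module.kotlin.readValue

// Hypothetical stand-in for the spec class: same binding pattern, fewer fields.
data class UploadTuning(
    @JsonProperty("num_object_loaders") val numObjectLoaders: Int? = null,
    @JsonProperty("part_size_mb") val partSizeMb: Int? = null,
)

fun main() {
    // Illustrative user config; absent fields stay null, so the factory's
    // elvis defaults (?: 10) kick in downstream.
    val json = """{"part_size_mb": 25}"""
    val tuning = jacksonObjectMapper().readValue<UploadTuning>(json)
    println("workers=${tuning.numObjectLoaders}, partSizeMb=${tuning.partSizeMb}")
    // -> workers=null, partSizeMb=25
}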

airbyte-integrations/connectors/destination-s3/src/main/kotlin/S3V2WriteOpOverride.kt
+178 −29

@@ -2,69 +2,218 @@ package io.airbyte.integrations.destination.s3_v2
 
 import io.airbyte.cdk.load.command.DestinationCatalog
 import io.airbyte.cdk.load.file.object_storage.PathFactory
+import io.airbyte.cdk.load.file.object_storage.StreamingUpload
 import io.airbyte.cdk.load.file.s3.S3Client
+import io.airbyte.cdk.load.file.s3.S3Object
+import io.airbyte.cdk.load.file.s3.S3StreamingUpload
+import io.airbyte.cdk.load.message.CheckpointMessage
+import io.airbyte.cdk.load.message.DestinationFile
+import io.airbyte.cdk.load.message.DestinationFileStreamComplete
+import io.airbyte.cdk.load.message.DestinationFileStreamIncomplete
+import io.airbyte.cdk.load.message.DestinationRecord
+import io.airbyte.cdk.load.message.DestinationRecordStreamComplete
+import io.airbyte.cdk.load.message.DestinationRecordStreamIncomplete
+import io.airbyte.cdk.load.message.GlobalCheckpoint
+import io.airbyte.cdk.load.message.StreamCheckpoint
+import io.airbyte.cdk.load.message.Undefined
+import io.airbyte.cdk.load.state.SyncManager
 import io.airbyte.cdk.load.task.SelfTerminating
 import io.airbyte.cdk.load.task.TerminalCondition
+import io.airbyte.cdk.load.task.internal.ReservingDeserializingInputFlow
 import io.airbyte.cdk.load.write.WriteOpOverride
+import io.airbyte.protocol.models.v0.AirbyteMessage
 import io.github.oshai.kotlinlogging.KotlinLogging
 import jakarta.inject.Singleton
+import java.io.File
+import java.io.RandomAccessFile
+import java.nio.MappedByteBuffer
+import java.nio.channels.FileChannel
+import java.nio.file.Path
+import java.util.concurrent.atomic.AtomicLong
+import java.util.function.Consumer
 import kotlin.random.Random
 import kotlin.time.measureTime
 import kotlinx.coroutines.Dispatchers
 import kotlinx.coroutines.ExperimentalCoroutinesApi
 import kotlinx.coroutines.async
 import kotlinx.coroutines.awaitAll
+import kotlinx.coroutines.channels.Channel
 import kotlinx.coroutines.coroutineScope
+import kotlinx.coroutines.flow.consumeAsFlow
+import kotlinx.coroutines.flow.flowOn
+import kotlinx.coroutines.launch
 import kotlinx.coroutines.withContext
 
+data class FileSegment(
+    val fileUrl: String,
+    val objectKey: String,
+    val upload: StreamingUpload<S3Object>,
+    val partNumber: Int,
+    val partSize: Long,
+    val mappedbuffer: MappedByteBuffer,
+    val callback: suspend () -> Unit = {}
+)
+
 @Singleton
 class S3V2WriteOpOverride(
     private val client: S3Client,
     private val catalog: DestinationCatalog,
     private val config: S3V2Configuration<*>,
     private val pathFactory: PathFactory,
+    private val reservingDeserializingInputFlow: ReservingDeserializingInputFlow,
+    private val outputConsumer: Consumer<AirbyteMessage>,
+    private val syncManager: SyncManager,
 ): WriteOpOverride {
     private val log = KotlinLogging.logger { }
 
     override val terminalCondition: TerminalCondition = SelfTerminating
 
     @OptIn(ExperimentalCoroutinesApi::class)
     override suspend fun execute() = coroutineScope {
-        val prng = Random(System.currentTimeMillis())
-        val randomPart = prng.nextBytes(config.partSizeBytes.toInt())
-        val randomString = randomPart.take(32).joinToString("") { "%02x".format(it) }
-        val stream = catalog.streams.first()
-        val objectKey = pathFactory.getFinalDirectory(stream) + "/mock-perf-test-$randomString"
+        val mockPartQueue: Channel<FileSegment> = Channel(Channel.UNLIMITED)
+        val streamCount = AtomicLong(catalog.streams.size.toLong())
+        val totalBytesLoaded = AtomicLong(0L)
+        try {
+            withContext(Dispatchers.IO) {
+                val duration = measureTime {
+                    launch {
+                        reservingDeserializingInputFlow.collect { (_, reservation) ->
+                            when (val message = reservation.value) {
+                                is GlobalCheckpoint -> {
+                                    outputConsumer.accept(
+                                        message.withDestinationStats(CheckpointMessage.Stats(0))
+                                            .asProtocolMessage()
+                                    )
+                                }
+                                is StreamCheckpoint -> {
+                                    val (_, count) = syncManager.getStreamManager(message.checkpoint.stream)
+                                        .markCheckpoint()
+                                    log.info { "Flushing state" }
+                                    outputConsumer.accept(
+                                        message.withDestinationStats(CheckpointMessage.Stats(count))
+                                            .asProtocolMessage()
+                                    )
+                                    log.info { "Done flushing state" }
+                                }
+                                is DestinationFile -> {
+                                    syncManager.getStreamManager(message.stream).incrementReadCount()
+                                    if (message.fileMessage.bytes == null) {
+                                        throw IllegalStateException("This can't work unless you set FileMessage.bytes!")
+                                    }
+                                    val size = message.fileMessage.bytes!!
+                                    val numWholeParts = (size / config.partSizeBytes).toInt()
+                                    val numParts =
+                                        numWholeParts + if (size % config.partSizeBytes > 0) 1 else 0
+                                    val lastPartSize = size % config.partSizeBytes
+                                    val fileUrl = message.fileMessage.fileUrl!!
+                                    log.info {
+                                        "Breaking file $fileUrl (size=${size}B) into $numParts ${config.partSizeBytes}B parts"
+                                    }
+                                    val stream = catalog.getStream(message.stream)
+                                    val directory = pathFactory.getFinalDirectory(stream)
+                                    val sourceFileName = message.fileMessage.sourceFileUrl!!
+                                    val objectKey = Path.of(directory, sourceFileName).toString()
+                                    val upload = client.startStreamingUpload(objectKey)
+                                    val partCounter = AtomicLong(numParts.toLong())
+                                    val raf = RandomAccessFile(fileUrl, "r")
+                                    val memoryMap =
+                                        raf.channel.map(FileChannel.MapMode.READ_ONLY, 0, size)
+                                    repeat(numParts) { partNumber ->
+                                        mockPartQueue.send(
+                                            FileSegment(
+                                                fileUrl,
+                                                objectKey,
+                                                upload,
+                                                partNumber + 1,
+                                                if (partNumber == numParts - 1) lastPartSize else config.partSizeBytes,
+                                                memoryMap.slice(
+                                                    (partNumber * config.partSizeBytes).toInt(),
+                                                    (if (partNumber == numParts - 1) lastPartSize else config.partSizeBytes).toInt()
+                                                ),
+                                            ) {
+                                                val partsRemaining = partCounter.decrementAndGet()
+                                                if (partsRemaining == 0L) {
+                                                    log.info {
+                                                        "Finished uploading $numParts parts of $fileUrl; deleting file and finishing upload"
+                                                    }
+                                                    raf.close()
+                                                    File(fileUrl).delete()
+                                                    log.info { "Finished deleting" }
+                                                    upload.complete()
+                                                    log.info { "Finished completing the upload" }
+                                                } else {
+                                                    log.info {
+                                                        "Finished uploading part ${partNumber + 1} of $fileUrl. $partsRemaining parts remaining"
+                                                    }
+                                                }
+                                            }
+                                        )
+                                    }
+                                }
 
-        val numParts = (config.objectSizeBytes / config.partSizeBytes).toInt()
-        val partsPerWorker = numParts / config.numUploadWorkers
-        val actualSizeBytes = partsPerWorker * config.numUploadWorkers * config.partSizeBytes
+                                is DestinationFileStreamComplete,
+                                is DestinationFileStreamIncomplete -> {
+                                    if (streamCount.decrementAndGet() == 0L) {
+                                        log.info { "Read final stream complete, closing mockPartQueue" }
+                                        mockPartQueue.close()
+                                    } else {
+                                        log.info { "Read stream complete, ${streamCount.get()} streams remaining" }
+                                    }
+                                }
 
-        log.info {
-            "root key=$objectKey; part_size=${config.partSizeBytes}b; num_parts=$numParts (per_worker=$partsPerWorker); total_size=${actualSizeBytes}b; num_workers=${config.numUploadWorkers}"
-        }
+                                is DestinationRecordStreamComplete,
+                                is DestinationRecordStreamIncomplete,
+                                is DestinationRecord -> throw NotImplementedError("This hack is only for files")
 
-        val duration = measureTime {
-            withContext(Dispatchers.IO.limitedParallelism(config.numUploadWorkers)) {
-                (0 until config.numUploadWorkers).map {
-                    async {
-                        val workerKey = "$objectKey-worker-$it"
-                        log.info { "Starting upload to $workerKey" }
-                        val upload = client.startStreamingUpload(workerKey)
-                        repeat(partsPerWorker) {
-                            log.info { "Uploading part ${it + 1} of $workerKey" }
-                            upload.uploadPart(randomPart, it + 1)
+                                Undefined ->
+                                    log.warn { "Undefined message received. This should not happen." }
+                            }
+                            reservation.release()
                         }
-                        log.info { "Completing upload to $workerKey" }
-                        upload.complete()
                     }
-                }.awaitAll()
+
+                    (0 until config.numUploadWorkers).map {
+                        async {
+                            mockPartQueue.consumeAsFlow().collect { segment ->
+                                log.info { "Starting upload to ${segment.objectKey} part ${segment.partNumber}" }
+                                val partBytes = ByteArray(segment.partSize.toInt())
+                                segment.mappedbuffer.get(partBytes)
+                                segment.upload.uploadPart(partBytes, segment.partNumber)
+                                log.info { "Finished uploading part ${segment.partNumber} of ${segment.fileUrl}" }
+                                totalBytesLoaded.addAndGet(segment.partSize)
+                                segment.callback()
+                            }
+                        }
+                    }.awaitAll()
+                }
+                log.info {
+                    val mbs = totalBytesLoaded.get().toDouble() / 1024 / 1024 / duration.inWholeSeconds.toDouble()
+                    "Uploaded ${totalBytesLoaded.get()} bytes in ${duration.inWholeSeconds}s (${mbs}MB/s)"
+                }
             }
-        }
-        val mbs = actualSizeBytes.toFloat() / duration.inWholeSeconds.toFloat() / 1024 / 1024
-        log.info {
-            // format mbs to 2 decimal places
-            "Uploaded $actualSizeBytes bytes in $duration seconds (${"%.2f".format(mbs)} MB/s)"
+        } catch (e: Throwable) {
+            log.error(e) { "Error uploading file, bailing" }
         }
     }
 }
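The shape of the new override: one reader coroutine drains the input flow, and for each DestinationFile it memory-maps the staged file, slices it into partSizeBytes-sized windows, and enqueues one FileSegment per part onto an unbounded Channel; numUploadWorkers consumer coroutines drain that channel and call uploadPart, and a per-file countdown (partCounter) lets whichever worker finishes the last part complete the multipart upload and delete the local file. One sharp edge as written: when a file's size is an exact multiple of partSizeBytes, lastPartSize computes to 0 and the final slice would be empty; the test size below (1099 MB) happens to avoid that path. Below is a self-contained sketch of the same producer/consumer pattern with the CDK and S3 client swapped for stand-ins (Segment, fakeUploadPart, and all sizes are illustrative); it iterates the channel directly with `for (seg in queue)`, which fans work out across workers much like the commit's consumeAsFlow().collect:

import java.io.RandomAccessFile
import java.nio.ByteBuffer
import java.nio.channels.FileChannel
import java.util.concurrent.atomic.AtomicLong
import kotlin.io.path.createTempFile
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.async
import kotlinx.coroutines.awaitAll
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.withContext

// One enqueued unit of work: a read-only window onto part of the mapped file.
data class Segment(
    val partNumber: Int,
    val buffer: ByteBuffer,
    val onDone: suspend () -> Unit,  // countdown callback, runs after upload
)

// Stand-in for StreamingUpload.uploadPart(bytes, partNumber).
suspend fun fakeUploadPart(partNumber: Int, bytes: ByteArray) {
    println("uploaded part $partNumber (${bytes.size} B)")
}

fun main() = runBlocking {
    val partSize = 4L * 1024  // tiny parts so the demo finishes instantly
    val file = createTempFile().toFile().apply {
        writeBytes(ByteArray(10_000) { it.toByte() })  // 10 000 B -> 3 parts
        deleteOnExit()
    }
    val size = file.length()
    val wholeParts = (size / partSize).toInt()
    val numParts = wholeParts + if (size % partSize > 0) 1 else 0
    val lastPartSize = size % partSize  // NB: 0 if size is an exact multiple
    val remaining = AtomicLong(numParts.toLong())
    val queue = Channel<Segment>(Channel.UNLIMITED)

    withContext(Dispatchers.IO) {
        // Producer: map the file once, enqueue one Segment per part.
        launch {
            val raf = RandomAccessFile(file, "r")
            val map = raf.channel.map(FileChannel.MapMode.READ_ONLY, 0, size)
            repeat(numParts) { i ->
                val len = if (i == numParts - 1) lastPartSize else partSize
                val slice = map.slice((i * partSize).toInt(), len.toInt())
                queue.send(Segment(i + 1, slice) {
                    if (remaining.decrementAndGet() == 0L) {
                        raf.close()  // last part done: release the mapping
                        println("all $numParts parts done; complete() would go here")
                    }
                })
            }
            queue.close()  // closed channel lets the workers drain and exit
        }
        // Consumers: a fixed pool of 4 workers sharing the channel.
        (1..4).map {
            async {
                for (seg in queue) {
                    val bytes = ByteArray(seg.buffer.remaining())
                    seg.buffer.get(bytes)
                    fakeUploadPart(seg.partNumber, bytes)
                    seg.onDone()
                }
            }
        }.awaitAll()
    }
}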

airbyte-integrations/connectors/destination-s3/src/test-integration/kotlin/io/airbyte/integrations/destination/s3_v2/S3V2PerformanceTest.kt
+2 −2

@@ -15,8 +15,8 @@ class S3V2JsonNoFrillsPerformanceTest :
         configSpecClass = S3V2Specification::class.java,
         defaultRecordsToInsert = 1_000_000,
         micronautProperties = S3V2TestUtils.PERFORMANCE_TEST_MICRONAUT_PROPERTIES,
-        numFilesForFileTransfer = 5,
-        fileSizeMbForFileTransfer = 1024,
+        numFilesForFileTransfer = 3,
+        fileSizeMbForFileTransfer = 1099,
     ) {
     @Test
     override fun testFileTransfer() {
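For scale: with the new 10 MB default part size, each 1099 MB test file yields 110 multipart parts, 109 whole parts plus a 9 MB tail, comfortably under S3's 10,000-part ceiling; a size that is not an exact multiple of the part size also keeps the override's remainder path exercised. A quick check of the arithmetic, mirroring the override's numWholeParts/lastPartSize math and assuming MB means MiB as in the factory's `* 1024L * 1024L` conversion:

fun main() {
    val partSizeBytes = 10L * 1024 * 1024    // new default part size (10 MiB)
    val fileSizeBytes = 1099L * 1024 * 1024  // fileSizeMbForFileTransfer = 1099
    val wholeParts = fileSizeBytes / partSizeBytes     // 109
    val lastPartBytes = fileSizeBytes % partSizeBytes  // 9 MiB tail
    val numParts = wholeParts + if (lastPartBytes > 0) 1 else 0
    println("$numParts parts = $wholeParts x 10 MiB + ${lastPartBytes / 1024 / 1024} MiB")
    // -> 110 parts = 109 x 10 MiB + 9 MiB
}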
