Commit 1ed29c7

DO NOT MERGE: S3 write-op override for fast file xfer

1 parent: e4e9168

File tree: 5 files changed (+135, −31 lines)

airbyte-integrations/connectors/destination-s3/metadata.yaml (+1, −1)

@@ -2,7 +2,7 @@ data:
   connectorSubtype: file
   connectorType: destination
   definitionId: 4816b78f-1489-44c1-9060-4b19d5fa9362
-  dockerImageTag: 1.5.6
+  dockerImageTag: 1.5.7
   dockerRepository: airbyte/destination-s3
   githubIssueLabel: destination-s3
   icon: s3.svg

airbyte-integrations/connectors/destination-s3/src/main/kotlin/S3V2Configuration.kt (+2, −3)

@@ -73,10 +73,9 @@ class S3V2ConfigurationFactory :
             objectStoragePathConfiguration = pojo.toObjectStoragePathConfiguration(),
             objectStorageFormatConfiguration = pojo.toObjectStorageFormatConfiguration(),
             objectStorageCompressionConfiguration = pojo.toCompressionConfiguration(),
-            numUploadWorkers = pojo.numObjectLoaders ?: 25,
-            partSizeBytes = (pojo.partSizeMb ?: 50) * 1024L * 1024L,
+            numUploadWorkers = pojo.numObjectLoaders ?: 20,
+            partSizeBytes = (pojo.partSizeMb ?: 10) * 1024L * 1024L,
            useLegacyClient = pojo.useLegacyClient ?: false,
-            objectSizeBytes = (pojo.totalDataMb ?: 2024) * 1024L * 1024L,
        )
    }
}
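
Editor's note (not part of the diff): the new defaults trade large in-flight part buffers for smaller ones — roughly 20 workers × 10 MiB ≈ 200 MiB of live buffers, versus 25 × 50 MiB ≈ 1.25 GiB before, since each worker in the override below holds one part buffer at a time. A quick sketch of the arithmetic, assuming standard S3 multipart limits (parts of at least 5 MiB except the last, at most 10,000 parts per upload):

// Sketch of what the new defaults imply (values from this diff; the
// in-flight estimate assumes one part buffer per worker, as in the
// override below).
fun main() {
    val numUploadWorkers = 20              // was 25
    val partSizeBytes = 10 * 1024L * 1024L // was 50 MiB
    val inFlightBytes = numUploadWorkers * partSizeBytes
    println("~${inFlightBytes / (1024 * 1024)} MiB of part buffers in flight") // ~200 MiB
    // S3 multipart uploads allow at most 10,000 parts, so 10 MiB parts
    // cap a single object at ~97 GiB.
    println("max object = ${10_000 * partSizeBytes / (1024L * 1024 * 1024)} GiB")
}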

airbyte-integrations/connectors/destination-s3/src/main/kotlin/S3V2Specification.kt (−2)

@@ -80,8 +80,6 @@ class S3V2Specification :
     val numObjectLoaders: Int? = null
     @get:JsonProperty("part_size_mb")
     val partSizeMb: Int? = null
-    @get:JsonProperty("total_data_mb")
-    val totalDataMb: Int? = null
     @get:JsonProperty("use_legacy_client")
     val useLegacyClient: Boolean? = null
}
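
Editor's note (not part of the diff): with total_data_mb removed, the volume uploaded is no longer a fixed spec knob; it is derived from each incoming file's reported size, which the override below splits into ceil(size / partSizeBytes) parts. A sketch of that part-count logic, with assumed example values:

// Part-count logic mirrored from the override below (example values assumed).
fun partCount(fileSizeBytes: Long, partSizeBytes: Long): Int {
    val wholeParts = (fileSizeBytes / partSizeBytes).toInt()
    // One extra part if the file size is not an exact multiple of the part size.
    return wholeParts + if (fileSizeBytes % partSizeBytes > 0) 1 else 0
}

fun main() {
    val partSize = 10L * 1024 * 1024                  // new 10 MiB default
    println(partCount(1024L * 1024 * 1024, partSize)) // 1 GiB file -> 103 parts
}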

airbyte-integrations/connectors/destination-s3/src/main/kotlin/S3V2WriteOpOverride.kt (+131, −24)

@@ -2,69 +2,176 @@ package io.airbyte.integrations.destination.s3_v2
 
 import io.airbyte.cdk.load.command.DestinationCatalog
 import io.airbyte.cdk.load.file.object_storage.PathFactory
+import io.airbyte.cdk.load.file.object_storage.StreamingUpload
 import io.airbyte.cdk.load.file.s3.S3Client
+import io.airbyte.cdk.load.file.s3.S3Object
+import io.airbyte.cdk.load.file.s3.S3StreamingUpload
+import io.airbyte.cdk.load.message.CheckpointMessage
+import io.airbyte.cdk.load.message.DestinationFile
+import io.airbyte.cdk.load.message.DestinationFileStreamComplete
+import io.airbyte.cdk.load.message.DestinationFileStreamIncomplete
+import io.airbyte.cdk.load.message.DestinationRecord
+import io.airbyte.cdk.load.message.DestinationRecordStreamComplete
+import io.airbyte.cdk.load.message.DestinationRecordStreamIncomplete
+import io.airbyte.cdk.load.message.GlobalCheckpoint
+import io.airbyte.cdk.load.message.StreamCheckpoint
+import io.airbyte.cdk.load.message.Undefined
+import io.airbyte.cdk.load.state.SyncManager
 import io.airbyte.cdk.load.task.SelfTerminating
 import io.airbyte.cdk.load.task.TerminalCondition
+import io.airbyte.cdk.load.task.internal.ReservingDeserializingInputFlow
 import io.airbyte.cdk.load.write.WriteOpOverride
+import io.airbyte.protocol.models.v0.AirbyteMessage
 import io.github.oshai.kotlinlogging.KotlinLogging
 import jakarta.inject.Singleton
+import java.io.RandomAccessFile
+import java.nio.file.Path
+import java.util.concurrent.atomic.AtomicLong
+import java.util.function.Consumer
 import kotlin.random.Random
 import kotlin.time.measureTime
 import kotlinx.coroutines.Dispatchers
 import kotlinx.coroutines.ExperimentalCoroutinesApi
 import kotlinx.coroutines.async
 import kotlinx.coroutines.awaitAll
+import kotlinx.coroutines.channels.Channel
 import kotlinx.coroutines.coroutineScope
+import kotlinx.coroutines.flow.consumeAsFlow
+import kotlinx.coroutines.flow.flowOn
+import kotlinx.coroutines.launch
 import kotlinx.coroutines.withContext
 
+data class FileSegment(
+    val fileUrl: String,
+    val objectKey: String,
+    val upload: StreamingUpload<S3Object>,
+    val partNumber: Int,
+    val callback: suspend () -> Unit = {}
+)
+
 @Singleton
 class S3V2WriteOpOverride(
     private val client: S3Client,
     private val catalog: DestinationCatalog,
     private val config: S3V2Configuration<*>,
     private val pathFactory: PathFactory,
+    private val reservingDeserializingInputFlow: ReservingDeserializingInputFlow,
+    private val outputConsumer: Consumer<AirbyteMessage>,
+    private val syncManager: SyncManager,
 ): WriteOpOverride {
     private val log = KotlinLogging.logger { }
 
     override val terminalCondition: TerminalCondition = SelfTerminating
 
     @OptIn(ExperimentalCoroutinesApi::class)
     override suspend fun execute() = coroutineScope {
-        val prng = Random(System.currentTimeMillis())
-        val randomPart = prng.nextBytes(config.partSizeBytes.toInt())
-        val randomString = randomPart.take(32).joinToString("") { "%02x".format(it) }
-        val stream = catalog.streams.first()
-        val objectKey = pathFactory.getFinalDirectory(stream) + "/mock-perf-test-$randomString"
-
-        val numParts = (config.objectSizeBytes / config.partSizeBytes).toInt()
-        val partsPerWorker = numParts / config.numUploadWorkers
-        val actualSizeBytes = partsPerWorker * config.numUploadWorkers * config.partSizeBytes
-
-        log.info {
-            "root key=$objectKey; part_size=${config.partSizeBytes}b; num_parts=$numParts (per_worker=$partsPerWorker); total_size=${actualSizeBytes}b; num_workers=${config.numUploadWorkers}"
+        val mockPartQueue: Channel<FileSegment> = Channel(Channel.UNLIMITED)
+        val streamCount = AtomicLong(catalog.streams.size.toLong())
+        launch {
+            reservingDeserializingInputFlow.collect { (_, reservation) ->
+                when (val message = reservation.value) {
+                    is GlobalCheckpoint -> launch {
+                        outputConsumer.accept(
+                            message.withDestinationStats(CheckpointMessage.Stats(0)).asProtocolMessage()
+                        )
+                    }
+                    is StreamCheckpoint -> launch {
+                        val (_, count) = syncManager.getStreamManager(message.checkpoint.stream).markCheckpoint()
+                        outputConsumer.accept(
+                            message.withDestinationStats(CheckpointMessage.Stats(count)).asProtocolMessage()
+                        )
+                    }
+                    is DestinationFile -> {
+                        syncManager.getStreamManager(message.stream).incrementReadCount()
+                        if (message.fileMessage.bytes == null) {
+                            throw IllegalStateException("This can't work unless you set FileMessage.bytes!")
+                        }
+                        val size = message.fileMessage.bytes!!
+                        val numWholeParts = (size / config.partSizeBytes).toInt()
+                        val numParts = numWholeParts + if (size % config.partSizeBytes > 0) 1 else 0
+                        val fileUrl = message.fileMessage.fileUrl!!
+                        log.info {
+                            "Breaking file $fileUrl (size=${size}B) into $numParts ${config.partSizeBytes}B parts"
+                        }
+                        val stream = catalog.getStream(message.stream)
+                        val directory = pathFactory.getFinalDirectory(stream)
+                        val sourceFileName = message.fileMessage.sourceFileUrl!!
+                        val objectKey = Path.of(directory, sourceFileName).toString()
+                        val upload = client.startStreamingUpload(objectKey)
+                        val partCounter = AtomicLong(numParts.toLong())
+                        repeat(numParts) { partNumber ->
+                            mockPartQueue.send(FileSegment(fileUrl, objectKey, upload, partNumber + 1) {
+                                val partsRemaining = partCounter.decrementAndGet()
+                                if (partsRemaining == 0L) {
+                                    log.info {
+                                        "Finished uploading $numParts parts of $fileUrl"
+                                    }
+                                    upload.complete()
+                                } else {
+                                    log.info {
+                                        "Finished uploading part ${partNumber + 1} of $fileUrl. $partsRemaining parts remaining"
+                                    }
+                                }
+                            })
+                        }
+                    }
+                    is DestinationFileStreamComplete,
+                    is DestinationFileStreamIncomplete -> {
+                        if (streamCount.decrementAndGet() == 0L) {
+                            log.info {
+                                "Read final stream complete, closing mockPartQueue"
+                            }
+                            mockPartQueue.close()
+                        } else {
+                            log.info {
+                                "Read stream complete, ${streamCount.get()} streams remaining"
+                            }
+                        }
+                    }
+                    is DestinationRecordStreamComplete,
+                    is DestinationRecordStreamIncomplete,
+                    is DestinationRecord -> throw NotImplementedError("This hack is only for files")
+                    Undefined ->
+                        log.warn {
+                            "Undefined message received. This should not happen."
+                        }
+                }
+                reservation.release()
+            }
         }
 
+        val totalBytesLoaded = AtomicLong(0L)
         val duration = measureTime {
-            withContext(Dispatchers.IO.limitedParallelism(config.numUploadWorkers)) {
+            val dispatcher = Dispatchers.IO.limitedParallelism(config.numUploadWorkers)
+            withContext(dispatcher) {
                 (0 until config.numUploadWorkers).map {
                     async {
-                        val workerKey = "$objectKey-worker-$it"
-                        log.info { "Starting upload to $workerKey" }
-                        val upload = client.startStreamingUpload(workerKey)
-                        repeat(partsPerWorker) {
-                            log.info { "Uploading part ${it + 1} of $workerKey" }
-                            upload.uploadPart(randomPart, it + 1)
+                        mockPartQueue.consumeAsFlow().collect { segment ->
+                            log.info { "Starting upload to ${segment.objectKey} part ${segment.partNumber}" }
+                            RandomAccessFile(segment.fileUrl, "r").use { file ->
+                                // Clamp the final part to the bytes remaining in the file.
+                                val partSize = minOf(
+                                    config.partSizeBytes,
+                                    file.length() - (segment.partNumber - 1) * config.partSizeBytes
+                                )
+                                val partBytes = ByteArray(partSize.toInt())
+                                file.seek((segment.partNumber - 1) * config.partSizeBytes)
+                                file.readFully(partBytes)
+                                segment.upload.uploadPart(partBytes, segment.partNumber)
+                                log.info {
+                                    "Finished uploading part ${segment.partNumber} of ${segment.fileUrl}"
+                                }
+                                totalBytesLoaded.addAndGet(partSize)
+                                segment.callback()
+                            }
                         }
-                        log.info { "Completing upload to $workerKey" }
-                        upload.complete()
                     }
                 }.awaitAll()
             }
         }
-        val mbs = actualSizeBytes.toFloat() / duration.inWholeSeconds.toFloat() / 1024 / 1024
         log.info {
-            // format mbs to 2 decimal places
-            "Uploaded $actualSizeBytes bytes in $duration seconds (${"%.2f".format(mbs)} MB/s)"
+            val mbs = totalBytesLoaded.get().toDouble() / 1024 / 1024 / duration.inWholeSeconds.toDouble()
+            "Uploaded ${totalBytesLoaded.get()} bytes in ${duration.inWholeSeconds}s (${"%.2f".format(mbs)} MB/s)"
         }
     }
 }
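
Editor's note on the concurrency pattern above: a single reader turns each incoming file into per-part work items on an unbounded Channel, a fixed pool of workers drains that channel, and an AtomicLong countdown in each segment's callback completes the multipart upload when its last part lands. A standalone sketch of the pattern follows (illustrative names only; no Airbyte or S3 types):

import java.util.concurrent.atomic.AtomicLong
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.async
import kotlinx.coroutines.awaitAll
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.flow.consumeAsFlow
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.withContext

// One unit of work: a single part of a larger "file", plus a callback to run
// once the part has been handled.
data class Part(val file: String, val partNumber: Int, val onDone: suspend () -> Unit)

@OptIn(ExperimentalCoroutinesApi::class)
fun main() = runBlocking {
    val queue = Channel<Part>(Channel.UNLIMITED)
    val numWorkers = 4
    val numParts = 10

    // Producer: enqueue every part of one file; the countdown fires the
    // completion action exactly once, when the last part finishes.
    val remaining = AtomicLong(numParts.toLong())
    repeat(numParts) { i ->
        queue.send(Part("file-a", i + 1) {
            if (remaining.decrementAndGet() == 0L) println("file-a: all parts done, completing upload")
        })
    }
    queue.close() // in the override this happens when the last stream completes

    // Workers: each element of the channel is delivered to exactly one
    // collector, so the channel acts as a work queue shared by the pool.
    withContext(Dispatchers.IO.limitedParallelism(numWorkers)) {
        (0 until numWorkers).map { workerId ->
            async {
                queue.consumeAsFlow().collect { part ->
                    println("worker $workerId handling ${part.file} part ${part.partNumber}")
                    part.onDone()
                }
            }
        }.awaitAll()
    }
}

Each worker calls consumeAsFlow() on the shared channel separately, which sidesteps that function's collect-once restriction; receiveAsFlow would express the same fan-out without it, but the sketch mirrors the commit's choice.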

airbyte-integrations/connectors/destination-s3/src/test-integration/kotlin/io/airbyte/integrations/destination/s3_v2/S3V2PerformanceTest.kt (+1, −1)

@@ -15,7 +15,7 @@ class S3V2JsonNoFrillsPerformanceTest :
        configSpecClass = S3V2Specification::class.java,
        defaultRecordsToInsert = 1_000_000,
        micronautProperties = S3V2TestUtils.PERFORMANCE_TEST_MICRONAUT_PROPERTIES,
-        numFilesForFileTransfer = 5,
+        numFilesForFileTransfer = 10,
        fileSizeMbForFileTransfer = 1024,
    ) {
    @Test
