Commit 9db9582

DO NOT MERGE: S3 write-op override for fast file xfer

1 parent e4e9168 commit 9db9582

5 files changed: +172 -37 lines changed

airbyte-integrations/connectors/destination-s3/metadata.yaml

+1 -1

@@ -2,7 +2,7 @@ data:
   connectorSubtype: file
   connectorType: destination
   definitionId: 4816b78f-1489-44c1-9060-4b19d5fa9362
-  dockerImageTag: 1.5.6
+  dockerImageTag: 1.5.7
   dockerRepository: airbyte/destination-s3
   githubIssueLabel: destination-s3
   icon: s3.svg

airbyte-integrations/connectors/destination-s3/src/main/kotlin/S3V2Configuration.kt

+2 -3

@@ -73,10 +73,9 @@ class S3V2ConfigurationFactory :
             objectStoragePathConfiguration = pojo.toObjectStoragePathConfiguration(),
             objectStorageFormatConfiguration = pojo.toObjectStorageFormatConfiguration(),
             objectStorageCompressionConfiguration = pojo.toCompressionConfiguration(),
-            numUploadWorkers = pojo.numObjectLoaders ?: 25,
-            partSizeBytes = (pojo.partSizeMb ?: 50) * 1024L * 1024L,
+            numUploadWorkers = pojo.numObjectLoaders ?: 10,
+            partSizeBytes = (pojo.partSizeMb ?: 10) * 1024L * 1024L,
             useLegacyClient = pojo.useLegacyClient ?: false,
-            objectSizeBytes = (pojo.totalDataMb ?: 2024) * 1024L * 1024L,
         )
     }
 }
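As a quick cross-check of the new fallbacks, the sketch below (standalone and hypothetical, not connector code) applies the same elvis-operator defaulting as the factory above: leaving both tuning fields unset now yields 10 upload workers and 10 MiB parts instead of 25 workers and 50 MiB parts.

```kotlin
// Standalone sketch, not connector code: how the lowered defaults resolve when the
// user leaves the tuning fields unset. The literal values mirror the factory above.
fun main() {
    val numObjectLoaders: Int? = null // not provided in the connector config
    val partSizeMb: Int? = null       // not provided in the connector config

    val numUploadWorkers = numObjectLoaders ?: 10
    val partSizeBytes = (partSizeMb ?: 10) * 1024L * 1024L

    // Prints: workers=10 partSizeBytes=10485760
    println("workers=$numUploadWorkers partSizeBytes=$partSizeBytes")
}
```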

airbyte-integrations/connectors/destination-s3/src/main/kotlin/S3V2Specification.kt

-2

@@ -80,8 +80,6 @@ class S3V2Specification :
     val numObjectLoaders: Int? = null
     @get:JsonProperty("part_size_mb")
     val partSizeMb: Int? = null
-    @get:JsonProperty("total_data_mb")
-    val totalDataMb: Int? = null
     @get:JsonProperty("use_legacy_client")
     val useLegacyClient: Boolean? = null
 }

airbyte-integrations/connectors/destination-s3/src/main/kotlin/S3V2WriteOpOverride.kt

+167 -29

@@ -2,69 +2,207 @@ package io.airbyte.integrations.destination.s3_v2
 
 import io.airbyte.cdk.load.command.DestinationCatalog
 import io.airbyte.cdk.load.file.object_storage.PathFactory
+import io.airbyte.cdk.load.file.object_storage.StreamingUpload
 import io.airbyte.cdk.load.file.s3.S3Client
+import io.airbyte.cdk.load.file.s3.S3Object
+import io.airbyte.cdk.load.file.s3.S3StreamingUpload
+import io.airbyte.cdk.load.message.CheckpointMessage
+import io.airbyte.cdk.load.message.DestinationFile
+import io.airbyte.cdk.load.message.DestinationFileStreamComplete
+import io.airbyte.cdk.load.message.DestinationFileStreamIncomplete
+import io.airbyte.cdk.load.message.DestinationRecord
+import io.airbyte.cdk.load.message.DestinationRecordStreamComplete
+import io.airbyte.cdk.load.message.DestinationRecordStreamIncomplete
+import io.airbyte.cdk.load.message.GlobalCheckpoint
+import io.airbyte.cdk.load.message.StreamCheckpoint
+import io.airbyte.cdk.load.message.Undefined
+import io.airbyte.cdk.load.state.SyncManager
 import io.airbyte.cdk.load.task.SelfTerminating
 import io.airbyte.cdk.load.task.TerminalCondition
+import io.airbyte.cdk.load.task.internal.ReservingDeserializingInputFlow
 import io.airbyte.cdk.load.write.WriteOpOverride
+import io.airbyte.protocol.models.v0.AirbyteMessage
 import io.github.oshai.kotlinlogging.KotlinLogging
 import jakarta.inject.Singleton
+import java.io.File
+import java.io.RandomAccessFile
+import java.nio.file.Path
+import java.util.concurrent.atomic.AtomicLong
+import java.util.function.Consumer
 import kotlin.random.Random
 import kotlin.time.measureTime
 import kotlinx.coroutines.Dispatchers
 import kotlinx.coroutines.ExperimentalCoroutinesApi
 import kotlinx.coroutines.async
 import kotlinx.coroutines.awaitAll
+import kotlinx.coroutines.channels.Channel
 import kotlinx.coroutines.coroutineScope
+import kotlinx.coroutines.flow.consumeAsFlow
+import kotlinx.coroutines.flow.flowOn
+import kotlinx.coroutines.launch
 import kotlinx.coroutines.withContext
 
+data class FileSegment(
+    val fileUrl: String,
+    val objectKey: String,
+    val upload: StreamingUpload<S3Object>,
+    val partNumber: Int,
+    val partSize: Long,
+    val callback: suspend () -> Unit = {}
+)
+
 @Singleton
 class S3V2WriteOpOverride(
     private val client: S3Client,
     private val catalog: DestinationCatalog,
     private val config: S3V2Configuration<*>,
     private val pathFactory: PathFactory,
+    private val reservingDeserializingInputFlow: ReservingDeserializingInputFlow,
+    private val outputConsumer: Consumer<AirbyteMessage>,
+    private val syncManager: SyncManager,
 ): WriteOpOverride {
     private val log = KotlinLogging.logger { }
 
     override val terminalCondition: TerminalCondition = SelfTerminating
 
     @OptIn(ExperimentalCoroutinesApi::class)
     override suspend fun execute() = coroutineScope {
-        val prng = Random(System.currentTimeMillis())
-        val randomPart = prng.nextBytes(config.partSizeBytes.toInt())
-        val randomString = randomPart.take(32).joinToString("") { "%02x".format(it) }
-        val stream = catalog.streams.first()
-        val objectKey = pathFactory.getFinalDirectory(stream) + "/mock-perf-test-$randomString"
+        val mockPartQueue: Channel<FileSegment> = Channel(Channel.UNLIMITED)
+        val streamCount = AtomicLong(catalog.streams.size.toLong())
+        val totalBytesLoaded = AtomicLong(0L)
+        try {
+            withContext(Dispatchers.IO) {
+                val duration = measureTime {
+                    launch {
+                        reservingDeserializingInputFlow.collect { (_, reservation) ->
+                            when (val message = reservation.value) {
+                                is GlobalCheckpoint -> {
+                                    outputConsumer.accept(
+                                        message.withDestinationStats(CheckpointMessage.Stats(0))
+                                            .asProtocolMessage()
+                                    )
+                                }
+                                is StreamCheckpoint -> {
+                                    val (_, count) = syncManager.getStreamManager(message.checkpoint.stream)
+                                        .markCheckpoint()
+                                    log.info { "Flushing state" }
+                                    outputConsumer.accept(
+                                        message.withDestinationStats(
+                                            CheckpointMessage.Stats(
+                                                count
+                                            )
+                                        )
+                                            .asProtocolMessage()
+                                    )
+                                    log.info { "Done flushing state" }
+                                }
+                                is DestinationFile -> {
+                                    syncManager.getStreamManager(message.stream)
+                                        .incrementReadCount()
+                                    if (message.fileMessage.bytes == null) {
+                                        throw IllegalStateException("This can't work unless you set FileMessage.bytes!")
+                                    }
+                                    val size = message.fileMessage.bytes!!
+                                    val numWholeParts = (size / config.partSizeBytes).toInt()
+                                    val numParts =
+                                        numWholeParts + if (size % config.partSizeBytes > 0) 1 else 0
+                                    val lastPartSize = size % config.partSizeBytes
+                                    val fileUrl = message.fileMessage.fileUrl!!
+                                    log.info {
+                                        "Breaking file $fileUrl (size=${size}B) into $numParts ${config.partSizeBytes}B parts"
+                                    }
+                                    val stream = catalog.getStream(message.stream)
+                                    val directory = pathFactory.getFinalDirectory(stream)
+                                    val sourceFileName = message.fileMessage.sourceFileUrl!!
+                                    val objectKey = Path.of(directory, sourceFileName).toString()
+                                    val upload = client.startStreamingUpload(objectKey)
+                                    val partCounter = AtomicLong(numParts.toLong())
+                                    repeat(numParts) { partNumber ->
+                                        mockPartQueue.send(
+                                            FileSegment(
+                                                fileUrl,
+                                                objectKey,
+                                                upload,
+                                                partNumber + 1,
+                                                if (partNumber == numParts - 1) lastPartSize else config.partSizeBytes
+                                            ) {
+                                                val partsRemaining = partCounter.decrementAndGet()
+                                                if (partsRemaining == 0L) {
+                                                    log.info {
+                                                        "Finished uploading $numParts parts of $fileUrl; deleting file and finishing upload"
+                                                    }
+                                                    File(fileUrl).delete()
+                                                    log.info {
+                                                        "Finished deleting"
+                                                    }
+                                                    upload.complete()
+                                                    log.info {
+                                                        "Finished completing the upload"
+                                                    }
+                                                } else {
+                                                    log.info {
+                                                        "Finished uploading part ${partNumber + 1} of $fileUrl. $partsRemaining parts remaining"
+                                                    }
+                                                }
+                                            }
+                                        )
+                                    }
+                                }
 
-        val numParts = (config.objectSizeBytes / config.partSizeBytes).toInt()
-        val partsPerWorker = numParts / config.numUploadWorkers
-        val actualSizeBytes = partsPerWorker * config.numUploadWorkers * config.partSizeBytes
+                                is DestinationFileStreamComplete,
+                                is DestinationFileStreamIncomplete -> {
+                                    if (streamCount.decrementAndGet() == 0L) {
+                                        log.info {
+                                            "Read final stream complete, closing mockPartQueue"
+                                        }
+                                        mockPartQueue.close()
+                                    } else {
+                                        log.info {
+                                            "Read stream complete, ${streamCount.get()} streams remaining"
+                                        }
+                                    }
+                                }
 
-        log.info {
-            "root key=$objectKey; part_size=${config.partSizeBytes}b; num_parts=$numParts (per_worker=$partsPerWorker); total_size=${actualSizeBytes}b; num_workers=${config.numUploadWorkers}"
-        }
+                                is DestinationRecordStreamComplete,
+                                is DestinationRecordStreamIncomplete,
+                                is DestinationRecord -> throw NotImplementedError("This hack is only for files")
 
-        val duration = measureTime {
-            withContext(Dispatchers.IO.limitedParallelism(config.numUploadWorkers)) {
-                (0 until config.numUploadWorkers).map {
-                    async {
-                        val workerKey = "$objectKey-worker-$it"
-                        log.info { "Starting upload to $workerKey" }
-                        val upload = client.startStreamingUpload(workerKey)
-                        repeat(partsPerWorker) {
-                            log.info { "Uploading part ${it + 1} of $workerKey" }
-                            upload.uploadPart(randomPart, it + 1)
+                                Undefined ->
+                                    log.warn {
+                                        "Undefined message received. This should not happen."
+                                    }
+                            }
+                            reservation.release()
                         }
-                        log.info { "Completing upload to $workerKey" }
-                        upload.complete()
                     }
-                }.awaitAll()
+
+                    (0 until config.numUploadWorkers).map {
+                        async {
+                            mockPartQueue.consumeAsFlow().collect { segment ->
+                                log.info { "Starting upload to ${segment.objectKey} part ${segment.partNumber}" }
+                                RandomAccessFile(segment.fileUrl, "r").use { file ->
+                                    val partBytes = ByteArray(segment.partSize.toInt())
+                                    file.seek((segment.partNumber - 1) * config.partSizeBytes)
+                                    file.read(partBytes)
+                                    segment.upload.uploadPart(partBytes, segment.partNumber)
+                                    log.info {
+                                        "Finished uploading part ${segment.partNumber} of ${segment.fileUrl}"
+                                    }
+                                    totalBytesLoaded.addAndGet(segment.partSize)
+                                    segment.callback()
+                                }
+                            }
+                        }
+                    }.awaitAll()
+                }
+                log.info {
+                    val mbs = totalBytesLoaded.get()
+                        .toDouble() / 1024 / 1024 / duration.inWholeSeconds.toDouble()
+                    "Uploaded ${totalBytesLoaded.get()} bytes in ${duration.inWholeSeconds}s (${mbs}MB/s)"
+                }
             }
-        }
-        val mbs = actualSizeBytes.toFloat() / duration.inWholeSeconds.toFloat() / 1024 / 1024
-        log.info {
-            // format mbs to 2 decimal places
-            "Uploaded $actualSizeBytes bytes in $duration seconds (${"%.2f".format(mbs)} MB/s)"
+        } catch (e: Throwable) {
+            log.error(e) { "Error uploading file, bailing" }
         }
     }
 }
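To make the shape of the new override easier to follow outside the CDK, here is a self-contained sketch of the same producer/consumer pattern: one coroutine splits a local file into fixed-size segments and sends them to an unbounded Channel, a fixed pool of workers drains the channel and reads each segment's byte range with RandomAccessFile, and a per-file counter marks where the "complete the upload" step would run after the last part. The S3 multipart upload is replaced by a local byte counter, and every name here (Segment, partSizeBytes, numWorkers) is illustrative rather than the connector's API.

```kotlin
import java.io.File
import java.io.RandomAccessFile
import java.util.concurrent.atomic.AtomicLong
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.async
import kotlinx.coroutines.awaitAll
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.withContext

// Illustrative stand-in for FileSegment; the callback fires after the segment is "uploaded".
data class Segment(
    val path: String,
    val partNumber: Int,          // 1-based, like multipart upload part numbers
    val partSize: Long,
    val callback: suspend () -> Unit,
)

fun main(): Unit = runBlocking {
    val partSizeBytes = 4L * 1024 // tiny parts so the demo finishes instantly
    val numWorkers = 4

    // A throwaway local file standing in for the staged file to transfer.
    val source = File.createTempFile("xfer-demo", ".bin").apply {
        writeBytes(ByteArray(10_000) { it.toByte() })
        deleteOnExit()
    }

    val queue = Channel<Segment>(Channel.UNLIMITED)
    val bytesUploaded = AtomicLong(0)

    withContext(Dispatchers.IO) {
        // Producer: split the file into parts, mirroring the DestinationFile branch above.
        launch {
            val size = source.length()
            val wholeParts = (size / partSizeBytes).toInt()
            val numParts = wholeParts + if (size % partSizeBytes > 0) 1 else 0
            val lastPartSize = size % partSizeBytes
            val remaining = AtomicLong(numParts.toLong())
            repeat(numParts) { i ->
                val isLast = i == numParts - 1 && lastPartSize > 0
                queue.send(
                    Segment(source.path, i + 1, if (isLast) lastPartSize else partSizeBytes) {
                        // The worker that finishes the final part would complete the upload here.
                        if (remaining.decrementAndGet() == 0L) {
                            println("all $numParts parts done; completing upload")
                        }
                    }
                )
            }
            queue.close() // lets the workers drain the remaining segments and exit
        }

        // Consumers: a fixed pool of workers pulls segments and reads their byte ranges.
        (0 until numWorkers).map {
            async {
                for (segment in queue) {
                    RandomAccessFile(segment.path, "r").use { file ->
                        val part = ByteArray(segment.partSize.toInt())
                        file.seek((segment.partNumber - 1) * partSizeBytes)
                        file.readFully(part)
                        // The real override calls segment.upload.uploadPart(part, partNumber) here.
                        bytesUploaded.addAndGet(part.size.toLong())
                    }
                    segment.callback()
                }
            }
        }.awaitAll()
    }
    println("uploaded ${bytesUploaded.get()} of ${source.length()} bytes")
}
```

Compared with the per-worker random-data loop this commit removes, the shared queue lets any idle worker pick up the next part of any file, which is how the override keeps config.numUploadWorkers busy across several transferred files.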

airbyte-integrations/connectors/destination-s3/src/test-integration/kotlin/io/airbyte/integrations/destination/s3_v2/S3V2PerformanceTest.kt

+2 -2

@@ -15,8 +15,8 @@ class S3V2JsonNoFrillsPerformanceTest :
         configSpecClass = S3V2Specification::class.java,
         defaultRecordsToInsert = 1_000_000,
         micronautProperties = S3V2TestUtils.PERFORMANCE_TEST_MICRONAUT_PROPERTIES,
-        numFilesForFileTransfer = 5,
-        fileSizeMbForFileTransfer = 1024,
+        numFilesForFileTransfer = 3,
+        fileSizeMbForFileTransfer = 1099,
     ) {
     @Test
     override fun testFileTransfer() {
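For a rough sense of scale, each 1099 MB test file splits into 110 parts under the 10 MB default part size introduced above (109 full parts plus one 9 MB tail). A standalone arithmetic check, illustrative only:

```kotlin
// Standalone arithmetic check (illustrative): how the new perf-test file size
// interacts with the 10 MB default part size from S3V2Configuration.
fun main() {
    val partSizeMb = 10L
    val fileSizeMb = 1099L
    val numFiles = 3L

    val partsPerFile = fileSizeMb / partSizeMb + if (fileSizeMb % partSizeMb > 0) 1 else 0
    println("parts per file: $partsPerFile")                  // 110 (109 full parts + one 9 MB part)
    println("total transferred: ${numFiles * fileSizeMb} MB") // 3297 MB across 3 files
}
```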

0 comments