
Commit 6def13b

DO NOT MERGE: S3 write-op override for fast file xfer
1 parent e4e9168 commit 6def13b

5 files changed (+173 −36)

airbyte-integrations/connectors/destination-s3/metadata.yaml

+1 −1

@@ -2,7 +2,7 @@ data:
   connectorSubtype: file
   connectorType: destination
   definitionId: 4816b78f-1489-44c1-9060-4b19d5fa9362
-  dockerImageTag: 1.5.6
+  dockerImageTag: 1.5.7
   dockerRepository: airbyte/destination-s3
   githubIssueLabel: destination-s3
   icon: s3.svg

airbyte-integrations/connectors/destination-s3/src/main/kotlin/S3V2Configuration.kt

+2 −3

@@ -73,10 +73,9 @@ class S3V2ConfigurationFactory :
             objectStoragePathConfiguration = pojo.toObjectStoragePathConfiguration(),
             objectStorageFormatConfiguration = pojo.toObjectStorageFormatConfiguration(),
             objectStorageCompressionConfiguration = pojo.toCompressionConfiguration(),
-            numUploadWorkers = pojo.numObjectLoaders ?: 25,
-            partSizeBytes = (pojo.partSizeMb ?: 50) * 1024L * 1024L,
+            numUploadWorkers = pojo.numObjectLoaders ?: 10,
+            partSizeBytes = (pojo.partSizeMb ?: 10) * 1024L * 1024L,
             useLegacyClient = pojo.useLegacyClient ?: false,
-            objectSizeBytes = (pojo.totalDataMb ?: 2024) * 1024L * 1024L,
         )
     }
 }
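
Note on the new defaults: numUploadWorkers drops from 25 to 10, the default part size from 50 MiB to 10 MiB, and with objectSizeBytes gone the amount uploaded now follows the incoming files rather than a fixed synthetic total. A minimal standalone sketch of the same derivation (deriveUploadSettings is a made-up name for illustration; the 5 MiB floor and 10,000-part ceiling are S3's documented multipart limits, not values from this diff):

// Hypothetical helper mirroring the factory defaults above.
fun deriveUploadSettings(partSizeMb: Int?, numObjectLoaders: Int?): Pair<Long, Int> {
    val partSizeBytes = (partSizeMb ?: 10) * 1024L * 1024L // default: 10 MiB parts
    val numUploadWorkers = numObjectLoaders ?: 10          // default: 10 concurrent part uploads
    // S3 rejects multipart parts under 5 MiB (except the final part).
    require(partSizeBytes >= 5L * 1024 * 1024) { "part_size_mb must be at least 5" }
    // At 10 MiB per part, S3's 10,000-part cap bounds a single object near 97.6 GiB.
    return partSizeBytes to numUploadWorkers
}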

airbyte-integrations/connectors/destination-s3/src/main/kotlin/S3V2Specification.kt

−2

@@ -80,8 +80,6 @@ class S3V2Specification :
     val numObjectLoaders: Int? = null
     @get:JsonProperty("part_size_mb")
     val partSizeMb: Int? = null
-    @get:JsonProperty("total_data_mb")
-    val totalDataMb: Int? = null
     @get:JsonProperty("use_legacy_client")
     val useLegacyClient: Boolean? = null
 }
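
This pairs with the configuration change above: total_data_mb fed the deleted objectSizeBytes, so the spec loses the field too, and object sizes are now driven by the actual DestinationFile messages. A small sketch of how the surviving optional knobs bind from user config, assuming a simplified constructor-based shape (MiniSpec and the num_object_loaders property name are illustrative; the real annotation for numObjectLoaders sits just above this hunk):

import com.fasterxml.jackson.annotation.JsonProperty
import com.fasterxml.jackson.module.kotlin.jacksonObjectMapper
import com.fasterxml.jackson.module.kotlin.readValue

// Simplified stand-in for the spec: every tuning knob optional, defaults applied later.
data class MiniSpec(
    @JsonProperty("num_object_loaders") val numObjectLoaders: Int? = null,
    @JsonProperty("part_size_mb") val partSizeMb: Int? = null,
    @JsonProperty("use_legacy_client") val useLegacyClient: Boolean? = null,
)

fun main() {
    val spec: MiniSpec = jacksonObjectMapper().readValue("""{"part_size_mb": 25}""")
    println(spec.partSizeMb)       // 25
    println(spec.numObjectLoaders) // null -> the factory falls back to 10
}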

airbyte-integrations/connectors/destination-s3/src/main/kotlin/S3V2WriteOpOverride.kt

+168 −28

@@ -2,69 +2,209 @@ package io.airbyte.integrations.destination.s3_v2
 
 import io.airbyte.cdk.load.command.DestinationCatalog
 import io.airbyte.cdk.load.file.object_storage.PathFactory
+import io.airbyte.cdk.load.file.object_storage.StreamingUpload
 import io.airbyte.cdk.load.file.s3.S3Client
+import io.airbyte.cdk.load.file.s3.S3Object
+import io.airbyte.cdk.load.file.s3.S3StreamingUpload
+import io.airbyte.cdk.load.message.CheckpointMessage
+import io.airbyte.cdk.load.message.DestinationFile
+import io.airbyte.cdk.load.message.DestinationFileStreamComplete
+import io.airbyte.cdk.load.message.DestinationFileStreamIncomplete
+import io.airbyte.cdk.load.message.DestinationRecord
+import io.airbyte.cdk.load.message.DestinationRecordStreamComplete
+import io.airbyte.cdk.load.message.DestinationRecordStreamIncomplete
+import io.airbyte.cdk.load.message.GlobalCheckpoint
+import io.airbyte.cdk.load.message.StreamCheckpoint
+import io.airbyte.cdk.load.message.Undefined
+import io.airbyte.cdk.load.state.SyncManager
 import io.airbyte.cdk.load.task.SelfTerminating
 import io.airbyte.cdk.load.task.TerminalCondition
+import io.airbyte.cdk.load.task.internal.ReservingDeserializingInputFlow
 import io.airbyte.cdk.load.write.WriteOpOverride
+import io.airbyte.protocol.models.v0.AirbyteMessage
 import io.github.oshai.kotlinlogging.KotlinLogging
 import jakarta.inject.Singleton
+import java.io.File
+import java.io.RandomAccessFile
+import java.nio.file.Path
+import java.util.concurrent.atomic.AtomicLong
+import java.util.function.Consumer
 import kotlin.random.Random
 import kotlin.time.measureTime
 import kotlinx.coroutines.Dispatchers
 import kotlinx.coroutines.ExperimentalCoroutinesApi
 import kotlinx.coroutines.async
 import kotlinx.coroutines.awaitAll
+import kotlinx.coroutines.channels.Channel
 import kotlinx.coroutines.coroutineScope
+import kotlinx.coroutines.flow.consumeAsFlow
+import kotlinx.coroutines.flow.flowOn
+import kotlinx.coroutines.launch
 import kotlinx.coroutines.withContext
 
+data class FileSegment(
+    val fileUrl: String,
+    val objectKey: String,
+    val upload: StreamingUpload<S3Object>,
+    val partNumber: Int,
+    val partSize: Long,
+    val callback: suspend () -> Unit = {}
+)
+
 @Singleton
 class S3V2WriteOpOverride(
     private val client: S3Client,
     private val catalog: DestinationCatalog,
     private val config: S3V2Configuration<*>,
     private val pathFactory: PathFactory,
+    private val reservingDeserializingInputFlow: ReservingDeserializingInputFlow,
+    private val outputConsumer: Consumer<AirbyteMessage>,
+    private val syncManager: SyncManager,
 ): WriteOpOverride {
     private val log = KotlinLogging.logger { }
 
     override val terminalCondition: TerminalCondition = SelfTerminating
 
     @OptIn(ExperimentalCoroutinesApi::class)
     override suspend fun execute() = coroutineScope {
-        val prng = Random(System.currentTimeMillis())
-        val randomPart = prng.nextBytes(config.partSizeBytes.toInt())
-        val randomString = randomPart.take(32).joinToString("") { "%02x".format(it) }
-        val stream = catalog.streams.first()
-        val objectKey = pathFactory.getFinalDirectory(stream) + "/mock-perf-test-$randomString"
+        val mockPartQueue: Channel<FileSegment> = Channel(Channel.UNLIMITED)
+        val streamCount = AtomicLong(catalog.streams.size.toLong())
+        withContext(Dispatchers.IO) {
+            launch {
+                reservingDeserializingInputFlow.collect { (_, reservation) ->
+                    when (val message = reservation.value) {
+                        is GlobalCheckpoint -> launch {
+                            outputConsumer.accept(
+                                message.withDestinationStats(CheckpointMessage.Stats(0))
+                                    .asProtocolMessage()
+                            )
+                        }
 
-        val numParts = (config.objectSizeBytes / config.partSizeBytes).toInt()
-        val partsPerWorker = numParts / config.numUploadWorkers
-        val actualSizeBytes = partsPerWorker * config.numUploadWorkers * config.partSizeBytes
+                        is StreamCheckpoint -> {
+                            val (_, count) = syncManager.getStreamManager(message.checkpoint.stream)
+                                .markCheckpoint()
+                            launch {
+                                log.info { "Flushing state" }
+                                outputConsumer.accept(
+                                    message.withDestinationStats(CheckpointMessage.Stats(count))
+                                        .asProtocolMessage()
+                                )
+                                log.info { "Done flushing state" }
+                            }
+                        }
 
-        log.info {
-            "root key=$objectKey; part_size=${config.partSizeBytes}b; num_parts=$numParts (per_worker=$partsPerWorker); total_size=${actualSizeBytes}b; num_workers=${config.numUploadWorkers}"
-        }
+                        is DestinationFile -> {
+                            syncManager.getStreamManager(message.stream).incrementReadCount()
+                            if (message.fileMessage.bytes == null) {
+                                throw IllegalStateException("This can't work unless you set FileMessage.bytes!")
+                            }
+                            val size = message.fileMessage.bytes!!
+                            val numWholeParts = (size / config.partSizeBytes).toInt()
+                            val numParts =
+                                numWholeParts + if (size % config.partSizeBytes > 0) 1 else 0
+                            val lastPartSize = size % config.partSizeBytes
+                            val fileUrl = message.fileMessage.fileUrl!!
+                            log.info {
+                                "Breaking file $fileUrl (size=${size}B) into $numParts ${config.partSizeBytes}B parts"
+                            }
+                            val stream = catalog.getStream(message.stream)
+                            val directory = pathFactory.getFinalDirectory(stream)
+                            val sourceFileName = message.fileMessage.sourceFileUrl!!
+                            val objectKey = Path.of(directory, sourceFileName).toString()
+                            val upload = client.startStreamingUpload(objectKey)
+                            val partCounter = AtomicLong(numParts.toLong())
+                            repeat(numParts) { partNumber ->
+                                mockPartQueue.send(
+                                    FileSegment(
+                                        fileUrl,
+                                        objectKey,
+                                        upload,
+                                        partNumber + 1,
+                                        if (partNumber == numParts - 1) lastPartSize else config.partSizeBytes
+                                    ) {
+                                        val partsRemaining = partCounter.decrementAndGet()
+                                        if (partsRemaining == 0L) {
+                                            log.info {
+                                                "Finished uploading $numParts parts of $fileUrl; deleting file and finishing upload"
+                                            }
+                                            File(fileUrl).delete()
+                                            log.info {
+                                                "Finished deleting"
+                                            }
+                                            upload.complete()
+                                            log.info {
+                                                "Finished completing the upload"
+                                            }
+                                        } else {
+                                            log.info {
+                                                "Finished uploading part ${partNumber + 1} of $fileUrl. $partsRemaining parts remaining"
+                                            }
+                                        }
+                                    }
+                                )
+                            }
+                        }
 
-        val duration = measureTime {
-            withContext(Dispatchers.IO.limitedParallelism(config.numUploadWorkers)) {
-                (0 until config.numUploadWorkers).map {
-                    async {
-                        val workerKey = "$objectKey-worker-$it"
-                        log.info { "Starting upload to $workerKey" }
-                        val upload = client.startStreamingUpload(workerKey)
-                        repeat(partsPerWorker) {
-                            log.info { "Uploading part ${it + 1} of $workerKey" }
-                            upload.uploadPart(randomPart, it + 1)
+                        is DestinationFileStreamComplete,
+                        is DestinationFileStreamIncomplete -> {
+                            if (streamCount.decrementAndGet() == 0L) {
+                                log.info {
+                                    "Read final stream complete, closing mockPartQueue"
+                                }
+                                mockPartQueue.close()
+                            } else {
+                                log.info {
+                                    "Read stream complete, ${streamCount.get()} streams remaining"
+                                }
+                            }
                         }
-                        log.info { "Completing upload to $workerKey" }
-                        upload.complete()
+
+                        is DestinationRecordStreamComplete,
+                        is DestinationRecordStreamIncomplete,
+                        is DestinationRecord -> throw NotImplementedError("This hack is only for files")
+
+                        Undefined ->
+                            log.warn {
+                                "Undefined message received. This should not happen."
+                            }
                     }
-                }.awaitAll()
+                    reservation.release()
+                }
             }
         }
-        val mbs = actualSizeBytes.toFloat() / duration.inWholeSeconds.toFloat() / 1024 / 1024
-        log.info {
-            // format mbs to 2 decimal places
-            "Uploaded $actualSizeBytes bytes in $duration seconds (${"%.2f".format(mbs)} MB/s)"
+
+        val totalBytesLoaded = AtomicLong(0L)
+        try {
+            val duration = measureTime {
+                val dispatcher = Dispatchers.IO.limitedParallelism(config.numUploadWorkers)
+                withContext(dispatcher) {
+                    (0 until config.numUploadWorkers).map {
+                        async {
+                            mockPartQueue.consumeAsFlow().collect { segment ->
+                                log.info { "Starting upload to ${segment.objectKey} part ${segment.partNumber}" }
+                                RandomAccessFile(segment.fileUrl, "r").use { file ->
+                                    val partBytes = ByteArray(segment.partSize.toInt())
+                                    file.seek((segment.partNumber - 1) * config.partSizeBytes)
+                                    file.read(partBytes)
+                                    segment.upload.uploadPart(partBytes, segment.partNumber)
+                                    log.info {
+                                        "Finished uploading part ${segment.partNumber} of ${segment.fileUrl}"
+                                    }
+                                    totalBytesLoaded.addAndGet(segment.partSize)
+                                    segment.callback()
+                                }
+                            }
+                        }
+                    }.awaitAll()
+                }
+            }
+            log.info {
+                val mbs = totalBytesLoaded.get()
+                    .toDouble() / 1024 / 1024 / duration.inWholeSeconds.toDouble()
+                "Uploaded ${totalBytesLoaded.get()} bytes in ${duration.inWholeSeconds}s (${mbs}MB/s)"
+            }
+        } catch (e: Throwable) {
+            log.error(e) { "Error uploading file, bailing" }
         }
     }
 }
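
The heart of the override is the bookkeeping in the DestinationFile branch: a file of size bytes becomes ceil(size / partSize) segments, the last carrying the remainder, and an AtomicLong countdown shared by the segments' callbacks guarantees upload.complete() and the local delete fire exactly once, after the final part lands, whichever worker uploads it. A condensed standalone sketch of that accounting (Segment and segmentsFor are illustrative stand-ins, not CDK types; the lastPartSize > 0 guard is an extra safeguard for the exact-multiple case, where the diff's partNumber == numParts - 1 branch would hand the final part a size of zero):

import java.util.concurrent.atomic.AtomicLong

// Illustrative stand-in for the diff's FileSegment: one part of one multipart upload.
data class Segment(
    val partNumber: Int,        // 1-based, matching S3's part numbering
    val partSize: Long,
    val callback: suspend () -> Unit,
)

// Split `size` bytes into parts; the last-finishing part's callback completes the upload.
fun segmentsFor(
    size: Long,
    partSizeBytes: Long,
    onAllPartsDone: suspend () -> Unit,
): List<Segment> {
    val numWholeParts = (size / partSizeBytes).toInt()
    val numParts = numWholeParts + if (size % partSizeBytes > 0) 1 else 0
    val lastPartSize = size % partSizeBytes
    val remaining = AtomicLong(numParts.toLong())
    return (1..numParts).map { partNumber ->
        val isLast = partNumber == numParts
        Segment(
            partNumber = partNumber,
            // Guard against size being an exact multiple, where the remainder is 0.
            partSize = if (isLast && lastPartSize > 0) lastPartSize else partSizeBytes,
        ) {
            // Workers finish parts in arbitrary order; decrementAndGet makes the
            // "last part done" decision atomic, so completion fires exactly once.
            if (remaining.decrementAndGet() == 0L) onAllPartsDone()
        }
    }
}

Because callbacks are created per file while the queue is shared, any number of files can be in flight at once; with Channel.UNLIMITED the queue holds only small segment descriptors, so memory is bounded by the workers' in-flight part buffers rather than by queue depth.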

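On the consuming side, the upload pool is just numUploadWorkers collectors draining one channel on a parallelism-capped IO dispatcher, each seeking to its part's offset and reading one part from disk. A runnable sketch of the same fan-out pattern (PartJob is hypothetical, and plain byte buffers stand in for the live StreamingUpload each real segment carries; readFully is used where the diff's bare file.read(partBytes) may legally return fewer bytes than requested):

import java.io.RandomAccessFile
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.async
import kotlinx.coroutines.awaitAll
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.flow.consumeAsFlow
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.withContext

data class PartJob(val filePath: String, val partNumber: Int, val partSize: Long)

@OptIn(ExperimentalCoroutinesApi::class)
fun main(): Unit = runBlocking {
    val partSizeBytes = 10L * 1024 * 1024
    val numWorkers = 10
    val queue = Channel<PartJob>(Channel.UNLIMITED)

    // A producer would enqueue jobs here, then close the channel:
    // queue.send(PartJob("/tmp/staged-file", 1, partSizeBytes))
    queue.close()

    withContext(Dispatchers.IO.limitedParallelism(numWorkers)) {
        (0 until numWorkers).map {
            async {
                // Each element goes to exactly one collector, so N collectors form a
                // work-sharing pool. (receiveAsFlow is the variant designed for multiple
                // collectors; consumeAsFlow mirrors the diff and also cancels the
                // channel for everyone if one worker fails.)
                queue.consumeAsFlow().collect { job ->
                    RandomAccessFile(job.filePath, "r").use { file ->
                        val buf = ByteArray(job.partSize.toInt())
                        file.seek((job.partNumber - 1) * partSizeBytes)
                        file.readFully(buf) // unlike read(), guarantees a full buffer
                        // here: upload.uploadPart(buf, job.partNumber)
                    }
                }
            }
        }.awaitAll()
    }
}
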
airbyte-integrations/connectors/destination-s3/src/test-integration/kotlin/io/airbyte/integrations/destination/s3_v2/S3V2PerformanceTest.kt

+2 −2

@@ -15,8 +15,8 @@ class S3V2JsonNoFrillsPerformanceTest :
         configSpecClass = S3V2Specification::class.java,
         defaultRecordsToInsert = 1_000_000,
         micronautProperties = S3V2TestUtils.PERFORMANCE_TEST_MICRONAUT_PROPERTIES,
-        numFilesForFileTransfer = 5,
-        fileSizeMbForFileTransfer = 1024,
+        numFilesForFileTransfer = 3,
+        fileSizeMbForFileTransfer = 1099,
     ) {
     @Test
     override fun testFileTransfer() {
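
The new test shape exercises the remainder path: 1099 is not a multiple of the 10 MiB default part size, so each transferred file ends in a short final part. Working the numbers (a quick check, assuming fileSizeMbForFileTransfer is binary megabytes, consistent with the connector's * 1024L * 1024L conversions):

fun main() {
    val partSizeBytes = 10L * 1024 * 1024   // new connector default
    val fileSizeBytes = 1099L * 1024 * 1024 // one transferred file in the perf test
    val numWholeParts = fileSizeBytes / partSizeBytes
    val lastPartSize = fileSizeBytes % partSizeBytes
    val numParts = numWholeParts + if (lastPartSize > 0) 1 else 0
    // Prints: 110 parts per file, last part 9 MiB (x3 files per sync)
    println("$numParts parts per file, last part ${lastPartSize / 1024 / 1024} MiB")
}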
