diff --git a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/task/DestinationTaskLauncher.kt b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/task/DestinationTaskLauncher.kt
index 1cb1c75b95d2f..3f9eb0c9b01c7 100644
--- a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/task/DestinationTaskLauncher.kt
+++ b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/task/DestinationTaskLauncher.kt
@@ -44,6 +44,7 @@ import io.airbyte.cdk.load.task.internal.SpillToDiskTaskFactory
 import io.airbyte.cdk.load.task.internal.UpdateBatchStateTaskFactory
 import io.airbyte.cdk.load.task.internal.UpdateCheckpointsTask
 import io.airbyte.cdk.load.util.setOnce
+import io.airbyte.cdk.load.write.WriteOpOverride
 import io.github.oshai.kotlinlogging.KotlinLogging
 import io.micronaut.context.annotation.Secondary
 import io.micronaut.context.annotation.Value
@@ -153,6 +154,7 @@ class DefaultDestinationTaskLauncher(
     private val loadPipeline: LoadPipeline?,
     private val partitioner: InputPartitioner,
     private val updateBatchTaskFactory: UpdateBatchStateTaskFactory,
+    private val writeOpOverride: WriteOpOverride? = null
 ) : DestinationTaskLauncher {
     private val log = KotlinLogging.logger {}

@@ -200,6 +202,13 @@ class DefaultDestinationTaskLauncher(
     }

     override suspend fun run() {
+        if (writeOpOverride != null) {
+            log.info { "Write operation override found; skipping default task launch (override runs in WriteOperation)." }
+            return
+        } else {
+            log.info { "No write operation override found, continuing with normal operation." }
+        }
+
         // Start the input consumer ASAP
         log.info { "Starting input consumer task" }
         val inputConsumerTask =
diff --git a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/write/WriteOperation.kt b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/write/WriteOperation.kt
index 5d673a5b26f60..025f6a0ebe438 100644
--- a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/write/WriteOperation.kt
+++ b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/write/WriteOperation.kt
@@ -8,6 +8,7 @@ import io.airbyte.cdk.Operation
 import io.airbyte.cdk.load.state.DestinationFailure
 import io.airbyte.cdk.load.state.DestinationSuccess
 import io.airbyte.cdk.load.state.SyncManager
+import io.airbyte.cdk.load.task.Task
 import io.airbyte.cdk.load.task.TaskLauncher
 import io.github.oshai.kotlinlogging.KotlinLogging
 import io.micronaut.context.annotation.Factory
@@ -17,6 +18,8 @@ import java.io.InputStream
 import javax.inject.Singleton
 import kotlinx.coroutines.runBlocking

+interface WriteOpOverride : Task
+
 /**
  * Write operation. Executed by the core framework when the operation is "write". Launches the core
  * services and awaits completion.
@@ -26,12 +29,21 @@ import kotlinx.coroutines.runBlocking
 class WriteOperation(
     private val taskLauncher: TaskLauncher,
     private val syncManager: SyncManager,
+    private val writeOpOverride: WriteOpOverride? = null
 ) : Operation {
     val log = KotlinLogging.logger {}

     override fun execute() = runBlocking {
         taskLauncher.run()

+        if (writeOpOverride != null) {
+            val now = System.currentTimeMillis()
+            log.info { "Running override task" }
+            writeOpOverride.execute()
+            log.info { "Write operation override took ${System.currentTimeMillis() - now} ms" }
+            return@runBlocking
+        }
+
         when (val result = syncManager.awaitDestinationResult()) {
             is DestinationSuccess -> {
                 if (!syncManager.allStreamsComplete()) {
diff --git a/airbyte-integrations/connectors/destination-s3/metadata.yaml b/airbyte-integrations/connectors/destination-s3/metadata.yaml
index 652a236bc3add..c8188a57bacf4 100644
--- a/airbyte-integrations/connectors/destination-s3/metadata.yaml
+++ b/airbyte-integrations/connectors/destination-s3/metadata.yaml
@@ -2,7 +2,7 @@ data:
   connectorSubtype: file
   connectorType: destination
   definitionId: 4816b78f-1489-44c1-9060-4b19d5fa9362
-  dockerImageTag: 1.5.5
+  dockerImageTag: 1.5.7
   dockerRepository: airbyte/destination-s3
   githubIssueLabel: destination-s3
   icon: s3.svg
diff --git a/airbyte-integrations/connectors/destination-s3/src/main/kotlin/S3V2Configuration.kt b/airbyte-integrations/connectors/destination-s3/src/main/kotlin/S3V2Configuration.kt
index 683e8befccbf2..441d77080e89e 100644
--- a/airbyte-integrations/connectors/destination-s3/src/main/kotlin/S3V2Configuration.kt
+++ b/airbyte-integrations/connectors/destination-s3/src/main/kotlin/S3V2Configuration.kt
@@ -44,7 +44,7 @@ data class S3V2Configuration(
     /** Below has no effect until [S3V2ObjectLoader] is enabled. */
     val numPartWorkers: Int = 2,
-    val numUploadWorkers: Int = 10,
+    val numUploadWorkers: Int = 5,
     val maxMemoryRatioReservedForParts: Double = 0.2,
     val objectSizeBytes: Long = 200L * 1024 * 1024,
     val partSizeBytes: Long = 10L * 1024 * 1024,
@@ -73,9 +73,7 @@ class S3V2ConfigurationFactory :
             objectStoragePathConfiguration = pojo.toObjectStoragePathConfiguration(),
             objectStorageFormatConfiguration = pojo.toObjectStorageFormatConfiguration(),
             objectStorageCompressionConfiguration = pojo.toCompressionConfiguration(),
-            numPartWorkers = pojo.numPartWorkers ?: 2,
-            numUploadWorkers = pojo.numObjectLoaders ?: 3,
-            maxMemoryRatioReservedForParts = pojo.maxMemoryRatioReservedForParts ?: 0.2,
+            numUploadWorkers = pojo.numObjectLoaders ?: 10,
             partSizeBytes = (pojo.partSizeMb ?: 10) * 1024L * 1024L,
             useLegacyClient = pojo.useLegacyClient ?: false,
         )
diff --git a/airbyte-integrations/connectors/destination-s3/src/main/kotlin/S3V2Specification.kt b/airbyte-integrations/connectors/destination-s3/src/main/kotlin/S3V2Specification.kt
index 31ab6711642e4..b7a03f41e5c6c 100644
--- a/airbyte-integrations/connectors/destination-s3/src/main/kotlin/S3V2Specification.kt
+++ b/airbyte-integrations/connectors/destination-s3/src/main/kotlin/S3V2Specification.kt
@@ -76,14 +76,10 @@ class S3V2Specification :
     )
     override val fileNamePattern: String? = null

-    @get:JsonProperty("num_part_workers")
-    val numPartWorkers: Int? = null
-    @get:JsonProperty("num_upload_workers")
+    @get:JsonProperty("num_uploaders")
     val numObjectLoaders: Int? = null
     @get:JsonProperty("part_size_mb")
     val partSizeMb: Int? = null
-    @get:JsonProperty("max_memory_ratio_reserved_for_parts")
-    val maxMemoryRatioReservedForParts: Double? = null
     @get:JsonProperty("use_legacy_client")
     val useLegacyClient: Boolean? = null
 }
diff --git a/airbyte-integrations/connectors/destination-s3/src/main/kotlin/S3V2WriteOpOverride.kt b/airbyte-integrations/connectors/destination-s3/src/main/kotlin/S3V2WriteOpOverride.kt
new file mode 100644
index 0000000000000..2a13fcab2c220
--- /dev/null
+++ b/airbyte-integrations/connectors/destination-s3/src/main/kotlin/S3V2WriteOpOverride.kt
@@ -0,0 +1,219 @@
+package io.airbyte.integrations.destination.s3_v2
+
+import io.airbyte.cdk.load.command.DestinationCatalog
+import io.airbyte.cdk.load.file.object_storage.PathFactory
+import io.airbyte.cdk.load.file.object_storage.StreamingUpload
+import io.airbyte.cdk.load.file.s3.S3Client
+import io.airbyte.cdk.load.file.s3.S3Object
+import io.airbyte.cdk.load.file.s3.S3StreamingUpload
+import io.airbyte.cdk.load.message.CheckpointMessage
+import io.airbyte.cdk.load.message.DestinationFile
+import io.airbyte.cdk.load.message.DestinationFileStreamComplete
+import io.airbyte.cdk.load.message.DestinationFileStreamIncomplete
+import io.airbyte.cdk.load.message.DestinationRecord
+import io.airbyte.cdk.load.message.DestinationRecordStreamComplete
+import io.airbyte.cdk.load.message.DestinationRecordStreamIncomplete
+import io.airbyte.cdk.load.message.GlobalCheckpoint
+import io.airbyte.cdk.load.message.StreamCheckpoint
+import io.airbyte.cdk.load.message.Undefined
+import io.airbyte.cdk.load.state.SyncManager
+import io.airbyte.cdk.load.task.SelfTerminating
+import io.airbyte.cdk.load.task.TerminalCondition
+import io.airbyte.cdk.load.task.internal.ReservingDeserializingInputFlow
+import io.airbyte.cdk.load.write.WriteOpOverride
+import io.airbyte.protocol.models.v0.AirbyteMessage
+import io.github.oshai.kotlinlogging.KotlinLogging
+import jakarta.inject.Singleton
+import java.io.File
+import java.io.RandomAccessFile
+import java.nio.MappedByteBuffer
+import java.nio.channels.FileChannel
+import java.nio.file.Path
+import java.util.concurrent.atomic.AtomicLong
+import java.util.function.Consumer
+import kotlin.random.Random
+import kotlin.time.measureTime
+import kotlinx.coroutines.Dispatchers
+import kotlinx.coroutines.ExperimentalCoroutinesApi
+import kotlinx.coroutines.async
+import kotlinx.coroutines.awaitAll
+import kotlinx.coroutines.channels.Channel
+import kotlinx.coroutines.coroutineScope
+import kotlinx.coroutines.flow.consumeAsFlow
+import kotlinx.coroutines.flow.flowOn
+import kotlinx.coroutines.launch
+import kotlinx.coroutines.withContext
+
+data class FileSegment(
+    val fileUrl: String,
+    val objectKey: String,
+    val upload: StreamingUpload<S3Object>,
+    val partNumber: Int,
+    val partSize: Long,
+    val mappedBuffer: MappedByteBuffer,
+    val callback: suspend () -> Unit = {}
+)
+
+@Singleton
+class S3V2WriteOpOverride(
+    private val client: S3Client,
+    private val catalog: DestinationCatalog,
+    private val config: S3V2Configuration<*>,
+    private val pathFactory: PathFactory,
+    private val reservingDeserializingInputFlow: ReservingDeserializingInputFlow,
+    private val outputConsumer: Consumer<AirbyteMessage>,
+    private val syncManager: SyncManager,
+) : WriteOpOverride {
+    private val log = KotlinLogging.logger { }
+
+    override val terminalCondition: TerminalCondition = SelfTerminating
+
+    @OptIn(ExperimentalCoroutinesApi::class)
+    override suspend fun execute() = coroutineScope {
+        val mockPartQueue: Channel<FileSegment> = Channel(Channel.UNLIMITED)
+        val streamCount = AtomicLong(catalog.streams.size.toLong())
+        val totalBytesLoaded = AtomicLong(0L)
+        try {
+            withContext(Dispatchers.IO) {
+                val duration = measureTime {
+                    launch {
+                        reservingDeserializingInputFlow.collect { (_, reservation) ->
+                            when (val message = reservation.value) {
+                                is GlobalCheckpoint -> {
+                                    outputConsumer.accept(
+                                        message.withDestinationStats(CheckpointMessage.Stats(0))
+                                            .asProtocolMessage()
+                                    )
+                                }
+                                is StreamCheckpoint -> {
+                                    val (_, count) =
+                                        syncManager.getStreamManager(message.checkpoint.stream)
+                                            .markCheckpoint()
+                                    log.info { "Flushing state" }
+                                    outputConsumer.accept(
+                                        message.withDestinationStats(CheckpointMessage.Stats(count))
+                                            .asProtocolMessage()
+                                    )
+                                    log.info { "Done flushing state" }
+                                }
+                                is DestinationFile -> {
+                                    syncManager.getStreamManager(message.stream)
+                                        .incrementReadCount()
+                                    if (message.fileMessage.bytes == null) {
+                                        throw IllegalStateException(
+                                            "This can't work unless you set FileMessage.bytes!"
+                                        )
+                                    }
+                                    val size = message.fileMessage.bytes!!
+                                    val numWholeParts = (size / config.partSizeBytes).toInt()
+                                    val numParts =
+                                        numWholeParts + if (size % config.partSizeBytes > 0) 1 else 0
+                                    // If size is an exact multiple of the part size, the final part is a full part.
+                                    val lastPartSize =
+                                        (size % config.partSizeBytes).let { if (it == 0L) config.partSizeBytes else it }
+                                    val fileUrl = message.fileMessage.fileUrl!!
+                                    log.info {
+                                        "Breaking file $fileUrl (size=${size}B) into $numParts ${config.partSizeBytes}B parts"
+                                    }
+                                    val stream = catalog.getStream(message.stream)
+                                    val directory = pathFactory.getFinalDirectory(stream)
+                                    val sourceFileName = message.fileMessage.sourceFileUrl!!
+                                    val objectKey = Path.of(directory, sourceFileName).toString()
+                                    val upload = client.startStreamingUpload(objectKey)
+                                    val partCounter = AtomicLong(numParts.toLong())
+                                    val raf = RandomAccessFile(fileUrl, "r")
+                                    val memoryMap =
+                                        raf.channel.map(FileChannel.MapMode.READ_ONLY, 0, size)
+                                    repeat(numParts) { partNumber ->
+                                        val partSize =
+                                            if (partNumber == numParts - 1) lastPartSize
+                                            else config.partSizeBytes
+                                        mockPartQueue.send(
+                                            FileSegment(
+                                                fileUrl,
+                                                objectKey,
+                                                upload,
+                                                partNumber + 1,
+                                                partSize,
+                                                memoryMap.slice(
+                                                    (partNumber * config.partSizeBytes).toInt(),
+                                                    partSize.toInt()
+                                                ),
+                                            ) {
+                                                val partsRemaining = partCounter.decrementAndGet()
+                                                if (partsRemaining == 0L) {
+                                                    log.info {
+                                                        "Finished uploading $numParts parts of $fileUrl; deleting file and finishing upload"
+                                                    }
+                                                    raf.close()
+                                                    File(fileUrl).delete()
+                                                    log.info { "Finished deleting" }
+                                                    upload.complete()
+                                                    log.info { "Finished completing the upload" }
+                                                } else {
+                                                    log.info {
+                                                        "Finished uploading part ${partNumber + 1} of $fileUrl. $partsRemaining parts remaining"
+                                                    }
+                                                }
+                                            }
+                                        )
+                                    }
+                                }
+
+                                is DestinationFileStreamComplete,
+                                is DestinationFileStreamIncomplete -> {
+                                    if (streamCount.decrementAndGet() == 0L) {
+                                        log.info {
+                                            "Read final stream complete, closing mockPartQueue"
+                                        }
+                                        mockPartQueue.close()
+                                    } else {
+                                        log.info {
+                                            "Read stream complete, ${streamCount.get()} streams remaining"
+                                        }
+                                    }
+                                }
+
+                                is DestinationRecordStreamComplete,
+                                is DestinationRecordStreamIncomplete,
+                                is DestinationRecord ->
+                                    throw NotImplementedError("This hack is only for files")
+
+                                Undefined ->
+                                    log.warn {
+                                        "Undefined message received. This should not happen."
+                                    }
+                            }
+                            reservation.release()
+                        }
+                    }
+
+                    (0 until config.numUploadWorkers).map {
+                        async {
+                            mockPartQueue.consumeAsFlow().collect { segment ->
+                                log.info {
+                                    "Starting upload to ${segment.objectKey} part ${segment.partNumber}"
+                                }
+                                val partBytes = ByteArray(segment.partSize.toInt())
+                                segment.mappedBuffer.get(partBytes)
+                                segment.upload.uploadPart(partBytes, segment.partNumber)
+                                log.info {
+                                    "Finished uploading part ${segment.partNumber} of ${segment.fileUrl}"
+                                }
+                                totalBytesLoaded.addAndGet(segment.partSize)
+                                segment.callback()
+                            }
+                        }
+                    }.awaitAll()
+                }
+                log.info {
+                    // Use millisecond precision so sub-second runs don't divide by zero.
+                    val seconds = duration.inWholeMilliseconds / 1000.0
+                    val mbs = totalBytesLoaded.get().toDouble() / 1024 / 1024 / seconds
+                    "Uploaded ${totalBytesLoaded.get()} bytes in ${duration.inWholeSeconds}s (${mbs}MB/s)"
+                }
+            }
+        } catch (e: Throwable) {
+            log.error(e) { "Error uploading file, bailing" }
+        }
+    }
+}
diff --git a/airbyte-integrations/connectors/destination-s3/src/test-integration/kotlin/io/airbyte/integrations/destination/s3_v2/S3V2PerformanceTest.kt b/airbyte-integrations/connectors/destination-s3/src/test-integration/kotlin/io/airbyte/integrations/destination/s3_v2/S3V2PerformanceTest.kt
index 0b9cdbc867296..3feb9ad20db61 100644
--- a/airbyte-integrations/connectors/destination-s3/src/test-integration/kotlin/io/airbyte/integrations/destination/s3_v2/S3V2PerformanceTest.kt
+++ b/airbyte-integrations/connectors/destination-s3/src/test-integration/kotlin/io/airbyte/integrations/destination/s3_v2/S3V2PerformanceTest.kt
@@ -15,8 +15,8 @@ class S3V2JsonNoFrillsPerformanceTest :
         configSpecClass = S3V2Specification::class.java,
         defaultRecordsToInsert = 1_000_000,
         micronautProperties = S3V2TestUtils.PERFORMANCE_TEST_MICRONAUT_PROPERTIES,
-        numFilesForFileTransfer = 5,
-        fileSizeMbForFileTransfer = 1024,
+        numFilesForFileTransfer = 3,
+        fileSizeMbForFileTransfer = 1099,
     ) {
     @Test
     override fun testFileTransfer() {
diff --git a/airbyte-integrations/connectors/destination-s3/src/test-integration/kotlin/io/airbyte/integrations/destination/s3_v2/S3V2WriteTest.kt b/airbyte-integrations/connectors/destination-s3/src/test-integration/kotlin/io/airbyte/integrations/destination/s3_v2/S3V2WriteTest.kt
index 623ebd2c6db44..663700e9d2ad3 100644
--- a/airbyte-integrations/connectors/destination-s3/src/test-integration/kotlin/io/airbyte/integrations/destination/s3_v2/S3V2WriteTest.kt
+++ b/airbyte-integrations/connectors/destination-s3/src/test-integration/kotlin/io/airbyte/integrations/destination/s3_v2/S3V2WriteTest.kt
@@ -314,7 +314,12 @@ class S3V2WriteTestJsonUncompressed :
         schematizedArrayBehavior = SchematizedNestedValueBehavior.PASS_THROUGH,
         preserveUndeclaredFields = true,
         allTypesBehavior = Untyped,
-    )
+    ) {
+    @Test
+    override fun testBasicWrite() {
+        super.testBasicWrite()
+    }
+}

 class S3V2WriteTestJsonRootLevelFlattening :
     S3V2WriteTest(
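
Reviewer note (not part of the diff): the new hook is wired purely through DI. Below is a minimal sketch of what a connector-side override involves, using only the surface this change adds (WriteOpOverride : Task, TerminalCondition, SelfTerminating). The class name and body are illustrative assumptions, not part of the change:

    import io.airbyte.cdk.load.task.SelfTerminating
    import io.airbyte.cdk.load.task.TerminalCondition
    import io.airbyte.cdk.load.write.WriteOpOverride
    import jakarta.inject.Singleton

    // Hypothetical example. Any @Singleton implementing WriteOpOverride is picked
    // up by the container: DefaultDestinationTaskLauncher.run() then returns
    // immediately, and WriteOperation.execute() runs this task in place of the
    // normal pipeline.
    @Singleton
    class MyWriteOpOverride : WriteOpOverride {
        // SelfTerminating: the task is expected to run to completion on its own.
        override val terminalCondition: TerminalCondition = SelfTerminating

        override suspend fun execute() {
            // Consume the input stream and emit checkpoint state here; a real
            // override must handle the full message stream itself, as
            // S3V2WriteOpOverride does above. This stub does nothing.
        }
    }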