DO NOT MERGE: S3 write-op override for fast file xfer #55753

Draft. Wants to merge 2 commits into base branch jschmidt/perf-test/spammy-override.
@@ -44,6 +44,7 @@ import io.airbyte.cdk.load.task.internal.SpillToDiskTaskFactory
 import io.airbyte.cdk.load.task.internal.UpdateBatchStateTaskFactory
 import io.airbyte.cdk.load.task.internal.UpdateCheckpointsTask
 import io.airbyte.cdk.load.util.setOnce
+import io.airbyte.cdk.load.write.WriteOpOverride
 import io.github.oshai.kotlinlogging.KotlinLogging
 import io.micronaut.context.annotation.Secondary
 import io.micronaut.context.annotation.Value
@@ -153,6 +154,7 @@ class DefaultDestinationTaskLauncher<K : WithStream>(
     private val loadPipeline: LoadPipeline?,
     private val partitioner: InputPartitioner,
     private val updateBatchTaskFactory: UpdateBatchStateTaskFactory,
+    private val writeOpOverride: WriteOpOverride? = null
 ) : DestinationTaskLauncher {
     private val log = KotlinLogging.logger {}

@@ -200,6 +202,13 @@ class DefaultDestinationTaskLauncher<K : WithStream>(
     }

     override suspend fun run() {
+        if (writeOpOverride != null) {
+            log.info { "Write operation override found, running override task." }
+            return
+        } else {
+            log.info { "No write operation override found, continuing with normal operation." }
+        }
+
         // Start the input consumer ASAP
         log.info { "Starting input consumer task" }
         val inputConsumerTask =
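For orientation: `WriteOpOverride` (defined below in this PR) is just a `Task`, so any `@Singleton` implementing it is picked up by the nullable constructor injection above and diverts the whole write path. A minimal sketch of an implementation, assuming only what this diff shows; the class name and body here are illustrative, not part of the PR:

import io.airbyte.cdk.load.task.SelfTerminating
import io.airbyte.cdk.load.task.TerminalCondition
import io.airbyte.cdk.load.write.WriteOpOverride
import jakarta.inject.Singleton

// Hypothetical override bean: its mere presence makes DefaultDestinationTaskLauncher.run()
// return early, and WriteOperation.execute() then runs this task instead of awaiting the
// standard destination result.
@Singleton
class NoOpWriteOpOverride : WriteOpOverride {
    override val terminalCondition: TerminalCondition = SelfTerminating

    override suspend fun execute() {
        // The entire sync must happen here, including emitting state back to the platform.
    }
}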
@@ -8,6 +8,7 @@ import io.airbyte.cdk.Operation
 import io.airbyte.cdk.load.state.DestinationFailure
 import io.airbyte.cdk.load.state.DestinationSuccess
 import io.airbyte.cdk.load.state.SyncManager
+import io.airbyte.cdk.load.task.Task
 import io.airbyte.cdk.load.task.TaskLauncher
 import io.github.oshai.kotlinlogging.KotlinLogging
 import io.micronaut.context.annotation.Factory
@@ -17,6 +18,8 @@ import java.io.InputStream
 import javax.inject.Singleton
 import kotlinx.coroutines.runBlocking

+interface WriteOpOverride : Task
+
 /**
  * Write operation. Executed by the core framework when the operation is "write". Launches the core
  * services and awaits completion.
@@ -26,12 +29,21 @@ import kotlinx.coroutines.runBlocking
 class WriteOperation(
     private val taskLauncher: TaskLauncher,
     private val syncManager: SyncManager,
+    private val writeOpOverride: WriteOpOverride? = null
 ) : Operation {
     val log = KotlinLogging.logger {}

     override fun execute() = runBlocking {
         taskLauncher.run()

+        if (writeOpOverride != null) {
+            val now = System.currentTimeMillis()
+            log.info { "Running override task" }
+            writeOpOverride.execute()
+            log.info { "Write operation override took ${System.currentTimeMillis() - now} ms" }
+            return@runBlocking
+        }
+
         when (val result = syncManager.awaitDestinationResult()) {
             is DestinationSuccess -> {
                 if (!syncManager.allStreamsComplete()) {
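Net effect of the two changes above: when an override bean is present, `taskLauncher.run()` returns immediately after logging, `WriteOperation` runs the override to completion and logs its wall-clock time, and the normal `awaitDestinationResult()` path, including the all-streams-complete check and failure handling, is skipped entirely. That bypass is what makes this suitable only as a perf-test hack, consistent with the DO NOT MERGE title.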
@@ -2,7 +2,7 @@ data:
   connectorSubtype: file
   connectorType: destination
   definitionId: 4816b78f-1489-44c1-9060-4b19d5fa9362
-  dockerImageTag: 1.5.5
+  dockerImageTag: 1.5.7
   dockerRepository: airbyte/destination-s3
   githubIssueLabel: destination-s3
   icon: s3.svg
@@ -44,7 +44,7 @@ data class S3V2Configuration<T : OutputStream>(

     /** Below has no effect until [S3V2ObjectLoader] is enabled. */
     val numPartWorkers: Int = 2,
-    val numUploadWorkers: Int = 10,
+    val numUploadWorkers: Int = 5,
     val maxMemoryRatioReservedForParts: Double = 0.2,
     val objectSizeBytes: Long = 200L * 1024 * 1024,
     val partSizeBytes: Long = 10L * 1024 * 1024,
@@ -73,9 +73,7 @@ class S3V2ConfigurationFactory :
            objectStoragePathConfiguration = pojo.toObjectStoragePathConfiguration(),
            objectStorageFormatConfiguration = pojo.toObjectStorageFormatConfiguration(),
            objectStorageCompressionConfiguration = pojo.toCompressionConfiguration(),
-           numPartWorkers = pojo.numPartWorkers ?: 2,
-           numUploadWorkers = pojo.numObjectLoaders ?: 3,
-           maxMemoryRatioReservedForParts = pojo.maxMemoryRatioReservedForParts ?: 0.2,
+           numUploadWorkers = pojo.numObjectLoaders ?: 10,
            partSizeBytes = (pojo.partSizeMb ?: 10) * 1024L * 1024L,
            useLegacyClient = pojo.useLegacyClient ?: false,
        )
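Note the interaction between the two edits above: the data-class default for `numUploadWorkers` drops from 10 to 5, while the factory fallback for an unset `num_uploaders` rises from 3 to 10. Configurations built through `S3V2ConfigurationFactory` therefore get 10 upload workers unless the user overrides it; only code constructing `S3V2Configuration` directly sees the new default of 5.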
@@ -76,14 +76,10 @@ class S3V2Specification :
     )
     override val fileNamePattern: String? = null

-    @get:JsonProperty("num_part_workers")
-    val numPartWorkers: Int? = null
-    @get:JsonProperty("num_upload_workers")
+    @get:JsonProperty("num_uploaders")
     val numObjectLoaders: Int? = null
     @get:JsonProperty("part_size_mb")
     val partSizeMb: Int? = null
-    @get:JsonProperty("max_memory_ratio_reserved_for_parts")
-    val maxMemoryRatioReservedForParts: Double? = null
     @get:JsonProperty("use_legacy_client")
     val useLegacyClient: Boolean? = null
 }
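After this change the user-facing spec exposes only three optional tuning knobs: `num_uploaders` (renamed from `num_upload_workers`, still bound to `numObjectLoaders`), `part_size_mb`, and `use_legacy_client`. The `num_part_workers` and `max_memory_ratio_reserved_for_parts` options are gone, matching their removal from the configuration factory above.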
@@ -0,0 +1,219 @@
package io.airbyte.integrations.destination.s3_v2

import io.airbyte.cdk.load.command.DestinationCatalog
import io.airbyte.cdk.load.file.object_storage.PathFactory
import io.airbyte.cdk.load.file.object_storage.StreamingUpload
import io.airbyte.cdk.load.file.s3.S3Client
import io.airbyte.cdk.load.file.s3.S3Object
import io.airbyte.cdk.load.file.s3.S3StreamingUpload
import io.airbyte.cdk.load.message.CheckpointMessage
import io.airbyte.cdk.load.message.DestinationFile
import io.airbyte.cdk.load.message.DestinationFileStreamComplete
import io.airbyte.cdk.load.message.DestinationFileStreamIncomplete
import io.airbyte.cdk.load.message.DestinationRecord
import io.airbyte.cdk.load.message.DestinationRecordStreamComplete
import io.airbyte.cdk.load.message.DestinationRecordStreamIncomplete
import io.airbyte.cdk.load.message.GlobalCheckpoint
import io.airbyte.cdk.load.message.StreamCheckpoint
import io.airbyte.cdk.load.message.Undefined
import io.airbyte.cdk.load.state.SyncManager
import io.airbyte.cdk.load.task.SelfTerminating
import io.airbyte.cdk.load.task.TerminalCondition
import io.airbyte.cdk.load.task.internal.ReservingDeserializingInputFlow
import io.airbyte.cdk.load.write.WriteOpOverride
import io.airbyte.protocol.models.v0.AirbyteMessage
import io.github.oshai.kotlinlogging.KotlinLogging
import jakarta.inject.Singleton
import java.io.File
import java.io.RandomAccessFile
import java.nio.MappedByteBuffer
import java.nio.channels.FileChannel
import java.nio.file.Path
import java.util.concurrent.atomic.AtomicLong
import java.util.function.Consumer
import kotlin.random.Random
import kotlin.time.measureTime
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.async
import kotlinx.coroutines.awaitAll
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.flow.consumeAsFlow
import kotlinx.coroutines.flow.flowOn
import kotlinx.coroutines.launch
import kotlinx.coroutines.withContext

data class FileSegment(
    val fileUrl: String,
    val objectKey: String,
    val upload: StreamingUpload<S3Object>,
    val partNumber: Int,
    val partSize: Long,
    val mappedBuffer: MappedByteBuffer,
    val callback: suspend () -> Unit = {}
)

@Singleton
class S3V2WriteOpOverride(
    private val client: S3Client,
    private val catalog: DestinationCatalog,
    private val config: S3V2Configuration<*>,
    private val pathFactory: PathFactory,
    private val reservingDeserializingInputFlow: ReservingDeserializingInputFlow,
    private val outputConsumer: Consumer<AirbyteMessage>,
    private val syncManager: SyncManager,
) : WriteOpOverride {
    private val log = KotlinLogging.logger {}

    override val terminalCondition: TerminalCondition = SelfTerminating

    @OptIn(ExperimentalCoroutinesApi::class)
    override suspend fun execute() = coroutineScope {
        val mockPartQueue: Channel<FileSegment> = Channel(Channel.UNLIMITED)
        val streamCount = AtomicLong(catalog.streams.size.toLong())
        val totalBytesLoaded = AtomicLong(0L)
        try {
            withContext(Dispatchers.IO) {
                val duration = measureTime {
                    launch {
                        reservingDeserializingInputFlow.collect { (_, reservation) ->
                            when (val message = reservation.value) {
                                is GlobalCheckpoint -> {
                                    outputConsumer.accept(
                                        message.withDestinationStats(CheckpointMessage.Stats(0))
                                            .asProtocolMessage()
                                    )
                                }
                                is StreamCheckpoint -> {
                                    val (_, count) =
                                        syncManager
                                            .getStreamManager(message.checkpoint.stream)
                                            .markCheckpoint()
                                    log.info { "Flushing state" }
                                    outputConsumer.accept(
                                        message
                                            .withDestinationStats(CheckpointMessage.Stats(count))
                                            .asProtocolMessage()
                                    )
                                    log.info { "Done flushing state" }
                                }
                                is DestinationFile -> {
                                    syncManager.getStreamManager(message.stream)
                                        .incrementReadCount()
                                    if (message.fileMessage.bytes == null) {
                                        throw IllegalStateException(
                                            "This can't work unless you set FileMessage.bytes!"
                                        )
                                    }
                                    val size = message.fileMessage.bytes!!
                                    val numWholeParts = (size / config.partSizeBytes).toInt()
                                    val numParts =
                                        numWholeParts +
                                            if (size % config.partSizeBytes > 0) 1 else 0
                                    // size % partSizeBytes would be 0 for exact multiples;
                                    // derive the tail from the part count so it is never empty.
                                    val lastPartSize = size - (numParts - 1) * config.partSizeBytes
                                    val fileUrl = message.fileMessage.fileUrl!!
                                    log.info {
                                        "Breaking file $fileUrl (size=${size}B) into $numParts ${config.partSizeBytes}B parts"
                                    }
                                    val stream = catalog.getStream(message.stream)
                                    val directory = pathFactory.getFinalDirectory(stream)
                                    val sourceFileName = message.fileMessage.sourceFileUrl!!
                                    val objectKey = Path.of(directory, sourceFileName).toString()
                                    val upload = client.startStreamingUpload(objectKey)
                                    val partCounter = AtomicLong(numParts.toLong())
                                    val raf = RandomAccessFile(fileUrl, "r")
                                    val memoryMap =
                                        raf.channel.map(FileChannel.MapMode.READ_ONLY, 0, size)
                                    repeat(numParts) { partNumber ->
                                        mockPartQueue.send(
                                            FileSegment(
                                                fileUrl,
                                                objectKey,
                                                upload,
                                                partNumber + 1,
                                                if (partNumber == numParts - 1) lastPartSize
                                                else config.partSizeBytes,
                                                memoryMap.slice(
                                                    (partNumber * config.partSizeBytes).toInt(),
                                                    (if (partNumber == numParts - 1) lastPartSize
                                                    else config.partSizeBytes).toInt()
                                                ),
                                            ) {
                                                val partsRemaining = partCounter.decrementAndGet()
                                                if (partsRemaining == 0L) {
                                                    log.info {
                                                        "Finished uploading $numParts parts of $fileUrl; deleting file and finishing upload"
                                                    }
                                                    raf.close()
                                                    File(fileUrl).delete()
                                                    log.info { "Finished deleting" }
                                                    upload.complete()
                                                    log.info { "Finished completing the upload" }
                                                } else {
                                                    log.info {
                                                        "Finished uploading part ${partNumber + 1} of $fileUrl. $partsRemaining parts remaining"
                                                    }
                                                }
                                            }
                                        )
                                    }
                                }
                                is DestinationFileStreamComplete,
                                is DestinationFileStreamIncomplete -> {
                                    if (streamCount.decrementAndGet() == 0L) {
                                        log.info {
                                            "Read final stream complete, closing mockPartQueue"
                                        }
                                        mockPartQueue.close()
                                    } else {
                                        log.info {
                                            "Read stream complete, ${streamCount.get()} streams remaining"
                                        }
                                    }
                                }
                                is DestinationRecordStreamComplete,
                                is DestinationRecordStreamIncomplete,
                                is DestinationRecord ->
                                    throw NotImplementedError("This hack is only for files")
                                Undefined ->
                                    log.warn { "Undefined message received. This should not happen." }
                            }
                            reservation.release()
                        }
                    }

                    (0 until config.numUploadWorkers).map {
                        async {
                            mockPartQueue.consumeAsFlow().collect { segment ->
                                log.info {
                                    "Starting upload to ${segment.objectKey} part ${segment.partNumber}"
                                }
                                val partBytes = ByteArray(segment.partSize.toInt())
                                segment.mappedBuffer.get(partBytes)
                                segment.upload.uploadPart(partBytes, segment.partNumber)
                                log.info {
                                    "Finished uploading part ${segment.partNumber} of ${segment.fileUrl}"
                                }
                                totalBytesLoaded.addAndGet(segment.partSize)
                                segment.callback()
                            }
                        }
                    }.awaitAll()
                }
                log.info {
                    val mbs =
                        totalBytesLoaded.get().toDouble() / 1024 / 1024 /
                            duration.inWholeSeconds.toDouble()
                    "Uploaded ${totalBytesLoaded.get()} bytes in ${duration.inWholeSeconds}s (${mbs}MB/s)"
                }
            }
        } catch (e: Throwable) {
            log.error(e) { "Error uploading file, bailing" }
        }
    }
}
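Two properties of the new override worth calling out. First, because it memory-maps each file and slices the map with `Int` offsets, it inherits the JDK's 2 GiB ceiling: `FileChannel.map` rejects sizes above `Integer.MAX_VALUE`, so this path only works for files under 2 GiB (the 1099 MB files in the perf test below fit). Second, the part math. A worked example with the perf-test file size and the default 10 MiB part size, mirroring the formulas in `S3V2WriteOpOverride`:

// Worked example: one 1099 MB perf-test file, 10 MiB parts.
val size = 1099L * 1024 * 1024 // 1_152_385_024 bytes
val partSizeBytes = 10L * 1024 * 1024 // 10_485_760 bytes, S3V2Configuration default

val numWholeParts = (size / partSizeBytes).toInt() // 109
val numParts = numWholeParts + if (size % partSizeBytes > 0) 1 else 0 // 110
val lastPartSize = size - (numParts - 1) * partSizeBytes // 9_437_184 bytes: a 9 MiB tail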
@@ -15,8 +15,8 @@ class S3V2JsonNoFrillsPerformanceTest :
         configSpecClass = S3V2Specification::class.java,
         defaultRecordsToInsert = 1_000_000,
         micronautProperties = S3V2TestUtils.PERFORMANCE_TEST_MICRONAUT_PROPERTIES,
-        numFilesForFileTransfer = 5,
-        fileSizeMbForFileTransfer = 1024,
+        numFilesForFileTransfer = 3,
+        fileSizeMbForFileTransfer = 1099,
     ) {
     @Test
     override fun testFileTransfer() {
@@ -314,7 +314,12 @@ class S3V2WriteTestJsonUncompressed :
         schematizedArrayBehavior = SchematizedNestedValueBehavior.PASS_THROUGH,
         preserveUndeclaredFields = true,
         allTypesBehavior = Untyped,
-    )
+    ) {
+    @Test
+    override fun testBasicWrite() {
+        super.testBasicWrite()
+    }
+}

 class S3V2WriteTestJsonRootLevelFlattening :
     S3V2WriteTest(
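The `testBasicWrite` override in `S3V2WriteTestJsonUncompressed` changes no behavior: it re-declares the inherited test with its own `@Test` annotation, a common trick for pinning a single inherited test case to the subclass (e.g. to run it in isolation from the IDE or a Gradle filter) while iterating on a draft like this one.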