WIP: Load CDK BulkLoad interface + MSSQL V2 Usage #55671

Closed
@@ -16,7 +16,6 @@ import io.airbyte.cdk.load.message.PipelineEvent
import io.airbyte.cdk.load.message.StreamKey
import io.airbyte.cdk.load.pipeline.BatchUpdate
import io.airbyte.cdk.load.state.ReservationManager
import io.airbyte.cdk.load.state.Reserved
import io.airbyte.cdk.load.task.implementor.FileAggregateMessage
import io.airbyte.cdk.load.task.implementor.FileTransferQueueMessage
import io.airbyte.cdk.load.write.LoadStrategy
@@ -120,7 +119,7 @@ class SyncBeanFactory {
@Named("recordQueue")
fun recordQueue(
loadStrategy: LoadStrategy? = null,
): PartitionedQueue<Reserved<PipelineEvent<StreamKey, DestinationRecordAirbyteValue>>> {
): PartitionedQueue<PipelineEvent<StreamKey, DestinationRecordAirbyteValue>> {
return PartitionedQueue(
Array(loadStrategy?.inputPartitions ?: 1) {
ChannelMessageQueue(Channel(Channel.UNLIMITED))
@@ -10,15 +10,19 @@ import io.airbyte.cdk.load.state.CheckpointId
/** Used internally by the CDK to pass messages between steps in the loader pipeline. */
sealed interface PipelineEvent<K : WithStream, T>

/**
* A message that contains a keyed payload. The key is used to manage the state of the payload's
* corresponding [io.airbyte.cdk.load.pipeline.BatchAccumulator]. [checkpointCounts] is used by the
* CDK to perform state message bookkeeping. [postProcessingCallback] is for releasing resources
* associated with the message.
*/
class PipelineMessage<K : WithStream, T>(
val checkpointCounts: Map<CheckpointId, Long>,
val key: K,
val value: T
val value: T,
val postProcessingCallback: suspend () -> Unit = {},
) : PipelineEvent<K, T>

/**
* We send the end message on the stream and not the key, because there's no way to partition an
* empty message.
*/
/** Broadcast at end-of-stream to all partitions to signal that the stream has ended. */
class PipelineEndOfStream<K : WithStream, T>(val stream: DestinationStream.Descriptor) :
PipelineEvent<K, T>
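
As a usage sketch (not part of this diff), here is roughly how a producer can construct a `PipelineMessage` and use the new `postProcessingCallback` to release a resource, such as a memory reservation, only after the pipeline has processed the message. The helper name and parameters below are illustrative assumptions; only the `PipelineMessage` constructor and `publish(message, partition)` call mirror the code in this PR.

```kotlin
import io.airbyte.cdk.load.message.DestinationRecordAirbyteValue
import io.airbyte.cdk.load.message.PartitionedQueue
import io.airbyte.cdk.load.message.PipelineEvent
import io.airbyte.cdk.load.message.PipelineMessage
import io.airbyte.cdk.load.message.StreamKey
import io.airbyte.cdk.load.state.CheckpointId

// Hypothetical helper: wrap one record in a PipelineMessage and defer the
// caller-supplied cleanup (e.g. releasing a memory reservation) until the
// pipeline has finished processing the message.
suspend fun publishRecord(
    queue: PartitionedQueue<PipelineEvent<StreamKey, DestinationRecordAirbyteValue>>,
    key: StreamKey,
    record: DestinationRecordAirbyteValue,
    checkpointId: CheckpointId,
    partition: Int,
    onProcessed: suspend () -> Unit, // e.g. { reserved.release() }
) {
    val message =
        PipelineMessage(
            checkpointCounts = mapOf(checkpointId to 1L),
            key = key,
            value = record,
            postProcessingCallback = onProcessed,
        )
    queue.publish(message, partition)
}
```

This is the same pattern `DefaultInputConsumerTask` adopts further down, where `{ reserved.release() }` is passed as the callback so each record's memory reservation is released after processing.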
@@ -7,11 +7,38 @@ package io.airbyte.cdk.load.pipeline
import io.airbyte.cdk.load.message.WithStream

/**
* [BatchAccumulator] is used internally by the CDK to implement RecordLoaders. Connector devs
* should never need to implement this interface.
* [BatchAccumulator] is used internally by the CDK to implement
* [io.airbyte.cdk.load.write.LoadStrategy]s. Connector devs should never need to implement this
* interface.
*
* It is the glue that connects a specific step in a specific pipeline to the generic pipeline on
* the back end. (For example, in a three-stage pipeline like bulk load, step 1 is to create a part,
* step 2 is to upload it, and step 3 is to load it from object storage into a table.)
*
* - [S] is a state type that will be threaded through accumulator calls.
* - [K] is a key type associated with the input data. (NOTE: Currently, there is no support for
* key-mapping, so the key is always [io.airbyte.cdk.load.message.StreamKey].) Specifically, state
* will always be managed per-key.
* - [T] is the input data type
* - [U] is the output data type
*
* The first time data is seen for a given key, [start] is called (with the partition number). The
* state returned by [start] will be passed per input to [accept].
*
* If [accept] returns a non-null output, that output will be forwarded to the next stage (if
* applicable) and/or trigger bookkeeping (iff the output type implements
* [io.airbyte.cdk.load.message.WithBatchState]).
*
* If [accept] returns a non-null state, that state will be passed to the next call to [accept]. If
* [accept] returns a null state, the state will be discarded and a new one will be created on the
* next input by a new call to [start].
*
* When the input stream is exhausted, [finish] will be called with any remaining state iff at least
* one input was seen for that key. This means that [finish] will not be called on empty keys or on
* keys where the last call to [accept] yielded a null (finished) state.
*/
interface BatchAccumulator<S, K : WithStream, T, U> {
fun start(key: K, part: Int): S
fun accept(record: T, state: S): Pair<S, U?>
fun finish(state: S): U
suspend fun start(key: K, part: Int): S
suspend fun accept(input: T, state: S): Pair<S?, U?>
suspend fun finish(state: S): U
}
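
To make the contract above concrete, here is a minimal hypothetical implementation written against this interface: it buffers string inputs per key, emits a batch when full (returning a null state so the CDK calls `start` again on the next input), and flushes the remainder in `finish` at end-of-stream. `StringBatch`, the fixed-size policy, and the type parameters chosen here are invented for illustration only.

```kotlin
import io.airbyte.cdk.load.message.StreamKey
import io.airbyte.cdk.load.pipeline.BatchAccumulator

// Invented output type for the sketch; a real output would typically implement
// io.airbyte.cdk.load.message.WithBatchState so the CDK can do bookkeeping.
data class StringBatch(val records: List<String>)

class FixedSizeBatchAccumulator(
    private val batchSize: Int = 100,
) : BatchAccumulator<MutableList<String>, StreamKey, String, StringBatch> {
    // Called the first time a key is seen (and again after accept returns a null state).
    override suspend fun start(key: StreamKey, part: Int): MutableList<String> = mutableListOf()

    override suspend fun accept(
        input: String,
        state: MutableList<String>
    ): Pair<MutableList<String>?, StringBatch?> {
        state.add(input)
        return if (state.size >= batchSize) {
            // Batch is full: emit it and discard the state; the next input for
            // this key triggers a fresh call to start().
            Pair(null, StringBatch(state.toList()))
        } else {
            // Keep accumulating; nothing to forward downstream yet.
            Pair(state, null)
        }
    }

    // Called at end-of-stream iff at least one input was seen for the key and
    // the last accept() left a non-null state.
    override suspend fun finish(state: MutableList<String>): StringBatch = StringBatch(state.toList())
}
```

Under this contract, the `DirectLoadRecordAccumulator` change below is just the special case where the state is a `DirectLoader` and a `Complete` result ends the state.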
@@ -9,7 +9,6 @@ import io.airbyte.cdk.load.message.PartitionedQueue
import io.airbyte.cdk.load.message.PipelineEvent
import io.airbyte.cdk.load.message.QueueWriter
import io.airbyte.cdk.load.message.StreamKey
import io.airbyte.cdk.load.state.Reserved
import io.airbyte.cdk.load.task.internal.LoadPipelineStepTask
import io.airbyte.cdk.load.write.DirectLoader
import io.airbyte.cdk.load.write.DirectLoaderFactory
@@ -24,8 +23,7 @@ import jakarta.inject.Singleton
class DirectLoadPipelineStep<S : DirectLoader>(
val accumulator: DirectLoadRecordAccumulator<S, StreamKey>,
@Named("recordQueue")
val inputQueue:
PartitionedQueue<Reserved<PipelineEvent<StreamKey, DestinationRecordAirbyteValue>>>,
val inputQueue: PartitionedQueue<PipelineEvent<StreamKey, DestinationRecordAirbyteValue>>,
@Named("batchStateUpdateQueue") val batchQueue: QueueWriter<BatchUpdate>,
@Value("\${airbyte.destination.core.record-batch-size-override:null}")
val batchSizeOverride: Long? = null,
@@ -26,23 +26,23 @@ data class DirectLoadAccResult(override val state: Batch.State) : WithBatchState
class DirectLoadRecordAccumulator<S : DirectLoader, K : WithStream>(
val directLoaderFactory: DirectLoaderFactory<S>
) : BatchAccumulator<S, K, DestinationRecordAirbyteValue, DirectLoadAccResult> {
override fun start(key: K, part: Int): S {
override suspend fun start(key: K, part: Int): S {
return directLoaderFactory.create(key.stream, part)
}

override fun accept(
record: DestinationRecordAirbyteValue,
override suspend fun accept(
input: DestinationRecordAirbyteValue,
state: S
): Pair<S, DirectLoadAccResult?> {
state.accept(record).let {
): Pair<S?, DirectLoadAccResult?> {
state.accept(input).let {
return when (it) {
is Incomplete -> Pair(state, null)
is Complete -> Pair(state, DirectLoadAccResult(Batch.State.COMPLETE))
is Complete -> Pair(null, DirectLoadAccResult(Batch.State.COMPLETE))
}
}
}

override fun finish(state: S): DirectLoadAccResult {
override suspend fun finish(state: S): DirectLoadAccResult {
state.finish()
return DirectLoadAccResult(Batch.State.COMPLETE)
}
@@ -17,6 +17,9 @@ interface InputPartitioner {
fun getPartition(record: DestinationRecordAirbyteValue, numParts: Int): Int
}

/**
* The default input partitioner, which partitions by the stream name. TODO: Should be round-robin?
*/
@Singleton
@Secondary
class ByStreamInputPartitioner : InputPartitioner {
@@ -0,0 +1,29 @@
/*
* Copyright (c) 2024 Airbyte, Inc., all rights reserved.
*/

package io.airbyte.cdk.load.pipeline

import io.airbyte.cdk.load.message.DestinationRecordAirbyteValue
import kotlin.math.abs
import kotlin.random.Random

/**
* Declare a singleton of this type to have input distributed evenly across the input partitions.
* (The default is [ByStreamInputPartitioner].)
*/
open class RoundRobinInputPartitioner(private val rotateEveryNRecords: Int = 10_000) :
InputPartitioner {
private var nextPartition =
Random(System.currentTimeMillis()).nextInt(Int.MAX_VALUE / rotateEveryNRecords) *
rotateEveryNRecords

override fun getPartition(record: DestinationRecordAirbyteValue, numParts: Int): Int {
val part = nextPartition++ / rotateEveryNRecords
return if (part == Int.MIN_VALUE) { // avoid overflow
0
} else {
abs(part) % numParts
}
}
}
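
For illustration, a connector opts into one of these partitioners (or supplies its own) by declaring a non-`@Secondary` singleton, which Micronaut should prefer over the `@Secondary` `ByStreamInputPartitioner` default above. Both classes below are hypothetical examples, not part of this change.

```kotlin
import io.airbyte.cdk.load.message.DestinationRecordAirbyteValue
import io.airbyte.cdk.load.pipeline.InputPartitioner
import io.airbyte.cdk.load.pipeline.RoundRobinInputPartitioner
import jakarta.inject.Singleton

// Opt into the round-robin behavior provided above.
@Singleton
class MyDestinationRoundRobinPartitioner : RoundRobinInputPartitioner()

// Or implement a fully custom policy: here, spread records by a stable hash.
// Unlike the by-stream default, this does not keep a stream's records on a
// single partition.
// @Singleton  // enable instead of the round-robin declaration above
class HashingInputPartitioner : InputPartitioner {
    override fun getPartition(record: DestinationRecordAirbyteValue, numParts: Int): Int =
        record.hashCode().mod(numParts) // mod() keeps the result non-negative
}
```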
@@ -145,7 +145,7 @@ class DefaultDestinationTaskLauncher<K : WithStream>(
// New interface shim
@Named("recordQueue")
private val recordQueueForPipeline:
PartitionedQueue<Reserved<PipelineEvent<StreamKey, DestinationRecordAirbyteValue>>>,
PartitionedQueue<PipelineEvent<StreamKey, DestinationRecordAirbyteValue>>,
@Named("batchStateUpdateQueue") private val batchUpdateQueue: ChannelMessageQueue<BatchUpdate>,
private val loadPipeline: LoadPipeline?,
private val partitioner: InputPartitioner,
@@ -197,23 +197,6 @@ class DefaultDestinationTaskLauncher<K : WithStream>(
}

override suspend fun run() {
// Start the input consumer ASAP
log.info { "Starting input consumer task" }
val inputConsumerTask =
inputConsumerTaskFactory.make(
catalog = catalog,
inputFlow = inputFlow,
recordQueueSupplier = recordQueueSupplier,
checkpointQueue = checkpointQueue,
fileTransferQueue = fileTransferQueue,
destinationTaskLauncher = this,
recordQueueForPipeline = recordQueueForPipeline,
loadPipeline = loadPipeline,
partitioner = partitioner,
openStreamQueue = openStreamQueue,
)
launch(inputConsumerTask)

// Launch the client interface setup task
log.info { "Starting startup task" }
val setupTask = setupTaskFactory.make(this)
@@ -225,12 +208,29 @@ class DefaultDestinationTaskLauncher<K : WithStream>(
}

if (loadPipeline != null) {
log.info { "Setting up load pipeline" }
loadPipeline.start { launch(it) }
log.info { "Setup load pipeline" }
loadPipeline.start { task -> launch(task, withExceptionHandling = true) }
log.info { "Launching update batch task" }
val updateBatchTask = updateBatchTaskFactory.make(this)
launch(updateBatchTask)
} else {
// Start the input consumer ASAP
log.info { "Starting input consumer task" }
val inputConsumerTask =
inputConsumerTaskFactory.make(
catalog = catalog,
inputFlow = inputFlow,
recordQueueSupplier = recordQueueSupplier,
checkpointQueue = checkpointQueue,
fileTransferQueue = fileTransferQueue,
destinationTaskLauncher = this,
recordQueueForPipeline = recordQueueForPipeline,
loadPipeline = loadPipeline,
partitioner = partitioner,
openStreamQueue = openStreamQueue,
)
launch(inputConsumerTask)

// TODO: pluggable file transfer
if (!fileTransferEnabled) {
// Start a spill-to-disk task for each record stream
@@ -289,6 +289,26 @@
catalog.streams.forEach { openStreamQueue.publish(it) }
log.info { "Closing open stream queue" }
openStreamQueue.close()
} else {
// When the pipeline is enabled, input consuming for
// each stream will wait on stream start to complete,
// but not on setup. This is the simplest way to make
// it do that.
log.info { "Setup complete, starting input consumer task" }
val inputConsumerTask =
inputConsumerTaskFactory.make(
catalog = catalog,
inputFlow = inputFlow,
recordQueueSupplier = recordQueueSupplier,
checkpointQueue = checkpointQueue,
fileTransferQueue = fileTransferQueue,
destinationTaskLauncher = this,
recordQueueForPipeline = recordQueueForPipeline,
loadPipeline = loadPipeline,
partitioner = partitioner,
openStreamQueue = openStreamQueue,
)
launch(inputConsumerTask)
}
}

@@ -80,7 +80,7 @@ class DefaultInputConsumerTask(
// Required by new interface
@Named("recordQueue")
private val recordQueueForPipeline:
PartitionedQueue<Reserved<PipelineEvent<StreamKey, DestinationRecordAirbyteValue>>>,
PartitionedQueue<PipelineEvent<StreamKey, DestinationRecordAirbyteValue>>,
private val loadPipeline: LoadPipeline? = null,
private val partitioner: InputPartitioner,
private val openStreamQueue: QueueWriter<DestinationStream>
@@ -165,20 +165,20 @@
mapOf(manager.getCurrentCheckpointId() to 1),
StreamKey(stream),
record
)
) { reserved.release() }
val partition = partitioner.getPartition(record, recordQueueForPipeline.partitions)
recordQueueForPipeline.publish(reserved.replace(pipelineMessage), partition)
recordQueueForPipeline.publish(pipelineMessage, partition)
}
is DestinationRecordStreamComplete -> {
manager.markEndOfStream(true)
log.info { "Read COMPLETE for stream $stream" }
recordQueueForPipeline.broadcast(reserved.replace(PipelineEndOfStream(stream)))
recordQueueForPipeline.broadcast(PipelineEndOfStream(stream))
reserved.release()
}
is DestinationRecordStreamIncomplete -> {
manager.markEndOfStream(false)
log.info { "Read INCOMPLETE for stream $stream" }
recordQueueForPipeline.broadcast(reserved.replace(PipelineEndOfStream(stream)))
recordQueueForPipeline.broadcast(PipelineEndOfStream(stream))
reserved.release()
}
is DestinationFile -> {
@@ -310,7 +310,7 @@ interface InputConsumerTaskFactory {

// Required by new interface
recordQueueForPipeline:
PartitionedQueue<Reserved<PipelineEvent<StreamKey, DestinationRecordAirbyteValue>>>,
PartitionedQueue<PipelineEvent<StreamKey, DestinationRecordAirbyteValue>>,
loadPipeline: LoadPipeline?,
partitioner: InputPartitioner,
openStreamQueue: QueueWriter<DestinationStream>,
@@ -333,7 +333,7 @@ class DefaultInputConsumerTaskFactory(

// Required by new interface
recordQueueForPipeline:
PartitionedQueue<Reserved<PipelineEvent<StreamKey, DestinationRecordAirbyteValue>>>,
PartitionedQueue<PipelineEvent<StreamKey, DestinationRecordAirbyteValue>>,
loadPipeline: LoadPipeline?,
partitioner: InputPartitioner,
openStreamQueue: QueueWriter<DestinationStream>,