WIP: File Loader #55732

Draft · wants to merge 2 commits into master
@@ -9,14 +9,14 @@ import io.airbyte.cdk.load.command.DestinationConfiguration
import io.airbyte.cdk.load.command.DestinationStream
import io.airbyte.cdk.load.message.BatchEnvelope
import io.airbyte.cdk.load.message.ChannelMessageQueue
import io.airbyte.cdk.load.message.DestinationFile
import io.airbyte.cdk.load.message.DestinationRecordAirbyteValue
import io.airbyte.cdk.load.message.MultiProducerChannel
import io.airbyte.cdk.load.message.PartitionedQueue
import io.airbyte.cdk.load.message.PipelineEvent
import io.airbyte.cdk.load.message.StreamKey
import io.airbyte.cdk.load.pipeline.BatchUpdate
import io.airbyte.cdk.load.state.ReservationManager
import io.airbyte.cdk.load.state.Reserved
import io.airbyte.cdk.load.task.implementor.FileAggregateMessage
import io.airbyte.cdk.load.task.implementor.FileTransferQueueMessage
import io.airbyte.cdk.load.write.LoadStrategy
@@ -120,7 +120,22 @@ class SyncBeanFactory {
@Named("recordQueue")
fun recordQueue(
loadStrategy: LoadStrategy? = null,
- ): PartitionedQueue<Reserved<PipelineEvent<StreamKey, DestinationRecordAirbyteValue>>> {
+ ): PartitionedQueue<PipelineEvent<StreamKey, DestinationRecordAirbyteValue>> {
return PartitionedQueue(
Array(loadStrategy?.inputPartitions ?: 1) {
ChannelMessageQueue(Channel(Channel.UNLIMITED))
}
)
}

+ /**
+ * Same as recordQueue, but for files.
+ */
+ @Singleton
+ @Named("fileQueue")
+ fun fileQueue(
+ loadStrategy: LoadStrategy? = null,
+ ): PartitionedQueue<PipelineEvent<StreamKey, DestinationFile>> {
+ return PartitionedQueue(
+ Array(loadStrategy?.inputPartitions ?: 1) {
+ ChannelMessageQueue(Channel(Channel.UNLIMITED))
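For orientation, a hedged sketch of how these beans might be consumed downstream. FourPartitionStrategy and FilePartitionWorker are hypothetical names, and this assumes inputPartitions is the LoadStrategy member the factory reads, as the code above suggests:

import io.airbyte.cdk.load.message.DestinationFile
import io.airbyte.cdk.load.message.PartitionedQueue
import io.airbyte.cdk.load.message.PipelineEvent
import io.airbyte.cdk.load.message.StreamKey
import io.airbyte.cdk.load.write.LoadStrategy

// Hypothetical strategy requesting four input partitions, so both
// recordQueue and fileQueue would be built with four channels.
class FourPartitionStrategy : LoadStrategy {
    override val inputPartitions: Int = 4
}

// Hypothetical worker that drains a single partition of the file queue.
class FilePartitionWorker(
    private val fileQueue: PartitionedQueue<PipelineEvent<StreamKey, DestinationFile>>,
    private val partition: Int,
) {
    suspend fun run() =
        fileQueue.consume(partition).collect { event ->
            // handle each file event routed to this partition
        }
}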
@@ -32,6 +32,7 @@ import io.airbyte.protocol.models.v0.AirbyteTraceMessage
import io.micronaut.context.annotation.Value
import jakarta.inject.Singleton
import java.math.BigInteger
+ import java.nio.MappedByteBuffer
import java.time.OffsetDateTime

/**
@@ -167,7 +168,7 @@ data class DestinationFile(
override val stream: DestinationStream.Descriptor,
val emittedAtMs: Long,
val serialized: String,
- val fileMessage: AirbyteRecordMessageFile
+ val fileMessage: AirbyteRecordMessageFile,
) : DestinationFileDomainMessage {
/** Convenience constructor, primarily intended for use in tests. */
class AirbyteRecordMessageFile {
@@ -11,17 +11,17 @@ class PartitionedQueue<T>(private val queues: Array<MessageQueue<T>>) : Closeabl
val partitions = queues.size

fun consume(partition: Int): Flow<T> {
- if (partition < 0 || partition >= queues.size) {
- throw IllegalArgumentException("Invalid partition: $partition")
- }
- return queues[partition].consume()
+ // if (partition < 0 || partition >= queues.size) {
+ // throw IllegalArgumentException("Invalid partition: $partition")
+ // }
+ return queues[partition % partitions].consume()
}

suspend fun publish(value: T, partition: Int) {
- if (partition < 0 || partition >= queues.size) {
- throw IllegalArgumentException("Invalid partition: $partition")
- }
- queues[partition].publish(value)
+ // if (partition < 0 || partition >= queues.size) {
+ // throw IllegalArgumentException("Invalid partition: $partition")
+ // }
+ queues[partition % partitions].publish(value)
}

suspend fun broadcast(value: T) = queues.forEach { it.publish(value) }
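With the bounds checks commented out, out-of-range indices now wrap via modulo instead of throwing. A standalone sketch of that pattern (not the CDK class), including one caveat: Kotlin's % keeps the dividend's sign, so a negative partition would still index out of bounds:

class WrappingLookup(private val partitions: Int) {
    // Same wrapping rule as queues[partition % partitions] above.
    fun resolve(partition: Int) = partition % partitions
}

fun main() {
    val lookup = WrappingLookup(4)
    println(lookup.resolve(2)) // 2
    println(lookup.resolve(6)) // 2: indices past the end wrap around
    println(lookup.resolve(-1)) // -1: negative stays negative in Kotlin
}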
@@ -10,15 +10,19 @@ import io.airbyte.cdk.load.state.CheckpointId
/** Used internally by the CDK to pass messages between steps in the loader pipeline. */
sealed interface PipelineEvent<K : WithStream, T>

+ /**
+ * A message that contains a keyed payload. The key is used to manage the state of the payload's
+ * corresponding [io.airbyte.cdk.load.pipeline.BatchAccumulator]. [checkpointCounts] is used by the
+ * CDK to perform state message bookkeeping. [postProcessingCallback] is for releasing resources
+ * associated with the message.
+ */
class PipelineMessage<K : WithStream, T>(
val checkpointCounts: Map<CheckpointId, Long>,
val key: K,
- val value: T
+ val value: T,
+ val postProcessingCallback: suspend () -> Unit = {},
) : PipelineEvent<K, T>

- /**
- * We send the end message on the stream and not the key, because there's no way to partition an
- * empty message.
- */
+ /** Broadcast at end-of-stream to all partitions to signal that the stream has ended. */
class PipelineEndOfStream<K : WithStream, T>(val stream: DestinationStream.Descriptor) :
PipelineEvent<K, T>
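A simplified, non-generic Kotlin analogue of this hierarchy (not the CDK types), showing how a consumer might dispatch on the two events and invoke the post-processing callback:

sealed interface Event

class Message(
    val key: String,
    val value: String,
    val postProcessingCallback: suspend () -> Unit = {},
) : Event

class EndOfStream(val stream: String) : Event

suspend fun dispatch(event: Event) =
    when (event) {
        is Message -> {
            // process event.value keyed by event.key, then release resources
            event.postProcessingCallback()
        }
        is EndOfStream -> {
            // flush any per-key state held for event.stream
        }
    }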
@@ -7,11 +7,38 @@ package io.airbyte.cdk.load.pipeline
import io.airbyte.cdk.load.message.WithStream

/**
- * [BatchAccumulator] is used internally by the CDK to implement RecordLoaders. Connector devs
- * should never need to implement this interface.
+ * [BatchAccumulator] is used internally by the CDK to implement
+ * [io.airbyte.cdk.load.write.LoadStrategy]s. Connector devs should never need to implement this
+ * interface.
+ *
+ * It is the glue that connects a specific step in a specific pipeline to the generic pipeline on
+ * the back end. (For example, in a three-stage pipeline like bulk load, step 1 is to create a part,
+ * step 2 is to upload it, and step 3 is to load it from object storage into a table.)
+ *
+ * - [S] is a state type that will be threaded through accumulator calls.
+ * - [K] is a key type associated with the input data. (NOTE: Currently, there is no support for
+ * key-mapping, so the key is always [io.airbyte.cdk.load.message.StreamKey]). Specifically, state
+ * will always be managed per-key.
+ * - [T] is the input data type
+ * - [U] is the output data type
+ *
+ * The first time data is seen for a given key, [start] is called (with the partition number). The
+ * state returned by [start] will be passed to [accept] with each input.
+ *
+ * If [accept] returns a non-null output, that output will be forwarded to the next stage (if
+ * applicable) and/or trigger bookkeeping (iff the output type implements
+ * [io.airbyte.cdk.load.message.WithBatchState]).
+ *
+ * If [accept] returns a non-null state, that state will be passed to the next call to [accept]. If
+ * [accept] returns a null state, the state will be discarded and a new one will be created on the
+ * next input by a new call to [start].
+ *
+ * When the input stream is exhausted, [finish] will be called with any remaining state iff at least
+ * one input was seen for that key. This means that [finish] will not be called on empty keys or on
+ * keys where the last call to [accept] yielded a null (finished) state.
+ */
interface BatchAccumulator<S, K : WithStream, T, U> {
- fun start(key: K, part: Int): S
- fun accept(record: T, state: S): Pair<S, U?>
- fun finish(state: S): U
+ suspend fun start(key: K, part: Int): S
+ suspend fun accept(input: T, state: S): Pair<S?, U?>
+ suspend fun finish(state: S): U
}
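To make the documented lifecycle concrete, a minimal hypothetical accumulator that batches string inputs into fixed-size groups (illustrative only, not a CDK implementation):

import io.airbyte.cdk.load.message.WithStream

// Hypothetical state object threaded through accept() calls.
class BufferState(val items: MutableList<String> = mutableListOf())

class FixedSizeBatcher<K : WithStream>(private val maxSize: Int) :
    BatchAccumulator<BufferState, K, String, List<String>> {
    // Called the first time data is seen for a key.
    override suspend fun start(key: K, part: Int): BufferState = BufferState()

    // Buffers each input; emits a batch and discards the state at the threshold.
    override suspend fun accept(input: String, state: BufferState): Pair<BufferState?, List<String>?> {
        state.items.add(input)
        return if (state.items.size >= maxSize) {
            Pair(null, state.items.toList()) // null state => start() runs again on the next input
        } else {
            Pair(state, null) // keep accumulating, no output yet
        }
    }

    // Called at end-of-stream with leftover state, iff inputs were seen for the key.
    override suspend fun finish(state: BufferState): List<String> = state.items.toList()
}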
@@ -9,7 +9,6 @@ import io.airbyte.cdk.load.message.PartitionedQueue
import io.airbyte.cdk.load.message.PipelineEvent
import io.airbyte.cdk.load.message.QueueWriter
import io.airbyte.cdk.load.message.StreamKey
- import io.airbyte.cdk.load.state.Reserved
import io.airbyte.cdk.load.task.internal.LoadPipelineStepTask
import io.airbyte.cdk.load.write.DirectLoader
import io.airbyte.cdk.load.write.DirectLoaderFactory
@@ -24,8 +23,7 @@ import jakarta.inject.Singleton
class DirectLoadPipelineStep<S : DirectLoader>(
val accumulator: DirectLoadRecordAccumulator<S, StreamKey>,
@Named("recordQueue")
- val inputQueue:
- PartitionedQueue<Reserved<PipelineEvent<StreamKey, DestinationRecordAirbyteValue>>>,
+ val inputQueue: PartitionedQueue<PipelineEvent<StreamKey, DestinationRecordAirbyteValue>>,
@Named("batchStateUpdateQueue") val batchQueue: QueueWriter<BatchUpdate>,
@Value("\${airbyte.destination.core.record-batch-size-override:null}")
val batchSizeOverride: Long? = null,
@@ -26,23 +26,23 @@ data class DirectLoadAccResult(override val state: Batch.State) : WithBatchState
class DirectLoadRecordAccumulator<S : DirectLoader, K : WithStream>(
val directLoaderFactory: DirectLoaderFactory<S>
) : BatchAccumulator<S, K, DestinationRecordAirbyteValue, DirectLoadAccResult> {
- override fun start(key: K, part: Int): S {
+ override suspend fun start(key: K, part: Int): S {
return directLoaderFactory.create(key.stream, part)
}

- override fun accept(
- record: DestinationRecordAirbyteValue,
+ override suspend fun accept(
+ input: DestinationRecordAirbyteValue,
state: S
- ): Pair<S, DirectLoadAccResult?> {
- state.accept(record).let {
+ ): Pair<S?, DirectLoadAccResult?> {
+ state.accept(input).let {
return when (it) {
is Incomplete -> Pair(state, null)
- is Complete -> Pair(state, DirectLoadAccResult(Batch.State.COMPLETE))
+ is Complete -> Pair(null, DirectLoadAccResult(Batch.State.COMPLETE))
}
}
}

- override fun finish(state: S): DirectLoadAccResult {
+ override suspend fun finish(state: S): DirectLoadAccResult {
state.finish()
return DirectLoadAccResult(Batch.State.COMPLETE)
}
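The behavioral change here: on Complete the accumulator now returns a null state, so a finished DirectLoader is discarded rather than reused. A standalone Kotlin analogue of that control flow (all names hypothetical, not the CDK's DirectLoader API):

sealed interface LoadResult
object Incomplete : LoadResult
object Complete : LoadResult

// Hypothetical loader: reports Complete after every third record.
class Loader {
    private var count = 0
    fun accept(record: String): LoadResult = if (++count >= 3) Complete else Incomplete
}

// Mirrors the accept() logic above: keep the state while Incomplete,
// drop it (return null) once the loader reports Complete.
fun step(loader: Loader, record: String): Pair<Loader?, String?> =
    when (loader.accept(record)) {
        Incomplete -> Pair(loader, null)
        Complete -> Pair(null, "COMPLETE")
    }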
@@ -5,6 +5,7 @@
package io.airbyte.cdk.load.pipeline

import io.airbyte.cdk.load.message.DestinationRecordAirbyteValue
+ import io.airbyte.cdk.load.message.DestinationStreamAffinedMessage
import io.micronaut.context.annotation.Secondary
import jakarta.inject.Singleton
import kotlin.math.abs
@@ -17,6 +18,9 @@ interface InputPartitioner {
fun getPartition(record: DestinationRecordAirbyteValue, numParts: Int): Int
}

+ /**
+ * The default input partitioner, which partitions by the stream name. TODO: Should be round-robin?
+ */
@Singleton
@Secondary
class ByStreamInputPartitioner : InputPartitioner {
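The class body is elided above. Partitioning by stream name typically reduces to hashing the stream's name into the partition range, along the lines of this hypothetical sketch (not the actual elided body):

// Hypothetical stand-in: a stable partition derived from the stream name.
// Int.mod() (unlike %) always returns a non-negative value for a positive modulus.
fun partitionForStream(streamName: String, numParts: Int): Int =
    streamName.hashCode().mod(numParts)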
@@ -9,7 +9,7 @@ import io.airbyte.cdk.load.task.internal.LoadPipelineStepTask

interface LoadPipelineStep {
val numWorkers: Int
- fun taskForPartition(partition: Int): LoadPipelineStepTask<*, *, *, *, *>
+ fun taskForPartition(partition: Int): Task
}

/**
@@ -0,0 +1,30 @@
/*
* Copyright (c) 2024 Airbyte, Inc., all rights reserved.
*/

package io.airbyte.cdk.load.pipeline

import io.airbyte.cdk.load.message.DestinationRecordAirbyteValue
import io.airbyte.cdk.load.message.DestinationStreamAffinedMessage
import kotlin.math.abs
import kotlin.random.Random

/**
* Declare a singleton of this type to have input distributed evenly across the input partitions.
* (The default is [ByStreamInputPartitioner].)
*/
open class RoundRobinInputPartitioner(private val rotateEveryNRecords: Int = 10_000) :
InputPartitioner {
private var nextPartition =
Random(System.currentTimeMillis()).nextInt(Int.MAX_VALUE / rotateEveryNRecords) *
rotateEveryNRecords

override fun getPartition(record: DestinationRecordAirbyteValue, numParts: Int): Int {
val part = nextPartition++ / rotateEveryNRecords
return if (part == Int.MIN_VALUE) { // avoid overflow
0
} else {
abs(part) % numParts
}
}
}
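Since ByStreamInputPartitioner is bound as @Secondary, a connector could presumably opt in by declaring a primary singleton of this type (hypothetical connector-side code):

import io.airbyte.cdk.load.pipeline.RoundRobinInputPartitioner
import jakarta.inject.Singleton

// Hypothetical: a non-@Secondary InputPartitioner bean takes precedence
// over the default ByStreamInputPartitioner in Micronaut's injection.
@Singleton
class MyRoundRobinPartitioner : RoundRobinInputPartitioner()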
@@ -389,7 +389,6 @@ class DefaultStreamManager(
if (readCount == 0L) {
return true
}

val completedCount = checkpointCounts.values.sumOf { it.recordsCompleted.get() }
return completedCount == readCount
}
@@ -11,6 +11,7 @@ import io.airbyte.cdk.load.command.DestinationStream
import io.airbyte.cdk.load.message.BatchEnvelope
import io.airbyte.cdk.load.message.ChannelMessageQueue
import io.airbyte.cdk.load.message.CheckpointMessageWrapped
+ import io.airbyte.cdk.load.message.DestinationFile
import io.airbyte.cdk.load.message.DestinationRecordAirbyteValue
import io.airbyte.cdk.load.message.DestinationStreamEvent
import io.airbyte.cdk.load.message.MessageQueue
@@ -145,7 +146,9 @@ class DefaultDestinationTaskLauncher<K : WithStream>(
// New interface shim
@Named("recordQueue")
private val recordQueueForPipeline:
- PartitionedQueue<Reserved<PipelineEvent<StreamKey, DestinationRecordAirbyteValue>>>,
+ PartitionedQueue<PipelineEvent<StreamKey, DestinationRecordAirbyteValue>>,
+ @Named("fileQueue")
+ private val fileQueueForPipeline: PartitionedQueue<PipelineEvent<StreamKey, DestinationFile>>,
@Named("batchStateUpdateQueue") private val batchUpdateQueue: ChannelMessageQueue<BatchUpdate>,
private val loadPipeline: LoadPipeline?,
private val partitioner: InputPartitioner,
@@ -201,13 +204,15 @@ class DefaultDestinationTaskLauncher<K : WithStream>(
log.info { "Starting input consumer task" }
val inputConsumerTask =
inputConsumerTaskFactory.make(
+ config = config,
catalog = catalog,
inputFlow = inputFlow,
recordQueueSupplier = recordQueueSupplier,
checkpointQueue = checkpointQueue,
fileTransferQueue = fileTransferQueue,
destinationTaskLauncher = this,
recordQueueForPipeline = recordQueueForPipeline,
+ fileQueueForPipeline = fileQueueForPipeline,
loadPipeline = loadPipeline,
partitioner = partitioner,
openStreamQueue = openStreamQueue,