@@ -18,7 +18,6 @@ import io.airbyte.cdk.load.pipeline.BatchUpdate
18
18
import io.airbyte.cdk.load.pipeline.OutputPartitioner
19
19
import io.airbyte.cdk.load.pipeline.PipelineFlushStrategy
20
20
import io.airbyte.cdk.load.state.CheckpointId
21
- import io.airbyte.cdk.load.state.Reserved
22
21
import io.airbyte.cdk.load.task.OnEndOfSync
23
22
import io.airbyte.cdk.load.task.Task
24
23
import io.airbyte.cdk.load.task.TerminalCondition
@@ -34,7 +33,7 @@ data class RangeState<S>(
34
33
/** A long-running task that actually implements a load pipeline step. */
35
34
class LoadPipelineStepTask <S : AutoCloseable , K1 : WithStream , T , K2 : WithStream , U : Any >(
36
35
private val batchAccumulator : BatchAccumulator <S , K1 , T , U >,
37
- private val inputFlow : Flow <Reserved < PipelineEvent <K1 , T > >>,
36
+ private val inputFlow : Flow <PipelineEvent <K1 , T >>,
38
37
private val batchUpdateQueue : QueueWriter <BatchUpdate >,
39
38
private val outputPartitioner : OutputPartitioner <K1 , T , K2 , U >? ,
40
39
private val outputQueue : PartitionedQueue <PipelineEvent <K2 , U >>? ,
@@ -44,11 +43,11 @@ class LoadPipelineStepTask<S : AutoCloseable, K1 : WithStream, T, K2 : WithStrea
44
43
override val terminalCondition: TerminalCondition = OnEndOfSync
45
44
46
45
override suspend fun execute () {
47
- inputFlow.fold(mutableMapOf<K1 , RangeState <S >>()) { stateStore, reservation ->
46
+ inputFlow.fold(mutableMapOf<K1 , RangeState <S >>()) { stateStore, input ->
48
47
try {
49
- when (val input = reservation.value ) {
48
+ when (input) {
50
49
is PipelineMessage -> {
51
- // Fetch and update the local state associated with the current batch .
50
+ // Get or create the accumulator state associated w/ the input key .
52
51
val state =
53
52
stateStore
54
53
.getOrPut(input.key) {
@@ -57,43 +56,73 @@ class LoadPipelineStepTask<S : AutoCloseable, K1 : WithStream, T, K2 : WithStrea
57
56
)
58
57
}
59
58
.let { it.copy(inputCount = it.inputCount + 1 ) }
60
- val (newState, output) =
59
+
60
+ // Accumulate the input and get the new state and output.
61
+ val (newStateMaybe, outputMaybe) =
61
62
batchAccumulator.accept(
62
63
input.value,
63
64
state.state,
64
65
)
65
- reservation.release() // TODO: Accumulate and release when persisted
66
+ /* * TODO: Make this impossible at the return type level */
67
+ if (newStateMaybe == null && outputMaybe == null ) {
68
+ throw IllegalStateException (
69
+ " BatchAccumulator must return a new state or an output"
70
+ )
71
+ }
72
+
73
+ // Update bookkeeping metadata
74
+ input
75
+ .postProcessingCallback() // TODO: Accumulate and release when persisted
66
76
input.checkpointCounts.forEach {
67
77
state.checkpointCounts.merge(it.key, it.value) { old, new -> old + new }
68
78
}
69
79
70
- // If the accumulator did not produce a result, check if we should flush.
71
- // If so, use the result of a finish call as the output.
72
- val finalOutput =
73
- output
74
- ? : if (flushStrategy?.shouldFlush(state.inputCount) == true ) {
75
- batchAccumulator.finish(newState)
80
+ // Finalize the state and output
81
+ val (finalState, finalOutput) =
82
+ if (outputMaybe == null ) {
83
+ // Possibly force an output (and if so, discard the state)
84
+ if (flushStrategy?.shouldFlush(state.inputCount) == true ) {
85
+ val finalOutput = batchAccumulator.finish(newStateMaybe!! )
86
+ Pair (null , finalOutput)
76
87
} else {
77
- null
88
+ Pair (newStateMaybe, null )
78
89
}
90
+ } else {
91
+ // Otherwise, just use what we were given
92
+ Pair (newStateMaybe, outputMaybe)
93
+ }
79
94
80
- if (finalOutput != null ) {
81
- // Publish the emitted output and evict the state.
82
- handleOutput(input.key, state.checkpointCounts, finalOutput)
83
- stateStore.remove(input.key)
95
+ // Publish the output if there is one & reset the input count
96
+ val inputCount =
97
+ if (finalOutput != null ) {
98
+ // Publish the emitted output and clear the accumulated checkpoint counts.
99
+ handleOutput(input.key, state.checkpointCounts, finalOutput)
100
+ state.checkpointCounts.clear()
101
+ 0
102
+ } else {
103
+ state.inputCount
104
+ }
105
+
106
+ // Keep the state if one remains after finalization (a forced flush discards it), otherwise evict.
107
+ if (finalState != null ) {
108
+ // If accept returned a new state, update the state store.
109
+ stateStore[input.key] =
110
+ state.copy(state = finalState, inputCount = inputCount)
84
111
} else {
85
- // If there's no output yet, just update the local state.
86
- stateStore[input.key] = RangeState (newState, state.checkpointCounts)
112
+ stateStore.remove(input.key)
87
113
}
114
+
88
115
stateStore
89
116
}
90
117
is PipelineEndOfStream -> {
91
118
// Give any key associated with the stream a chance to finish
92
119
val keysToRemove = stateStore.keys.filter { it.stream == input.stream }
93
120
keysToRemove.forEach { key ->
94
121
stateStore.remove(key)?.let { stored ->
95
- val output = batchAccumulator.finish(stored.state)
96
- handleOutput(key, stored.checkpointCounts, output)
122
+ if (stored.inputCount > 0 ) {
123
+ val output = batchAccumulator.finish(stored.state)
124
+ handleOutput(key, stored.checkpointCounts, output)
125
+ }
97
126
}
98
127
}
99
128
@@ -122,7 +151,7 @@ class LoadPipelineStepTask<S : AutoCloseable, K1 : WithStream, T, K2 : WithStrea
122
151
// Only publish the output if there's a next step.
123
152
outputQueue?.let {
124
153
val outputKey = outputPartitioner!! .getOutputKey(inputKey, output)
125
- val message = PipelineMessage (checkpointCounts, outputKey, output)
154
+ val message = PipelineMessage (checkpointCounts.toMap() , outputKey, output)
126
155
val outputPart = outputPartitioner.getPart(outputKey, it.partitions)
127
156
it.publish(message, outputPart)
128
157
}
@@ -132,7 +161,7 @@ class LoadPipelineStepTask<S : AutoCloseable, K1 : WithStream, T, K2 : WithStrea
132
161
val update =
133
162
BatchStateUpdate (
134
163
stream = inputKey.stream,
135
- checkpointCounts = checkpointCounts,
164
+ checkpointCounts = checkpointCounts.toMap() ,
136
165
state = output.state
137
166
)
138
167
batchUpdateQueue.publish(update)
0 commit comments