
Commit c984de6

Re-fix

Parent: 2c3cce0

5 files changed: +47 −10 lines

cpp/src/arrow/acero/aggregate_internal.cc (+2 −5)

@@ -177,14 +177,11 @@ void AggregatesToString(std::stringstream* ss, const Schema& input_schema,
   *ss << ']';
 }
 
-Status ExtractSegmenterValues(std::vector<Datum>* values_ptr,
-                              const ExecBatch& input_batch,
+Status ExtractSegmenterValues(std::vector<Datum>& values, const ExecBatch& input_batch,
                               const std::vector<int>& field_ids) {
+  DCHECK_EQ(values.size(), field_ids.size());
   DCHECK_GT(input_batch.length, 0);
-  std::vector<Datum>& values = *values_ptr;
   int64_t row = input_batch.length - 1;
-  values.clear();
-  values.resize(field_ids.size());
   for (size_t i = 0; i < field_ids.size(); i++) {
     const Datum& value = input_batch.values[field_ids[i]];
     if (value.is_scalar()) {
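
The change above replaces the output-pointer parameter with a reference and moves the sizing of the scratch vector out of the per-batch path: the vector is now sized once at node construction (see the header diff below) and only validated here with a DCHECK, so a call writes only into existing slots. A minimal standalone sketch of that pattern, with hypothetical names standing in for the Arrow types:

    #include <cassert>
    #include <cstddef>
    #include <utility>
    #include <vector>

    // Stand-in for arrow::Datum in this sketch.
    struct Value {
      int v = 0;
    };

    // Per-batch routine in the new style: takes the scratch vector by reference
    // and assumes the caller pre-sized it; the assert replaces the old per-call
    // clear()+resize().
    void ExtractInto(std::vector<Value>& values, const std::vector<int>& field_ids) {
      assert(values.size() == field_ids.size());
      for (size_t i = 0; i < field_ids.size(); i++) {
        values[i].v = field_ids[i];  // assign in place; no reallocation
      }
    }

    // Node-like holder: the scratch vector is sized once at construction,
    // mirroring segmenter_values_(segment_field_ids.size()) in the header diff.
    struct Node {
      std::vector<int> field_ids_;  // declared first, so initialized first
      std::vector<Value> scratch_;
      explicit Node(std::vector<int> field_ids)
          : field_ids_(std::move(field_ids)), scratch_(field_ids_.size()) {}
      void OnBatch() { ExtractInto(scratch_, field_ids_); }
    };

    int main() {
      Node node({0, 2, 5});
      node.OnBatch();  // repeated calls reuse the pre-sized scratch vector
      return 0;
    }

One consequence of the pattern: when there are no segment keys, the per-batch call reduces to two checks and an empty loop, performing no writes to shared state at all.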

cpp/src/arrow/acero/aggregate_internal.h (+4 −3)

@@ -143,11 +143,10 @@ Status HandleSegments(RowSegmenter* segmenter, const ExecBatch& batch,
 }
 
 /// @brief Extract values of segment keys from a segment batch
-/// @param[out] values_ptr Vector to store the extracted segment key values
+/// @param[out] values Vector to store the extracted segment key values
 /// @param[in] input_batch Segment batch. Must have the a constant value for segment key
 /// @param[in] field_ids Segment key field ids
-Status ExtractSegmenterValues(std::vector<Datum>* values_ptr,
-                              const ExecBatch& input_batch,
+Status ExtractSegmenterValues(std::vector<Datum>& values, const ExecBatch& input_batch,
                               const std::vector<int>& field_ids);
 
 Result<std::vector<Datum>> ExtractValues(const ExecBatch& input_batch,
@@ -171,6 +170,7 @@ class ScalarAggregateNode : public ExecNode, public TracedNode {
         TracedNode(this),
         segmenter_(std::move(segmenter)),
         segment_field_ids_(std::move(segment_field_ids)),
+        segmenter_values_(segment_field_ids.size()),
         target_fieldsets_(std::move(target_fieldsets)),
         aggs_(std::move(aggs)),
         kernels_(std::move(kernels)),
@@ -249,6 +249,7 @@ class GroupByNode : public ExecNode, public TracedNode {
       : ExecNode(input->plan(), {input}, {"groupby"}, std::move(output_schema)),
         TracedNode(this),
         segmenter_(std::move(segmenter)),
+        segmenter_values_(segment_key_field_ids.size()),
        key_field_ids_(std::move(key_field_ids)),
        segment_key_field_ids_(std::move(segment_key_field_ids)),
        agg_src_types_(std::move(agg_src_types)),
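
A general C++ note on the two constructor additions (a language rule, not a claim about these classes' member layout): members are initialized in declaration order, not in the order they appear in the initializer list, so an initializer that reads a parameter's .size() is sequenced before a std::move of that parameter only if the reading member is declared first. A hypothetical illustration:

    #include <utility>
    #include <vector>

    struct SafeOrder {
      std::vector<int> scratch_;  // declared first: scratch_(ids.size()) runs first
      std::vector<int> ids_;      // declared second: the move happens afterwards
      explicit SafeOrder(std::vector<int> ids)
          : scratch_(ids.size()), ids_(std::move(ids)) {}
    };

    int main() {
      SafeOrder s({1, 2, 3});  // scratch_ sized to 3 before ids is moved from
      return 0;
    }

If the declaration order were reversed, ids would already be moved-from when .size() ran, and a moved-from std::vector has valid but unspecified contents.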

cpp/src/arrow/acero/aggregate_node_test.cc (+39 −0)

@@ -213,6 +213,26 @@ TEST(GroupByNode, NoSkipNulls) {
   AssertExecBatchesEqualIgnoringOrder(out_schema, {expected_batch}, out_batches.batches);
 }
 
+TEST(GroupByNode, BasicParallel) {
+  const int64_t num_batches = 8;
+
+  std::vector<ExecBatch> batches(num_batches, ExecBatchFromJSON({int32()}, "[[42]]"));
+
+  Declaration plan = Declaration::Sequence(
+      {{"exec_batch_source",
+        ExecBatchSourceNodeOptions(schema({field("key", int32())}), batches)},
+       {"aggregate", AggregateNodeOptions{/*aggregates=*/{{"hash_count_all", "count(*)"}},
+                                          /*keys=*/{"key"}}}});
+
+  ASSERT_OK_AND_ASSIGN(BatchesWithCommonSchema out_batches,
+                       DeclarationToExecBatches(plan));
+
+  ExecBatch expected_batch = ExecBatchFromJSON(
+      {int32(), int64()}, "[[42, " + std::to_string(num_batches) + "]]");
+  AssertExecBatchesEqualIgnoringOrder(out_batches.schema, {expected_batch},
+                                      out_batches.batches);
+}
+
 TEST(ScalarAggregateNode, AnyAll) {
   // GH-43768: boolean_any and boolean_all with constant input should work well
   // when min_count != 0.
@@ -265,5 +285,24 @@ TEST(ScalarAggregateNode, AnyAll) {
   }
 }
 
+TEST(ScalarAggregateNode, BasicParallel) {
+  const int64_t num_batches = 8;
+
+  std::vector<ExecBatch> batches(num_batches, ExecBatchFromJSON({int32()}, "[[42]]"));
+
+  Declaration plan = Declaration::Sequence(
+      {{"exec_batch_source",
+        ExecBatchSourceNodeOptions(schema({field("", int32())}), batches)},
+       {"aggregate", AggregateNodeOptions{/*aggregates=*/{{"count_all", "count(*)"}}}}});
+
+  ASSERT_OK_AND_ASSIGN(BatchesWithCommonSchema out_batches,
+                       DeclarationToExecBatches(plan));
+
+  ExecBatch expected_batch =
+      ExecBatchFromJSON({int64()}, "[[" + std::to_string(num_batches) + "]]");
+  AssertExecBatchesEqualIgnoringOrder(out_batches.schema, {expected_batch},
+                                      out_batches.batches);
+}
+
 }  // namespace acero
 }  // namespace arrow
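
Both new tests push eight identical one-row batches through an exec_batch_source, which (as the BasicParallel names suggest) exercises concurrent InputReceived calls against the now pre-sized segmenter_values_. A hedged variation for also checking the single-threaded path, assuming DeclarationToExecBatches accepts a use_threads argument:

    // Hypothetical variation, not part of this commit: run the same plan with
    // threading disabled to compare against the serial path. Verify the actual
    // DeclarationToExecBatches overload in the Acero headers before relying on it.
    ASSERT_OK_AND_ASSIGN(BatchesWithCommonSchema serial_batches,
                         DeclarationToExecBatches(plan, /*use_threads=*/false));
    AssertExecBatchesEqualIgnoringOrder(serial_batches.schema, {expected_batch},
                                        serial_batches.batches);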

cpp/src/arrow/acero/groupby_aggregate_node.cc (+1 −1)

@@ -380,7 +380,7 @@ Status GroupByNode::InputReceived(ExecNode* input, ExecBatch batch) {
     auto batch = ExecSpan(exec_batch);
     RETURN_NOT_OK(Consume(batch));
     RETURN_NOT_OK(
-        ExtractSegmenterValues(&segmenter_values_, exec_batch, segment_key_field_ids_));
+        ExtractSegmenterValues(segmenter_values_, exec_batch, segment_key_field_ids_));
     if (!segment.is_open) RETURN_NOT_OK(OutputResult(/*is_last=*/false));
     return Status::OK();
   };

cpp/src/arrow/acero/scalar_aggregate_node.cc (+1 −1)

@@ -241,7 +241,7 @@ Status ScalarAggregateNode::InputReceived(ExecNode* input, ExecBatch batch) {
     auto exec_batch = full_batch.Slice(segment.offset, segment.length);
     RETURN_NOT_OK(DoConsume(ExecSpan(exec_batch), thread_index));
     RETURN_NOT_OK(
-        ExtractSegmenterValues(&segmenter_values_, exec_batch, segment_field_ids_));
+        ExtractSegmenterValues(segmenter_values_, exec_batch, segment_field_ids_));
 
     // If the segment closes the current segment group, we can output segment group
     // aggregation.