diff --git a/cpp/src/arrow/acero/aggregate_internal.cc b/cpp/src/arrow/acero/aggregate_internal.cc index 0c1bc3db365a6..68cbc8549c516 100644 --- a/cpp/src/arrow/acero/aggregate_internal.cc +++ b/cpp/src/arrow/acero/aggregate_internal.cc @@ -177,14 +177,11 @@ void AggregatesToString(std::stringstream* ss, const Schema& input_schema, *ss << ']'; } -Status ExtractSegmenterValues(std::vector* values_ptr, - const ExecBatch& input_batch, +Status ExtractSegmenterValues(std::vector& values, const ExecBatch& input_batch, const std::vector& field_ids) { + DCHECK_EQ(values.size(), field_ids.size()); DCHECK_GT(input_batch.length, 0); - std::vector& values = *values_ptr; int64_t row = input_batch.length - 1; - values.clear(); - values.resize(field_ids.size()); for (size_t i = 0; i < field_ids.size(); i++) { const Datum& value = input_batch.values[field_ids[i]]; if (value.is_scalar()) { diff --git a/cpp/src/arrow/acero/aggregate_internal.h b/cpp/src/arrow/acero/aggregate_internal.h index 7cdc424cbb76b..94622f9149059 100644 --- a/cpp/src/arrow/acero/aggregate_internal.h +++ b/cpp/src/arrow/acero/aggregate_internal.h @@ -143,11 +143,10 @@ Status HandleSegments(RowSegmenter* segmenter, const ExecBatch& batch, } /// @brief Extract values of segment keys from a segment batch -/// @param[out] values_ptr Vector to store the extracted segment key values +/// @param[out] values Vector to store the extracted segment key values /// @param[in] input_batch Segment batch. Must have the a constant value for segment key /// @param[in] field_ids Segment key field ids -Status ExtractSegmenterValues(std::vector* values_ptr, - const ExecBatch& input_batch, +Status ExtractSegmenterValues(std::vector& values, const ExecBatch& input_batch, const std::vector& field_ids); Result> ExtractValues(const ExecBatch& input_batch, @@ -171,6 +170,7 @@ class ScalarAggregateNode : public ExecNode, public TracedNode { TracedNode(this), segmenter_(std::move(segmenter)), segment_field_ids_(std::move(segment_field_ids)), + segmenter_values_(segment_field_ids_.size()), target_fieldsets_(std::move(target_fieldsets)), aggs_(std::move(aggs)), kernels_(std::move(kernels)), @@ -249,6 +249,7 @@ class GroupByNode : public ExecNode, public TracedNode { : ExecNode(input->plan(), {input}, {"groupby"}, std::move(output_schema)), TracedNode(this), segmenter_(std::move(segmenter)), + segmenter_values_(segment_key_field_ids.size()), key_field_ids_(std::move(key_field_ids)), segment_key_field_ids_(std::move(segment_key_field_ids)), agg_src_types_(std::move(agg_src_types)), diff --git a/cpp/src/arrow/acero/aggregate_node_test.cc b/cpp/src/arrow/acero/aggregate_node_test.cc index f980496d527d1..73476a681e186 100644 --- a/cpp/src/arrow/acero/aggregate_node_test.cc +++ b/cpp/src/arrow/acero/aggregate_node_test.cc @@ -213,6 +213,26 @@ TEST(GroupByNode, NoSkipNulls) { AssertExecBatchesEqualIgnoringOrder(out_schema, {expected_batch}, out_batches.batches); } +TEST(GroupByNode, BasicParallel) { + const int64_t num_batches = 8; + + std::vector batches(num_batches, ExecBatchFromJSON({int32()}, "[[42]]")); + + Declaration plan = Declaration::Sequence( + {{"exec_batch_source", + ExecBatchSourceNodeOptions(schema({field("key", int32())}), batches)}, + {"aggregate", AggregateNodeOptions{/*aggregates=*/{{"hash_count_all", "count(*)"}}, + /*keys=*/{"key"}}}}); + + ASSERT_OK_AND_ASSIGN(BatchesWithCommonSchema out_batches, + DeclarationToExecBatches(plan)); + + ExecBatch expected_batch = ExecBatchFromJSON( + {int32(), int64()}, "[[42, " + std::to_string(num_batches) + "]]"); + AssertExecBatchesEqualIgnoringOrder(out_batches.schema, {expected_batch}, + out_batches.batches); +} + TEST(ScalarAggregateNode, AnyAll) { // GH-43768: boolean_any and boolean_all with constant input should work well // when min_count != 0. @@ -265,5 +285,24 @@ TEST(ScalarAggregateNode, AnyAll) { } } +TEST(ScalarAggregateNode, BasicParallel) { + const int64_t num_batches = 8; + + std::vector batches(num_batches, ExecBatchFromJSON({int32()}, "[[42]]")); + + Declaration plan = Declaration::Sequence( + {{"exec_batch_source", + ExecBatchSourceNodeOptions(schema({field("", int32())}), batches)}, + {"aggregate", AggregateNodeOptions{/*aggregates=*/{{"count_all", "count(*)"}}}}}); + + ASSERT_OK_AND_ASSIGN(BatchesWithCommonSchema out_batches, + DeclarationToExecBatches(plan)); + + ExecBatch expected_batch = + ExecBatchFromJSON({int64()}, "[[" + std::to_string(num_batches) + "]]"); + AssertExecBatchesEqualIgnoringOrder(out_batches.schema, {expected_batch}, + out_batches.batches); +} + } // namespace acero } // namespace arrow diff --git a/cpp/src/arrow/acero/groupby_aggregate_node.cc b/cpp/src/arrow/acero/groupby_aggregate_node.cc index 2beef360b45d4..48af6d90e09d9 100644 --- a/cpp/src/arrow/acero/groupby_aggregate_node.cc +++ b/cpp/src/arrow/acero/groupby_aggregate_node.cc @@ -380,7 +380,7 @@ Status GroupByNode::InputReceived(ExecNode* input, ExecBatch batch) { auto batch = ExecSpan(exec_batch); RETURN_NOT_OK(Consume(batch)); RETURN_NOT_OK( - ExtractSegmenterValues(&segmenter_values_, exec_batch, segment_key_field_ids_)); + ExtractSegmenterValues(segmenter_values_, exec_batch, segment_key_field_ids_)); if (!segment.is_open) RETURN_NOT_OK(OutputResult(/*is_last=*/false)); return Status::OK(); }; diff --git a/cpp/src/arrow/acero/scalar_aggregate_node.cc b/cpp/src/arrow/acero/scalar_aggregate_node.cc index b34f7511cc12b..eb1cf04022aaf 100644 --- a/cpp/src/arrow/acero/scalar_aggregate_node.cc +++ b/cpp/src/arrow/acero/scalar_aggregate_node.cc @@ -241,7 +241,7 @@ Status ScalarAggregateNode::InputReceived(ExecNode* input, ExecBatch batch) { auto exec_batch = full_batch.Slice(segment.offset, segment.length); RETURN_NOT_OK(DoConsume(ExecSpan(exec_batch), thread_index)); RETURN_NOT_OK( - ExtractSegmenterValues(&segmenter_values_, exec_batch, segment_field_ids_)); + ExtractSegmenterValues(segmenter_values_, exec_batch, segment_field_ids_)); // If the segment closes the current segment group, we can output segment group // aggregation.