GH-45821: [C++][Compute] Grouper improvements #45822

Merged · 3 commits · Mar 24, 2025
2 changes: 1 addition & 1 deletion cpp/src/arrow/compute/key_map_internal.h
@@ -35,7 +35,7 @@ namespace compute {
//
// A detailed explanation of this data structure (including concepts such as blocks,
// slots, stamps) and operations provided by this class is given in the document:
-// arrow/compute/exec/doc/key_map.md.
+// arrow/acero/doc/key_map.md.
//
class ARROW_EXPORT SwissTable {
friend class SwissTableMerge;
208 changes: 158 additions & 50 deletions cpp/src/arrow/compute/row/grouper.cc
@@ -17,6 +17,7 @@

#include "arrow/compute/row/grouper.h"

#include <cstring>
#include <iostream>
#include <memory>
#include <mutex>
@@ -318,7 +319,7 @@ Result<std::unique_ptr<RowSegmenter>> RowSegmenter::Make(

namespace {

-Status CheckAndCapLengthForConsume(int64_t batch_length, int64_t& consume_offset,
+Status CheckAndCapLengthForConsume(int64_t batch_length, int64_t consume_offset,
int64_t* consume_length) {
if (consume_offset < 0) {
return Status::Invalid("invalid grouper consume offset: ", consume_offset);
@@ -329,6 +330,8 @@ Status CheckAndCapLengthForConsume(int64_t batch_length, int64_t& consume_offset
return Status::OK();
}
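[Editorial note] The capping branch of this helper is elided by the hunk above. Below is a minimal sketch of the assumed semantics — inferred from the `length = -1` defaults in grouper.h and the recursive `ConsumeImpl(..., 0, -1, mode)` calls further down, not copied from the Arrow source: a negative length resolves to "everything from offset to the end of the batch", and a non-negative length is clamped to the batch bounds.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// Hypothetical helper mirroring what CheckAndCapLengthForConsume is assumed
// to do once the offset has been validated as non-negative and in range.
int64_t CapConsumeLength(int64_t batch_length, int64_t offset, int64_t length) {
  const int64_t remaining = batch_length - offset;
  return length < 0 ? remaining : std::min(length, remaining);
}

int main() {
  assert(CapConsumeLength(/*batch_length=*/100, /*offset=*/0, /*length=*/-1) == 100);
  assert(CapConsumeLength(100, 40, -1) == 60);    // rest of the batch
  assert(CapConsumeLength(100, 40, 1000) == 60);  // clamped to the end
  return 0;
}
```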

enum class GrouperMode { kPopulate, kConsume, kLookup };

struct GrouperImpl : public Grouper {
static Result<std::unique_ptr<GrouperImpl>> Make(
const std::vector<TypeHolder>& key_types, ExecContext* ctx) {
@@ -388,11 +391,60 @@ struct GrouperImpl : public Grouper {
return Status::OK();
}

Status Populate(const ExecSpan& batch, int64_t offset, int64_t length) override {
return ConsumeImpl(batch, offset, length, GrouperMode::kPopulate).status();
}

Result<Datum> Consume(const ExecSpan& batch, int64_t offset, int64_t length) override {
return ConsumeImpl(batch, offset, length, GrouperMode::kConsume);
}

Result<Datum> Lookup(const ExecSpan& batch, int64_t offset, int64_t length) override {
return ConsumeImpl(batch, offset, length, GrouperMode::kLookup);
}

template <typename VisitGroupFunc, typename VisitUnknownGroupFunc>
void VisitKeys(int64_t length, const int32_t* key_offsets, const uint8_t* key_data,
bool insert_new_keys, VisitGroupFunc&& visit_group,
VisitUnknownGroupFunc&& visit_unknown_group) {
for (int64_t i = 0; i < length; ++i) {
const int32_t key_length = key_offsets[i + 1] - key_offsets[i];
const uint8_t* key_ptr = key_data + key_offsets[i];
std::string key(reinterpret_cast<const char*>(key_ptr), key_length);

uint32_t group_id;
if (insert_new_keys) {
const auto [it, inserted] = map_.emplace(std::move(key), num_groups_);
if (inserted) {
// New key: update offsets and key_bytes
++num_groups_;
if (key_length > 0) {
const auto next_key_offset = static_cast<int32_t>(key_bytes_.size());
key_bytes_.resize(next_key_offset + key_length);
offsets_.push_back(next_key_offset + key_length);
memcpy(key_bytes_.data() + next_key_offset, key_ptr, key_length);
}
}
group_id = it->second;
} else {
const auto it = map_.find(std::move(key));
if (it == map_.end()) {
// Key not found
visit_unknown_group();
continue;
}
group_id = it->second;
}
visit_group(group_id);
}
}

Result<Datum> ConsumeImpl(const ExecSpan& batch, int64_t offset, int64_t length,
GrouperMode mode) {
ARROW_RETURN_NOT_OK(CheckAndCapLengthForConsume(batch.length, offset, &length));
if (offset != 0 || length != batch.length) {
auto batch_slice = batch.ToExecBatch().Slice(offset, length);
-      return Consume(ExecSpan(batch_slice), 0, -1);
+      return ConsumeImpl(ExecSpan(batch_slice), 0, -1, mode);
}
std::vector<int32_t> offsets_batch(batch.length + 1);
for (int i = 0; i < batch.num_values(); ++i) {
@@ -417,35 +469,50 @@ struct GrouperImpl : public Grouper {
RETURN_NOT_OK(encoders_[i]->Encode(batch[i], batch.length, key_buf_ptrs.data()));
}

if (mode == GrouperMode::kPopulate) {
VisitKeys(
batch.length, offsets_batch.data(), key_bytes_batch.data(),
/*insert_new_keys=*/true,
/*visit_group=*/[](...) {},
/*visit_unknown_group=*/[] {});
return Datum();
}

TypedBufferBuilder<uint32_t> group_ids_batch(ctx_->memory_pool());
RETURN_NOT_OK(group_ids_batch.Resize(batch.length));
std::shared_ptr<Buffer> null_bitmap;

-for (int64_t i = 0; i < batch.length; ++i) {
-  int32_t key_length = offsets_batch[i + 1] - offsets_batch[i];
-  std::string key(
-      reinterpret_cast<const char*>(key_bytes_batch.data() + offsets_batch[i]),
-      key_length);
-
-  auto it_success = map_.emplace(key, num_groups_);
-  auto group_id = it_success.first->second;
-
-  if (it_success.second) {
-    // new key; update offsets and key_bytes
-    ++num_groups_;
-    // Skip if there are no keys
-    if (key_length > 0) {
-      auto next_key_offset = static_cast<int32_t>(key_bytes_.size());
-      key_bytes_.resize(next_key_offset + key_length);
-      offsets_.push_back(next_key_offset + key_length);
-      memcpy(key_bytes_.data() + next_key_offset, key.c_str(), key_length);
-    }
-  }
-
-  group_ids_batch.UnsafeAppend(group_id);
-}
+if (mode == GrouperMode::kConsume) {
+  auto visit_group = [&](uint32_t group_id) {
+    group_ids_batch.UnsafeAppend(group_id);
+  };
+  auto visit_unknown_group = [] {};
+
+  VisitKeys(batch.length, offsets_batch.data(), key_bytes_batch.data(),
+            /*insert_new_keys=*/true, visit_group, visit_unknown_group);
+} else {
+  DCHECK_EQ(mode, GrouperMode::kLookup);
+
+  // Create a null bitmap to indicate which keys were found.
+  TypedBufferBuilder<bool> null_bitmap_builder(ctx_->memory_pool());
+  RETURN_NOT_OK(null_bitmap_builder.Resize(batch.length));
+
+  auto visit_group = [&](uint32_t group_id) {
+    group_ids_batch.UnsafeAppend(group_id);
+    null_bitmap_builder.UnsafeAppend(true);
+  };
+  auto visit_unknown_group = [&] {
+    group_ids_batch.UnsafeAppend(0);  // any defined value really
+    null_bitmap_builder.UnsafeAppend(false);
+  };
+
+  VisitKeys(batch.length, offsets_batch.data(), key_bytes_batch.data(),
+            /*insert_new_keys=*/false, visit_group, visit_unknown_group);
+
+  ARROW_ASSIGN_OR_RAISE(null_bitmap, null_bitmap_builder.Finish());
+}
ARROW_ASSIGN_OR_RAISE(auto group_ids, group_ids_batch.Finish());
-return Datum(UInt32Array(batch.length, std::move(group_ids)));
+return Datum(UInt32Array(batch.length, std::move(group_ids), std::move(null_bitmap)));
}

uint32_t num_groups() const override { return num_groups_; }
@@ -470,6 +537,7 @@ struct GrouperImpl : public Grouper {
}

ExecContext* ctx_;
// TODO We could use std::string_view since the keys are copied in key_bytes_.
std::unordered_map<std::string, uint32_t> map_;
std::vector<int32_t> offsets_ = {0};
std::vector<uint8_t> key_bytes_;
@@ -577,11 +645,24 @@ struct GrouperFastImpl : public Grouper {
return Status::OK();
}

Status Populate(const ExecSpan& batch, int64_t offset, int64_t length) override {
return ConsumeImpl(batch, offset, length, GrouperMode::kPopulate).status();
}

Result<Datum> Consume(const ExecSpan& batch, int64_t offset, int64_t length) override {
return ConsumeImpl(batch, offset, length, GrouperMode::kConsume);
}

Result<Datum> Lookup(const ExecSpan& batch, int64_t offset, int64_t length) override {
return ConsumeImpl(batch, offset, length, GrouperMode::kLookup);
}

Result<Datum> ConsumeImpl(const ExecSpan& batch, int64_t offset, int64_t length,
GrouperMode mode) {
ARROW_RETURN_NOT_OK(CheckAndCapLengthForConsume(batch.length, offset, &length));
if (offset != 0 || length != batch.length) {
auto batch_slice = batch.ToExecBatch().Slice(offset, length);
-      return Consume(ExecSpan(batch_slice), 0, -1);
+      return ConsumeImpl(ExecSpan(batch_slice), 0, -1, mode);
}
// ARROW-14027: broadcast scalar arguments for now
for (int i = 0; i < batch.num_values(); i++) {
@@ -595,13 +676,13 @@ struct GrouperFastImpl : public Grouper {
ctx_->memory_pool()));
}
}
-        return ConsumeImpl(ExecSpan(expanded));
+        return ConsumeImpl(ExecSpan(expanded), mode);
}
}
-    return ConsumeImpl(batch);
+    return ConsumeImpl(batch, mode);
}

-  Result<Datum> ConsumeImpl(const ExecSpan& batch) {
+  Result<Datum> ConsumeImpl(const ExecSpan& batch, GrouperMode mode) {
int64_t num_rows = batch.length;
int num_columns = batch.num_values();
// Process dictionaries
@@ -621,10 +702,6 @@ struct GrouperFastImpl : public Grouper {
}
}

-std::shared_ptr<arrow::Buffer> group_ids;
-ARROW_ASSIGN_OR_RAISE(
-    group_ids, AllocateBuffer(sizeof(uint32_t) * num_rows, ctx_->memory_pool()));

for (int icol = 0; icol < num_columns; ++icol) {
const uint8_t* non_nulls = NULLPTR;
const uint8_t* fixedlen = NULLPTR;
@@ -649,11 +726,32 @@ struct GrouperFastImpl : public Grouper {
cols_[icol] = col_base.Slice(offset, num_rows);
}

std::shared_ptr<arrow::Buffer> group_ids, null_bitmap;
// If we need to return the group ids, then allocate a buffer of group ids
// for all rows, otherwise each minibatch will reuse the same buffer.
const int64_t groups_ids_size =
(mode == GrouperMode::kPopulate) ? minibatch_size_max_ : num_rows;
ARROW_ASSIGN_OR_RAISE(group_ids, AllocateBuffer(sizeof(uint32_t) * groups_ids_size,
ctx_->memory_pool()));
if (mode == GrouperMode::kLookup) {
ARROW_ASSIGN_OR_RAISE(null_bitmap,
AllocateBitmap(groups_ids_size, ctx_->memory_pool()));
}

// Split into smaller mini-batches
//
for (uint32_t start_row = 0; start_row < num_rows;) {
uint32_t batch_size_next = std::min(static_cast<uint32_t>(minibatch_size_),
static_cast<uint32_t>(num_rows) - start_row);
uint32_t* batch_group_ids = group_ids->mutable_data_as<uint32_t>() +
((mode == GrouperMode::kPopulate) ? 0 : start_row);
if (mode == GrouperMode::kLookup) {
// Zero-initialize each mini-batch just before it is partially populated
// in map_.find() below.
// This is potentially more cache-efficient than zeroing the entire buffer
// at once before this loop.
memset(batch_group_ids, 0, batch_size_next * sizeof(uint32_t));
}

// Encode
rows_minibatch_.Clean();
@@ -672,28 +770,38 @@ struct GrouperFastImpl : public Grouper {
match_bitvector.mutable_data(), local_slots.mutable_data());
map_.find(batch_size_next, minibatch_hashes_.data(),
match_bitvector.mutable_data(), local_slots.mutable_data(),
-          reinterpret_cast<uint32_t*>(group_ids->mutable_data()) + start_row,
-          &temp_stack_, map_equal_impl_, nullptr);
+          batch_group_ids, &temp_stack_, map_equal_impl_, nullptr);
}
if (mode == GrouperMode::kLookup) {
// Fill validity bitmap from match_bitvector
::arrow::internal::CopyBitmap(match_bitvector.mutable_data(), /*offset=*/0,
/*length=*/batch_size_next,
null_bitmap->mutable_data(),
/*dest_offset=*/start_row);
} else {
// Insert new keys
auto ids = util::TempVectorHolder<uint16_t>(&temp_stack_, batch_size_next);
int num_ids;
util::bit_util::bits_to_indexes(0, encode_ctx_.hardware_flags, batch_size_next,
match_bitvector.mutable_data(), &num_ids,
ids.mutable_data());

RETURN_NOT_OK(map_.map_new_keys(
num_ids, ids.mutable_data(), minibatch_hashes_.data(), batch_group_ids,
&temp_stack_, map_equal_impl_, map_append_impl_, nullptr));
}
-  auto ids = util::TempVectorHolder<uint16_t>(&temp_stack_, batch_size_next);
-  int num_ids;
-  util::bit_util::bits_to_indexes(0, encode_ctx_.hardware_flags, batch_size_next,
-                                  match_bitvector.mutable_data(), &num_ids,
-                                  ids.mutable_data());
-
-  RETURN_NOT_OK(map_.map_new_keys(
-      num_ids, ids.mutable_data(), minibatch_hashes_.data(),
-      reinterpret_cast<uint32_t*>(group_ids->mutable_data()) + start_row,
-      &temp_stack_, map_equal_impl_, map_append_impl_, nullptr));

start_row += batch_size_next;

-if (minibatch_size_ * 2 <= minibatch_size_max_) {
-  minibatch_size_ *= 2;
-}
+// XXX why not use minibatch_size_max_ from the start?
+minibatch_size_ = std::min(minibatch_size_max_, 2 * minibatch_size_);
Comment on lines -691 to +796

Member (PR author): @zanmato1984 Would you know the answer to this XXX?

Contributor: I don't. It doesn't seem to be necessary for either performance or memory profile.
}

-return Datum(UInt32Array(batch.length, std::move(group_ids)));
+if (mode == GrouperMode::kPopulate) {
+  return Datum{};
+} else {
+  return Datum(
+      UInt32Array(batch.length, std::move(group_ids), std::move(null_bitmap)));
+}
}

uint32_t num_groups() const override { return static_cast<uint32_t>(rows_.length()); }
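[Editorial note] On the mini-batch sizing question raised in the review thread above: the loop starts from a small mini-batch and grows it geometrically up to `minibatch_size_max_`. A self-contained sketch of the schedule, using illustrative constants rather than the actual Arrow values:

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>

int main() {
  const int64_t num_rows = 10000;           // hypothetical batch length
  const int64_t minibatch_size_max = 4096;  // hypothetical cap
  int64_t minibatch_size = 128;             // hypothetical starting size

  for (int64_t start_row = 0; start_row < num_rows;) {
    const int64_t batch_size_next = std::min(minibatch_size, num_rows - start_row);
    std::cout << "mini-batch rows [" << start_row << ", "
              << start_row + batch_size_next << ")\n";
    start_row += batch_size_next;
    // Same update as the new code above: double, capped at the maximum.
    minibatch_size = std::min(minibatch_size_max, 2 * minibatch_size);
  }
  return 0;
}
```

One plausible rationale for not starting at the cap — unconfirmed by the thread — is to keep per-mini-batch temporaries small when inputs are short.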
9 changes: 9 additions & 0 deletions cpp/src/arrow/compute/row/grouper.h
@@ -120,6 +120,15 @@ class ARROW_EXPORT Grouper {
virtual Result<Datum> Consume(const ExecSpan& batch, int64_t offset = 0,
int64_t length = -1) = 0;

/// Like Consume, but keys not already encountered yield null instead of
/// being assigned a new group id.
virtual Result<Datum> Lookup(const ExecSpan& batch, int64_t offset = 0,
int64_t length = -1) = 0;

/// Like Consume, but only populates the Grouper without returning the group ids.
virtual Status Populate(const ExecSpan& batch, int64_t offset = 0,
int64_t length = -1) = 0;

/// Get current unique keys. May be called multiple times.
virtual Result<ExecBatch> GetUniques() = 0;

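[Editorial note] To make the contract of the three entry points concrete, here is a minimal semantic model in plain C++ — a toy stand-in, not the Arrow API (which operates on ExecSpan and returns Datum/Status): Consume inserts unseen keys and returns a group id per row; Populate has the same insertion side effect but returns nothing; Lookup never inserts and marks unseen keys as null (modeled here with std::optional).

```cpp
#include <cstdint>
#include <iostream>
#include <optional>
#include <string>
#include <unordered_map>
#include <vector>

// Toy model of the Grouper contract documented above.
class ToyGrouper {
 public:
  // Consume: assign ids, inserting keys seen for the first time.
  std::vector<uint32_t> Consume(const std::vector<std::string>& keys) {
    std::vector<uint32_t> ids;
    ids.reserve(keys.size());
    for (const auto& key : keys) {
      // map_.size() is evaluated before any insertion, so a new key
      // receives the next unused group id.
      auto [it, inserted] = map_.emplace(key, static_cast<uint32_t>(map_.size()));
      ids.push_back(it->second);
    }
    return ids;
  }

  // Populate: same insertions as Consume, but no group ids are returned.
  void Populate(const std::vector<std::string>& keys) { Consume(keys); }

  // Lookup: never inserts; unseen keys come back as nullopt ("null").
  std::vector<std::optional<uint32_t>> Lookup(
      const std::vector<std::string>& keys) const {
    std::vector<std::optional<uint32_t>> ids;
    ids.reserve(keys.size());
    for (const auto& key : keys) {
      auto it = map_.find(key);
      ids.push_back(it == map_.end() ? std::nullopt
                                     : std::optional<uint32_t>(it->second));
    }
    return ids;
  }

 private:
  std::unordered_map<std::string, uint32_t> map_;
};

int main() {
  ToyGrouper grouper;
  grouper.Populate({"a", "b"});             // groups: a=0, b=1; nothing returned
  auto ids = grouper.Consume({"b", "c"});   // -> {1, 2}; "c" gets a new id
  auto found = grouper.Lookup({"a", "d"});  // -> {0, null}; "d" stays unknown
  std::cout << ids[0] << " " << ids[1] << "\n";
  std::cout << (found[1].has_value() ? "found" : "null") << "\n";
  return 0;
}
```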