Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cache GroupByHash raw values where appropriate #25294

Merged
merged 2 commits into from
Mar 19, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -46,29 +46,7 @@ public class FlatGroupByHash
// Max (page value count / cumulative dictionary size) to trigger the low cardinality case
private static final double SMALL_DICTIONARIES_MAX_CARDINALITY_RATIO = 0.25;

public enum HashMode {
// Hash values are pre-computed as input, and emitted as output
PRECOMPUTED,
// Hash values are computed by the FlatGroupByHash instance and stored along with the entry
CACHED,
// Hash values are re-computed for each access on demand
ON_DEMAND;

public boolean isHashPrecomputed()
{
return this == PRECOMPUTED;
}

public boolean isHashCached()
{
return switch (this) {
case PRECOMPUTED, CACHED -> true;
case ON_DEMAND -> false;
};
}
}

private final HashMode hashMode;
private final GroupByHashMode hashMode;
private final FlatHash flatHash;
private final int groupByChannelCount;

Expand All @@ -86,7 +64,7 @@ public boolean isHashCached()

public FlatGroupByHash(
List<Type> hashTypes,
HashMode hashMode,
GroupByHashMode hashMode,
int expectedSize,
boolean processDictionary,
FlatHashStrategyCompiler hashStrategyCompiler,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ private static int calculateMaxFill(int capacity)
private int nextGroupId;
private int maxFill;

public FlatHash(FlatHashStrategy flatHashStrategy, FlatGroupByHash.HashMode hashMode, int expectedSize, UpdateMemory checkMemoryReservation)
public FlatHash(FlatHashStrategy flatHashStrategy, GroupByHashMode hashMode, int expectedSize, UpdateMemory checkMemoryReservation)
{
this.flatHashStrategy = requireNonNull(flatHashStrategy, "flatHashStrategy is null");
this.checkMemoryReservation = requireNonNull(checkMemoryReservation, "checkMemoryReservation is null");
Expand Down
61 changes: 46 additions & 15 deletions core/trino-main/src/main/java/io/trino/operator/GroupByHash.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
import io.trino.annotation.NotThreadSafe;
import io.trino.spi.Page;
import io.trino.spi.PageBuilder;
import io.trino.spi.type.ArrayType;
import io.trino.spi.type.MapType;
import io.trino.spi.type.RowType;
import io.trino.spi.type.Type;

import java.util.List;
Expand All @@ -32,36 +35,64 @@ static GroupByHash createGroupByHash(
Session session,
List<Type> types,
boolean hasPrecomputedHash,
boolean cacheHashValues,
boolean spillable,
int expectedSize,
FlatHashStrategyCompiler hashStrategyCompiler,
UpdateMemory updateMemory)
{
boolean dictionaryAggregationEnabled = isDictionaryAggregationEnabled(session);
return createGroupByHash(types, hasPrecomputedHash, cacheHashValues, expectedSize, dictionaryAggregationEnabled, hashStrategyCompiler, updateMemory);
return createGroupByHash(
types,
selectGroupByHashMode(hasPrecomputedHash, spillable, types),
expectedSize,
dictionaryAggregationEnabled,
hashStrategyCompiler,
updateMemory);
}

// Selects the hash handling strategy for a group-by operation based on the input characteristics.
// Precomputed hashes always pass through unchanged; otherwise hashes are cached whenever
// recomputing them (or performing deep equality checks) would be expensive.
static GroupByHashMode selectGroupByHashMode(boolean hasPrecomputedHash, boolean spillable, List<Type> types)
{
    if (hasPrecomputedHash) {
        return GroupByHashMode.PRECOMPUTED;
    }
    // Spillable aggregations should always cache hash values since spilling requires sorting by the hash value
    if (spillable) {
        return GroupByHashMode.CACHED;
    }
    // When 3 or more columns are present, always cache the hash value
    if (types.size() >= 3) {
        return GroupByHashMode.CACHED;
    }

    int variableWidthTypes = 0;
    for (Type type : types) {
        // The presence of any container types should trigger hash value caching since computing the hash and
        // checking valueIdentical is so much more expensive for these values
        if (type instanceof MapType || type instanceof ArrayType || type instanceof RowType) {
            return GroupByHashMode.CACHED;
        }
        // Cache hash values when 2 or more variable width types are present
        if (type.isFlatVariableWidth()) {
            variableWidthTypes++;
            if (variableWidthTypes >= 2) {
                return GroupByHashMode.CACHED;
            }
        }
    }
    // All remaining scenarios will use on-demand hashing
    return GroupByHashMode.ON_DEMAND;
}

static GroupByHash createGroupByHash(
List<Type> types,
boolean hasPrecomputedHash,
boolean cacheHashValues,
GroupByHashMode hashMode,
int expectedSize,
boolean dictionaryAggregationEnabled,
FlatHashStrategyCompiler hashStrategyCompiler,
UpdateMemory updateMemory)
{
if (types.size() == 1 && types.get(0).equals(BIGINT)) {
return new BigintGroupByHash(hasPrecomputedHash, expectedSize, updateMemory);
}
FlatGroupByHash.HashMode hashMode;
if (hasPrecomputedHash) {
hashMode = FlatGroupByHash.HashMode.PRECOMPUTED;
}
else if (cacheHashValues) {
hashMode = FlatGroupByHash.HashMode.CACHED;
}
else {
hashMode = FlatGroupByHash.HashMode.ON_DEMAND;
return new BigintGroupByHash(hashMode.isHashPrecomputed(), expectedSize, updateMemory);
}
return new FlatGroupByHash(
types,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.operator;

// Describes how group-by hash values are obtained and retained by the hash table.
public enum GroupByHashMode {
    // Hash values arrive pre-computed as an input channel, and are emitted as output
    PRECOMPUTED,
    // Hash values are computed by the FlatGroupByHash instance and retained alongside each entry.
    // This costs extra memory, but makes re-hashing cheaper (no hash recomputation) and lets the
    // valueIdentical check skip a deep equality comparison when the stored hashes differ
    CACHED,
    // Hash values are recomputed from the entry on every access. This saves the per-entry hash
    // storage at the cost of more work during rehash
    ON_DEMAND;

    // True when the hash arrives as input instead of being computed internally
    public boolean isHashPrecomputed()
    {
        return this == PRECOMPUTED;
    }

    // True when a stored hash value is available for each entry (every mode except ON_DEMAND)
    public boolean isHashCached()
    {
        return this != ON_DEMAND;
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import java.util.List;

import static com.google.common.base.Verify.verify;
import static io.trino.operator.GroupByHash.selectGroupByHashMode;
import static io.trino.operator.UpdateMemory.NOOP;
import static java.util.Objects.requireNonNull;

Expand All @@ -30,7 +31,7 @@ public class GroupByHashPageIndexer

public GroupByHashPageIndexer(List<Type> hashTypes, FlatHashStrategyCompiler hashStrategyCompiler)
{
this(GroupByHash.createGroupByHash(hashTypes, false, false, 20, false, hashStrategyCompiler, NOOP));
this(GroupByHash.createGroupByHash(hashTypes, selectGroupByHashMode(false, false, hashTypes), 20, false, hashStrategyCompiler, NOOP));
}

public GroupByHashPageIndexer(GroupByHash hash)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,7 @@ else if (step.isOutputPartial() || !spillEnabled || !isSpillable()) {
groupByTypes,
groupByChannels,
hashChannel,
false, // spillable
operatorContext,
maxPartialMemory,
flatHashStrategyCompiler,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ public InMemoryHashAggregationBuilder(
List<Type> groupByTypes,
List<Integer> groupByChannels,
Optional<Integer> hashChannel,
boolean spillable,
OperatorContext operatorContext,
Optional<DataSize> maxPartialMemory,
FlatHashStrategyCompiler hashStrategyCompiler,
Expand All @@ -82,6 +83,7 @@ public InMemoryHashAggregationBuilder(
groupByTypes,
groupByChannels,
hashChannel,
spillable,
operatorContext,
maxPartialMemory,
Optional.empty(),
Expand All @@ -97,6 +99,7 @@ public InMemoryHashAggregationBuilder(
List<Type> groupByTypes,
List<Integer> groupByChannels,
Optional<Integer> hashChannel,
boolean spillable,
OperatorContext operatorContext,
Optional<DataSize> maxPartialMemory,
Optional<Integer> unspillIntermediateChannelOffset,
Expand Down Expand Up @@ -124,7 +127,7 @@ public InMemoryHashAggregationBuilder(
operatorContext.getSession(),
groupByTypes,
hashChannel.isPresent(),
false,
spillable,
expectedGroups,
hashStrategyCompiler,
updateMemory);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ private void rebuildHashAggregationBuilder()
groupByTypes,
groupByPartialChannels,
hashChannel,
false, // spillable
operatorContext,
Optional.of(DataSize.succinctBytes(0)),
Optional.of(overwriteIntermediateChannelOffset),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,7 @@ private void rebuildHashAggregationBuilder()
groupByTypes,
groupByChannels,
hashChannel,
true, // spillable
operatorContext,
Optional.of(DataSize.succinctBytes(0)),
hashStrategyCompiler,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ public static class MultiChannelBenchmarkData
private int groupCount = GROUP_COUNT;

@Param({"PRECOMPUTED", "CACHED", "ON_DEMAND"})
private FlatGroupByHash.HashMode hashMode = FlatGroupByHash.HashMode.ON_DEMAND;
private GroupByHashMode hashMode = GroupByHashMode.ON_DEMAND;

@Param({"VARCHAR", "BIGINT"})
private String dataType = "VARCHAR";
Expand Down Expand Up @@ -245,7 +245,7 @@ public boolean isHashPrecomputed()
return hashMode.isHashPrecomputed();
}

public FlatGroupByHash.HashMode getFlatGroupByHashMode()
public GroupByHashMode getFlatGroupByHashMode()
{
return hashMode;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,7 @@ public Object groupBy(BenchmarkContext data)
{
GroupByHash groupByHash = GroupByHash.createGroupByHash(
data.getTypes(),
false,
false,
data.getHashMode(),
EXPECTED_GROUP_COUNT,
false,
hashStrategyCompiler,
Expand All @@ -105,7 +104,9 @@ public Object groupBy(BenchmarkContext data)
pageBuilder.reset();
}
}
pages.add(pageBuilder.build());
if (!pageBuilder.isEmpty()) {
pages.add(pageBuilder.build());
}
return ImmutableList.of(pages, results); // all the things that might get erased by the compiler
}

Expand Down Expand Up @@ -231,6 +232,8 @@ public static class BenchmarkContext
@Param({"0", ".1", ".5", ".9"})
private double nullChance;

private GroupByHashMode hashMode;

private final int positions;
private List<Page> pages;
private List<Type> types;
Expand All @@ -255,6 +258,7 @@ public void setup()
.map(channel -> channel.columnType.type)
.collect(toImmutableList());
pages = createPages(query);
hashMode = GroupByHash.selectGroupByHashMode(false, false, types);
}

private List<Page> createPages(AggregationDefinition definition)
Expand Down Expand Up @@ -295,6 +299,11 @@ public WorkType getWorkType()
{
return workType;
}

public GroupByHashMode getHashMode()
{
return hashMode;
}
}

public enum WorkType
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import static io.trino.block.BlockAssertions.createLongsBlock;
import static io.trino.block.BlockAssertions.createStringSequenceBlock;
import static io.trino.operator.GroupByHash.createGroupByHash;
import static io.trino.operator.GroupByHash.selectGroupByHashMode;
import static io.trino.operator.UpdateMemory.NOOP;
import static io.trino.spi.type.BigintType.BIGINT;
import static io.trino.spi.type.VarcharType.VARCHAR;
Expand Down Expand Up @@ -70,7 +71,7 @@ public GroupByHash createGroupByHash(int expectedSize, UpdateMemory updateMemory
case BIGINT -> new BigintGroupByHash(true, expectedSize, updateMemory);
case FLAT -> new FlatGroupByHash(
ImmutableList.of(BigintType.BIGINT),
FlatGroupByHash.HashMode.PRECOMPUTED,
GroupByHashMode.PRECOMPUTED,
expectedSize,
true,
new FlatHashStrategyCompiler(new TypeOperators()),
Expand Down Expand Up @@ -301,7 +302,7 @@ public void testUpdateMemoryVarchar()

// Create GroupByHash with tiny size
AtomicInteger rehashCount = new AtomicInteger();
GroupByHash groupByHash = createGroupByHash(ImmutableList.of(type), true, false, 1, false, new FlatHashStrategyCompiler(new TypeOperators()), () -> {
GroupByHash groupByHash = createGroupByHash(ImmutableList.of(type), selectGroupByHashMode(true, false, ImmutableList.of(type)), 1, false, new FlatHashStrategyCompiler(new TypeOperators()), () -> {
rehashCount.incrementAndGet();
return true;
});
Expand All @@ -322,7 +323,7 @@ public void testUpdateMemoryBigint()

// Create GroupByHash with tiny size
AtomicInteger rehashCount = new AtomicInteger();
GroupByHash groupByHash = createGroupByHash(ImmutableList.of(type), true, false, 1, false, new FlatHashStrategyCompiler(new TypeOperators()), () -> {
GroupByHash groupByHash = createGroupByHash(ImmutableList.of(type), selectGroupByHashMode(true, false, ImmutableList.of(type)), 1, false, new FlatHashStrategyCompiler(new TypeOperators()), () -> {
rehashCount.incrementAndGet();
return true;
});
Expand Down Expand Up @@ -357,7 +358,7 @@ private static void testMemoryReservationYield(Type type, Block valuesBlock, int
int yields = 0;

// test addPage
GroupByHash groupByHash = createGroupByHash(ImmutableList.of(type), true, false, 1, false, new FlatHashStrategyCompiler(new TypeOperators()), updateMemory);
GroupByHash groupByHash = createGroupByHash(ImmutableList.of(type), selectGroupByHashMode(true, false, ImmutableList.of(type)), 1, false, new FlatHashStrategyCompiler(new TypeOperators()), updateMemory);
boolean finish = false;
Work<?> addPageWork = groupByHash.addPage(page);
while (!finish) {
Expand All @@ -383,7 +384,7 @@ private static void testMemoryReservationYield(Type type, Block valuesBlock, int
currentQuota.set(0);
allowedQuota.set(6);
yields = 0;
groupByHash = createGroupByHash(ImmutableList.of(type), true, false, 1, false, new FlatHashStrategyCompiler(new TypeOperators()), updateMemory);
groupByHash = createGroupByHash(ImmutableList.of(type), selectGroupByHashMode(true, false, ImmutableList.of(type)), 1, false, new FlatHashStrategyCompiler(new TypeOperators()), updateMemory);

finish = false;
Work<int[]> getGroupIdsWork = groupByHash.getGroupIds(page);
Expand Down Expand Up @@ -630,7 +631,7 @@ public void testProperWorkTypesSelected()

private static void assertGroupByHashWork(Page page, List<Type> types, Class<?> clazz)
{
GroupByHash groupByHash = createGroupByHash(types, false, false, 100, true, new FlatHashStrategyCompiler(new TypeOperators()), NOOP);
GroupByHash groupByHash = createGroupByHash(types, selectGroupByHashMode(false, false, types), 100, true, new FlatHashStrategyCompiler(new TypeOperators()), NOOP);
Work<int[]> work = groupByHash.getGroupIds(page);
// Compare by name since classes are private
assertThat(work).isInstanceOf(clazz);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -258,8 +258,7 @@ private GroupByHash createGroupByHash(Type partitionType, UpdateMemory updateMem
{
return GroupByHash.createGroupByHash(
ImmutableList.of(partitionType),
false,
false,
GroupByHash.selectGroupByHashMode(false, false, ImmutableList.of(partitionType)),
1,
false,
new FlatHashStrategyCompiler(typeOperators),
Expand Down
Loading