Skip to content

Commit 6cbfff3

Browse files
committed
Allow field name declaration in row literal
Add support for `row(a 1, b 2)` instead of the much more complex `cast(row(1, 2) as row(a integer, b integer))`.
1 parent 0b565ee commit 6cbfff3

File tree

71 files changed

+491
-274
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

71 files changed

+491
-274
lines changed

core/trino-grammar/src/main/antlr4/io/trino/grammar/sql/SqlBase.g4

+8-1
Original file line numberDiff line numberDiff line change
@@ -574,7 +574,7 @@ primaryExpression
574574
| QUESTION_MARK #parameter
575575
| POSITION '(' valueExpression IN valueExpression ')' #position
576576
| '(' expression (',' expression)+ ')' #rowConstructor
577-
| ROW '(' expression (',' expression)* ')' #rowConstructor
577+
| ROW '(' fieldConstructor (',' fieldConstructor)* ')' #rowConstructor
578578
| name=LISTAGG '(' setQuantifier? expression (',' string)?
579579
(ON OVERFLOW listAggOverflowBehavior)? ')'
580580
(WITHIN GROUP '(' ORDER BY sortItem (',' sortItem)* ')')
@@ -646,6 +646,13 @@ primaryExpression
646646
')' #jsonArray
647647
;
648648

649+
// Match a single expression before identifier expression so that the field `ROW(1, 2)` is not parsed as
650+
// "ROW" identifier followed by the expression `(1, 2)`.
651+
fieldConstructor
652+
: expression
653+
| identifier expression
654+
;
655+
649656
jsonPathInvocation
650657
: jsonValueExpression ',' path=string
651658
(AS pathName=identifier)?

core/trino-main/src/main/java/io/trino/cost/ValuesStatsRule.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ private List<Object> getSymbolValues(ValuesNode valuesNode, int symbolId, Type r
8888
checkState(valuesNode.getRows().isPresent(), "rows is empty");
8989
return valuesNode.getRows().get().stream()
9090
.map(row -> switch (row) {
91-
case Row value -> ((Constant) value.items().get(symbolId)).value();
91+
case Row value -> ((Constant) value.fields().get(symbolId).value()).value();
9292
case Constant(Type type, SqlRow value) -> readNativeValue(symbolType, value.getRawFieldBlock(symbolId), value.getRawIndex());
9393
default -> throw new IllegalArgumentException("Expected Row or Constant: " + row);
9494
})

core/trino-main/src/main/java/io/trino/sql/analyzer/AggregationAnalyzer.java

+7-1
Original file line numberDiff line numberDiff line change
@@ -701,10 +701,16 @@ protected Boolean visitTryExpression(TryExpression node, Void context)
701701
@Override
702702
protected Boolean visitRow(Row node, Void context)
703703
{
704-
return node.getItems().stream()
704+
return node.getFields().stream()
705705
.allMatch(item -> process(item, context));
706706
}
707707

708+
@Override
709+
protected Boolean visitRowField(Row.Field node, Void context)
710+
{
711+
return process(node.getExpression(), context);
712+
}
713+
708714
@Override
709715
protected Boolean visitParameter(Parameter node, Void context)
710716
{

core/trino-main/src/main/java/io/trino/sql/analyzer/ExpressionAnalyzer.java

+3-3
Original file line numberDiff line numberDiff line change
@@ -679,11 +679,11 @@ public Type process(Node node, @Nullable Context context)
679679
@Override
680680
protected Type visitRow(Row node, Context context)
681681
{
682-
List<Type> types = node.getItems().stream()
683-
.map(child -> process(child, context))
682+
List<RowType.Field> fields = node.getFields().stream()
683+
.map(field -> new RowType.Field(field.getName().map(Identifier::getCanonicalValue), process(field.getExpression(), context)))
684684
.collect(toImmutableList());
685685

686-
Type type = RowType.anonymous(types);
686+
Type type = RowType.from(fields);
687687
return setExpressionType(node, type);
688688
}
689689

core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -3929,7 +3929,7 @@ protected Scope visitValues(Values node, Optional<Scope> scope)
39293929
// TODO coerce the whole Row and add an Optimizer rule that converts CAST(ROW(...) AS ...) into ROW(CAST(...), CAST(...), ...).
39303930
// The rule would also handle Row-type expressions that were specified as CAST(ROW). It should support multiple casts over a ROW.
39313931
for (int i = 0; i < actualType.getTypeParameters().size(); i++) {
3932-
Expression item = ((Row) row).getItems().get(i);
3932+
Expression item = ((Row) row).getFields().get(i).getExpression();
39333933
Type actualItemType = actualType.getTypeParameters().get(i);
39343934
Type expectedItemType = commonSuperType.getTypeParameters().get(i);
39353935
if (!actualItemType.equals(expectedItemType)) {

core/trino-main/src/main/java/io/trino/sql/ir/DefaultTraversalVisitor.java

+2-2
Original file line numberDiff line numberDiff line change
@@ -158,8 +158,8 @@ protected Void visitLogical(Logical node, C context)
158158
@Override
159159
protected Void visitRow(Row node, C context)
160160
{
161-
for (Expression expression : node.items()) {
162-
process(expression, context);
161+
for (Row.Field field : node.fields()) {
162+
process(field.value(), context);
163163
}
164164
return null;
165165
}

core/trino-main/src/main/java/io/trino/sql/ir/ExpressionFormatter.java

+7-2
Original file line numberDiff line numberDiff line change
@@ -67,11 +67,16 @@ protected String visitArray(Array node, Void context)
6767
@Override
6868
protected String visitRow(Row node, Void context)
6969
{
70-
return node.items().stream()
71-
.map(child -> process(child, context))
70+
return node.fields().stream()
71+
.map(this::formatRowField)
7272
.collect(joining(", ", "ROW (", ")"));
7373
}
7474

75+
private String formatRowField(Row.Field field)
76+
{
77+
return field.name().map(name -> name + " ").orElse("") + process(field.value(), null);
78+
}
79+
7580
@Override
7681
protected String visitExpression(Expression node, Void context)
7782
{

core/trino-main/src/main/java/io/trino/sql/ir/ExpressionTreeRewriter.java

+6-2
Original file line numberDiff line numberDiff line change
@@ -107,9 +107,13 @@ protected Expression visitRow(Row node, Context<C> context)
107107
}
108108
}
109109

110-
List<Expression> items = rewrite(node.items(), context);
110+
ImmutableList.Builder<Row.Field> builder = ImmutableList.builder();
111+
for (Row.Field field : node.fields()) {
112+
builder.add(new Row.Field(field.name(), rewrite(field.value(), context.get())));
113+
}
114+
List<Row.Field> items = builder.build();
111115

112-
if (!sameElements(node.items(), items)) {
116+
if (!sameElements(node.fields(), items)) {
113117
return new Row(items);
114118
}
115119

core/trino-main/src/main/java/io/trino/sql/ir/IrExpressions.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ public static boolean mayFail(PlannerContext plannerContext, Expression expressi
102102
case Logical e -> e.terms().stream().anyMatch(argument -> mayFail(plannerContext, argument));
103103
case NullIf e -> mayFail(plannerContext, e.first()) || mayFail(plannerContext, e.second());
104104
case Reference e -> false;
105-
case Row e -> e.items().stream().anyMatch(argument -> mayFail(plannerContext, argument));
105+
case Row e -> e.fields().stream().anyMatch(field -> mayFail(plannerContext, field.value()));
106106
case Switch e -> mayFail(plannerContext, e.operand()) || e.whenClauses().stream().anyMatch(clause -> mayFail(plannerContext, clause.getOperand()) || mayFail(plannerContext, clause.getResult())) ||
107107
mayFail(plannerContext, e.defaultValue());
108108
};

core/trino-main/src/main/java/io/trino/sql/ir/Row.java

+43-7
Original file line numberDiff line numberDiff line change
@@ -19,24 +19,32 @@
1919
import io.trino.spi.type.Type;
2020

2121
import java.util.List;
22+
import java.util.Optional;
2223
import java.util.stream.Collectors;
2324

2425
import static java.util.Objects.requireNonNull;
2526

2627
@JsonSerialize
27-
public record Row(List<Expression> items)
28+
public record Row(List<Field> fields)
2829
implements Expression
2930
{
3031
public Row
3132
{
32-
requireNonNull(items, "items is null");
33-
items = ImmutableList.copyOf(items);
33+
requireNonNull(fields, "fields is null");
34+
fields = ImmutableList.copyOf(fields);
35+
}
36+
37+
public static Row anonymousRow(List<Expression> values)
38+
{
39+
return new Row(values.stream()
40+
.map(Field::anonymousField)
41+
.collect(Collectors.toList()));
3442
}
3543

3644
@Override
3745
public Type type()
3846
{
39-
return RowType.anonymous(items.stream().map(Expression::type).collect(Collectors.toList()));
47+
return RowType.from(fields.stream().map(Field::asRowTypeField).collect(Collectors.toList()));
4048
}
4149

4250
@Override
@@ -48,16 +56,44 @@ public <R, C> R accept(IrVisitor<R, C> visitor, C context)
4856
@Override
4957
public List<? extends Expression> children()
5058
{
51-
return items;
59+
return fields.stream()
60+
.map(Field::value)
61+
.collect(Collectors.toList());
5262
}
5363

5464
@Override
5565
public String toString()
5666
{
5767
return "(" +
58-
items.stream()
59-
.map(Expression::toString)
68+
fields.stream()
69+
.map(Field::toString)
6070
.collect(Collectors.joining(", ")) +
6171
")";
6272
}
73+
74+
@JsonSerialize
75+
public record Field(Optional<String> name, Expression value)
76+
{
77+
public Field
78+
{
79+
requireNonNull(name, "name is null");
80+
requireNonNull(value, "value is null");
81+
}
82+
83+
public static Field anonymousField(Expression value)
84+
{
85+
return new Field(Optional.empty(), value);
86+
}
87+
88+
public RowType.Field asRowTypeField()
89+
{
90+
return new RowType.Field(name, value.type());
91+
}
92+
93+
@Override
94+
public String toString()
95+
{
96+
return name.map(n -> n + " " + value).orElseGet(value::toString);
97+
}
98+
}
6399
}

core/trino-main/src/main/java/io/trino/sql/ir/optimizer/IrExpressionEvaluator.java

+4-3
Original file line numberDiff line numberDiff line change
@@ -173,10 +173,11 @@ private Object evaluateInternal(Switch expression, Session session, Map<String,
173173
private Object evaluateInternal(Row expression, Session session, Map<String, Object> bindings)
174174
{
175175
return buildRowValue((RowType) expression.type(), builders -> {
176-
for (int i = 0; i < expression.items().size(); ++i) {
176+
for (int i = 0; i < expression.fields().size(); ++i) {
177+
Expression fieldValue = expression.fields().get(i).value();
177178
writeNativeValue(
178-
expression.items().get(i).type(), builders.get(i),
179-
evaluate(expression.items().get(i), session, bindings));
179+
fieldValue.type(), builders.get(i),
180+
evaluate(fieldValue, session, bindings));
180181
}
181182
});
182183
}

core/trino-main/src/main/java/io/trino/sql/ir/optimizer/IrExpressionOptimizer.java

+11-1
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,17 @@ private Optional<Expression> processChildren(Expression expression, Session sess
192192
case Logical logical -> process(logical.terms(), session, bindings).map(arguments -> new Logical(logical.operator(), arguments));
193193
case Call call -> process(call.arguments(), session, bindings).map(arguments -> new Call(call.function(), arguments));
194194
case Array array -> process(array.elements(), session, bindings).map(elements -> new Array(array.elementType(), elements));
195-
case Row row -> process(row.items(), session, bindings).map(fields -> new Row(fields));
195+
case Row row -> {
196+
boolean changed = false;
197+
ImmutableList.Builder<Row.Field> fields = ImmutableList.builder();
198+
for (Row.Field field : row.fields()) {
199+
Optional<Expression> optimized = process(field.value(), session, bindings);
200+
changed = changed || optimized.isPresent();
201+
fields.add(new Row.Field(field.name(), optimized.orElse(field.value())));
202+
}
203+
204+
yield changed ? Optional.of(new Row(fields.build())) : Optional.empty();
205+
}
196206
case Between between -> {
197207
Optional<Expression> value = process(between.value(), session, bindings);
198208
Optional<Expression> min = process(between.min(), session, bindings);

core/trino-main/src/main/java/io/trino/sql/ir/optimizer/rule/EvaluateFieldReference.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ public class EvaluateFieldReference
4343
public Optional<Expression> apply(Expression expression, Session session, Map<Symbol, Expression> bindings)
4444
{
4545
return switch (expression) {
46-
case FieldReference(Row row, int field) -> Optional.of(row.items().get(field));
46+
case FieldReference(Row row, int field) -> Optional.of(row.fields().get(field).value());
4747
case FieldReference(Constant(RowType type, SqlRow row), int field) -> {
4848
Type fieldType = type.getFields().get(field).getType();
4949
yield Optional.of(new Constant(

core/trino-main/src/main/java/io/trino/sql/ir/optimizer/rule/EvaluateRow.java

+3-2
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ public class EvaluateRow
3737
@Override
3838
public Optional<Expression> apply(Expression expression, Session session, Map<Symbol, Expression> bindings)
3939
{
40-
if (!(expression instanceof Row(List<Expression> fields)) || !fields.stream().allMatch(Constant.class::isInstance)) {
40+
if (!(expression instanceof Row(List<Row.Field> fields)) || !fields.stream().map(Row.Field::value).allMatch(Constant.class::isInstance)) {
4141
return Optional.empty();
4242
}
4343

@@ -46,7 +46,8 @@ public Optional<Expression> apply(Expression expression, Session session, Map<Sy
4646
rowType,
4747
buildRowValue(rowType, builders -> {
4848
for (int i = 0; i < fields.size(); ++i) {
49-
writeNativeValue(fields.get(i).type(), builders.get(i), ((Constant) fields.get(i)).value());
49+
Expression fieldValue = fields.get(i).value();
50+
writeNativeValue(fieldValue.type(), builders.get(i), ((Constant) fieldValue).value());
5051
}
5152
})));
5253
}

core/trino-main/src/main/java/io/trino/sql/planner/EffectivePredicateExtractor.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -366,7 +366,7 @@ public Expression visitValues(ValuesNode node, Void context)
366366
for (Expression expression : node.getRows().get()) {
367367
if (expression instanceof Row row) {
368368
for (int i = 0; i < node.getOutputSymbols().size(); i++) {
369-
Expression value = row.items().get(i);
369+
Expression value = row.fields().get(i).value();
370370
if (!DeterminismEvaluator.isDeterministic(value)) {
371371
nonDeterministic[i] = true;
372372
}

core/trino-main/src/main/java/io/trino/sql/planner/LogicalPlanner.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -346,7 +346,7 @@ public PlanNode planStatement(Analysis analysis, Statement statement)
346346
if ((statement instanceof CreateTableAsSelect && analysis.getCreate().orElseThrow().isCreateTableAsSelectNoOp()) ||
347347
statement instanceof RefreshMaterializedView && analysis.isSkipMaterializedViewRefresh()) {
348348
Symbol symbol = symbolAllocator.newSymbol("rows", BIGINT);
349-
PlanNode source = new ValuesNode(idAllocator.getNextId(), ImmutableList.of(symbol), ImmutableList.of(new Row(ImmutableList.of(new Constant(BIGINT, 0L)))));
349+
PlanNode source = new ValuesNode(idAllocator.getNextId(), ImmutableList.of(symbol), ImmutableList.of(Row.anonymousRow(ImmutableList.of(new Constant(BIGINT, 0L)))));
350350
return new OutputNode(idAllocator.getNextId(), source, ImmutableList.of("rows"), ImmutableList.of(symbol));
351351
}
352352
return createOutputPlan(planStatementWithoutOutput(analysis, statement), analysis);

core/trino-main/src/main/java/io/trino/sql/planner/QueryPlanner.java

+2-2
Original file line numberDiff line numberDiff line change
@@ -675,7 +675,7 @@ public PlanNode plan(Update node)
675675
rowBuilder.add(new Constant(INTEGER, 0L));
676676

677677
// Finally, the merge row is complete
678-
Expression mergeRow = new Row(rowBuilder.build());
678+
Expression mergeRow = Row.anonymousRow(rowBuilder.build());
679679

680680
List<io.trino.sql.tree.Expression> constraints = analysis.getCheckConstraints(table);
681681
if (!constraints.isEmpty()) {
@@ -850,7 +850,7 @@ public MergeWriterNode plan(Merge merge)
850850
coerceIfNecessary(analysis, casePredicate.get(), subPlan.rewrite(casePredicate.get())));
851851
}
852852

853-
whenClauses.add(new WhenClause(condition, new Row(rowBuilder.build())));
853+
whenClauses.add(new WhenClause(condition, Row.anonymousRow(rowBuilder.build())));
854854

855855
List<io.trino.sql.tree.Expression> constraints = analysis.getCheckConstraints(mergeAnalysis.getTargetTable());
856856
if (!constraints.isEmpty()) {

core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java

+5-3
Original file line numberDiff line numberDiff line change
@@ -1765,15 +1765,17 @@ protected RelationPlan visitValues(Values node, Void context)
17651765
ImmutableList.Builder<Expression> rows = ImmutableList.builder();
17661766
for (io.trino.sql.tree.Expression row : node.getRows()) {
17671767
if (row instanceof io.trino.sql.tree.Row) {
1768-
rows.add(new Row(((io.trino.sql.tree.Row) row).getItems().stream()
1769-
.map(item -> coerceIfNecessary(analysis, item, translationMap.rewrite(item)))
1768+
rows.add(new Row(((io.trino.sql.tree.Row) row).getFields().stream()
1769+
.map(field -> new Row.Field(
1770+
field.getName().map(Identifier::getCanonicalValue),
1771+
coerceIfNecessary(analysis, field.getExpression(), translationMap.rewrite(field.getExpression()))))
17701772
.collect(toImmutableList())));
17711773
}
17721774
else if (analysis.getType(row) instanceof RowType) {
17731775
rows.add(coerceIfNecessary(analysis, row, translationMap.rewrite(row)));
17741776
}
17751777
else {
1776-
rows.add(new Row(ImmutableList.of(coerceIfNecessary(analysis, row, translationMap.rewrite(row)))));
1778+
rows.add(Row.anonymousRow(ImmutableList.of(coerceIfNecessary(analysis, row, translationMap.rewrite(row)))));
17771779
}
17781780
}
17791781

core/trino-main/src/main/java/io/trino/sql/planner/SubqueryPlanner.java

+3-3
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,7 @@ private PlanBuilder planScalarSubquery(PlanBuilder subPlan, Cluster<SubqueryExpr
245245
}
246246
}
247247

248-
Expression expression = new Cast(new Row(fields.build()), type);
248+
Expression expression = new Cast(Row.anonymousRow(fields.build()), type);
249249

250250
root = new ProjectNode(idAllocator.getNextId(), root, Assignments.of(column, expression));
251251
}
@@ -442,7 +442,7 @@ private PlanAndMappings planValue(PlanBuilder subPlan, io.trino.sql.tree.Express
442442

443443
Assignments assignments = Assignments.builder()
444444
.putIdentities(subPlan.getRoot().getOutputSymbols())
445-
.put(wrapped, new Row(ImmutableList.of(column.toSymbolReference())))
445+
.put(wrapped, Row.anonymousRow(ImmutableList.of(column.toSymbolReference())))
446446
.build();
447447

448448
subPlan = subPlan.withNewRoot(new ProjectNode(idAllocator.getNextId(), subPlan.getRoot(), assignments));
@@ -481,7 +481,7 @@ private PlanAndMappings planSubquery(io.trino.sql.tree.Expression subquery, Opti
481481
new ProjectNode(
482482
idAllocator.getNextId(),
483483
relationPlan.getRoot(),
484-
Assignments.of(column, new Cast(new Row(fields.build()), type))));
484+
Assignments.of(column, new Cast(Row.anonymousRow(fields.build()), type))));
485485

486486
return coerceIfNecessary(subqueryPlan, column, subquery, coercion);
487487
}

0 commit comments

Comments
 (0)