Skip to content

Commit 41fa390

Browse files
committed
Add parsing for non-SQL language functions
1 parent 43b9cf9 commit 41fa390

File tree

16 files changed

+348
-47
lines changed

16 files changed

+348
-47
lines changed

core/trino-grammar/src/main/antlr4/io/trino/grammar/sql/SqlBase.g4

+11-1
Original file line numberDiff line numberDiff line change
@@ -861,7 +861,12 @@ pathSpecification
861861
;
862862

863863
functionSpecification
864-
: FUNCTION functionDeclaration returnsClause routineCharacteristic* controlStatement
864+
: FUNCTION functionDeclaration returnsClause routineCharacteristic*
865+
(controlStatement | AS functionDefinition)
866+
;
867+
868+
functionDefinition
869+
: DOLLAR_STRING
865870
;
866871

867872
functionDeclaration
@@ -883,6 +888,7 @@ routineCharacteristic
883888
| CALLED ON NULL INPUT #calledOnNullInputCharacteristic
884889
| SECURITY (DEFINER | INVOKER) #securityCharacteristic
885890
| COMMENT string #commentCharacteristic
891+
| (WITH properties) #propertiesCharacteristic
886892
;
887893

888894
controlStatement
@@ -1332,6 +1338,10 @@ UNICODE_STRING
13321338
: 'U&\'' ( ~'\'' | '\'\'' )* '\''
13331339
;
13341340

1341+
DOLLAR_STRING
1342+
: '$$' .*? '$$'
1343+
;
1344+
13351345
// Note: we allow any character inside the binary literal and validate
13361346
// its a correct literal when the AST is being constructed. This
13371347
// allows us to provide more meaningful error messages to the user

core/trino-main/src/main/java/io/trino/metadata/LanguageFunctionManager.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -437,7 +437,7 @@ private synchronized void analyzeAndPlan(AccessControl accessControl)
437437
analyzing = true;
438438

439439
SqlRoutineAnalysis analysis = analyze(functionContext(accessControl));
440-
routine = planner.planSqlFunction(session, functionSpecification, analysis);
440+
routine = planner.planSqlFunction(session, analysis);
441441

442442
Hasher hasher = Hashing.sha256().newHasher();
443443
SqlRoutineHash.hash(routine, hasher, blockEncodingSerde);

core/trino-main/src/main/java/io/trino/sql/routine/SqlRoutineAnalysis.java

+3
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import com.google.common.collect.ImmutableMap;
1717
import io.trino.spi.type.Type;
1818
import io.trino.sql.analyzer.Analysis;
19+
import io.trino.sql.tree.ControlStatement;
1920

2021
import java.util.Map;
2122
import java.util.Optional;
@@ -29,6 +30,7 @@ public record SqlRoutineAnalysis(
2930
boolean calledOnNull,
3031
boolean deterministic,
3132
Optional<String> comment,
33+
ControlStatement statement,
3234
Analysis analysis)
3335
{
3436
public SqlRoutineAnalysis
@@ -37,6 +39,7 @@ public record SqlRoutineAnalysis(
3739
arguments = ImmutableMap.copyOf(requireNonNull(arguments, "arguments is null"));
3840
requireNonNull(returnType, "returnType is null");
3941
requireNonNull(comment, "comment is null");
42+
requireNonNull(statement, "statement is null");
4043
requireNonNull(analysis, "analysis is null");
4144
}
4245
}

core/trino-main/src/main/java/io/trino/sql/routine/SqlRoutineAnalyzer.java

+43-14
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,8 @@
5555
import io.trino.sql.tree.Node;
5656
import io.trino.sql.tree.NullInputCharacteristic;
5757
import io.trino.sql.tree.ParameterDeclaration;
58+
import io.trino.sql.tree.PropertiesCharacteristic;
59+
import io.trino.sql.tree.Property;
5860
import io.trino.sql.tree.RepeatStatement;
5961
import io.trino.sql.tree.ReturnStatement;
6062
import io.trino.sql.tree.ReturnsClause;
@@ -71,11 +73,11 @@
7173
import java.util.Optional;
7274
import java.util.Set;
7375

74-
import static com.google.common.base.Preconditions.checkArgument;
7576
import static com.google.common.collect.ImmutableList.toImmutableList;
7677
import static com.google.common.collect.Iterables.getLast;
7778
import static io.trino.spi.StandardErrorCode.ALREADY_EXISTS;
7879
import static io.trino.spi.StandardErrorCode.INVALID_ARGUMENTS;
80+
import static io.trino.spi.StandardErrorCode.INVALID_FUNCTION_PROPERTY;
7981
import static io.trino.spi.StandardErrorCode.MISSING_RETURN;
8082
import static io.trino.spi.StandardErrorCode.NOT_FOUND;
8183
import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED;
@@ -103,7 +105,6 @@ public SqlRoutineAnalyzer(PlannerContext plannerContext, WarningCollector warnin
103105
public static FunctionMetadata extractFunctionMetadata(FunctionId functionId, FunctionSpecification function)
104106
{
105107
validateLanguage(function);
106-
validateReturn(function);
107108

108109
String functionName = getFunctionName(function);
109110
Signature.Builder signatureBuilder = Signature.builder()
@@ -139,7 +140,7 @@ public SqlRoutineAnalysis analyze(Session session, AccessControl accessControl,
139140
{
140141
String functionName = getFunctionName(function);
141142

142-
validateLanguage(function);
143+
ControlStatement statement = validateLanguage(function);
143144

144145
boolean calledOnNull = isCalledOnNull(function);
145146
Optional<String> comment = getComment(function);
@@ -150,10 +151,10 @@ public SqlRoutineAnalysis analyze(Session session, AccessControl accessControl,
150151

151152
Map<String, Type> arguments = getArguments(function);
152153

153-
validateReturn(function);
154+
validateReturn(statement);
154155

155156
StatementVisitor visitor = new StatementVisitor(session, accessControl, returnType);
156-
visitor.process(function.getStatement(), new Context(arguments, Set.of()));
157+
visitor.process(statement, new Context(arguments, Set.of()));
157158

158159
Analysis analysis = visitor.getAnalysis();
159160

@@ -174,6 +175,7 @@ public SqlRoutineAnalysis analyze(Session session, AccessControl accessControl,
174175
calledOnNull,
175176
actuallyDeterministic,
176177
comment,
178+
statement,
177179
visitor.getAnalysis());
178180
}
179181

@@ -239,13 +241,24 @@ private static Optional<Identifier> getLanguage(FunctionSpecification function)
239241
.findAny();
240242
}
241243

242-
private static void validateLanguage(FunctionSpecification function)
244+
private static ControlStatement validateLanguage(FunctionSpecification function)
243245
{
244246
getLanguage(function).ifPresent(language -> {
245247
if (!language.getValue().equalsIgnoreCase("sql")) {
248+
function.getStatement().ifPresent(statement -> {
249+
throw semanticException(NOT_SUPPORTED, statement, "Only functions using language 'SQL' may be defined using SQL");
250+
});
246251
throw semanticException(NOT_SUPPORTED, language, "Unsupported function language: %s", language.getCanonicalValue());
247252
}
248253
});
254+
255+
List<Property> properties = getProperties(function);
256+
if (!properties.isEmpty()) {
257+
throw semanticException(INVALID_FUNCTION_PROPERTY, properties.getFirst(), "Function language 'SQL' does not support properties");
258+
}
259+
260+
return function.getStatement().orElseThrow(() ->
261+
semanticException(SYNTAX_ERROR, function.getDefinition().orElseThrow(), "Functions using language 'SQL' must be defined using SQL"));
249262
}
250263

251264
private static boolean isDeterministic(FunctionSpecification function)
@@ -321,17 +334,33 @@ private static Optional<String> getComment(FunctionSpecification function)
321334
.findAny();
322335
}
323336

324-
private static void validateReturn(FunctionSpecification function)
337+
private static List<Property> getProperties(FunctionSpecification function)
325338
{
326-
ControlStatement statement = function.getStatement();
327-
if (statement instanceof ReturnStatement) {
328-
return;
339+
List<PropertiesCharacteristic> properties = function.getRoutineCharacteristics().stream()
340+
.filter(PropertiesCharacteristic.class::isInstance)
341+
.map(PropertiesCharacteristic.class::cast)
342+
.toList();
343+
344+
if (properties.size() > 1) {
345+
throw semanticException(SYNTAX_ERROR, properties.get(1), "Multiple properties clauses specified");
329346
}
330347

331-
checkArgument(statement instanceof CompoundStatement, "invalid function statement: %s", statement);
332-
CompoundStatement body = (CompoundStatement) statement;
333-
if (!(getLast(body.getStatements(), null) instanceof ReturnStatement)) {
334-
throw semanticException(MISSING_RETURN, body, "Function must end in a RETURN statement");
348+
return properties.stream()
349+
.map(PropertiesCharacteristic::getProperties)
350+
.flatMap(List::stream)
351+
.collect(toImmutableList());
352+
}
353+
354+
private static void validateReturn(ControlStatement statement)
355+
{
356+
switch (statement) {
357+
case ReturnStatement _ -> {}
358+
case CompoundStatement body -> {
359+
if (!(getLast(body.getStatements(), null) instanceof ReturnStatement)) {
360+
throw semanticException(MISSING_RETURN, body, "Function must end in a RETURN statement");
361+
}
362+
}
363+
default -> throw new IllegalArgumentException("Invalid function statement: " + statement);
335364
}
336365
}
337366

core/trino-main/src/main/java/io/trino/sql/routine/SqlRoutinePlanner.java

+5-7
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,6 @@
5656
import io.trino.sql.tree.ControlStatement;
5757
import io.trino.sql.tree.ElseIfClause;
5858
import io.trino.sql.tree.Expression;
59-
import io.trino.sql.tree.FunctionSpecification;
6059
import io.trino.sql.tree.Identifier;
6160
import io.trino.sql.tree.IfStatement;
6261
import io.trino.sql.tree.IterateStatement;
@@ -99,24 +98,23 @@ public SqlRoutinePlanner(PlannerContext plannerContext)
9998
this.optimizer = newOptimizer(plannerContext);
10099
}
101100

102-
public IrRoutine planSqlFunction(Session session, FunctionSpecification function, SqlRoutineAnalysis routineAnalysis)
101+
public IrRoutine planSqlFunction(Session session, SqlRoutineAnalysis analysis)
103102
{
104103
List<IrVariable> allVariables = new ArrayList<>();
105104
Map<String, IrVariable> scopeVariables = new LinkedHashMap<>();
106105

107106
ImmutableList.Builder<IrVariable> parameters = ImmutableList.builder();
108-
routineAnalysis.arguments().forEach((name, type) -> {
107+
analysis.arguments().forEach((name, type) -> {
109108
IrVariable variable = new IrVariable(allVariables.size(), type, constantNull(type));
110109
allVariables.add(variable);
111110
scopeVariables.put(name, variable);
112111
parameters.add(variable);
113112
});
114113

115-
Analysis analysis = routineAnalysis.analysis();
116-
StatementVisitor visitor = new StatementVisitor(session, allVariables, analysis);
117-
IrStatement body = visitor.process(function.getStatement(), new Context(scopeVariables, Map.of()));
114+
StatementVisitor visitor = new StatementVisitor(session, allVariables, analysis.analysis());
115+
IrStatement body = visitor.process(analysis.statement(), new Context(scopeVariables, Map.of()));
118116

119-
return new IrRoutine(routineAnalysis.returnType(), parameters.build(), body);
117+
return new IrRoutine(analysis.returnType(), parameters.build(), body);
120118
}
121119

122120
private class StatementVisitor

core/trino-main/src/test/java/io/trino/sql/routine/TestSqlFunctions.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -620,7 +620,7 @@ private static ScalarFunctionImplementation compileFunction(@Language("SQL") Str
620620
SqlRoutineAnalysis analysis = analyzer.analyze(session, new AllowAllAccessControl(), function);
621621

622622
SqlRoutinePlanner planner = new SqlRoutinePlanner(PLANNER_CONTEXT);
623-
IrRoutine routine = planner.planSqlFunction(session, function, analysis);
623+
IrRoutine routine = planner.planSqlFunction(session, analysis);
624624

625625
if (serialize) {
626626
// Simulate worker communication

core/trino-main/src/test/java/io/trino/sql/routine/TestSqlRoutineAnalyzer.java

+17
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626

2727
import static io.trino.spi.StandardErrorCode.ALREADY_EXISTS;
2828
import static io.trino.spi.StandardErrorCode.INVALID_ARGUMENTS;
29+
import static io.trino.spi.StandardErrorCode.INVALID_FUNCTION_PROPERTY;
2930
import static io.trino.spi.StandardErrorCode.MISSING_RETURN;
3031
import static io.trino.spi.StandardErrorCode.NOT_FOUND;
3132
import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED;
@@ -76,6 +77,10 @@ void testCharacteristics()
7677
assertFails("FUNCTION test() RETURNS int NOT DETERMINISTIC DETERMINISTIC RETURN 123")
7778
.hasErrorCode(SYNTAX_ERROR)
7879
.hasMessage("line 1:47: Multiple deterministic clauses specified");
80+
81+
assertFails("FUNCTION test() RETURNS int WITH (x = 1) WITH (y = 2) RETURN 123")
82+
.hasErrorCode(SYNTAX_ERROR)
83+
.hasMessage("line 1:42: Multiple properties clauses specified");
7984
}
8085

8186
@Test
@@ -120,8 +125,20 @@ void testLanguage()
120125
.returns(true, from(SqlRoutineAnalysis::deterministic));
121126

122127
assertFails("FUNCTION test() RETURNS bigint LANGUAGE JAVASCRIPT RETURN abs(-42)")
128+
.hasErrorCode(NOT_SUPPORTED)
129+
.hasMessage("line 1:52: Only functions using language 'SQL' may be defined using SQL");
130+
131+
assertFails("FUNCTION test() RETURNS bigint LANGUAGE JAVASCRIPT AS 'xxx'")
123132
.hasErrorCode(NOT_SUPPORTED)
124133
.hasMessage("line 1:41: Unsupported function language: JAVASCRIPT");
134+
135+
assertFails("FUNCTION test() RETURNS bigint AS 'xxx'")
136+
.hasErrorCode(SYNTAX_ERROR)
137+
.hasMessage("line 1:35: Functions using language 'SQL' must be defined using SQL");
138+
139+
assertFails("FUNCTION test() RETURNS bigint WITH (abc = 'test') RETURN abs(-42)")
140+
.hasErrorCode(INVALID_FUNCTION_PROPERTY)
141+
.hasMessage("line 1:38: Function language 'SQL' does not support properties");
125142
}
126143

127144
@Test

core/trino-parser/src/main/java/io/trino/sql/SqlFormatter.java

+21-1
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@
109109
import io.trino.sql.tree.PlanSiblings;
110110
import io.trino.sql.tree.Prepare;
111111
import io.trino.sql.tree.PrincipalSpecification;
112+
import io.trino.sql.tree.PropertiesCharacteristic;
112113
import io.trino.sql.tree.Property;
113114
import io.trino.sql.tree.QualifiedName;
114115
import io.trino.sql.tree.Query;
@@ -160,6 +161,7 @@
160161
import io.trino.sql.tree.ShowTables;
161162
import io.trino.sql.tree.SingleColumn;
162163
import io.trino.sql.tree.StartTransaction;
164+
import io.trino.sql.tree.StringLiteral;
163165
import io.trino.sql.tree.Table;
164166
import io.trino.sql.tree.TableExecute;
165167
import io.trino.sql.tree.TableFunctionArgument;
@@ -2296,7 +2298,11 @@ protected Void visitFunctionSpecification(FunctionSpecification node, Integer in
22962298
process(characteristic, indent);
22972299
builder.append("\n");
22982300
}
2299-
process(node.getStatement(), indent);
2301+
node.getStatement().ifPresent(statement -> process(statement, indent));
2302+
node.getDefinition().map(StringLiteral::getValue).ifPresent(definition -> {
2303+
append(indent, "AS ");
2304+
builder.append("$$\n").append(definition).append("$$");
2305+
});
23002306
return null;
23012307
}
23022308

@@ -2352,6 +2358,20 @@ protected Void visitCommentCharacteristic(CommentCharacteristic node, Integer in
23522358
return null;
23532359
}
23542360

2361+
@Override
2362+
protected Void visitPropertiesCharacteristic(PropertiesCharacteristic node, Integer indent)
2363+
{
2364+
append(indent, "WITH (\n");
2365+
Iterator<Property> iterator = node.getProperties().iterator();
2366+
while (iterator.hasNext()) {
2367+
Property property = iterator.next();
2368+
append(indent + 1, formatProperty(property));
2369+
builder.append(iterator.hasNext() ? ",\n" : "\n");
2370+
}
2371+
append(indent, ")");
2372+
return null;
2373+
}
2374+
23552375
@Override
23562376
protected Void visitReturnClause(ReturnsClause node, Integer indent)
23572377
{

core/trino-parser/src/main/java/io/trino/sql/parser/AstBuilder.java

+36-5
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,7 @@
203203
import io.trino.sql.tree.Prepare;
204204
import io.trino.sql.tree.PrincipalSpecification;
205205
import io.trino.sql.tree.ProcessingMode;
206+
import io.trino.sql.tree.PropertiesCharacteristic;
206207
import io.trino.sql.tree.Property;
207208
import io.trino.sql.tree.QualifiedName;
208209
import io.trino.sql.tree.QuantifiedComparisonExpression;
@@ -3710,17 +3711,39 @@ public Node visitJsonTableDefaultPlan(SqlBaseParser.JsonTableDefaultPlanContext
37103711
@Override
37113712
public Node visitFunctionSpecification(SqlBaseParser.FunctionSpecificationContext context)
37123713
{
3713-
ControlStatement statement = (ControlStatement) visit(context.controlStatement());
3714-
if (!(statement instanceof ReturnStatement || statement instanceof CompoundStatement)) {
3715-
throw parseError("Function body must start with RETURN or BEGIN", context.controlStatement());
3716-
}
3714+
Optional<ControlStatement> statement = visitIfPresent(context.controlStatement(), ControlStatement.class);
3715+
statement.ifPresent(body -> {
3716+
if (!(body instanceof ReturnStatement || body instanceof CompoundStatement)) {
3717+
throw parseError("Function body must start with RETURN or BEGIN", context.controlStatement());
3718+
}
3719+
});
3720+
37173721
return new FunctionSpecification(
37183722
getLocation(context),
37193723
getQualifiedName(context.functionDeclaration().qualifiedName()),
37203724
visit(context.functionDeclaration().parameterDeclaration(), ParameterDeclaration.class),
37213725
(ReturnsClause) visit(context.returnsClause()),
37223726
visit(context.routineCharacteristic(), RoutineCharacteristic.class),
3723-
statement);
3727+
statement,
3728+
visitIfPresent(context.functionDefinition(), StringLiteral.class));
3729+
}
3730+
3731+
@Override
3732+
public Node visitFunctionDefinition(SqlBaseParser.FunctionDefinitionContext context)
3733+
{
3734+
String value = context.getText();
3735+
value = value.substring(2, value.length() - 2);
3736+
if (value.isEmpty() || ((value.charAt(0) != '\r') && (value.charAt(0) != '\n'))) {
3737+
throw parseError("Function definition must start with a newline after opening quotes", context);
3738+
}
3739+
// strip leading \r or \n or \r\n
3740+
if (value.charAt(0) == '\r') {
3741+
value = value.substring(1);
3742+
}
3743+
if (!value.isEmpty() && value.charAt(0) == '\n') {
3744+
value = value.substring(1);
3745+
}
3746+
return new StringLiteral(getLocation(context), value);
37243747
}
37253748

37263749
@Override
@@ -3776,6 +3799,14 @@ public Node visitCommentCharacteristic(SqlBaseParser.CommentCharacteristicContex
37763799
return new CommentCharacteristic(getLocation(context), visitString(context.string()).getValue());
37773800
}
37783801

3802+
@Override
3803+
public Node visitPropertiesCharacteristic(SqlBaseParser.PropertiesCharacteristicContext context)
3804+
{
3805+
return new PropertiesCharacteristic(
3806+
getLocation(context),
3807+
visit(context.properties().propertyAssignments().property(), Property.class));
3808+
}
3809+
37793810
@Override
37803811
public Node visitReturnStatement(SqlBaseParser.ReturnStatementContext context)
37813812
{

0 commit comments

Comments
 (0)