55
55
import io .trino .sql .tree .Node ;
56
56
import io .trino .sql .tree .NullInputCharacteristic ;
57
57
import io .trino .sql .tree .ParameterDeclaration ;
58
+ import io .trino .sql .tree .PropertiesCharacteristic ;
59
+ import io .trino .sql .tree .Property ;
58
60
import io .trino .sql .tree .RepeatStatement ;
59
61
import io .trino .sql .tree .ReturnStatement ;
60
62
import io .trino .sql .tree .ReturnsClause ;
71
73
import java .util .Optional ;
72
74
import java .util .Set ;
73
75
74
- import static com .google .common .base .Preconditions .checkArgument ;
75
76
import static com .google .common .collect .ImmutableList .toImmutableList ;
76
77
import static com .google .common .collect .Iterables .getLast ;
77
78
import static io .trino .spi .StandardErrorCode .ALREADY_EXISTS ;
78
79
import static io .trino .spi .StandardErrorCode .INVALID_ARGUMENTS ;
80
+ import static io .trino .spi .StandardErrorCode .INVALID_FUNCTION_PROPERTY ;
79
81
import static io .trino .spi .StandardErrorCode .MISSING_RETURN ;
80
82
import static io .trino .spi .StandardErrorCode .NOT_FOUND ;
81
83
import static io .trino .spi .StandardErrorCode .NOT_SUPPORTED ;
@@ -103,7 +105,6 @@ public SqlRoutineAnalyzer(PlannerContext plannerContext, WarningCollector warnin
103
105
public static FunctionMetadata extractFunctionMetadata (FunctionId functionId , FunctionSpecification function )
104
106
{
105
107
validateLanguage (function );
106
- validateReturn (function );
107
108
108
109
String functionName = getFunctionName (function );
109
110
Signature .Builder signatureBuilder = Signature .builder ()
@@ -139,7 +140,7 @@ public SqlRoutineAnalysis analyze(Session session, AccessControl accessControl,
139
140
{
140
141
String functionName = getFunctionName (function );
141
142
142
- validateLanguage (function );
143
+ ControlStatement statement = validateLanguage (function );
143
144
144
145
boolean calledOnNull = isCalledOnNull (function );
145
146
Optional <String > comment = getComment (function );
@@ -150,10 +151,10 @@ public SqlRoutineAnalysis analyze(Session session, AccessControl accessControl,
150
151
151
152
Map <String , Type > arguments = getArguments (function );
152
153
153
- validateReturn (function );
154
+ validateReturn (statement );
154
155
155
156
StatementVisitor visitor = new StatementVisitor (session , accessControl , returnType );
156
- visitor .process (function . getStatement () , new Context (arguments , Set .of ()));
157
+ visitor .process (statement , new Context (arguments , Set .of ()));
157
158
158
159
Analysis analysis = visitor .getAnalysis ();
159
160
@@ -174,6 +175,7 @@ public SqlRoutineAnalysis analyze(Session session, AccessControl accessControl,
174
175
calledOnNull ,
175
176
actuallyDeterministic ,
176
177
comment ,
178
+ statement ,
177
179
visitor .getAnalysis ());
178
180
}
179
181
@@ -239,13 +241,24 @@ private static Optional<Identifier> getLanguage(FunctionSpecification function)
239
241
.findAny ();
240
242
}
241
243
242
- private static void validateLanguage (FunctionSpecification function )
244
+ private static ControlStatement validateLanguage (FunctionSpecification function )
243
245
{
244
246
getLanguage (function ).ifPresent (language -> {
245
247
if (!language .getValue ().equalsIgnoreCase ("sql" )) {
248
+ function .getStatement ().ifPresent (statement -> {
249
+ throw semanticException (NOT_SUPPORTED , statement , "Only functions using language 'SQL' may be defined using SQL" );
250
+ });
246
251
throw semanticException (NOT_SUPPORTED , language , "Unsupported function language: %s" , language .getCanonicalValue ());
247
252
}
248
253
});
254
+
255
+ List <Property > properties = getProperties (function );
256
+ if (!properties .isEmpty ()) {
257
+ throw semanticException (INVALID_FUNCTION_PROPERTY , properties .getFirst (), "Function language 'SQL' does not support properties" );
258
+ }
259
+
260
+ return function .getStatement ().orElseThrow (() ->
261
+ semanticException (SYNTAX_ERROR , function .getDefinition ().orElseThrow (), "Functions using language 'SQL' must be defined using SQL" ));
249
262
}
250
263
251
264
private static boolean isDeterministic (FunctionSpecification function )
@@ -321,17 +334,33 @@ private static Optional<String> getComment(FunctionSpecification function)
321
334
.findAny ();
322
335
}
323
336
324
- private static void validateReturn (FunctionSpecification function )
337
+ private static List < Property > getProperties (FunctionSpecification function )
325
338
{
326
- ControlStatement statement = function .getStatement ();
327
- if (statement instanceof ReturnStatement ) {
328
- return ;
339
+ List <PropertiesCharacteristic > properties = function .getRoutineCharacteristics ().stream ()
340
+ .filter (PropertiesCharacteristic .class ::isInstance )
341
+ .map (PropertiesCharacteristic .class ::cast )
342
+ .toList ();
343
+
344
+ if (properties .size () > 1 ) {
345
+ throw semanticException (SYNTAX_ERROR , properties .get (1 ), "Multiple properties clauses specified" );
329
346
}
330
347
331
- checkArgument (statement instanceof CompoundStatement , "invalid function statement: %s" , statement );
332
- CompoundStatement body = (CompoundStatement ) statement ;
333
- if (!(getLast (body .getStatements (), null ) instanceof ReturnStatement )) {
334
- throw semanticException (MISSING_RETURN , body , "Function must end in a RETURN statement" );
348
+ return properties .stream ()
349
+ .map (PropertiesCharacteristic ::getProperties )
350
+ .flatMap (List ::stream )
351
+ .collect (toImmutableList ());
352
+ }
353
+
354
+ private static void validateReturn (ControlStatement statement )
355
+ {
356
+ switch (statement ) {
357
+ case ReturnStatement _ -> {}
358
+ case CompoundStatement body -> {
359
+ if (!(getLast (body .getStatements (), null ) instanceof ReturnStatement )) {
360
+ throw semanticException (MISSING_RETURN , body , "Function must end in a RETURN statement" );
361
+ }
362
+ }
363
+ default -> throw new IllegalArgumentException ("Invalid function statement: " + statement );
335
364
}
336
365
}
337
366
0 commit comments