5
5
import logging
6
6
import traceback
7
7
from collections import defaultdict
8
- from typing import Any , Dict , List , Optional , Set , Tuple , Union
8
+ from typing import Any , Dict , List , Optional , Set , Tuple , TypeVar , Union
9
9
10
10
import pydantic .dataclasses
11
11
import sqlglot
@@ -873,6 +873,49 @@ def _translate_internal_column_lineage(
873
873
)
874
874
875
875
876
# Constrained type variable: lets the helper below declare that it hands back
# the same kind of value it was given (`str` stays `str`, `Optional[str]`
# stays `Optional[str]`).
_StrOrNone = TypeVar("_StrOrNone", str, Optional[str])


def _normalize_db_or_schema(
    db_or_schema: _StrOrNone,
    dialect: sqlglot.Dialect,
) -> _StrOrNone:
    """
    Apply the dialect's identifier casing to a database or schema name.

    `None` is passed through untouched. For snowflake the name is uppercased,
    for mssql it is lowercased, and for every other dialect it is returned
    unchanged.
    """

    if db_or_schema is None:
        return None

    if is_dialect_instance(dialect, "snowflake"):
        # In snowflake, table identifiers must be uppercased to match sqlglot's behavior.
        return db_or_schema.upper()

    if is_dialect_instance(dialect, "mssql"):
        # In mssql, table identifiers must be lowercased.
        return db_or_schema.lower()

    # No dialect-specific casing rule applies.
    return db_or_schema
895
+
896
+
897
def _simplify_select_into(statement: sqlglot.exp.Expression) -> sqlglot.exp.Expression:
    """
    Rewrite a SELECT INTO statement as a CTAS (CREATE TABLE ... AS SELECT).

    Any expression that is not a SELECT carrying an INTO clause is handed
    back as-is.
    """

    is_select_into = isinstance(statement, sqlglot.exp.Select) and statement.args.get(
        "into"
    )
    if not is_select_into:
        return statement

    # Going from:  SELECT <cols> INTO <out> <expr>
    #          to:  CREATE TABLE <out> AS SELECT <cols> <expr>
    # The INTO node is popped off the SELECT first; the same SELECT object is
    # then reused as the body of the new CREATE.
    into_expr: sqlglot.exp.Into = statement.args["into"].pop()

    return sqlglot.exp.Create(
        this=into_expr.this,
        kind="TABLE",
        expression=statement,
    )
917
+
918
+
876
919
def _sqlglot_lineage_inner (
877
920
sql : sqlglot .exp .ExpOrStr ,
878
921
schema_resolver : SchemaResolverInterface ,
@@ -885,12 +928,9 @@ def _sqlglot_lineage_inner(
885
928
else :
886
929
dialect = get_dialect (default_dialect )
887
930
888
- if is_dialect_instance (dialect , "snowflake" ):
889
- # in snowflake, table identifiers must be uppercased to match sqlglot's behavior.
890
- if default_db :
891
- default_db = default_db .upper ()
892
- if default_schema :
893
- default_schema = default_schema .upper ()
931
+ default_db = _normalize_db_or_schema (default_db , dialect )
932
+ default_schema = _normalize_db_or_schema (default_schema , dialect )
933
+
894
934
if is_dialect_instance (dialect , "redshift" ) and not default_schema :
895
935
# On Redshift, there's no "USE SCHEMA <schema>" command. The default schema
896
936
# is public, and "current schema" is the one at the front of the search path.
@@ -918,6 +958,8 @@ def _sqlglot_lineage_inner(
918
958
# original_statement.sql(pretty=True, dialect=dialect),
919
959
# )
920
960
961
+ statement = _simplify_select_into (statement )
962
+
921
963
# Make sure the tables are resolved with the default db / schema.
922
964
# This only works for Unionable statements. For other types of statements,
923
965
# we have to do it manually afterwards, but that's slightly lower accuracy
0 commit comments