Skip to content

Commit 080f1ba

Browse files
committed
Switch to upper case identifiers
1 parent 956e15f commit 080f1ba

File tree

4 files changed

+42
-48
lines changed

4 files changed

+42
-48
lines changed

.github/workflows/tests.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ jobs:
1818
uses: actions/checkout@v2
1919
with:
2020
repository: 'timgraham/django'
21-
ref: 'snowflake-3.2.x'
21+
ref: 'remove-query-quotes'
2222
path: 'django_repo'
2323
- name: Install system packages for Django's Python test dependencies
2424
run: |

django_snowflake/features.py

-26
Original file line numberDiff line numberDiff line change
@@ -204,32 +204,6 @@ class DatabaseFeatures(BaseDatabaseFeatures):
204204
'backends.tests.SequenceResetTest.test_generic_relation',
205205
'backends.base.test_operations.SqlFlushTests.test_execute_sql_flush_statements',
206206
},
207-
'This test does not quote a field name in raw SQL as Snowflake requires.': {
208-
'aggregation_regress.tests.AggregationTests.test_annotation',
209-
'aggregation_regress.tests.AggregationTests.test_more_more',
210-
'aggregation_regress.tests.AggregationTests.test_more_more_more',
211-
'annotations.tests.NonAggregateAnnotationTestCase.test_raw_sql_with_inherited_field',
212-
'lookup.tests.LookupTests.test_values',
213-
'lookup.tests.LookupTests.test_values_list',
214-
'expressions.tests.BasicExpressionsTests.test_filtering_on_rawsql_that_is_boolean',
215-
'expressions.tests.BasicExpressionsTests.test_order_by_multiline_sql',
216-
'model_fields.test_booleanfield.BooleanFieldTests.test_return_type',
217-
'queries.test_qs_combinators.QuerySetSetOperationTests.test_union_multiple_models_with_values_list_and_order_by_extra_select', # noqa
218-
'queries.tests.EscapingTests.test_ticket_7302',
219-
'queries.tests.Queries5Tests.test_ordering',
220-
'queries.tests.ValuesQuerysetTests.test_extra_multiple_select_params_values_order_by',
221-
'queries.tests.ValuesQuerysetTests.test_extra_select_params_values_order_in_extra',
222-
'queries.tests.ValuesQuerysetTests.test_extra_values',
223-
'queries.tests.ValuesQuerysetTests.test_extra_values_list',
224-
'queries.tests.ValuesQuerysetTests.test_extra_values_order_multiple',
225-
'queries.tests.ValuesQuerysetTests.test_extra_values_order_twice',
226-
'queries.tests.ValuesQuerysetTests.test_flat_extra_values_list',
227-
'queries.tests.ValuesQuerysetTests.test_named_values_list_with_fields',
228-
'queries.tests.ValuesQuerysetTests.test_named_values_list_without_fields',
229-
'queries.tests.Queries1Tests.test_order_by_rawsql',
230-
'queries.tests.Queries1Tests.test_ticket7098',
231-
'queries.tests.Queries1Tests.test_tickets_7087_12242',
232-
},
233207
"Snowflake prohibits string truncation when using Cast.": {
234208
'db_functions.comparison.test_cast.CastTests.test_cast_to_char_field_with_max_length',
235209
},

django_snowflake/introspection.py

+40-20
Original file line numberDiff line numberDiff line change
@@ -53,38 +53,44 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
5353
}
5454

5555
def get_constraints(self, cursor, table_name):
56+
table_name = self.connection.ops.quote_name(table_name)
5657
constraints = {}
5758
# Foreign keys
58-
cursor.execute(f'SHOW IMPORTED KEYS IN TABLE "{table_name}"')
59+
cursor.execute(f'SHOW IMPORTED KEYS IN TABLE {table_name}')
5960
for row in cursor.fetchall():
6061
constraints[row[12]] = {
61-
'columns': [row[8]],
62+
'columns': [self.identifier_converter(row[8])],
6263
'primary_key': False,
6364
'unique': False,
64-
'foreign_key': (row[3], row[4]),
65+
'foreign_key': (self.identifier_converter(row[3]), self.identifier_converter(row[4])),
6566
'check': False,
6667
'index': False,
6768
}
6869
# Primary keys
69-
cursor.execute(f'SHOW PRIMARY KEYS IN TABLE "{table_name}"')
70+
cursor.execute(f'SHOW PRIMARY KEYS IN TABLE {table_name}')
7071
for row in cursor.fetchall():
71-
constraints[row[6]] = {
72-
'columns': [row[4]],
72+
# Add quotes around
73+
constraints['"' + row[6] + '"'] = {
74+
'columns': [self.identifier_converter(row[4])],
7375
'primary_key': True,
7476
'unique': False,
7577
'foreign_key': None,
7678
'check': False,
7779
'index': False,
7880
}
7981
# Unique constraints
80-
cursor.execute(f'SHOW UNIQUE KEYS IN TABLE "{table_name}"')
82+
cursor.execute(f'SHOW UNIQUE KEYS IN TABLE {table_name}')
8183
# The columns of multi-column unique indexes are ordered by row[5].
8284
# Map {constraint_name: [(row[5], column_name), ...] so the columns can
8385
# be sorted for each constraint.
8486
unique_column_orders = {}
8587
for row in cursor.fetchall():
86-
column_name = row[4]
87-
constraint_name = row[6]
88+
column_name = self.identifier_converter(row[4])
89+
# TODO: hack
90+
constraint_name = (
91+
'"' + row[6] + '"' if row[6].startswith("SYS_CONSTRAINT_")
92+
else self.identifier_converter(row[6])
93+
)
8894
if constraint_name in constraints:
8995
# If the constraint name is already present, this is a
9096
# multi-column unique constraint.
@@ -106,16 +112,24 @@ def get_constraints(self, cursor, table_name):
106112
return constraints
107113

108114
def get_primary_key_column(self, cursor, table_name):
109-
pks = [field.name for field in self.get_table_description(cursor, table_name) if field.pk]
115+
pks = [
116+
self.identifier_converter(field.name)
117+
for field in self.get_table_description(cursor, table_name)
118+
if field.pk
119+
]
110120
return pks[0] if pks else None
111121

112122
def get_relations(self, cursor, table_name):
113123
"""
114124
Return a dictionary of {field_name: (field_name_other_table, other_table)}
115125
representing all foreign keys in the given table.
116126
"""
117-
cursor.execute(f'SHOW IMPORTED KEYS IN TABLE "{table_name}"')
118-
return {row[8]: (row[4], row[3]) for row in cursor.fetchall()}
127+
table_name = self.connection.ops.quote_name(table_name)
128+
cursor.execute(f'SHOW IMPORTED KEYS IN TABLE {table_name}')
129+
return {
130+
self.identifier_converter(row[8]): (self.identifier_converter(row[4]), self.identifier_converter(row[3]))
131+
for row in cursor.fetchall()
132+
}
119133

120134
def get_field_type(self, data_type, description):
121135
field_type = super().get_field_type(data_type, description)
@@ -145,22 +159,28 @@ def get_table_description(self, cursor, table_name):
145159
table_info = cursor.fetchall()
146160
return [
147161
FieldInfo(
148-
# name, type_code, display_size, internal_size,
149-
name, get_data_type(data_type), None, get_field_size(data_type),
150-
# precision, scale, null_ok, default,
151-
*get_precision_and_scale(data_type), null == 'Y', default,
152-
# collation, pk,
153-
get_collation(data_type), pk == 'Y',
162+
# name, type_code, display_size,
163+
self.identifier_converter(name), get_data_type(data_type), None,
164+
# internal_size, precision, scale,
165+
get_field_size(data_type), *get_precision_and_scale(data_type),
166+
# null_ok, default, collation, pk,
167+
null == 'Y', default, get_collation(data_type), pk == 'Y',
154168
)
155169
for (
156170
name, data_type, kind, null, default, pk, unique_key, check,
157171
expression, comment, policy_name,
158172
) in table_info
159173
]
160174

175+
def identifier_converter(self, name):
176+
"""Identifier comparison is case insensitive on Snowflake."""
177+
# if name != name.upper():
178+
# return f'"{name}"'
179+
return name.lower()
180+
161181
def get_table_list(self, cursor):
162182
cursor.execute('SHOW TERSE TABLES')
163-
tables = [TableInfo(row[1], 't') for row in cursor.fetchall()]
183+
tables = [TableInfo(self.identifier_converter(row[1]), 't') for row in cursor.fetchall()]
164184
cursor.execute('SHOW TERSE VIEWS')
165-
views = [TableInfo(row[1], 'v') for row in cursor.fetchall()]
185+
views = [TableInfo(self.identifier_converter(row[1]), 'v') for row in cursor.fetchall()]
166186
return tables + views

django_snowflake/operations.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ def no_limit_value(self):
156156
def quote_name(self, name):
157157
if name.startswith('"') and name.endswith('"'):
158158
return name # Quoting once is enough.
159-
return '"%s"' % name.replace('.', '"."')
159+
return '"%s"' % name.upper().replace('.', '"."')
160160

161161
def regex_lookup(self, lookup_type):
162162
match_option = 'c' if lookup_type == 'regex' else 'i'

0 commit comments

Comments
 (0)