4
4
from collections import defaultdict , namedtuple
5
5
from logging import getLogger
6
6
from pathlib import Path
7
+ from typing import Any , Optional
7
8
8
9
import sqlalchemy as sa
9
10
from packaging .version import Version
@@ -438,9 +439,9 @@ def __new__(cls, name, schema=None, connection=None):
438
439
439
440
def __str__ (self ):
440
441
if self .schema is None :
441
- return self .name
442
+ return RelationKey . _unquote ( self .name )
442
443
else :
443
- return self .schema + "." + self .name
444
+ return RelationKey . _unquote ( self .schema ) + "." + RelationKey . _unquote ( self .name )
444
445
445
446
@staticmethod
446
447
def _unquote (part ):
@@ -654,6 +655,9 @@ def visit_HLLSKETCH(self, type_, **kw):
654
655
class RedshiftIdentifierPreparer (PGIdentifierPreparer ):
655
656
reserved_words = RESERVED_WORDS
656
657
658
+ def quote_schema (self , schema : Any , force : Optional [bool ] = ...) -> str :
659
+ return schema
660
+
657
661
658
662
class RedshiftDialectMixin (DefaultDialect ):
659
663
"""
@@ -670,6 +674,7 @@ class RedshiftDialectMixin(DefaultDialect):
670
674
statement_compiler = RedshiftCompiler
671
675
ddl_compiler = RedshiftDDLCompiler
672
676
preparer = RedshiftIdentifierPreparer
677
+ identifier_preparer = RedshiftIdentifierPreparer
673
678
type_compiler = RedshiftTypeCompiler
674
679
construct_arguments = [
675
680
(sa .schema .Index , {
@@ -793,12 +798,12 @@ def get_check_constraints(self, connection, table_name, schema=None, **kw):
793
798
def get_table_oid (self , connection , table_name , schema = None , ** kw ):
794
799
"""Fetch the oid for schema.table_name.
795
800
Return null if not found (external table does not have table oid)"""
796
- schema_field = '" {schema}" .' .format (schema = schema ) if schema else ""
801
+ schema_field = '{schema}.' .format (schema = schema ) if schema else ""
797
802
798
803
result = connection .execute (
799
804
sa .text (
800
805
"""
801
- select '{schema_field}" {table_name}" '::regclass::oid;
806
+ select '{schema_field}{table_name}'::regclass::oid;
802
807
""" .format (
803
808
schema_field = schema_field ,
804
809
table_name = table_name
@@ -857,8 +862,8 @@ def get_foreign_keys(self, connection, table_name, schema=None, **kw):
857
862
fkey_d = {
858
863
'name' : conname ,
859
864
'constrained_columns' : constrained_columns ,
860
- 'referred_schema' : referred_schema ,
861
- 'referred_table' : referred_table ,
865
+ 'referred_schema' : self . unquote ( referred_schema ) ,
866
+ 'referred_table' : self . unquote ( referred_table ) ,
862
867
'referred_columns' : referred_columns ,
863
868
}
864
869
fkeys .append (fkey_d )
@@ -908,6 +913,16 @@ def get_indexes(self, connection, table_name, schema, **kw):
908
913
"""
909
914
return []
910
915
916
+ @staticmethod
917
+ def unquote (text ):
918
+ if text is None :
919
+ return None
920
+
921
+ if text .startswith ('"' ) and text .endswith ('"' ):
922
+ return text [1 :- 1 ]
923
+
924
+ return text
925
+
911
926
@reflection .cache
912
927
def get_unique_constraints (self , connection , table_name ,
913
928
schema = None , ** kw ):
@@ -977,7 +992,7 @@ def _get_table_or_view_names(self, relkind, connection, schema=None, **kw):
977
992
relation_names = []
978
993
for key , relation in all_relations .items ():
979
994
if key .schema == schema and relation .relkind == relkind :
980
- relation_names .append (key .name )
995
+ relation_names .append (self . unquote ( key .name ) )
981
996
return relation_names
982
997
983
998
def _get_column_info (self , * args , ** kwargs ):
@@ -1110,6 +1125,7 @@ def _get_all_relation_info(self, connection, **kw):
1110
1125
@reflection .cache
1111
1126
def _get_schema_column_info (self , connection , ** kw ):
1112
1127
schema = kw .get ('schema' , None )
1128
+ schema = self .unquote (schema )
1113
1129
schema_clause = (
1114
1130
"AND schema = '{schema}'" .format (schema = schema ) if schema else ""
1115
1131
)
0 commit comments