diff --git a/README.rst b/README.rst index 927848dc..2a195455 100644 --- a/README.rst +++ b/README.rst @@ -465,7 +465,6 @@ Other limitations ~~~~~~~~~~~~~~~~~ - WITH RECURSIVE statement is not supported. -- Named schemas are not supported. - Temporary tables are not supported. - Numeric type dimensions (scale and precision) are constant. See the `docs <https://cloud.google.com/spanner/docs/data-types#numeric_types>`__. diff --git a/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py b/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py index 2c4238a4..e5559c65 100644 --- a/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py +++ b/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py @@ -233,6 +233,10 @@ class SpannerSQLCompiler(SQLCompiler): compound_keywords = _compound_keywords + def __init__(self, *args, **kwargs): + self.tablealiases = {} + super().__init__(*args, **kwargs) + def get_from_hint_text(self, _, text): """Return a hint text. @@ -378,8 +382,11 @@ def limit_clause(self, select, **kw): return text def returning_clause(self, stmt, returning_cols, **kw): + # Set the spanner_is_returning flag which is passed to visit_column. columns = [ - self._label_select_column(None, c, True, False, {}) + self._label_select_column( + None, c, True, False, {"spanner_is_returning": True} + ) for c in expression._select_iterables(returning_cols) ] @@ -391,6 +398,98 @@ def visit_sequence(self, seq, **kw): seq ) + def visit_table(self, table, spanner_aliased=False, iscrud=False, **kwargs): + """Produces the table name. + + Schema names are not allowed in Spanner SELECT statements. We + need to avoid generating SQL like + + SELECT schema.tbl.id + FROM schema.tbl + + To do so, we alias the table in order to produce SQL like: + + SELECT tbl_1.id, tbl_1.col + FROM schema.tbl AS tbl_1 + + And do similar for UPDATE and DELETE statements. + + This closely mirrors the mssql dialect which also avoids + schema-qualified columns in SELECTs, although the behaviour is + currently behind a deprecated 'legacy_schema_aliasing' flag. + """ + if spanner_aliased is table or self.isinsert: + return super().visit_table(table, **kwargs) + + # Add an alias for schema-qualified tables. + # Tables in the default schema are not aliased and follow the + # standard SQLAlchemy code path. + alias = self._schema_aliased_table(table) + if alias is not None: + return self.process(alias, spanner_aliased=table, **kwargs) + else: + return super().visit_table(table, **kwargs) + + def visit_alias(self, alias, **kw): + """Produces alias statements.""" + # translate for schema-qualified table aliases + kw["spanner_aliased"] = alias.element + return super().visit_alias(alias, **kw) + + def visit_column( + self, column, add_to_result_map=None, spanner_is_returning=False, **kw + ): + """Produces column expressions. + + In tandem with visit_table, replaces schema-qualified column + names with column names qualified against an alias. + """ + if column.table is not None and not self.isinsert or self.is_subquery(): + # translate for schema-qualified table aliases + t = self._schema_aliased_table(column.table) + if t is not None: + converted = elements._corresponding_column_or_error(t, column) + if add_to_result_map is not None: + add_to_result_map( + column.name, + column.name, + (column, column.name, column.key), + column.type, + ) + + return super().visit_column(converted, **kw) + if spanner_is_returning: + # Set include_table=False because although table names are + # allowed in RETURNING clauses, schema names are not. We + # can't use the same aliasing trick above that we use with + # other statements, because INSERT statements don't result + # in visit_table calls and INSERT table names can't be + # aliased. Statements like: + # + # INSERT INTO table (id, name) + # SELECT id, name FROM another_table + # THEN RETURN another_table.id + # + # aren't legal, so the columns remain unambiguous when not + # qualified by table name. + kw["include_table"] = False + + return super().visit_column(column, add_to_result_map=add_to_result_map, **kw) + + def _schema_aliased_table(self, table): + """Creates an alias for the table if it is schema-qualified. + + If the table is schema-qualified, returns an alias for the + table and caches the alias for future references to the + table. If the table is not schema-qualified, returns None. + """ + if getattr(table, "schema", None) is not None: + if table not in self.tablealiases: + self.tablealiases[table] = table.alias() + return self.tablealiases[table] + else: + return None + class SpannerDDLCompiler(DDLCompiler): """Spanner DDL statements compiler.""" diff --git a/test/mockserver_tests/test_auto_increment.py b/test/mockserver_tests/test_auto_increment.py index 6bc5e2c0..7fa245e8 100644 --- a/test/mockserver_tests/test_auto_increment.py +++ b/test/mockserver_tests/test_auto_increment.py @@ -125,9 +125,7 @@ def test_create_table_with_specific_sequence_kind(self): def test_insert_row(self): from test.mockserver_tests.auto_increment_model import Singer - self.add_insert_result( - "INSERT INTO singers (name) VALUES (@a0) THEN RETURN singers.id" - ) + self.add_insert_result("INSERT INTO singers (name) VALUES (@a0) THEN RETURN id") engine = create_engine( "spanner:///projects/p/instances/i/databases/d", connect_args={"client": self.client, "pool": FixedSizePool(size=10)}, diff --git a/test/mockserver_tests/test_basics.py b/test/mockserver_tests/test_basics.py index 6db248d6..e1445829 100644 --- a/test/mockserver_tests/test_basics.py +++ b/test/mockserver_tests/test_basics.py @@ -262,3 +262,53 @@ class Singer(Base): singer.name = "New Name" session.add(singer) session.commit() + + def test_select_table_in_named_schema(self): + class Base(DeclarativeBase): + pass + + class Singer(Base): + __tablename__ = "singers" + __table_args__ = {"schema": "my_schema"} + id: Mapped[int] = mapped_column(BigInteger, primary_key=True) + name: Mapped[str] = mapped_column(String) + + query = ( + "SELECT" + " singers_1.id AS my_schema_singers_id," + " singers_1.name AS my_schema_singers_name\n" + "FROM my_schema.singers AS singers_1\n" + "WHERE singers_1.id = @a0\n" + " LIMIT @a1" + ) + add_singer_query_result(query) + engine = create_engine( + "spanner:///projects/p/instances/i/databases/d", + connect_args={"client": self.client, "pool": FixedSizePool(size=10)}, + ) + + insert = "INSERT INTO my_schema.singers (name) VALUES (@a0) THEN RETURN id" + add_single_result(insert, "id", TypeCode.INT64, [("1",)]) + with Session(engine) as session: + singer = Singer(name="New Name") + session.add(singer) + session.commit() + + update = ( + "UPDATE my_schema.singers AS singers_1 " + "SET name=@a0 " + "WHERE singers_1.id = @a1" + ) + add_update_count(update, 1) + with Session(engine) as session: + singer = session.query(Singer).filter(Singer.id == 1).first() + singer.name = "New Name" + session.add(singer) + session.commit() + + delete = "DELETE FROM my_schema.singers AS singers_1 WHERE singers_1.id = @a0" + add_update_count(delete, 1) + with Session(engine) as session: + singer = session.query(Singer).filter(Singer.id == 1).first() + session.delete(singer) + session.commit() diff --git a/test/mockserver_tests/test_bit_reversed_sequence.py b/test/mockserver_tests/test_bit_reversed_sequence.py index a18bc08e..9e7a81a8 100644 --- a/test/mockserver_tests/test_bit_reversed_sequence.py +++ b/test/mockserver_tests/test_bit_reversed_sequence.py @@ -110,7 +110,7 @@ def test_insert_row(self): add_result( "INSERT INTO singers (id, name) " "VALUES ( GET_NEXT_SEQUENCE_VALUE(SEQUENCE singer_id), @a0) " - "THEN RETURN singers.id", + "THEN RETURN id", result, ) engine = create_engine( diff --git a/test/system/test_basics.py b/test/system/test_basics.py index e5411988..3001052d 100644 --- a/test/system/test_basics.py +++ b/test/system/test_basics.py @@ -25,6 +25,8 @@ Boolean, BIGINT, select, + update, + delete, ) from sqlalchemy.orm import Session, DeclarativeBase, Mapped, mapped_column from sqlalchemy.types import REAL @@ -58,6 +60,16 @@ def define_tables(cls, metadata): Column("name", String(20)), ) + with cls.bind.begin() as conn: + conn.execute(text("CREATE SCHEMA IF NOT EXISTS schema")) + Table( + "users", + metadata, + Column("ID", Integer, primary_key=True), + Column("name", String(20)), + schema="schema", + ) + def test_hello_world(self, connection): greeting = connection.execute(text("select 'Hello World'")) eq_("Hello World", greeting.fetchone()[0]) @@ -139,6 +151,12 @@ class User(Base): ID: Mapped[int] = mapped_column(primary_key=True) name: Mapped[str] = mapped_column(String(20)) + class SchemaUser(Base): + __tablename__ = "users" + __table_args__ = {"schema": "schema"} + ID: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] = mapped_column(String(20)) + engine = connection.engine with Session(engine) as session: number = Number( @@ -156,3 +174,35 @@ class User(Base): users = session.scalars(statement).all() eq_(1, len(users)) is_true(users[0].ID > 0) + + with Session(engine) as session: + user = SchemaUser(name="SchemaTest") + session.add(user) + session.commit() + + users = session.scalars( + select(SchemaUser).where(SchemaUser.name == "SchemaTest") + ).all() + eq_(1, len(users)) + is_true(users[0].ID > 0) + + session.execute( + update(SchemaUser) + .where(SchemaUser.name == "SchemaTest") + .values(name="NewName") + ) + session.commit() + + users = session.scalars( + select(SchemaUser).where(SchemaUser.name == "NewName") + ).all() + eq_(1, len(users)) + is_true(users[0].ID > 0) + + session.execute(delete(SchemaUser).where(SchemaUser.name == "NewName")) + session.commit() + + users = session.scalars( + select(SchemaUser).where(SchemaUser.name == "NewName") + ).all() + eq_(0, len(users))