From 7f6df23e100d36e085b6bef7667a9d3c7c550aab Mon Sep 17 00:00:00 2001 From: Sri Harsha CH Date: Sat, 29 Jul 2023 14:43:10 +0000 Subject: [PATCH 1/6] feat: support dml returning --- google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py b/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py index 9fe09140..167a8685 100644 --- a/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py +++ b/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py @@ -40,6 +40,7 @@ ) from sqlalchemy.sql.default_comparator import operator_lookup from sqlalchemy.sql.operators import json_getitem_op +from sqlalchemy.sql import expression from google.cloud.spanner_v1.data_types import JsonObject from google.cloud import spanner_dbapi @@ -343,6 +344,14 @@ def limit_clause(self, select, **kw): text += " OFFSET " + self.process(select._offset_clause, **kw) return text + def returning_clause(self, stmt, returning_cols): + columns = [ + self._label_returning_column(stmt, c) + for c in expression._select_iterables(returning_cols) + ] + + return "THEN RETURN " + ", ".join(columns) + class SpannerDDLCompiler(DDLCompiler): """Spanner DDL statements compiler.""" From 58d0d35411f90a7c11d4612c43a72b1111539dae Mon Sep 17 00:00:00 2001 From: Sri Harsha CH Date: Sun, 30 Jul 2023 08:01:15 +0000 Subject: [PATCH 2/6] feat: add tests for dml returning --- test/test_suite_13.py | 44 +++++++++++++++++++++++++++++++++++++++++++ test/test_suite_14.py | 44 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 88 insertions(+) diff --git a/test/test_suite_13.py b/test/test_suite_13.py index 0561de5d..afd558d7 100644 --- a/test/test_suite_13.py +++ b/test/test_suite_13.py @@ -2059,3 +2059,47 @@ def test_create_engine_wo_database(self): engine = create_engine(get_db_url().split("/database")[0]) with engine.connect() as connection: assert connection.connection.database is None + + +class ReturningTest(fixtures.TestBase): + def setUp(self): + self._engine = create_engine(get_db_url(), future=True) + metadata = MetaData() + + self._table = Table( + "returning_test", + metadata, + Column("id", Integer, primary_key=True), + Column("data", String(16), nullable=False), + ) + + metadata.create_all(self._engine) + + def test_returning_for_insert_and_update(self): + random_id = random.randint(1, 1000) + with self._engine.begin() as connection: + stmt = ( + self._table.insert() + .values(id=random_id, data="some % value") + .returning(self._table.c.id) + ) + row = connection.execute(stmt).fetchall() + eq_( + row, + [(random_id,)], + ) + connection.commit() + + with self._engine.begin() as connection: + update_text = "some + value" + stmt = ( + self._table.update() + .values(data=update_text) + .where(self._table.c.id == random_id) + .returning(self._table.c.data) + ) + row = connection.execute(stmt).fetchall() + eq_( + row, + [(update_text,)], + ) diff --git a/test/test_suite_14.py b/test/test_suite_14.py index 3ff069b2..ee57096d 100644 --- a/test/test_suite_14.py +++ b/test/test_suite_14.py @@ -2392,3 +2392,47 @@ def test_create_engine_wo_database(self): engine = create_engine(get_db_url().split("/database")[0]) with engine.connect() as connection: assert connection.connection.database is None + + +class ReturningTest(fixtures.TestBase): + def setUp(self): + self._engine = create_engine(get_db_url(), future=True) + metadata = MetaData() + + self._table = Table( + "returning_test", + metadata, + Column("id", Integer, primary_key=True), + Column("data", String(16), nullable=False), + ) + + metadata.create_all(self._engine) + + def test_returning_for_insert_and_update(self): + random_id = random.randint(1, 1000) + with self._engine.begin() as connection: + stmt = ( + self._table.insert() + .values(id=random_id, data="some % value") + .returning(self._table.c.id) + ) + row = connection.execute(stmt).fetchall() + eq_( + row, + [(random_id,)], + ) + connection.commit() + + with self._engine.begin() as connection: + update_text = "some + value" + stmt = ( + self._table.update() + .values(data=update_text) + .where(self._table.c.id == random_id) + .returning(self._table.c.data) + ) + row = connection.execute(stmt).fetchall() + eq_( + row, + [(update_text,)], + ) From 5d48c659adcad2316affd7b75be4681d961b27a2 Mon Sep 17 00:00:00 2001 From: Sri Harsha CH Date: Sun, 30 Jul 2023 08:09:33 +0000 Subject: [PATCH 3/6] feat: update tests --- test/test_suite_13.py | 3 +-- test/test_suite_14.py | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/test/test_suite_13.py b/test/test_suite_13.py index afd558d7..5b0eec12 100644 --- a/test/test_suite_13.py +++ b/test/test_suite_13.py @@ -2063,7 +2063,7 @@ def test_create_engine_wo_database(self): class ReturningTest(fixtures.TestBase): def setUp(self): - self._engine = create_engine(get_db_url(), future=True) + self._engine = create_engine(get_db_url()) metadata = MetaData() self._table = Table( @@ -2088,7 +2088,6 @@ def test_returning_for_insert_and_update(self): row, [(random_id,)], ) - connection.commit() with self._engine.begin() as connection: update_text = "some + value" diff --git a/test/test_suite_14.py b/test/test_suite_14.py index ee57096d..b7c7e2ac 100644 --- a/test/test_suite_14.py +++ b/test/test_suite_14.py @@ -2396,7 +2396,7 @@ def test_create_engine_wo_database(self): class ReturningTest(fixtures.TestBase): def setUp(self): - self._engine = create_engine(get_db_url(), future=True) + self._engine = create_engine(get_db_url()) metadata = MetaData() self._table = Table( @@ -2421,7 +2421,6 @@ def test_returning_for_insert_and_update(self): row, [(random_id,)], ) - connection.commit() with self._engine.begin() as connection: update_text = "some + value" From 091f308005175d705d8fe31028c462a0436ab5ad Mon Sep 17 00:00:00 2001 From: Sri Harsha CH Date: Sun, 30 Jul 2023 09:33:23 +0000 Subject: [PATCH 4/6] feat: support dml returning --- google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py b/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py index 167a8685..5e789062 100644 --- a/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py +++ b/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py @@ -346,7 +346,7 @@ def limit_clause(self, select, **kw): def returning_clause(self, stmt, returning_cols): columns = [ - self._label_returning_column(stmt, c) + self._label_select_column(None, c, True, False, {}) for c in expression._select_iterables(returning_cols) ] From fa78127bb8f6406e76a043b97b2f7c262bc3fa6b Mon Sep 17 00:00:00 2001 From: Sri Harsha CH Date: Sun, 30 Jul 2023 09:43:01 +0000 Subject: [PATCH 5/6] feat: add test in sqlalchemy 2.0 --- test/test_suite_20.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/test/test_suite_20.py b/test/test_suite_20.py index b4bf26fa..4664619e 100644 --- a/test/test_suite_20.py +++ b/test/test_suite_20.py @@ -3041,3 +3041,30 @@ def test_create_engine_wo_database(self): engine = create_engine(get_db_url().split("/database")[0]) with engine.connect() as connection: assert connection.connection.database is None + + +class ReturningTest(fixtures.TestBase): + def setUp(self): + self._engine = create_engine(get_db_url()) + metadata = MetaData() + + self._table = Table( + "returning_test", + metadata, + Column("id", Integer, primary_key=True), + Column("data", String(16), nullable=False), + ) + + metadata.create_all(self._engine) + time.sleep(1) + + def test_returning_for_insert(self): + with self._engine.connect() as connection: + random_id = random.randint(1, 100000000) + stmt = ( + self._table.insert() + .values(id=random_id, data="some % value") + .returning(self._table.c.id) + ) + row = list(connection.execute(stmt)) + eq_(row[0], random_id) From 197e00b3284231ad155e18e9f33536329c19d387 Mon Sep 17 00:00:00 2001 From: Sri Harsha CH Date: Sun, 30 Jul 2023 09:57:08 +0000 Subject: [PATCH 6/6] feat: add returning support for sqlalchemy 2.0 --- .../sqlalchemy_spanner/sqlalchemy_spanner.py | 2 +- test/test_suite_20.py | 28 +++++++++++++++---- 2 files changed, 23 insertions(+), 7 deletions(-) diff --git a/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py b/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py index 5e789062..1d2cc5c8 100644 --- a/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py +++ b/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py @@ -344,7 +344,7 @@ def limit_clause(self, select, **kw): text += " OFFSET " + self.process(select._offset_clause, **kw) return text - def returning_clause(self, stmt, returning_cols): + def returning_clause(self, stmt, returning_cols, **kw): columns = [ self._label_select_column(None, c, True, False, {}) for c in expression._select_iterables(returning_cols) diff --git a/test/test_suite_20.py b/test/test_suite_20.py index 4664619e..f1f689e2 100644 --- a/test/test_suite_20.py +++ b/test/test_suite_20.py @@ -3056,15 +3056,31 @@ def setUp(self): ) metadata.create_all(self._engine) - time.sleep(1) - def test_returning_for_insert(self): - with self._engine.connect() as connection: - random_id = random.randint(1, 100000000) + def test_returning_for_insert_and_update(self): + random_id = random.randint(1, 1000) + with self._engine.begin() as connection: stmt = ( self._table.insert() .values(id=random_id, data="some % value") .returning(self._table.c.id) ) - row = list(connection.execute(stmt)) - eq_(row[0], random_id) + row = connection.execute(stmt).fetchall() + eq_( + row, + [(random_id,)], + ) + + with self._engine.begin() as connection: + update_text = "some + value" + stmt = ( + self._table.update() + .values(data=update_text) + .where(self._table.c.id == random_id) + .returning(self._table.c.data) + ) + row = connection.execute(stmt).fetchall() + eq_( + row, + [(update_text,)], + )