feat: support Partitioned DML by olavloite · Pull Request #541 · googleapis/python-spanner-sqlalchemy · GitHub
[go: up one dir, main page]

Skip to content

feat: support Partitioned DML #541

8000
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Dec 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions samples/partitioned_dml_sample.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Copyright 2024 Google LLC All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from google.cloud.spanner_dbapi.parsed_statement import AutocommitDmlMode
from sqlalchemy import create_engine, text

from sample_helper import run_sample

def partitioned_dml_sample():
    """Demonstrate Partitioned DML with SQLAlchemy on Spanner.

    Partitioned DML must run on an auto-commit connection, because each
    Partitioned DML transaction may consist of only a single statement.
    """
    db_url = (
        "spanner:///projects/sample-project/"
        "instances/sample-instance/"
        "databases/sample-database"
    )
    engine = create_engine(db_url, echo=True)
    # Request an auto-commit connection; Partitioned DML cannot be used
    # inside a multi-statement transaction.
    autocommit_connection = engine.connect().execution_options(
        isolation_level="AUTOCOMMIT"
    )
    with autocommit_connection as connection:
        # Switch the underlying DB-API connection to PARTITIONED_NON_ATOMIC.
        connection.connection.set_autocommit_dml_mode(
            AutocommitDmlMode.PARTITIONED_NON_ATOMIC
        )
        # Back-fill a column with a single bulk update statement.
        statement = text("update venues set active=true where active is null")
        # Partitioned DML reports a lower bound on the number of updated rows.
        lower_bound_rowcount = connection.execute(statement).rowcount
        print("Updated at least ", lower_bound_rowcount, " venue records")


if __name__ == "__main__":
    run_sample(partitioned_dml_sample)
12 changes: 12 additions & 0 deletions test/mockserver_tests/mock_server_test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from google.cloud.spanner_dbapi.parsed_statement import AutocommitDmlMode
from sqlalchemy import Engine, create_engine
from sqlalchemy.testing.plugin.plugin_base import fixtures
import google.cloud.spanner_v1.types.type as spanner_type
Expand All @@ -35,6 +36,17 @@ def add_result(sql: str, result: ResultSet):
MockServerTestBase.spanner_service.mock_spanner.add_result(sql, result)


def add_update_count(
    sql: str, count: int, dml_mode: AutocommitDmlMode = AutocommitDmlMode.TRANSACTIONAL
):
    """Register a mock update-count result for the given SQL statement.

    Partitioned DML results report a lower-bound row count, while
    transactional DML results report an exact row count.
    """
    is_partitioned = dml_mode == AutocommitDmlMode.PARTITIONED_NON_ATOMIC
    stats_field = "row_count_lower_bound" if is_partitioned else "row_count_exact"
    stats = result_set.ResultSetStats({stats_field: count})
    add_result(sql, result_set.ResultSet(dict(stats=stats)))


def add_select1_result():
result = result_set.ResultSet(
dict(
Expand Down
25 changes: 19 additions & 6 deletions test/mockserver_tests/mock_spanner.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from google.cloud.spanner_v1 import TransactionOptions, ResultSetMetadata
from google.cloud.spanner_v1 import (
TransactionOptions,
ResultSetMetadata,
ExecuteSqlRequest,
)
from google.protobuf import empty_pb2
import test.mockserver_tests.spanner_pb2_grpc as spanner_grpc
import test.mockserver_tests.spanner_database_admin_pb2_grpc as database_admin_grpc
Expand Down Expand Up @@ -40,23 +44,25 @@ def get_result(self, sql: str) -> result_set.ResultSet:
return result

def get_result_as_partial_result_sets(
self, sql: str
self, sql: str, started_transaction: transaction.Transaction
) -> [result_set.PartialResultSet]:
result: result_set.ResultSet = self.get_result(sql)
partials = []
first = True
if len(result.rows) == 0:
partial = result_set.PartialResultSet()
partial.metadata = result.metadata
partial.metadata = ResultSetMetadata(result.metadata)
partials.append(partial)
else:
for row in result.rows:
partial = result_set.PartialResultSet()
if first:
partial.metadata = result.metadata
partial.metadata = ResultSetMetadata(result.metadata)
partial.values.extend(row)
partials.append(partial)
partials[len(partials) - 1].stats = result.stats
if started_transaction:
partials[0].metadata.transaction = started_transaction
return partials


Expand Down Expand Up @@ -120,9 +126,16 @@ def ExecuteSql(self, request, context):
self._requests.append(request)
return result_set.ResultSet()

def ExecuteStreamingSql(self, request, context):
def ExecuteStreamingSql(self, request: ExecuteSqlRequest, context):
self._requests.append(request)
partials = self.mock_spanner.get_result_as_partial_result_sets(request.sql)
started_transaction = None
if not request.transaction.begin == TransactionOptions():
started_transaction = self.__create_transaction(
request.session, request.transaction.begin
)
partials = self.mock_spanner.get_result_as_partial_result_sets(
request.sql, started_transaction
)
for result in partials:
yield result

Expand Down
31 changes: 30 additions & 1 deletion test/mockserver_tests/test_basics.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,17 @@
# limitations under the License.

from google.cloud.spanner_admin_database_v1 import UpdateDatabaseDdlRequest
from sqlalchemy import create_engine, select, MetaData, Table, Column, Integer, String
from google.cloud.spanner_dbapi.parsed_statement import AutocommitDmlMode
from sqlalchemy import (
create_engine,
select,
MetaData,
Table,
Column,
Integer,
String,
text,
)
from sqlalchemy.testing import eq_, is_instance_of
from google.cloud.spanner_v1 import (
FixedSizePool,
Expand All @@ -26,6 +36,7 @@
MockServerTestBase,
add_select1_result,
add_result,
add_update_count,
)


Expand Down Expand Up @@ -127,3 +138,21 @@ def test_create_multiple_tables(self):
"\n) PRIMARY KEY (id)",
requests[0].statements[i],
)

def test_partitioned_dml(self):
    """Partitioned DML should surface the lower-bound update count as rowcount."""
    sql = "UPDATE singers SET checked=true WHERE active = true"
    # Tell the mock server to answer this statement with a lower-bound
    # update count of 100.
    add_update_count(sql, 100, AutocommitDmlMode.PARTITIONED_NON_ATOMIC)
    engine = create_engine(
        "spanner:///projects/p/instances/i/databases/d",
        connect_args={"client": self.client, "pool": PingingPool(size=10)},
    )
    # TODO: Support autocommit_dml_mode as a connection variable in execution
    # options.
    autocommit_connection = engine.connect().execution_options(
        isolation_level="AUTOCOMMIT"
    )
    with autocommit_connection as connection:
        connection.connection.set_autocommit_dml_mode(
            AutocommitDmlMode.PARTITIONED_NON_ATOMIC
        )
        row_count = connection.execute(text(sql)).rowcount
        eq_(100, row_count)
Loading
0