From 92cc0b344370046720a73ca80b77652088412040 Mon Sep 17 00:00:00 2001 From: "STATION\\MF" Date: Tue, 6 Oct 2020 18:11:35 -0400 Subject: [PATCH 01/33] chore: updated docstrings --- google/cloud/spanner_dbapi/cursor.py | 77 ++++++++++++++++------------ 1 file changed, 44 insertions(+), 33 deletions(-) diff --git a/google/cloud/spanner_dbapi/cursor.py b/google/cloud/spanner_dbapi/cursor.py index 73764b4c26..7ed42a313a 100644 --- a/google/cloud/spanner_dbapi/cursor.py +++ b/google/cloud/spanner_dbapi/cursor.py @@ -4,7 +4,7 @@ # license that can be found in the LICENSE file or at # https://developers.google.com/open-source/licenses/bsd -"""Database cursor API.""" +"""Database cursor for Google Cloud Spanner DB-API.""" from google.api_core.exceptions import ( AlreadyExists, @@ -54,11 +54,10 @@ class Cursor: - """ - Database cursor to manage the context of a fetch operation. + """Database cursor to manage the context of a fetch operation. - :type connection: :class:`spanner_dbapi.connection.Connection` - :param connection: Parent connection object for this Cursor. + :type connection: :class:`~google.cloud.spanner_dbapi.connection.Connection` + :param connection: A DB-API connection to Google Cloud Spanner. """ def __init__(self, connection): @@ -72,14 +71,13 @@ def __init__(self, connection): self.arraysize = 1 def execute(self, sql, args=None): - """ - Abstracts and implements execute SQL statements on Cloud Spanner. - Args: - sql: A SQL statement - *args: variadic argument list - **kwargs: key worded arguments - Returns: - None + """Prepares and executes a Spanner database operation. + + :type sql: str + :param sql: A SQL query statement. + + :type args: list + :param args: Additional parameters to supplement the SQL query. """ self._raise_if_closed() @@ -212,6 +210,16 @@ def __exit__(self, etype, value, traceback): @property def description(self): + """Read-only attribute containing a sequence of the following items: + + - ``name`` + - ``type_code`` + - ``display_size`` + - ``internal_size`` + - ``precision`` + - ``scale`` + - ``null_ok`` + """ if not (self._res and self._res.metadata): return None @@ -232,15 +240,16 @@ def description(self): @property def rowcount(self): + """The number of rows produced by the last `.execute()`.""" return self._row_count @property def is_closed(self): """The cursor close indicator. - :rtype: :class:`bool` - :returns: True if this cursor or it's parent connection is closed, False - otherwise. + :rtype: bool + :returns: True if the cursor or the parent connection is closed, + otherwise False. """ return self._is_closed or self._connection.is_closed @@ -257,22 +266,19 @@ def _raise_if_closed(self): raise InterfaceError("cursor is already closed") def close(self): - """Close this cursor. - - The cursor will be unusable from this point forward. - """ + """Closes this Cursor, making it unusable from this point forward.""" self._is_closed = True def executemany(self, operation, seq_of_params): - """ - Execute the given SQL with every parameters set + """Execute the given SQL with every parameters set from the given sequence of parameters. - :type operation: :class:`str` + :type operation: str :param operation: SQL code to execute. - :type seq_of_params: :class:`list` - :param seq_of_params: Sequence of params to run the query with. + :type seq_of_params: list + :param seq_of_params: Sequence of additional parameters to run + the query with. """ self._raise_if_closed() @@ -290,6 +296,8 @@ def __iter__(self): return self._itr def fetchone(self): + """Fetch the next row of a query result set, returning a single + sequence, or None when no more data is available.""" self._raise_if_closed() try: @@ -298,21 +306,22 @@ def fetchone(self): return None def fetchall(self): + """Fetch all (remaining) rows of a query result, returning them as + a sequence of sequences. + """ self._raise_if_closed() return list(self.__iter__()) def fetchmany(self, size=None): - """ - Fetch the next set of rows of a query result, returning a sequence of sequences. - An empty sequence is returned when no more rows are available. - - Args: - size: optional integer to determine the maximum number of results to fetch. + """Fetch the next set of rows of a query result, returning a sequence + of sequences. An empty sequence is returned when no more rows are available. + :type size: int + :param size: (Optional) The maximum number of results to fetch. - Raises: - Error if the previous call to .execute*() did not produce any result set + :raises InterfaceError: + if the previous call to .execute*() did not produce any result set or if no call was issued yet. """ self._raise_if_closed() @@ -334,9 +343,11 @@ def lastrowid(self): return None def setinputsizes(sizes): + """A no-op, raising an error if the cursor or connection is closed.""" raise ProgrammingError("Unimplemented") def setoutputsize(size, column=None): + """A no-op, raising an error if the cursor or connection is closed.""" raise ProgrammingError("Unimplemented") def _run_prior_DDL_statements(self): From da4eb61f8be69ef210d8d237b4c255b6d3a469c9 Mon Sep 17 00:00:00 2001 From: "STATION\\MF" Date: Tue, 6 Oct 2020 18:32:11 -0400 Subject: [PATCH 02/33] chore: methods rearranged --- google/cloud/spanner_dbapi/cursor.py | 212 +++++++++++++-------------- 1 file changed, 106 insertions(+), 106 deletions(-) diff --git a/google/cloud/spanner_dbapi/cursor.py b/google/cloud/spanner_dbapi/cursor.py index 7ed42a313a..5503d2b4fa 100644 --- a/google/cloud/spanner_dbapi/cursor.py +++ b/google/cloud/spanner_dbapi/cursor.py @@ -70,6 +70,45 @@ def __init__(self, connection): # the number of rows to fetch at a time with fetchmany() self.arraysize = 1 + @property + def description(self): + """Read-only attribute containing a sequence of the following items: + + - ``name`` + - ``type_code`` + - ``display_size`` + - ``internal_size`` + - ``precision`` + - ``scale`` + - ``null_ok`` + """ + if not (self._res and self._res.metadata): + return None + + row_type = self._res.metadata.row_type + columns = [] + for field in row_type.fields: + columns.append( + ColumnInfo( + name=field.name, + type_code=field.type.code, + # Size of the SQL type of the column. + display_size=code_to_display_size.get(field.type.code), + # Client perceived size of the column. + internal_size=field.ByteSize(), + ) + ) + return tuple(columns) + + @property + def rowcount(self): + """The number of rows produced by the last `.execute()`.""" + return self._row_count + + def close(self): + """Closes this Cursor, making it unusable from this point forward.""" + self._is_closed = True + def execute(self, sql, args=None): """Prepares and executes a Spanner database operation. @@ -110,6 +149,73 @@ def execute(self, sql, args=None): except InternalServerError as e: raise OperationalError(e.details if hasattr(e, "details") else e) + def executemany(self, operation, seq_of_params): + """Execute the given SQL with every parameters set + from the given sequence of parameters. + + :type operation: str + :param operation: SQL code to execute. + + :type seq_of_params: list + :param seq_of_params: Sequence of additional parameters to run + the query with. + """ + self._raise_if_closed() + + for params in seq_of_params: + self.execute(operation, params) + + def fetchone(self): + """Fetch the next row of a query result set, returning a single + sequence, or None when no more data is available.""" + self._raise_if_closed() + + try: + return next(self) + except StopIteration: + return None + + def fetchmany(self, size=None): + """Fetch the next set of rows of a query result, returning a sequence + of sequences. An empty sequence is returned when no more rows are available. + + :type size: int + :param size: (Optional) The maximum number of results to fetch. + + :raises InterfaceError: + if the previous call to .execute*() did not produce any result set + or if no call was issued yet. + """ + self._raise_if_closed() + + if size is None: + size = self.arraysize + + items = [] + for i in range(size): + try: + items.append(tuple(self.__next__())) + except StopIteration: + break + + return items + + def fetchall(self): + """Fetch all (remaining) rows of a query result, returning them as + a sequence of sequences. + """ + self._raise_if_closed() + + return list(self.__iter__()) + + def setinputsizes(self, sizes): + """A no-op, raising an error if the cursor or connection is closed.""" + pass + + def setoutputsize(self, size, column=None): + """A no-op, raising an error if the cursor or connection is closed.""" + pass + def __handle_update(self, sql, params): self._connection.in_transaction(self.__do_execute_update, sql, params) @@ -208,41 +314,6 @@ def __enter__(self): def __exit__(self, etype, value, traceback): self.close() - @property - def description(self): - """Read-only attribute containing a sequence of the following items: - - - ``name`` - - ``type_code`` - - ``display_size`` - - ``internal_size`` - - ``precision`` - - ``scale`` - - ``null_ok`` - """ - if not (self._res and self._res.metadata): - return None - - row_type = self._res.metadata.row_type - columns = [] - for field in row_type.fields: - columns.append( - ColumnInfo( - name=field.name, - type_code=field.type.code, - # Size of the SQL type of the column. - display_size=code_to_display_size.get(field.type.code), - # Client perceived size of the column. - internal_size=field.ByteSize(), - ) - ) - return tuple(columns) - - @property - def rowcount(self): - """The number of rows produced by the last `.execute()`.""" - return self._row_count - @property def is_closed(self): """The cursor close indicator. @@ -265,26 +336,6 @@ def _raise_if_closed(self): if self.is_closed: raise InterfaceError("cursor is already closed") - def close(self): - """Closes this Cursor, making it unusable from this point forward.""" - self._is_closed = True - - def executemany(self, operation, seq_of_params): - """Execute the given SQL with every parameters set - from the given sequence of parameters. - - :type operation: str - :param operation: SQL code to execute. - - :type seq_of_params: list - :param seq_of_params: Sequence of additional parameters to run - the query with. - """ - self._raise_if_closed() - - for params in seq_of_params: - self.execute(operation, params) - def __next__(self): if self._itr is None: raise ProgrammingError("no results to return") @@ -295,61 +346,10 @@ def __iter__(self): raise ProgrammingError("no results to return") return self._itr - def fetchone(self): - """Fetch the next row of a query result set, returning a single - sequence, or None when no more data is available.""" - self._raise_if_closed() - - try: - return next(self) - except StopIteration: - return None - - def fetchall(self): - """Fetch all (remaining) rows of a query result, returning them as - a sequence of sequences. - """ - self._raise_if_closed() - - return list(self.__iter__()) - - def fetchmany(self, size=None): - """Fetch the next set of rows of a query result, returning a sequence - of sequences. An empty sequence is returned when no more rows are available. - - :type size: int - :param size: (Optional) The maximum number of results to fetch. - - :raises InterfaceError: - if the previous call to .execute*() did not produce any result set - or if no call was issued yet. - """ - self._raise_if_closed() - - if size is None: - size = self.arraysize - - items = [] - for i in range(size): - try: - items.append(tuple(self.__next__())) - except StopIteration: - break - - return items - @property def lastrowid(self): return None - def setinputsizes(sizes): - """A no-op, raising an error if the cursor or connection is closed.""" - raise ProgrammingError("Unimplemented") - - def setoutputsize(size, column=None): - """A no-op, raising an error if the cursor or connection is closed.""" - raise ProgrammingError("Unimplemented") - def _run_prior_DDL_statements(self): return self._connection.run_prior_DDL_statements() From d27ff7125bb8b75dfe06750f9359e04e1e1eedb5 Mon Sep 17 00:00:00 2001 From: "STATION\\MF" Date: Tue, 6 Oct 2020 18:52:45 -0400 Subject: [PATCH 03/33] chore: --- google/cloud/spanner_dbapi/cursor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/google/cloud/spanner_dbapi/cursor.py b/google/cloud/spanner_dbapi/cursor.py index 5503d2b4fa..c856872f50 100644 --- a/google/cloud/spanner_dbapi/cursor.py +++ b/google/cloud/spanner_dbapi/cursor.py @@ -53,7 +53,7 @@ } -class Cursor: +class Cursor(object): """Database cursor to manage the context of a fetch operation. :type connection: :class:`~google.cloud.spanner_dbapi.connection.Connection` @@ -334,7 +334,7 @@ def _raise_if_closed(self): :raises: :class:`InterfaceError` if this cursor is closed. """ if self.is_closed: - raise InterfaceError("cursor is already closed") + raise InterfaceError("Cursor and/or connection is already closed.") def __next__(self): if self._itr is None: From eaa024f40f84badba7fdec11382188e9a777db23 Mon Sep 17 00:00:00 2001 From: "STATION\\MF" Date: Tue, 6 Oct 2020 21:23:22 -0400 Subject: [PATCH 04/33] chore: --- google/cloud/spanner_dbapi/connection.py | 25 ++++++++++++++---------- google/cloud/spanner_dbapi/cursor.py | 8 ++++---- 2 files changed, 19 insertions(+), 14 deletions(-) diff --git a/google/cloud/spanner_dbapi/connection.py b/google/cloud/spanner_dbapi/connection.py index 869586e363..959671b806 100644 --- a/google/cloud/spanner_dbapi/connection.py +++ b/google/cloud/spanner_dbapi/connection.py @@ -14,14 +14,20 @@ ColumnDetails = namedtuple("column_details", ["null_ok", "spanner_type"]) -class Connection: - def __init__(self, db_handle): - self._dbhandle = db_handle +class Connection(object): + """DB-API Connection to a Google Cloud Spanner database. + + :type database: :class:`~google.cloud.spanner_v1.database.Database` + :param database: The database to which the connection is linked. + """ + def __init__(self, database): + self._database = database self._ddl_statements = [] self.is_closed = False def cursor(self): + """Factory to create a DB-API Cursor.""" self._raise_if_closed() return Cursor(self) @@ -48,15 +54,15 @@ def __handle_update_ddl(self, ddl_statements): """ self._raise_if_closed() # Synchronously wait on the operation's completion. - return self._dbhandle.update_ddl(ddl_statements).result() + return self._database.update_ddl(ddl_statements).result() def read_snapshot(self): self._raise_if_closed() - return self._dbhandle.snapshot() + return self._database.snapshot() def in_transaction(self, fn, *args, **kwargs): self._raise_if_closed() - return self._dbhandle.run_in_transaction(fn, *args, **kwargs) + return self._database.run_in_transaction(fn, *args, **kwargs) def append_ddl_statement(self, ddl_statement): self._raise_if_closed() @@ -90,7 +96,7 @@ def run_sql_in_snapshot(self, sql, params=None, param_types=None): # hence this method exists to circumvent that limit. self.run_prior_DDL_statements() - with self._dbhandle.snapshot() as snapshot: + with self._database.snapshot() as snapshot: res = snapshot.execute_sql( sql, params=params, param_types=param_types ) @@ -123,7 +129,7 @@ def close(self): The connection will be unusable from this point forward. """ self.rollback() - self.__dbhandle = None + self._database = None self.is_closed = True def commit(self): @@ -132,10 +138,9 @@ def commit(self): self.run_prior_DDL_statements() def rollback(self): + """A no-op, raising an error if the connection is closed.""" self._raise_if_closed() - # TODO: to be added. - def __enter__(self): return self diff --git a/google/cloud/spanner_dbapi/cursor.py b/google/cloud/spanner_dbapi/cursor.py index c856872f50..5976664723 100644 --- a/google/cloud/spanner_dbapi/cursor.py +++ b/google/cloud/spanner_dbapi/cursor.py @@ -118,11 +118,11 @@ def execute(self, sql, args=None): :type args: list :param args: Additional parameters to supplement the SQL query. """ - self._raise_if_closed() - if not self._connection: raise ProgrammingError("Cursor is not connected to the database") + self._raise_if_closed() + self._res = None # Classify whether this is a read-only SQL statement. @@ -210,11 +210,11 @@ def fetchall(self): def setinputsizes(self, sizes): """A no-op, raising an error if the cursor or connection is closed.""" - pass + self._raise_if_closed() def setoutputsize(self, size, column=None): """A no-op, raising an error if the cursor or connection is closed.""" - pass + self._raise_if_closed() def __handle_update(self, sql, params): self._connection.in_transaction(self.__do_execute_update, sql, params) From 3cde6888d58a3439e6c787bf09ebdbd413170e53 Mon Sep 17 00:00:00 2001 From: "STATION\\MF" Date: Tue, 6 Oct 2020 21:34:40 -0400 Subject: [PATCH 05/33] fix: lint --- google/cloud/spanner_dbapi/connection.py | 1 + 1 file changed, 1 insertion(+) diff --git a/google/cloud/spanner_dbapi/connection.py b/google/cloud/spanner_dbapi/connection.py index 959671b806..9e34038bf5 100644 --- a/google/cloud/spanner_dbapi/connection.py +++ b/google/cloud/spanner_dbapi/connection.py @@ -20,6 +20,7 @@ class Connection(object): :type database: :class:`~google.cloud.spanner_v1.database.Database` :param database: The database to which the connection is linked. """ + def __init__(self, database): self._database = database self._ddl_statements = [] From 75c6f1cabbf4125fc1dd30c2e98904d4a75ed9d5 Mon Sep 17 00:00:00 2001 From: "STATION\\MF" Date: Wed, 7 Oct 2020 03:16:50 -0400 Subject: [PATCH 06/33] chore: refactor --- google/cloud/spanner_dbapi/__init__.py | 3 +- google/cloud/spanner_dbapi/connection.py | 184 +++++++++++------------ google/cloud/spanner_dbapi/cursor.py | 163 +++++++++++++------- 3 files changed, 198 insertions(+), 152 deletions(-) diff --git a/google/cloud/spanner_dbapi/__init__.py b/google/cloud/spanner_dbapi/__init__.py index 098b0bd786..cf380ae1ed 100644 --- a/google/cloud/spanner_dbapi/__init__.py +++ b/google/cloud/spanner_dbapi/__init__.py @@ -34,9 +34,10 @@ Time, TimeFromTicks, Timestamp, + TimestampStr, TimestampFromTicks, ) -from .version import google_client_info +from .version import google_client_info, DEFAULT_USER_AGENT apilevel = "2.0" # supports DP-API 2.0 level. paramstyle = "format" # ANSI C printf format codes, e.g. ...WHERE name=%s. diff --git a/google/cloud/spanner_dbapi/connection.py b/google/cloud/spanner_dbapi/connection.py index 9e34038bf5..54cd3df7b0 100644 --- a/google/cloud/spanner_dbapi/connection.py +++ b/google/cloud/spanner_dbapi/connection.py @@ -4,14 +4,17 @@ # license that can be found in the LICENSE file or at # https://developers.google.com/open-source/licenses/bsd -from collections import namedtuple +"""DB-API Connection for the Google Cloud Spanner.""" -from google.cloud import spanner_v1 as spanner +# from google.cloud import spanner_v1 from .cursor import Cursor from .exceptions import InterfaceError -ColumnDetails = namedtuple("column_details", ["null_ok", "spanner_type"]) +# from .version import google_client_info +# from google.cloud.spanner_dbapi import Cursor +# from google.cloud.spanner_dbapi import google_client_info +# from google.cloud.spanner_dbapi import InterfaceError class Connection(object): @@ -23,117 +26,32 @@ class Connection(object): def __init__(self, database): self._database = database - self._ddl_statements = [] - + self.ddl_statements = [] self.is_closed = False - def cursor(self): - """Factory to create a DB-API Cursor.""" - self._raise_if_closed() - - return Cursor(self) + @property + def database(self): + return self._database def _raise_if_closed(self): - """Raise an exception if this connection is closed. - - Helper to check the connection state before - running a SQL/DDL/DML query. + """Helper to check the connection state before running a query. + Raises an exception if this connection is closed. - :raises: :class:`InterfaceError` if this connection is closed. + :raises: :class:`InterfaceError`: if this connection is closed. """ if self.is_closed: raise InterfaceError("connection is already closed") - def __handle_update_ddl(self, ddl_statements): - """ - Run the list of Data Definition Language (DDL) statements on the underlying - database. Each DDL statement MUST NOT contain a semicolon. - Args: - ddl_statements: a list of DDL statements, each without a semicolon. - Returns: - google.api_core.operation.Operation.result() - """ - self._raise_if_closed() - # Synchronously wait on the operation's completion. - return self._database.update_ddl(ddl_statements).result() - - def read_snapshot(self): - self._raise_if_closed() - return self._database.snapshot() - - def in_transaction(self, fn, *args, **kwargs): - self._raise_if_closed() - return self._database.run_in_transaction(fn, *args, **kwargs) - - def append_ddl_statement(self, ddl_statement): - self._raise_if_closed() - self._ddl_statements.append(ddl_statement) - - def run_prior_DDL_statements(self): - self._raise_if_closed() - - if not self._ddl_statements: - return - - ddl_statements = self._ddl_statements - self._ddl_statements = [] - - return self.__handle_update_ddl(ddl_statements) - - def list_tables(self): - return self.run_sql_in_snapshot( - """ - SELECT - t.table_name - FROM - information_schema.tables AS t - WHERE - t.table_catalog = '' and t.table_schema = '' - """ - ) - - def run_sql_in_snapshot(self, sql, params=None, param_types=None): - # Some SQL e.g. for INFORMATION_SCHEMA cannot be run in read-write transactions - # hence this method exists to circumvent that limit. - self.run_prior_DDL_statements() - - with self._database.snapshot() as snapshot: - res = snapshot.execute_sql( - sql, params=params, param_types=param_types - ) - return list(res) - - def get_table_column_schema(self, table_name): - rows = self.run_sql_in_snapshot( - """SELECT - COLUMN_NAME, IS_NULLABLE, SPANNER_TYPE - FROM - INFORMATION_SCHEMA.COLUMNS - WHERE - TABLE_SCHEMA = '' - AND - TABLE_NAME = @table_name""", - params={"table_name": table_name}, - param_types={"table_name": spanner.param_types.STRING}, - ) - - column_details = {} - for column_name, is_nullable, spanner_type in rows: - column_details[column_name] = ColumnDetails( - null_ok=is_nullable == "YES", spanner_type=spanner_type - ) - return column_details - def close(self): - """Close this connection. + """Closes this connection. The connection will be unusable from this point forward. """ - self.rollback() self._database = None self.is_closed = True def commit(self): + """Commits any pending transaction to the database.""" self._raise_if_closed() self.run_prior_DDL_statements() @@ -142,9 +60,81 @@ def rollback(self): """A no-op, raising an error if the connection is closed.""" self._raise_if_closed() + def cursor(self): + """Factory to create a DB-API Cursor.""" + self._raise_if_closed() + + return Cursor(self) + + def run_prior_DDL_statements(self): + self._raise_if_closed() + + if not self.ddl_statements: + return + + ddl_statements = self.ddl_statements + self.ddl_statements = [] + + return self._database.update_ddl(ddl_statements).result() + def __enter__(self): return self def __exit__(self, etype, value, traceback): self.commit() self.close() + + +# def connect( +# instance_id, database_id, project=None, credentials=None, user_agent=None +# ): +# """Creates a connection to a Google Cloud Spanner database. +# +# :type instance_id: str +# :param instance_id: ID of the instance to connect to. +# +# :type database_id: str +# :param database_id: The name of the database to connect to. +# +# :type project: str +# :param project: (Optional) The ID of the project which owns the +# instances, tables and data. If not provided, will +# attempt to determine from the environment. +# +# :type credentials: :class:`~google.auth.credentials.Credentials` +# :param credentials: (Optional) The authorization credentials to attach to +# requests. These credentials identify this application +# to the service. If none are specified, the client will +# attempt to ascertain the credentials from the +# environment. +# +# :type user_agent: str +# :param user_agent: (Optional) Prefix to the user agent header. +# +# :rtype: :class:`google.cloud.spanner_dbapi.connection.Connection` +# :returns: Connection object associated with the given Google Cloud Spanner +# resource. +# +# :raises: :class:`ValueError` in case of given instance/database +# doesn't exist. +# """ +# client_info = google_client_info(user_agent) +# +# client = spanner_v1.Client( +# project=project, +# credentials=credentials, +# # client_info=google_client_info(user_agent), +# client_info=client_info, +# ) +# +# instance = client.instance(instance_id) +# if not instance.exists(): +# raise ValueError("instance '%s' does not exist." % instance_id) +# +# database = instance.database( +# database_id, pool=spanner_v1.pool.BurstyPool() +# ) +# if not database.exists(): +# raise ValueError("database '%s' does not exist." % database_id) +# +# return Connection(database) diff --git a/google/cloud/spanner_dbapi/cursor.py b/google/cloud/spanner_dbapi/cursor.py index 5976664723..6f49808955 100644 --- a/google/cloud/spanner_dbapi/cursor.py +++ b/google/cloud/spanner_dbapi/cursor.py @@ -13,6 +13,9 @@ InvalidArgument, ) from google.cloud.spanner_v1 import param_types +from collections import namedtuple + +from google.cloud import spanner_v1 as spanner from .exceptions import ( IntegrityError, @@ -30,10 +33,12 @@ parse_insert, sql_pyformat_args_to_spanner, ) + from .utils import PeekIterator _UNSET_COUNT = -1 +ColumnDetails = namedtuple("column_details", ["null_ok", "spanner_type"]) # This table maps spanner_types to Spanner's data type sizes as per # https://cloud.google.com/spanner/docs/data-types#allowable-types @@ -70,6 +75,24 @@ def __init__(self, connection): # the number of rows to fetch at a time with fetchmany() self.arraysize = 1 + @property + def connection(self): + return self._connection + + @property + def lastrowid(self): + return None + + @property + def is_closed(self): + """The cursor close indicator. + + :rtype: bool + :returns: True if the cursor or the parent connection is closed, + otherwise False. + """ + return self._is_closed or self._connection.is_closed + @property def description(self): """Read-only attribute containing a sequence of the following items: @@ -105,10 +128,39 @@ def rowcount(self): """The number of rows produced by the last `.execute()`.""" return self._row_count + def _raise_if_closed(self): + """Raise an exception if this cursor is closed. + + Helper to check this cursor's state before running a + SQL/DDL/DML query. If the parent connection is + already closed it also raises an error. + + :raises: :class:`InterfaceError` if this cursor is closed. + """ + if self.is_closed: + raise InterfaceError("Cursor and/or connection is already closed.") + + def callproc(self, procname, args=None): + """A no-op, raising an error if the cursor or connection is closed.""" + self._raise_if_closed() + def close(self): """Closes this Cursor, making it unusable from this point forward.""" self._is_closed = True + def __do_execute_update(self, transaction, sql, params, param_types=None): + sql = ensure_where_clause(sql) + sql, params = sql_pyformat_args_to_spanner(sql, params) + + res = transaction.execute_update( + sql, params=params, param_types=get_param_types(params) + ) + self._itr = None + if type(res) == int: + self._row_count = res + + return res + def execute(self, sql, args=None): """Prepares and executes a Spanner database operation. @@ -129,19 +181,23 @@ def execute(self, sql, args=None): try: classification = classify_stmt(sql) if classification == STMT_DDL: - self._connection.append_ddl_statement(sql) + self._connection.ddl_statements.append(sql) return # For every other operation, we've got to ensure that # any prior DDL statements were run. - self._run_prior_DDL_statements() + # self._run_prior_DDL_statements() + self._connection.run_prior_DDL_statements() if classification == STMT_NON_UPDATING: self.__handle_DQL(sql, args or None) elif classification == STMT_INSERT: self.__handle_insert(sql, args or None) else: - self.__handle_update(sql, args or None) + # self.__handle_update(sql, args or None) + self._connection.database.run_in_transaction( + self.__do_execute_update, sql, args or None + ) except (AlreadyExists, FailedPrecondition) as e: raise IntegrityError(e.details if hasattr(e, "details") else e) except InvalidArgument as e: @@ -208,6 +264,10 @@ def fetchall(self): return list(self.__iter__()) + def nextset(self): + """A no-op, raising an error if the cursor or connection is closed.""" + self._raise_if_closed() + def setinputsizes(self, sizes): """A no-op, raising an error if the cursor or connection is closed.""" self._raise_if_closed() @@ -216,21 +276,10 @@ def setoutputsize(self, size, column=None): """A no-op, raising an error if the cursor or connection is closed.""" self._raise_if_closed() - def __handle_update(self, sql, params): - self._connection.in_transaction(self.__do_execute_update, sql, params) - - def __do_execute_update(self, transaction, sql, params, param_types=None): - sql = ensure_where_clause(sql) - sql, params = sql_pyformat_args_to_spanner(sql, params) - - res = transaction.execute_update( - sql, params=params, param_types=get_param_types(params) - ) - self._itr = None - if type(res) == int: - self._row_count = res - - return res + # def __handle_update(self, sql, params): + # self._connection.database.run_in_transaction( + # self.__do_execute_update, sql, params + # ) def __handle_insert(self, sql, params): parts = parse_insert(sql, params) @@ -251,14 +300,14 @@ def __handle_insert(self, sql, params): if parts.get("homogenous"): # The common case of multiple values being passed in # non-complex pyformat args and need to be uploaded in one RPC. - return self._connection.in_transaction( + return self._connection.database.run_in_transaction( self.__do_execute_insert_homogenous, parts ) else: # All the other cases that are esoteric and need # transaction.execute_sql sql_params_list = parts.get("sql_params_list") - return self._connection.in_transaction( + return self._connection.database.run_in_transaction( self.__do_execute_insert_heterogenous, sql_params_list ) @@ -281,7 +330,7 @@ def __do_execute_insert_homogenous(self, transaction, parts): return transaction.insert(table, columns, values) def __handle_DQL(self, sql, params): - with self._connection.read_snapshot() as snapshot: + with self._connection.database.snapshot() as snapshot: # Reference # https://googleapis.dev/python/spanner/latest/session-api.html#google.cloud.spanner_v1.session.Session.execute_sql sql, params = sql_pyformat_args_to_spanner(sql, params) @@ -314,28 +363,6 @@ def __enter__(self): def __exit__(self, etype, value, traceback): self.close() - @property - def is_closed(self): - """The cursor close indicator. - - :rtype: bool - :returns: True if the cursor or the parent connection is closed, - otherwise False. - """ - return self._is_closed or self._connection.is_closed - - def _raise_if_closed(self): - """Raise an exception if this cursor is closed. - - Helper to check this cursor's state before running a - SQL/DDL/DML query. If the parent connection is - already closed it also raises an error. - - :raises: :class:`InterfaceError` if this cursor is closed. - """ - if self.is_closed: - raise InterfaceError("Cursor and/or connection is already closed.") - def __next__(self): if self._itr is None: raise ProgrammingError("no results to return") @@ -346,21 +373,49 @@ def __iter__(self): raise ProgrammingError("no results to return") return self._itr - @property - def lastrowid(self): - return None - - def _run_prior_DDL_statements(self): - return self._connection.run_prior_DDL_statements() - def list_tables(self): - return self._connection.list_tables() + return self.run_sql_in_snapshot( + """ + SELECT + t.table_name + FROM + information_schema.tables AS t + WHERE + t.table_catalog = '' and t.table_schema = '' + """ + ) + + def run_sql_in_snapshot(self, sql, params=None, param_types=None): + # Some SQL e.g. for INFORMATION_SCHEMA cannot be run in read-write transactions + # hence this method exists to circumvent that limit. + self._connection.run_prior_DDL_statements() - def run_sql_in_snapshot(self, sql): - return self._connection.run_sql_in_snapshot(sql) + with self._connection.database.snapshot() as snapshot: + res = snapshot.execute_sql( + sql, params=params, param_types=param_types + ) + return list(res) def get_table_column_schema(self, table_name): - return self._connection.get_table_column_schema(table_name) + rows = self.run_sql_in_snapshot( + """SELECT + COLUMN_NAME, IS_NULLABLE, SPANNER_TYPE + FROM + INFORMATION_SCHEMA.COLUMNS + WHERE + TABLE_SCHEMA = '' + AND + TABLE_NAME = @table_name""", + params={"table_name": table_name}, + param_types={"table_name": spanner.param_types.STRING}, + ) + + column_details = {} + for column_name, is_nullable, spanner_type in rows: + column_details[column_name] = ColumnDetails( + null_ok=is_nullable == "YES", spanner_type=spanner_type + ) + return column_details class ColumnInfo: From 4383d1cd17fcd609819589e32377588bee55ae59 Mon Sep 17 00:00:00 2001 From: "STATION\\MF" Date: Thu, 8 Oct 2020 01:56:44 -0400 Subject: [PATCH 07/33] chore: refactor --- google/cloud/spanner_dbapi/__init__.py | 116 +++++++--------------- google/cloud/spanner_dbapi/connection.py | 118 +++++++++++------------ google/cloud/spanner_dbapi/version.py | 19 +--- tests/spanner_dbapi/test_connect.py | 110 --------------------- tests/spanner_dbapi/test_connection.py | 63 +++++++++++- tests/spanner_dbapi/test_cursor.py | 35 +++---- tests/spanner_dbapi/test_globals.py | 8 +- tests/spanner_dbapi/test_version.py | 36 ------- 8 files changed, 174 insertions(+), 331 deletions(-) delete mode 100644 tests/spanner_dbapi/test_connect.py delete mode 100644 tests/spanner_dbapi/test_version.py diff --git a/google/cloud/spanner_dbapi/__init__.py b/google/cloud/spanner_dbapi/__init__.py index cf380ae1ed..7695c0058f 100644 --- a/google/cloud/spanner_dbapi/__init__.py +++ b/google/cloud/spanner_dbapi/__init__.py @@ -6,38 +6,39 @@ """Connection-based DB API for Cloud Spanner.""" -from google.cloud import spanner_v1 - -from .connection import Connection -from .exceptions import ( - DatabaseError, - DataError, - Error, - IntegrityError, - InterfaceError, - InternalError, - NotSupportedError, - OperationalError, - ProgrammingError, - Warning, -) -from .parse_utils import get_param_types -from .types import ( - BINARY, - DATETIME, - NUMBER, - ROWID, - STRING, - Binary, - Date, - DateFromTicks, - Time, - TimeFromTicks, - Timestamp, - TimestampStr, - TimestampFromTicks, -) -from .version import google_client_info, DEFAULT_USER_AGENT +from google.cloud.spanner_dbapi.connection import Connection +from google.cloud.spanner_dbapi.connection import connect + +from google.cloud.spanner_dbapi.cursor import Cursor + +from google.cloud.spanner_dbapi.exceptions import DatabaseError +from google.cloud.spanner_dbapi.exceptions import DataError +from google.cloud.spanner_dbapi.exceptions import Error +from google.cloud.spanner_dbapi.exceptions import IntegrityError +from google.cloud.spanner_dbapi.exceptions import InterfaceError +from google.cloud.spanner_dbapi.exceptions import InternalError +from google.cloud.spanner_dbapi.exceptions import NotSupportedError +from google.cloud.spanner_dbapi.exceptions import OperationalError +from google.cloud.spanner_dbapi.exceptions import ProgrammingError +from google.cloud.spanner_dbapi.exceptions import Warning + +from google.cloud.spanner_dbapi.parse_utils import get_param_types + +from google.cloud.spanner_dbapi.types import BINARY +from google.cloud.spanner_dbapi.types import DATETIME +from google.cloud.spanner_dbapi.types import NUMBER +from google.cloud.spanner_dbapi.types import ROWID +from google.cloud.spanner_dbapi.types import STRING +from google.cloud.spanner_dbapi.types import Binary +from google.cloud.spanner_dbapi.types import Date +from google.cloud.spanner_dbapi.types import DateFromTicks +from google.cloud.spanner_dbapi.types import Time +from google.cloud.spanner_dbapi.types import TimeFromTicks +from google.cloud.spanner_dbapi.types import Timestamp +from google.cloud.spanner_dbapi.types import TimestampStr +from google.cloud.spanner_dbapi.types import TimestampFromTicks + +from google.cloud.spanner_dbapi.version import DEFAULT_USER_AGENT apilevel = "2.0" # supports DP-API 2.0 level. paramstyle = "format" # ANSI C printf format codes, e.g. ...WHERE name=%s. @@ -49,56 +50,10 @@ threadsafety = 1 -def connect( - instance_id, database_id, project=None, credentials=None, user_agent=None -): - """ - Create a connection to Cloud Spanner database. - - :type instance_id: :class:`str` - :param instance_id: ID of the instance to connect to. - - :type database_id: :class:`str` - :param database_id: The name of the database to connect to. - - :type project: :class:`str` - :param project: (Optional) The ID of the project which owns the - instances, tables and data. If not provided, will - attempt to determine from the environment. - - :type credentials: :class:`google.auth.credentials.Credentials` - :param credentials: (Optional) The authorization credentials to attach to requests. - These credentials identify this application to the service. - If none are specified, the client will attempt to ascertain - the credentials from the environment. - - :rtype: :class:`google.cloud.spanner_dbapi.connection.Connection` - :returns: Connection object associated with the given Cloud Spanner resource. - - :raises: :class:`ValueError` in case of given instance/database - doesn't exist. - """ - client = spanner_v1.Client( - project=project, - credentials=credentials, - client_info=google_client_info(user_agent), - ) - - instance = client.instance(instance_id) - if not instance.exists(): - raise ValueError("instance '%s' does not exist." % instance_id) - - database = instance.database( - database_id, pool=spanner_v1.pool.BurstyPool() - ) - if not database.exists(): - raise ValueError("database '%s' does not exist." % database_id) - - return Connection(database) - - __all__ = [ "Connection", + "connect", + "Cursor", "DatabaseError", "DataError", "Error", @@ -111,7 +66,6 @@ def connect( "Warning", "DEFAULT_USER_AGENT", "apilevel", - "connect", "paramstyle", "threadsafety", "get_param_types", diff --git a/google/cloud/spanner_dbapi/connection.py b/google/cloud/spanner_dbapi/connection.py index 54cd3df7b0..08b5294ecd 100644 --- a/google/cloud/spanner_dbapi/connection.py +++ b/google/cloud/spanner_dbapi/connection.py @@ -6,15 +6,13 @@ """DB-API Connection for the Google Cloud Spanner.""" -# from google.cloud import spanner_v1 +from google.api_core.gapic_v1.client_info import ClientInfo +from google.cloud import spanner_v1 as spanner -from .cursor import Cursor -from .exceptions import InterfaceError - -# from .version import google_client_info -# from google.cloud.spanner_dbapi import Cursor -# from google.cloud.spanner_dbapi import google_client_info -# from google.cloud.spanner_dbapi import InterfaceError +from google.cloud.spanner_dbapi.cursor import Cursor +from google.cloud.spanner_dbapi.exceptions import InterfaceError +from google.cloud.spanner_dbapi.version import DEFAULT_USER_AGENT +from google.cloud.spanner_dbapi.version import PY_VERSION class Connection(object): @@ -85,56 +83,54 @@ def __exit__(self, etype, value, traceback): self.close() -# def connect( -# instance_id, database_id, project=None, credentials=None, user_agent=None -# ): -# """Creates a connection to a Google Cloud Spanner database. -# -# :type instance_id: str -# :param instance_id: ID of the instance to connect to. -# -# :type database_id: str -# :param database_id: The name of the database to connect to. -# -# :type project: str -# :param project: (Optional) The ID of the project which owns the -# instances, tables and data. If not provided, will -# attempt to determine from the environment. -# -# :type credentials: :class:`~google.auth.credentials.Credentials` -# :param credentials: (Optional) The authorization credentials to attach to -# requests. These credentials identify this application -# to the service. If none are specified, the client will -# attempt to ascertain the credentials from the -# environment. -# -# :type user_agent: str -# :param user_agent: (Optional) Prefix to the user agent header. -# -# :rtype: :class:`google.cloud.spanner_dbapi.connection.Connection` -# :returns: Connection object associated with the given Google Cloud Spanner -# resource. -# -# :raises: :class:`ValueError` in case of given instance/database -# doesn't exist. -# """ -# client_info = google_client_info(user_agent) -# -# client = spanner_v1.Client( -# project=project, -# credentials=credentials, -# # client_info=google_client_info(user_agent), -# client_info=client_info, -# ) -# -# instance = client.instance(instance_id) -# if not instance.exists(): -# raise ValueError("instance '%s' does not exist." % instance_id) -# -# database = instance.database( -# database_id, pool=spanner_v1.pool.BurstyPool() -# ) -# if not database.exists(): -# raise ValueError("database '%s' does not exist." % database_id) -# -# return Connection(database) +def connect( + instance_id, database_id, project=None, credentials=None, user_agent=None +): + """Creates a connection to a Google Cloud Spanner database. + + :type instance_id: str + :param instance_id: ID of the instance to connect to. + + :type database_id: str + :param database_id: The name of the database to connect to. + + :type project: str + :param project: (Optional) The ID of the project which owns the + instances, tables and data. If not provided, will + attempt to determine from the environment. + + :type credentials: :class:`~google.auth.credentials.Credentials` + :param credentials: (Optional) The authorization credentials to attach to + requests. These credentials identify this application + to the service. If none are specified, the client will + attempt to ascertain the credentials from the + environment. + + :type user_agent: str + :param user_agent: (Optional) Prefix to the user agent header. + + :rtype: :class:`google.cloud.spanner_dbapi.connection.Connection` + :returns: Connection object associated with the given Google Cloud Spanner + resource. + + :raises: :class:`ValueError` in case of given instance/database + doesn't exist. + """ + + client_info = ClientInfo( + user_agent=user_agent or DEFAULT_USER_AGENT, python_version=PY_VERSION, + ) + + client = spanner.Client( + project=project, credentials=credentials, client_info=client_info, + ) + + instance = client.instance(instance_id) + if not instance.exists(): + raise ValueError("instance '%s' does not exist." % instance_id) + + database = instance.database(database_id, pool=spanner.pool.BurstyPool()) + if not database.exists(): + raise ValueError("database '%s' does not exist." % database_id) + + return Connection(database) diff --git a/google/cloud/spanner_dbapi/version.py b/google/cloud/spanner_dbapi/version.py index 563d1b4354..88d8f7cdaf 100644 --- a/google/cloud/spanner_dbapi/version.py +++ b/google/cloud/spanner_dbapi/version.py @@ -4,23 +4,8 @@ # license that can be found in the LICENSE file or at # https://developers.google.com/open-source/licenses/bsd -import sys - -from google.api_core.gapic_v1.client_info import ClientInfo +import platform +PY_VERSION = platform.python_version() VERSION = "2.2.0a1" DEFAULT_USER_AGENT = "django_spanner/" + VERSION - -vers = sys.version_info - - -def google_client_info(user_agent=None): - """ - Return a google.api_core.gapic_v1.client_info.ClientInfo - containg the user_agent and python_version for this library - """ - - return ClientInfo( - user_agent=user_agent or DEFAULT_USER_AGENT, - python_version="%d.%d.%d" % (vers.major, vers.minor, vers.micro or 0), - ) diff --git a/tests/spanner_dbapi/test_connect.py b/tests/spanner_dbapi/test_connect.py deleted file mode 100644 index 260d3a0993..0000000000 --- a/tests/spanner_dbapi/test_connect.py +++ /dev/null @@ -1,110 +0,0 @@ -# Copyright 2020 Google LLC -# -# Use of this source code is governed by a BSD-style -# license that can be found in the LICENSE file or at -# https://developers.google.com/open-source/licenses/bsd - -"""connect() module function unit tests.""" - -import unittest -from unittest import mock - -import google.auth.credentials -from google.api_core.gapic_v1.client_info import ClientInfo -from google.cloud.spanner_dbapi import connect, Connection - - -def _make_credentials(): - class _CredentialsWithScopes( - google.auth.credentials.Credentials, google.auth.credentials.Scoped - ): - pass - - return mock.Mock(spec=_CredentialsWithScopes) - - -class Test_connect(unittest.TestCase): - def test_connect(self): - PROJECT = "test-project" - USER_AGENT = "user-agent" - CREDENTIALS = _make_credentials() - CLIENT_INFO = ClientInfo(user_agent=USER_AGENT) - - with mock.patch( - "google.cloud.spanner_dbapi.spanner_v1.Client" - ) as client_mock: - with mock.patch( - "google.cloud.spanner_dbapi.google_client_info", - return_value=CLIENT_INFO, - ) as client_info_mock: - - connection = connect( - "test-instance", - "test-database", - PROJECT, - CREDENTIALS, - USER_AGENT, - ) - - self.assertIsInstance(connection, Connection) - client_info_mock.assert_called_once_with(USER_AGENT) - - client_mock.assert_called_once_with( - project=PROJECT, - credentials=CREDENTIALS, - client_info=CLIENT_INFO, - ) - - def test_instance_not_found(self): - with mock.patch( - "google.cloud.spanner_v1.instance.Instance.exists", - return_value=False, - ) as exists_mock: - - with self.assertRaises(ValueError): - connect("test-instance", "test-database") - - exists_mock.assert_called_once_with() - - def test_database_not_found(self): - with mock.patch( - "google.cloud.spanner_v1.instance.Instance.exists", - return_value=True, - ): - with mock.patch( - "google.cloud.spanner_v1.database.Database.exists", - return_value=False, - ) as exists_mock: - - with self.assertRaises(ValueError): - connect("test-instance", "test-database") - - exists_mock.assert_called_once_with() - - def test_connect_instance_id(self): - INSTANCE = "test-instance" - - with mock.patch( - "google.cloud.spanner_v1.client.Client.instance" - ) as instance_mock: - connection = connect(INSTANCE, "test-database") - - instance_mock.assert_called_once_with(INSTANCE) - - self.assertIsInstance(connection, Connection) - - def test_connect_database_id(self): - DATABASE = "test-database" - - with mock.patch( - "google.cloud.spanner_v1.instance.Instance.database" - ) as database_mock: - with mock.patch( - "google.cloud.spanner_v1.instance.Instance.exists", - return_value=True, - ): - connection = connect("test-instance", DATABASE) - - database_mock.assert_called_once_with(DATABASE, pool=mock.ANY) - - self.assertIsInstance(connection, Connection) diff --git a/tests/spanner_dbapi/test_connection.py b/tests/spanner_dbapi/test_connection.py index e7cd3f361f..603a3711d3 100644 --- a/tests/spanner_dbapi/test_connection.py +++ b/tests/spanner_dbapi/test_connection.py @@ -9,11 +9,30 @@ import unittest from unittest import mock -from google.cloud.spanner_dbapi import connect, InterfaceError + +def _make_credentials(): + from google.auth import credentials + + class _CredentialsWithScopes(credentials.Credentials, credentials.Scoped): + pass + + return mock.Mock(spec=_CredentialsWithScopes) class TestConnection(unittest.TestCase): + + PROJECT = "test-project" + USER_AGENT = "user-agent" + CREDENTIALS = _make_credentials() + + def _get_client_info(self): + from google.api_core.gapic_v1.client_info import ClientInfo + + return ClientInfo(user_agent=self.USER_AGENT) + def test_close(self): + from google.cloud.spanner_dbapi import connect, InterfaceError + with mock.patch( "google.cloud.spanner_v1.instance.Instance.exists", return_value=True, @@ -30,3 +49,45 @@ def test_close(self): with self.assertRaises(InterfaceError): connection.cursor() + + def test_db_connect(self): + from google.cloud.spanner_dbapi import Connection, connect + + with mock.patch("google.cloud.spanner_v1.Client"): + with mock.patch( + # "google.cloud.spanner_dbapi.version.google_client_info", + "google.api_core.gapic_v1.client_info.ClientInfo", + return_value=self._get_client_info(), + ): + connection = connect( + "test-instance", + "test-database", + self.PROJECT, + self.CREDENTIALS, + self.USER_AGENT, + ) + self.assertIsInstance(connection, Connection) + + def test_instance_not_found(self): + from google.cloud.spanner_dbapi import connect + + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.exists", + return_value=False, + ): + with self.assertRaises(ValueError): + connect("test-instance", "test-database") + + def test_database_not_found(self): + from google.cloud.spanner_dbapi import connect + + with mock.patch( + "google.cloud.spanner_v1.database.Database.exists", + return_value=False, + ): + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.exists", + return_value=True, + ): + with self.assertRaises(ValueError): + connect("test-instance", "test-database") diff --git a/tests/spanner_dbapi/test_cursor.py b/tests/spanner_dbapi/test_cursor.py index 673a95d3e5..f6eb46c697 100644 --- a/tests/spanner_dbapi/test_cursor.py +++ b/tests/spanner_dbapi/test_cursor.py @@ -9,32 +9,11 @@ import unittest from unittest import mock -from google.cloud.spanner_dbapi import connect, InterfaceError -from google.cloud.spanner_dbapi.cursor import ColumnInfo - class TestCursor(unittest.TestCase): def test_close(self): - with mock.patch( - "google.cloud.spanner_v1.instance.Instance.exists", - return_value=True, - ): - with mock.patch( - "google.cloud.spanner_v1.database.Database.exists", - return_value=True, - ): - connection = connect("test-instance", "test-database") - - cursor = connection.cursor() - self.assertFalse(cursor.is_closed) + from google.cloud.spanner_dbapi import connect, InterfaceError - cursor.close() - - self.assertTrue(cursor.is_closed) - with self.assertRaises(InterfaceError): - cursor.execute("SELECT * FROM database") - - def test_connection_closed(self): with mock.patch( "google.cloud.spanner_v1.instance.Instance.exists", return_value=True, @@ -48,13 +27,15 @@ def test_connection_closed(self): cursor = connection.cursor() self.assertFalse(cursor.is_closed) - connection.close() + cursor.close() self.assertTrue(cursor.is_closed) with self.assertRaises(InterfaceError): cursor.execute("SELECT * FROM database") def test_executemany_on_closed_cursor(self): + from google.cloud.spanner_dbapi import connect, InterfaceError + with mock.patch( "google.cloud.spanner_v1.instance.Instance.exists", return_value=True, @@ -74,6 +55,8 @@ def test_executemany_on_closed_cursor(self): ) def test_executemany(self): + from google.cloud.spanner_dbapi import connect + operation = """SELECT * FROM table1 WHERE "col1" = @a1""" params_seq = ((1,), (2,)) @@ -100,6 +83,8 @@ def test_executemany(self): class TestColumns(unittest.TestCase): def test_ctor(self): + from google.cloud.spanner_dbapi.cursor import ColumnInfo + name = "col-name" type_code = 8 display_size = 5 @@ -139,6 +124,8 @@ def test_ctor(self): ) def test___get_item__(self): + from google.cloud.spanner_dbapi.cursor import ColumnInfo + fields = ("col-name", 8, 5, 10, 3, None, False) cols = ColumnInfo(*fields) @@ -146,6 +133,8 @@ def test___get_item__(self): self.assertEqual(cols[i], fields[i]) def test___str__(self): + from google.cloud.spanner_dbapi.cursor import ColumnInfo + cols = ColumnInfo("col-name", 8, None, 10, 3, None, False) self.assertEqual( diff --git a/tests/spanner_dbapi/test_globals.py b/tests/spanner_dbapi/test_globals.py index 7c3e0396a9..3b702f7be7 100644 --- a/tests/spanner_dbapi/test_globals.py +++ b/tests/spanner_dbapi/test_globals.py @@ -6,11 +6,15 @@ from unittest import TestCase -from google.cloud.spanner_dbapi import apilevel, paramstyle, threadsafety - class DBAPIGlobalsTests(TestCase): def test_apilevel(self): + from google.cloud.spanner_dbapi import ( + apilevel, + paramstyle, + threadsafety, + ) + self.assertEqual(apilevel, "2.0", "We implement PEP-0249 version 2.0") self.assertEqual(paramstyle, "format", "Cloud Spanner uses @param") self.assertEqual( diff --git a/tests/spanner_dbapi/test_version.py b/tests/spanner_dbapi/test_version.py deleted file mode 100644 index 9dfed1f55f..0000000000 --- a/tests/spanner_dbapi/test_version.py +++ /dev/null @@ -1,36 +0,0 @@ -# Copyright 2020 Google LLC -# -# Use of this source code is governed by a BSD-style -# license that can be found in the LICENSE file or at -# https://developers.google.com/open-source/licenses/bsd - -import sys -from unittest import TestCase - -from google.api_core.gapic_v1.client_info import ClientInfo -from google.cloud.spanner_dbapi.version import ( - DEFAULT_USER_AGENT, - google_client_info, -) - -vers = sys.version_info - - -class VersionUtils(TestCase): - def test_google_client_info_default_useragent(self): - got = google_client_info().to_grpc_metadata() - want = ClientInfo( - user_agent=DEFAULT_USER_AGENT, - python_version="%d.%d.%d" - % (vers.major, vers.minor, vers.micro or 0), - ).to_grpc_metadata() - self.assertEqual(got, want) - - def test_google_client_info_custom_useragent(self): - got = google_client_info("custom-user-agent").to_grpc_metadata() - want = ClientInfo( - user_agent="custom-user-agent", - python_version="%d.%d.%d" - % (vers.major, vers.minor, vers.micro or 0), - ).to_grpc_metadata() - self.assertEqual(got, want) From a5d6d30b1eef899667be7b3f38199901f7525630 Mon Sep 17 00:00:00 2001 From: "STATION\\MF" Date: Thu, 8 Oct 2020 02:24:24 -0400 Subject: [PATCH 08/33] fix: conflicts in `connection.py` --- google/cloud/spanner_dbapi/connection.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/google/cloud/spanner_dbapi/connection.py b/google/cloud/spanner_dbapi/connection.py index 08b5294ecd..f93dd6502d 100644 --- a/google/cloud/spanner_dbapi/connection.py +++ b/google/cloud/spanner_dbapi/connection.py @@ -22,14 +22,12 @@ class Connection(object): :param database: The database to which the connection is linked. """ - def __init__(self, database): - self._database = database - self.ddl_statements = [] + def __init__(self, instance, database): + self.instance = instance + self.database = database self.is_closed = False - @property - def database(self): - return self._database + self.ddl_statements = [] def _raise_if_closed(self): """Helper to check the connection state before running a query. @@ -45,7 +43,7 @@ def close(self): The connection will be unusable from this point forward. """ - self._database = None + self.database = None self.is_closed = True def commit(self): @@ -73,7 +71,7 @@ def run_prior_DDL_statements(self): ddl_statements = self.ddl_statements self.ddl_statements = [] - return self._database.update_ddl(ddl_statements).result() + return self.database.update_ddl(ddl_statements).result() def __enter__(self): return self @@ -133,4 +131,4 @@ def connect( if not database.exists(): raise ValueError("database '%s' does not exist." % database_id) - return Connection(database) + return Connection(instance, database) From 26935824ed2aa42688a9cfa6ad74af2923cdf5aa Mon Sep 17 00:00:00 2001 From: "STATION\\MF" Date: Thu, 8 Oct 2020 22:33:19 -0400 Subject: [PATCH 09/33] test: coverage for `connection.py` --- google/cloud/spanner_dbapi/connection.py | 10 ++-- tests/spanner_dbapi/test_connection.py | 73 +++++++++++++++++++++--- 2 files changed, 70 insertions(+), 13 deletions(-) diff --git a/google/cloud/spanner_dbapi/connection.py b/google/cloud/spanner_dbapi/connection.py index f93dd6502d..1f6ae0b96b 100644 --- a/google/cloud/spanner_dbapi/connection.py +++ b/google/cloud/spanner_dbapi/connection.py @@ -65,13 +65,11 @@ def cursor(self): def run_prior_DDL_statements(self): self._raise_if_closed() - if not self.ddl_statements: - return + if self.ddl_statements: + ddl_statements = self.ddl_statements + self.ddl_statements = [] - ddl_statements = self.ddl_statements - self.ddl_statements = [] - - return self.database.update_ddl(ddl_statements).result() + return self.database.update_ddl(ddl_statements).result() def __enter__(self): return self diff --git a/tests/spanner_dbapi/test_connection.py b/tests/spanner_dbapi/test_connection.py index 603a3711d3..d4afc04073 100644 --- a/tests/spanner_dbapi/test_connection.py +++ b/tests/spanner_dbapi/test_connection.py @@ -4,7 +4,7 @@ # license that can be found in the LICENSE file or at # https://developers.google.com/open-source/licenses/bsd -"""Connection() class unit tests.""" +"""Cloud Spanner DB-API Connection class unit tests.""" import unittest from unittest import mock @@ -22,6 +22,8 @@ class _CredentialsWithScopes(credentials.Credentials, credentials.Scoped): class TestConnection(unittest.TestCase): PROJECT = "test-project" + INSTANCE = 'test-instance' + DATABASE = 'test-database' USER_AGENT = "user-agent" CREDENTIALS = _make_credentials() @@ -50,25 +52,82 @@ def test_close(self): with self.assertRaises(InterfaceError): connection.cursor() - def test_db_connect(self): + def test_commit(self): + from google.cloud.spanner_dbapi import Connection, InterfaceError + + connection = Connection(self.INSTANCE, self.DATABASE) + + with mock.patch( + "google.cloud.spanner_dbapi.connection.Connection.run_prior_DDL_statements" + ) as run_ddl_mock: + connection.commit() + run_ddl_mock.assert_called_once() + + connection.is_closed = True + + with self.assertRaises(InterfaceError): + connection.commit() + + def test_rollback(self): + from google.cloud.spanner_dbapi import Connection, InterfaceError + + connection = Connection(self.INSTANCE, self.DATABASE) + + with mock.patch( + "google.cloud.spanner_dbapi.connection.Connection._raise_if_closed" + ) as check_closed_mock: + connection.rollback() + check_closed_mock.assert_called_once() + + def test_run_prior_DDL_statements(self): + from google.cloud.spanner_dbapi import Connection, InterfaceError + + with mock.patch( + "google.cloud.spanner_v1.database.Database", + autospec=True, + ) as mock_database: + connection = Connection(self.INSTANCE, mock_database) + + connection.run_prior_DDL_statements() + mock_database.update_ddl.assert_not_called() + + connection.ddl_statements = ['ddl'] + + connection.run_prior_DDL_statements() + mock_database.update_ddl.assert_called_once() + + connection.is_closed = True + + with self.assertRaises(InterfaceError): + connection.run_prior_DDL_statements() + + def test_context(self): + from google.cloud.spanner_dbapi import Connection + + connection = Connection(self.INSTANCE, self.DATABASE) + with connection as conn: + self.assertEqual(conn, connection) + + self.assertTrue(connection.is_closed) + + def test_connect(self): from google.cloud.spanner_dbapi import Connection, connect with mock.patch("google.cloud.spanner_v1.Client"): with mock.patch( - # "google.cloud.spanner_dbapi.version.google_client_info", "google.api_core.gapic_v1.client_info.ClientInfo", return_value=self._get_client_info(), ): connection = connect( - "test-instance", - "test-database", + self.INSTANCE, + self.DATABASE, self.PROJECT, self.CREDENTIALS, self.USER_AGENT, ) self.assertIsInstance(connection, Connection) - def test_instance_not_found(self): + def test_connect_instance_not_found(self): from google.cloud.spanner_dbapi import connect with mock.patch( @@ -78,7 +137,7 @@ def test_instance_not_found(self): with self.assertRaises(ValueError): connect("test-instance", "test-database") - def test_database_not_found(self): + def test_connect_database_not_found(self): from google.cloud.spanner_dbapi import connect with mock.patch( From dd382392fce69e126cfbfbeca96e137b2b00c99f Mon Sep 17 00:00:00 2001 From: "STATION\\MF" Date: Fri, 9 Oct 2020 20:24:30 -0400 Subject: [PATCH 10/33] test: [WIP] `cursor.py` coverage --- google/cloud/spanner_dbapi/_helpers.py | 159 +++++++++++++++ google/cloud/spanner_dbapi/cursor.py | 260 +++++-------------------- noxfile.py | 4 +- tests/spanner_dbapi/test_connection.py | 11 +- tests/spanner_dbapi/test_cursor.py | 220 +++++++++++++++++++++ tox.ini | 0 6 files changed, 437 insertions(+), 217 deletions(-) create mode 100644 google/cloud/spanner_dbapi/_helpers.py delete mode 100644 tox.ini diff --git a/google/cloud/spanner_dbapi/_helpers.py b/google/cloud/spanner_dbapi/_helpers.py new file mode 100644 index 0000000000..d1f51e9523 --- /dev/null +++ b/google/cloud/spanner_dbapi/_helpers.py @@ -0,0 +1,159 @@ +# Copyright 2020 Google LLC +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file or at +# https://developers.google.com/open-source/licenses/bsd + +from google.cloud.spanner_dbapi.parse_utils import get_param_types +from google.cloud.spanner_dbapi.parse_utils import parse_insert +from google.cloud.spanner_dbapi.parse_utils import sql_pyformat_args_to_spanner +from google.cloud.spanner_v1 import param_types + + +SQL_LIST_TABLES = """ + SELECT + t.table_name + FROM + information_schema.tables AS t + WHERE + t.table_catalog = '' and t.table_schema = '' + """ + +SQL_GET_TABLE_COLUMN_SCHEMA = """SELECT + COLUMN_NAME, IS_NULLABLE, SPANNER_TYPE + FROM + INFORMATION_SCHEMA.COLUMNS + WHERE + TABLE_SCHEMA = '' + AND + TABLE_NAME = @table_name + """ + +# This table maps spanner_types to Spanner's data type sizes as per +# https://cloud.google.com/spanner/docs/data-types#allowable-types +# It is used to map `display_size` to a known type for Cursor.description +# after a row fetch. +# Since ResultMetadata +# https://cloud.google.com/spanner/docs/reference/rest/v1/ResultSetMetadata +# does not send back the actual size, we have to lookup the respective size. +# Some fields' sizes are dependent upon the dynamic data hence aren't sent back +# by Cloud Spanner. +code_to_display_size = { + param_types.BOOL.code: 1, + param_types.DATE.code: 4, + param_types.FLOAT64.code: 8, + param_types.INT64.code: 8, + param_types.TIMESTAMP.code: 12, +} + + +def execute_insert_heterogenous(transaction, sql_params_list): + for sql, params in sql_params_list: + sql, params = sql_pyformat_args_to_spanner(sql, params) + param_types = get_param_types(params) + res = transaction.execute_sql( + sql, params=params, param_types=param_types + ) + # TODO: File a bug with Cloud Spanner and the Python client maintainers + # about a lost commit when res isn't read from. + _ = list(res) + + +def execute_insert_homogenous(transaction, parts): + # Perform an insert in one shot. + table = parts.get("table") + columns = parts.get("columns") + values = parts.get("values") + return transaction.insert(table, columns, values) + + +def handle_insert(connection, sql, params): + parts = parse_insert(sql, params) + + # The split between the two styles exists because: + # in the common case of multiple values being passed + # with simple pyformat arguments, + # SQL: INSERT INTO T (f1, f2) VALUES (%s, %s, %s) + # Params: [(1, 2, 3, 4, 5, 6, 7, 8, 9, 10,)] + # we can take advantage of a single RPC with: + # transaction.insert(table, columns, values) + # instead of invoking: + # with transaction: + # for sql, params in sql_params_list: + # transaction.execute_sql(sql, params, param_types) + # which invokes more RPCs and is more costly. + + if parts.get("homogenous"): + # The common case of multiple values being passed in + # non-complex pyformat args and need to be uploaded in one RPC. + return connection.database.run_in_transaction( + execute_insert_homogenous, parts + ) + else: + # All the other cases that are esoteric and need + # transaction.execute_sql + sql_params_list = parts.get("sql_params_list") + return connection.database.run_in_transaction( + execute_insert_heterogenous, sql_params_list + ) + + +class ColumnInfo: + """Row column description object.""" + + def __init__( + self, + name, + type_code, + display_size=None, + internal_size=None, + precision=None, + scale=None, + null_ok=False, + ): + self.name = name + self.type_code = type_code + self.display_size = display_size + self.internal_size = internal_size + self.precision = precision + self.scale = scale + self.null_ok = null_ok + + self.fields = ( + self.name, + self.type_code, + self.display_size, + self.internal_size, + self.precision, + self.scale, + self.null_ok, + ) + + def __repr__(self): + return self.__str__() + + def __getitem__(self, index): + return self.fields[index] + + def __str__(self): + str_repr = ", ".join( + filter( + lambda part: part is not None, + [ + "name='%s'" % self.name, + "type_code=%d" % self.type_code, + "display_size=%d" % self.display_size + if self.display_size + else None, + "internal_size=%d" % self.internal_size + if self.internal_size + else None, + "precision='%s'" % self.precision + if self.precision + else None, + "scale='%s'" % self.scale if self.scale else None, + "null_ok='%s'" % self.null_ok if self.null_ok else None, + ], + ) + ) + return "ColumnInfo(%s)" % str_repr diff --git a/google/cloud/spanner_dbapi/cursor.py b/google/cloud/spanner_dbapi/cursor.py index 6f49808955..b768c92da6 100644 --- a/google/cloud/spanner_dbapi/cursor.py +++ b/google/cloud/spanner_dbapi/cursor.py @@ -6,33 +6,26 @@ """Database cursor for Google Cloud Spanner DB-API.""" -from google.api_core.exceptions import ( - AlreadyExists, - FailedPrecondition, - InternalServerError, - InvalidArgument, -) -from google.cloud.spanner_v1 import param_types +from google.api_core.exceptions import AlreadyExists +from google.api_core.exceptions import FailedPrecondition +from google.api_core.exceptions import InternalServerError +from google.api_core.exceptions import InvalidArgument + from collections import namedtuple from google.cloud import spanner_v1 as spanner -from .exceptions import ( - IntegrityError, - InterfaceError, - OperationalError, - ProgrammingError, -) -from .parse_utils import ( - STMT_DDL, - STMT_INSERT, - STMT_NON_UPDATING, - classify_stmt, - ensure_where_clause, - get_param_types, - parse_insert, - sql_pyformat_args_to_spanner, -) +from google.cloud.spanner_dbapi.exceptions import IntegrityError +from google.cloud.spanner_dbapi.exceptions import InterfaceError +from google.cloud.spanner_dbapi.exceptions import OperationalError +from google.cloud.spanner_dbapi.exceptions import ProgrammingError + +from google.cloud.spanner_dbapi import _helpers +from google.cloud.spanner_dbapi._helpers import ColumnInfo +from google.cloud.spanner_dbapi._helpers import code_to_display_size + +from google.cloud.spanner_dbapi import parse_utils +from google.cloud.spanner_dbapi.parse_utils import get_param_types from .utils import PeekIterator @@ -40,23 +33,6 @@ ColumnDetails = namedtuple("column_details", ["null_ok", "spanner_type"]) -# This table maps spanner_types to Spanner's data type sizes as per -# https://cloud.google.com/spanner/docs/data-types#allowable-types -# It is used to map `display_size` to a known type for Cursor.description -# after a row fetch. -# Since ResultMetadata -# https://cloud.google.com/spanner/docs/reference/rest/v1/ResultSetMetadata -# does not send back the actual size, we have to lookup the respective size. -# Some fields' sizes are dependent upon the dynamic data hence aren't sent back -# by Cloud Spanner. -code_to_display_size = { - param_types.BOOL.code: 1, - param_types.DATE.code: 4, - param_types.FLOAT64.code: 8, - param_types.INT64.code: 8, - param_types.TIMESTAMP.code: 12, -} - class Cursor(object): """Database cursor to manage the context of a fetch operation. @@ -67,7 +43,7 @@ class Cursor(object): def __init__(self, connection): self._itr = None - self._res = None + self._result_set = None self._row_count = _UNSET_COUNT self._connection = connection self._is_closed = False @@ -79,10 +55,6 @@ def __init__(self, connection): def connection(self): return self._connection - @property - def lastrowid(self): - return None - @property def is_closed(self): """The cursor close indicator. @@ -105,22 +77,23 @@ def description(self): - ``scale`` - ``null_ok`` """ - if not (self._res and self._res.metadata): + if not (self._result_set and self._result_set.metadata): return None - row_type = self._res.metadata.row_type + row_type = self._result_set.metadata.row_type columns = [] + for field in row_type.fields: - columns.append( - ColumnInfo( - name=field.name, - type_code=field.type.code, - # Size of the SQL type of the column. - display_size=code_to_display_size.get(field.type.code), - # Client perceived size of the column. - internal_size=field.ByteSize(), - ) + column_info = ColumnInfo( + name=field.name, + type_code=field.type.code, + # Size of the SQL type of the column. + display_size=code_to_display_size.get(field.type.code), + # Client perceived size of the column. + internal_size=field.ByteSize(), ) + columns.append(column_info) + return tuple(columns) @property @@ -148,18 +121,18 @@ def close(self): """Closes this Cursor, making it unusable from this point forward.""" self._is_closed = True - def __do_execute_update(self, transaction, sql, params, param_types=None): - sql = ensure_where_clause(sql) - sql, params = sql_pyformat_args_to_spanner(sql, params) + def _do_execute_update(self, transaction, sql, params, param_types=None): + sql = parse_utils.ensure_where_clause(sql) + sql, params = parse_utils.sql_pyformat_args_to_spanner(sql, params) - res = transaction.execute_update( + result = transaction.execute_update( sql, params=params, param_types=get_param_types(params) ) self._itr = None - if type(res) == int: - self._row_count = res + if type(result) == int: + self._row_count = result - return res + return result def execute(self, sql, args=None): """Prepares and executes a Spanner database operation. @@ -175,12 +148,12 @@ def execute(self, sql, args=None): self._raise_if_closed() - self._res = None + self._result_set = None # Classify whether this is a read-only SQL statement. try: - classification = classify_stmt(sql) - if classification == STMT_DDL: + classification = parse_utils.classify_stmt(sql) + if classification == parse_utils.STMT_DDL: self._connection.ddl_statements.append(sql) return @@ -189,14 +162,13 @@ def execute(self, sql, args=None): # self._run_prior_DDL_statements() self._connection.run_prior_DDL_statements() - if classification == STMT_NON_UPDATING: - self.__handle_DQL(sql, args or None) - elif classification == STMT_INSERT: - self.__handle_insert(sql, args or None) + if classification == parse_utils.STMT_NON_UPDATING: + self._handle_DQL(sql, args or None) + elif classification == parse_utils.STMT_INSERT: + _helpers.handle_insert(sql, args or None) else: - # self.__handle_update(sql, args or None) self._connection.database.run_in_transaction( - self.__do_execute_update, sql, args or None + self._do_execute_update, sql, args or None ) except (AlreadyExists, FailedPrecondition) as e: raise IntegrityError(e.details if hasattr(e, "details") else e) @@ -276,64 +248,11 @@ def setoutputsize(self, size, column=None): """A no-op, raising an error if the cursor or connection is closed.""" self._raise_if_closed() - # def __handle_update(self, sql, params): - # self._connection.database.run_in_transaction( - # self.__do_execute_update, sql, params - # ) - - def __handle_insert(self, sql, params): - parts = parse_insert(sql, params) - - # The split between the two styles exists because: - # in the common case of multiple values being passed - # with simple pyformat arguments, - # SQL: INSERT INTO T (f1, f2) VALUES (%s, %s, %s) - # Params: [(1, 2, 3, 4, 5, 6, 7, 8, 9, 10,)] - # we can take advantage of a single RPC with: - # transaction.insert(table, columns, values) - # instead of invoking: - # with transaction: - # for sql, params in sql_params_list: - # transaction.execute_sql(sql, params, param_types) - # which invokes more RPCs and is more costly. - - if parts.get("homogenous"): - # The common case of multiple values being passed in - # non-complex pyformat args and need to be uploaded in one RPC. - return self._connection.database.run_in_transaction( - self.__do_execute_insert_homogenous, parts - ) - else: - # All the other cases that are esoteric and need - # transaction.execute_sql - sql_params_list = parts.get("sql_params_list") - return self._connection.database.run_in_transaction( - self.__do_execute_insert_heterogenous, sql_params_list - ) - - def __do_execute_insert_heterogenous(self, transaction, sql_params_list): - for sql, params in sql_params_list: - sql, params = sql_pyformat_args_to_spanner(sql, params) - param_types = get_param_types(params) - res = transaction.execute_sql( - sql, params=params, param_types=param_types - ) - # TODO: File a bug with Cloud Spanner and the Python client maintainers - # about a lost commit when res isn't read from. - _ = list(res) - - def __do_execute_insert_homogenous(self, transaction, parts): - # Perform an insert in one shot. - table = parts.get("table") - columns = parts.get("columns") - values = parts.get("values") - return transaction.insert(table, columns, values) - - def __handle_DQL(self, sql, params): + def _handle_DQL(self, sql, params): with self._connection.database.snapshot() as snapshot: # Reference # https://googleapis.dev/python/spanner/latest/session-api.html#google.cloud.spanner_v1.session.Session.execute_sql - sql, params = sql_pyformat_args_to_spanner(sql, params) + sql, params = parse_utils.sql_pyformat_args_to_spanner(sql, params) res = snapshot.execute_sql( sql, params=params, param_types=get_param_types(params) ) @@ -349,10 +268,10 @@ def __handle_DQL(self, sql, params): # are for .fetchone() with those that would result in # many items returns a RuntimeError if .fetchone() is # invoked and vice versa. - self._res = res - # Read the first element so that StreamedResult can + self._result_set = res + # Read the first element so that the StreamedResultSet can # return the metadata after a DQL statement. See issue #155. - self._itr = PeekIterator(self._res) + self._itr = PeekIterator(self._result_set) # Unfortunately, Spanner doesn't seem to send back # information about the number of rows available. self._row_count = _UNSET_COUNT @@ -374,16 +293,7 @@ def __iter__(self): return self._itr def list_tables(self): - return self.run_sql_in_snapshot( - """ - SELECT - t.table_name - FROM - information_schema.tables AS t - WHERE - t.table_catalog = '' and t.table_schema = '' - """ - ) + return self.run_sql_in_snapshot(_helpers.SQL_LIST_TABLES) def run_sql_in_snapshot(self, sql, params=None, param_types=None): # Some SQL e.g. for INFORMATION_SCHEMA cannot be run in read-write transactions @@ -398,14 +308,7 @@ def run_sql_in_snapshot(self, sql, params=None, param_types=None): def get_table_column_schema(self, table_name): rows = self.run_sql_in_snapshot( - """SELECT - COLUMN_NAME, IS_NULLABLE, SPANNER_TYPE - FROM - INFORMATION_SCHEMA.COLUMNS - WHERE - TABLE_SCHEMA = '' - AND - TABLE_NAME = @table_name""", + sql=_helpers.SQL_GET_TABLE_COLUMN_SCHEMA, params={"table_name": table_name}, param_types={"table_name": spanner.param_types.STRING}, ) @@ -416,64 +319,3 @@ def get_table_column_schema(self, table_name): null_ok=is_nullable == "YES", spanner_type=spanner_type ) return column_details - - -class ColumnInfo: - """Row column description object.""" - - def __init__( - self, - name, - type_code, - display_size=None, - internal_size=None, - precision=None, - scale=None, - null_ok=False, - ): - self.name = name - self.type_code = type_code - self.display_size = display_size - self.internal_size = internal_size - self.precision = precision - self.scale = scale - self.null_ok = null_ok - - self.fields = ( - self.name, - self.type_code, - self.display_size, - self.internal_size, - self.precision, - self.scale, - self.null_ok, - ) - - def __repr__(self): - return self.__str__() - - def __getitem__(self, index): - return self.fields[index] - - def __str__(self): - str_repr = ", ".join( - filter( - lambda part: part is not None, - [ - "name='%s'" % self.name, - "type_code=%d" % self.type_code, - "display_size=%d" % self.display_size - if self.display_size - else None, - "internal_size=%d" % self.internal_size - if self.internal_size - else None, - "precision='%s'" % self.precision - if self.precision - else None, - "scale='%s'" % self.scale if self.scale else None, - "null_ok='%s'" % self.null_ok if self.null_ok else None, - ], - ) - ) - return "ColumnInfo(%s)" % str_repr diff --git a/noxfile.py b/noxfile.py index 9b7dfcd53f..c5d31e208d 100644 --- a/noxfile.py +++ b/noxfile.py @@ -69,13 +69,13 @@ def default(session): session.run( "py.test", "--quiet", - "--cov=django_spanner", + # "--cov=django_spanner", "--cov=google.cloud", "--cov=tests.spanner_dbapi", "--cov-append", "--cov-config=.coveragerc", "--cov-report=", - "--cov-fail-under=0", + "--cov-fail-under=90", os.path.join("tests", "spanner_dbapi"), *session.posargs ) diff --git a/tests/spanner_dbapi/test_connection.py b/tests/spanner_dbapi/test_connection.py index d4afc04073..951025296b 100644 --- a/tests/spanner_dbapi/test_connection.py +++ b/tests/spanner_dbapi/test_connection.py @@ -22,8 +22,8 @@ class _CredentialsWithScopes(credentials.Credentials, credentials.Scoped): class TestConnection(unittest.TestCase): PROJECT = "test-project" - INSTANCE = 'test-instance' - DATABASE = 'test-database' + INSTANCE = "test-instance" + DATABASE = "test-database" USER_AGENT = "user-agent" CREDENTIALS = _make_credentials() @@ -69,7 +69,7 @@ def test_commit(self): connection.commit() def test_rollback(self): - from google.cloud.spanner_dbapi import Connection, InterfaceError + from google.cloud.spanner_dbapi import Connection connection = Connection(self.INSTANCE, self.DATABASE) @@ -83,15 +83,14 @@ def test_run_prior_DDL_statements(self): from google.cloud.spanner_dbapi import Connection, InterfaceError with mock.patch( - "google.cloud.spanner_v1.database.Database", - autospec=True, + "google.cloud.spanner_v1.database.Database", autospec=True, ) as mock_database: connection = Connection(self.INSTANCE, mock_database) connection.run_prior_DDL_statements() mock_database.update_ddl.assert_not_called() - connection.ddl_statements = ['ddl'] + connection.ddl_statements = ["ddl"] connection.run_prior_DDL_statements() mock_database.update_ddl.assert_called_once() diff --git a/tests/spanner_dbapi/test_cursor.py b/tests/spanner_dbapi/test_cursor.py index f6eb46c697..556fce7b91 100644 --- a/tests/spanner_dbapi/test_cursor.py +++ b/tests/spanner_dbapi/test_cursor.py @@ -11,6 +11,226 @@ class TestCursor(unittest.TestCase): + + INSTANCE = "test-instance" + DATABASE = "test-database" + + def test_property_connection(self): + from google.cloud.spanner_dbapi import Connection, Cursor + + connection = Connection(self.INSTANCE, self.DATABASE) + cursor = Cursor(connection) + self.assertEqual(cursor.connection, connection) + + def test_property_description(self): + from google.cloud.spanner_dbapi import Connection, Cursor + from google.cloud.spanner_dbapi._helpers import ColumnInfo + + connection = Connection(self.INSTANCE, self.DATABASE) + cursor = Cursor(connection) + self.assertIsNone(cursor.description) + cursor._result_set = res_set = mock.MagicMock() + res_set.metadata.row_type.fields = [mock.MagicMock()] + self.assertIsNotNone(cursor.description) + self.assertIsInstance(cursor.description[0], ColumnInfo) + + def test_property_rowcount(self): + from google.cloud.spanner_dbapi import Connection, Cursor + from google.cloud.spanner_dbapi.cursor import _UNSET_COUNT + + connection = Connection(self.INSTANCE, self.DATABASE) + cursor = Cursor(connection) + self.assertEqual(cursor.rowcount, _UNSET_COUNT) + + def test_callproc(self): + from google.cloud.spanner_dbapi import Connection, Cursor + from google.cloud.spanner_dbapi.exceptions import InterfaceError + + connection = Connection(self.INSTANCE, self.DATABASE) + cursor = Cursor(connection) + cursor._is_closed = True + with self.assertRaises(InterfaceError): + cursor.callproc(procname=None) + + def test_do_execute_update(self): + from google.cloud.spanner_dbapi import Connection, Cursor + from google.cloud.spanner_dbapi.cursor import _UNSET_COUNT + + connection = Connection(self.INSTANCE, self.DATABASE) + cursor = Cursor(connection) + transaction = mock.MagicMock() + + def run_helper(ret_value): + transaction.execute_update.return_value = ret_value + res = cursor._do_execute_update( + transaction=transaction, sql="sql", params=None, + ) + return res + + expected = "good" + self.assertEqual(run_helper(expected), expected) + self.assertEqual(cursor._row_count, _UNSET_COUNT) + + expected = 1234 + self.assertEqual(run_helper(expected), expected) + self.assertEqual(cursor._row_count, expected) + + def test_execute_programming_error(self): + from google.cloud.spanner_dbapi import Connection, Cursor + from google.cloud.spanner_dbapi.exceptions import ProgrammingError + + connection = Connection(self.INSTANCE, self.DATABASE) + cursor = Cursor(connection) + cursor._connection = None + with self.assertRaises(ProgrammingError): + cursor.execute(sql="") + + def test_execute_attribute_error(self): + from google.cloud.spanner_dbapi import Connection, Cursor + + connection = Connection(self.INSTANCE, self.DATABASE) + cursor = Cursor(connection) + + with self.assertRaises(AttributeError): + cursor.execute(sql="") + + def test_execute_statement(self): + from google.cloud.spanner_dbapi import Connection, Cursor + from google.cloud.spanner_dbapi import parse_utils + + connection = Connection(self.INSTANCE, mock.MagicMock()) + cursor = Cursor(connection) + + with mock.patch( + "google.cloud.spanner_dbapi.parse_utils.classify_stmt", + return_value=parse_utils.STMT_DDL, + ) as mock_classify_stmt: + cursor.execute(sql="sql") + mock_classify_stmt.assert_called() + self.assertEqual(cursor._connection.ddl_statements, ["sql"]) + + with mock.patch( + "google.cloud.spanner_dbapi.parse_utils.classify_stmt", + return_value=parse_utils.STMT_NON_UPDATING, + ): + with mock.patch( + "google.cloud.spanner_dbapi.cursor.Cursor._handle_DQL", + return_value=parse_utils.STMT_NON_UPDATING, + ) as mock_handle_ddl: + cursor.execute(sql="sql") + mock_handle_ddl.assert_called() + + with mock.patch( + "google.cloud.spanner_dbapi.parse_utils.classify_stmt", + return_value=parse_utils.STMT_INSERT, + ): + with mock.patch( + "google.cloud.spanner_dbapi._helpers.handle_insert", + return_value=parse_utils.STMT_INSERT, + ) as mock_handle_insert: + cursor.execute(sql="sql") + mock_handle_insert.assert_called() + + def test_execute_integrity_error(self): + from google.api_core import exceptions + from google.cloud.spanner_dbapi import Connection, Cursor + from google.cloud.spanner_dbapi.exceptions import IntegrityError + + connection = Connection(self.INSTANCE, mock.MagicMock()) + cursor = Cursor(connection) + + with mock.patch( + "google.cloud.spanner_dbapi.parse_utils.classify_stmt", + side_effect=exceptions.AlreadyExists("message"), + ): + with self.assertRaises(IntegrityError): + cursor.execute(sql="sql") + + with mock.patch( + "google.cloud.spanner_dbapi.parse_utils.classify_stmt", + side_effect=exceptions.FailedPrecondition("message"), + ): + with self.assertRaises(IntegrityError): + cursor.execute(sql="sql") + + def test_execute_invalid_argument(self): + from google.api_core import exceptions + from google.cloud.spanner_dbapi import Connection, Cursor + from google.cloud.spanner_dbapi.exceptions import ProgrammingError + + connection = Connection(self.INSTANCE, mock.MagicMock()) + cursor = Cursor(connection) + + with mock.patch( + "google.cloud.spanner_dbapi.parse_utils.classify_stmt", + side_effect=exceptions.InvalidArgument("message"), + ): + with self.assertRaises(ProgrammingError): + cursor.execute(sql="sql") + + def test_execute_internal_server_error(self): + from google.api_core import exceptions + from google.cloud.spanner_dbapi import Connection, Cursor + from google.cloud.spanner_dbapi.exceptions import OperationalError + + connection = Connection(self.INSTANCE, mock.MagicMock()) + cursor = Cursor(connection) + + with mock.patch( + "google.cloud.spanner_dbapi.parse_utils.classify_stmt", + side_effect=exceptions.InternalServerError("message"), + ): + with self.assertRaises(OperationalError): + cursor.execute(sql="sql") + + def test_fetchone(self): + pass + + def test_fetchmany(self): + pass + + def test_fetchall(self): + pass + + def test_nextset(self): + pass + + def test_setinputsizes(self): + pass + + def test_setoutputsize(self): + pass + + def test_handle_insert(self): + pass + + def test_do_execute_insert_heterogenous(self): + pass + + def test_do_execute_insert_homogenous(self): + pass + + def test_handle_ddl(self): + pass + + def test_context(self): + pass + + def test_next(self): + pass + + def test_iter(self): + pass + + def test_list_tables(self): + pass + + def test_run_sql_in_snapshot(self): + pass + + def test_get_table_column_schema(self): + pass + def test_close(self): from google.cloud.spanner_dbapi import connect, InterfaceError diff --git a/tox.ini b/tox.ini deleted file mode 100644 index e69de29bb2..0000000000 From 4640c462c82810fce454bc40b94e5ee571ff279b Mon Sep 17 00:00:00 2001 From: "STATION\\MF" Date: Sat, 10 Oct 2020 08:51:21 -0400 Subject: [PATCH 11/33] fix: unittest assertions --- tests/spanner_dbapi/test_connection.py | 9 +++++---- tests/spanner_dbapi/test_cursor.py | 9 +++++---- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/tests/spanner_dbapi/test_connection.py b/tests/spanner_dbapi/test_connection.py index 951025296b..e6dcfb6263 100644 --- a/tests/spanner_dbapi/test_connection.py +++ b/tests/spanner_dbapi/test_connection.py @@ -61,7 +61,7 @@ def test_commit(self): "google.cloud.spanner_dbapi.connection.Connection.run_prior_DDL_statements" ) as run_ddl_mock: connection.commit() - run_ddl_mock.assert_called_once() + run_ddl_mock.assert_called_once_with() connection.is_closed = True @@ -77,7 +77,7 @@ def test_rollback(self): "google.cloud.spanner_dbapi.connection.Connection._raise_if_closed" ) as check_closed_mock: connection.rollback() - check_closed_mock.assert_called_once() + check_closed_mock.assert_called_once_with() def test_run_prior_DDL_statements(self): from google.cloud.spanner_dbapi import Connection, InterfaceError @@ -90,10 +90,11 @@ def test_run_prior_DDL_statements(self): connection.run_prior_DDL_statements() mock_database.update_ddl.assert_not_called() - connection.ddl_statements = ["ddl"] + ddl = ['ddl'] + connection.ddl_statements = ddl connection.run_prior_DDL_statements() - mock_database.update_ddl.assert_called_once() + mock_database.update_ddl.assert_called_once_with(ddl) connection.is_closed = True diff --git a/tests/spanner_dbapi/test_cursor.py b/tests/spanner_dbapi/test_cursor.py index 556fce7b91..f2b9bf3fc6 100644 --- a/tests/spanner_dbapi/test_cursor.py +++ b/tests/spanner_dbapi/test_cursor.py @@ -7,7 +7,7 @@ """Cursor() class unit tests.""" import unittest -from unittest import mock +import mock class TestCursor(unittest.TestCase): @@ -105,9 +105,10 @@ def test_execute_statement(self): "google.cloud.spanner_dbapi.parse_utils.classify_stmt", return_value=parse_utils.STMT_DDL, ) as mock_classify_stmt: - cursor.execute(sql="sql") - mock_classify_stmt.assert_called() - self.assertEqual(cursor._connection.ddl_statements, ["sql"]) + sql = 'sql' + cursor.execute(sql=sql) + mock_classify_stmt.assert_called_once_with(sql) + self.assertEqual(cursor._connection.ddl_statements, [sql]) with mock.patch( "google.cloud.spanner_dbapi.parse_utils.classify_stmt", From 023ea9062b548aedceb54df0431aa09be2b81107 Mon Sep 17 00:00:00 2001 From: "STATION\\MF" Date: Sat, 10 Oct 2020 09:04:53 -0400 Subject: [PATCH 12/33] chore: unit tests refactored --- tests/spanner_dbapi/test__helpers.py | 84 ++++++++++++++++++++++++++ tests/spanner_dbapi/test_connection.py | 2 +- tests/spanner_dbapi/test_cursor.py | 64 +------------------- 3 files changed, 86 insertions(+), 64 deletions(-) create mode 100644 tests/spanner_dbapi/test__helpers.py diff --git a/tests/spanner_dbapi/test__helpers.py b/tests/spanner_dbapi/test__helpers.py new file mode 100644 index 0000000000..7732a12ca4 --- /dev/null +++ b/tests/spanner_dbapi/test__helpers.py @@ -0,0 +1,84 @@ +# Copyright 2020 Google LLC +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file or at +# https://developers.google.com/open-source/licenses/bsd + +"""Cloud Spanner DB-API Connection class unit tests.""" + +import unittest + +# from unittest import mock + + +class TestHelpers(unittest.TestCase): + def test_execute_insert_heterogenous(self): + pass + + def test_execute_insert_homogenous(self): + pass + + def test_handle_insert(self): + pass + + +class TestColumnInfo(unittest.TestCase): + def test_ctor(self): + from google.cloud.spanner_dbapi.cursor import ColumnInfo + + name = "col-name" + type_code = 8 + display_size = 5 + internal_size = 10 + precision = 3 + scale = None + null_ok = False + + cols = ColumnInfo( + name, + type_code, + display_size, + internal_size, + precision, + scale, + null_ok, + ) + + self.assertEqual(cols.name, name) + self.assertEqual(cols.type_code, type_code) + self.assertEqual(cols.display_size, display_size) + self.assertEqual(cols.internal_size, internal_size) + self.assertEqual(cols.precision, precision) + self.assertEqual(cols.scale, scale) + self.assertEqual(cols.null_ok, null_ok) + self.assertEqual( + cols.fields, + ( + name, + type_code, + display_size, + internal_size, + precision, + scale, + null_ok, + ), + ) + + def test___get_item__(self): + from google.cloud.spanner_dbapi.cursor import ColumnInfo + + fields = ("col-name", 8, 5, 10, 3, None, False) + cols = ColumnInfo(*fields) + + for i in range(0, 7): + self.assertEqual(cols[i], fields[i]) + + def test___str__(self): + from google.cloud.spanner_dbapi.cursor import ColumnInfo + + cols = ColumnInfo("col-name", 8, None, 10, 3, None, False) + + self.assertEqual( + str(cols), + "ColumnInfo(name='col-name', type_code=8, internal_size=10, precision='3')", + ) diff --git a/tests/spanner_dbapi/test_connection.py b/tests/spanner_dbapi/test_connection.py index e6dcfb6263..865db1cb61 100644 --- a/tests/spanner_dbapi/test_connection.py +++ b/tests/spanner_dbapi/test_connection.py @@ -90,7 +90,7 @@ def test_run_prior_DDL_statements(self): connection.run_prior_DDL_statements() mock_database.update_ddl.assert_not_called() - ddl = ['ddl'] + ddl = ["ddl"] connection.ddl_statements = ddl connection.run_prior_DDL_statements() diff --git a/tests/spanner_dbapi/test_cursor.py b/tests/spanner_dbapi/test_cursor.py index f2b9bf3fc6..6676f8ffb2 100644 --- a/tests/spanner_dbapi/test_cursor.py +++ b/tests/spanner_dbapi/test_cursor.py @@ -105,7 +105,7 @@ def test_execute_statement(self): "google.cloud.spanner_dbapi.parse_utils.classify_stmt", return_value=parse_utils.STMT_DDL, ) as mock_classify_stmt: - sql = 'sql' + sql = "sql" cursor.execute(sql=sql) mock_classify_stmt.assert_called_once_with(sql) self.assertEqual(cursor._connection.ddl_statements, [sql]) @@ -300,65 +300,3 @@ def test_executemany(self): execute_mock.assert_has_calls( (mock.call(operation, (1,)), mock.call(operation, (2,))) ) - - -class TestColumns(unittest.TestCase): - def test_ctor(self): - from google.cloud.spanner_dbapi.cursor import ColumnInfo - - name = "col-name" - type_code = 8 - display_size = 5 - internal_size = 10 - precision = 3 - scale = None - null_ok = False - - cols = ColumnInfo( - name, - type_code, - display_size, - internal_size, - precision, - scale, - null_ok, - ) - - self.assertEqual(cols.name, name) - self.assertEqual(cols.type_code, type_code) - self.assertEqual(cols.display_size, display_size) - self.assertEqual(cols.internal_size, internal_size) - self.assertEqual(cols.precision, precision) - self.assertEqual(cols.scale, scale) - self.assertEqual(cols.null_ok, null_ok) - self.assertEqual( - cols.fields, - ( - name, - type_code, - display_size, - internal_size, - precision, - scale, - null_ok, - ), - ) - - def test___get_item__(self): - from google.cloud.spanner_dbapi.cursor import ColumnInfo - - fields = ("col-name", 8, 5, 10, 3, None, False) - cols = ColumnInfo(*fields) - - for i in range(0, 7): - self.assertEqual(cols[i], fields[i]) - - def test___str__(self): - from google.cloud.spanner_dbapi.cursor import ColumnInfo - - cols = ColumnInfo("col-name", 8, None, 10, 3, None, False) - - self.assertEqual( - str(cols), - "ColumnInfo(name='col-name', type_code=8, internal_size=10, precision='3')", - ) From 9abee91e8801e6055055214c5e234f48865f2fd9 Mon Sep 17 00:00:00 2001 From: "STATION\\MF" Date: Sat, 10 Oct 2020 10:21:18 -0400 Subject: [PATCH 13/33] chore: unit test refactored --- noxfile.py | 4 ++-- tests/unit/__init__.py | 0 tests/{ => unit}/spanner_dbapi/__init__.py | 0 tests/{ => unit}/spanner_dbapi/test__helpers.py | 0 tests/{ => unit}/spanner_dbapi/test_connection.py | 0 tests/{ => unit}/spanner_dbapi/test_cursor.py | 0 tests/{ => unit}/spanner_dbapi/test_globals.py | 0 tests/{ => unit}/spanner_dbapi/test_parse_utils.py | 0 tests/{ => unit}/spanner_dbapi/test_parser.py | 0 tests/{ => unit}/spanner_dbapi/test_types.py | 0 tests/{ => unit}/spanner_dbapi/test_utils.py | 0 11 files changed, 2 insertions(+), 2 deletions(-) create mode 100644 tests/unit/__init__.py rename tests/{ => unit}/spanner_dbapi/__init__.py (100%) rename tests/{ => unit}/spanner_dbapi/test__helpers.py (100%) rename tests/{ => unit}/spanner_dbapi/test_connection.py (100%) rename tests/{ => unit}/spanner_dbapi/test_cursor.py (100%) rename tests/{ => unit}/spanner_dbapi/test_globals.py (100%) rename tests/{ => unit}/spanner_dbapi/test_parse_utils.py (100%) rename tests/{ => unit}/spanner_dbapi/test_parser.py (100%) rename tests/{ => unit}/spanner_dbapi/test_types.py (100%) rename tests/{ => unit}/spanner_dbapi/test_utils.py (100%) diff --git a/noxfile.py b/noxfile.py index c5d31e208d..3cfebd9358 100644 --- a/noxfile.py +++ b/noxfile.py @@ -71,12 +71,12 @@ def default(session): "--quiet", # "--cov=django_spanner", "--cov=google.cloud", - "--cov=tests.spanner_dbapi", + "--cov=tests.unit", "--cov-append", "--cov-config=.coveragerc", "--cov-report=", "--cov-fail-under=90", - os.path.join("tests", "spanner_dbapi"), + os.path.join("tests", "unit"), *session.posargs ) diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/spanner_dbapi/__init__.py b/tests/unit/spanner_dbapi/__init__.py similarity index 100% rename from tests/spanner_dbapi/__init__.py rename to tests/unit/spanner_dbapi/__init__.py diff --git a/tests/spanner_dbapi/test__helpers.py b/tests/unit/spanner_dbapi/test__helpers.py similarity index 100% rename from tests/spanner_dbapi/test__helpers.py rename to tests/unit/spanner_dbapi/test__helpers.py diff --git a/tests/spanner_dbapi/test_connection.py b/tests/unit/spanner_dbapi/test_connection.py similarity index 100% rename from tests/spanner_dbapi/test_connection.py rename to tests/unit/spanner_dbapi/test_connection.py diff --git a/tests/spanner_dbapi/test_cursor.py b/tests/unit/spanner_dbapi/test_cursor.py similarity index 100% rename from tests/spanner_dbapi/test_cursor.py rename to tests/unit/spanner_dbapi/test_cursor.py diff --git a/tests/spanner_dbapi/test_globals.py b/tests/unit/spanner_dbapi/test_globals.py similarity index 100% rename from tests/spanner_dbapi/test_globals.py rename to tests/unit/spanner_dbapi/test_globals.py diff --git a/tests/spanner_dbapi/test_parse_utils.py b/tests/unit/spanner_dbapi/test_parse_utils.py similarity index 100% rename from tests/spanner_dbapi/test_parse_utils.py rename to tests/unit/spanner_dbapi/test_parse_utils.py diff --git a/tests/spanner_dbapi/test_parser.py b/tests/unit/spanner_dbapi/test_parser.py similarity index 100% rename from tests/spanner_dbapi/test_parser.py rename to tests/unit/spanner_dbapi/test_parser.py diff --git a/tests/spanner_dbapi/test_types.py b/tests/unit/spanner_dbapi/test_types.py similarity index 100% rename from tests/spanner_dbapi/test_types.py rename to tests/unit/spanner_dbapi/test_types.py diff --git a/tests/spanner_dbapi/test_utils.py b/tests/unit/spanner_dbapi/test_utils.py similarity index 100% rename from tests/spanner_dbapi/test_utils.py rename to tests/unit/spanner_dbapi/test_utils.py From 7dc3875a7a8df8c17c2fdb6cdc60ecfb4b3f22ed Mon Sep 17 00:00:00 2001 From: "STATION\\MF" Date: Sat, 10 Oct 2020 10:24:39 -0400 Subject: [PATCH 14/33] fix: missing arg --- google/cloud/spanner_dbapi/cursor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/google/cloud/spanner_dbapi/cursor.py b/google/cloud/spanner_dbapi/cursor.py index b768c92da6..044f980b8f 100644 --- a/google/cloud/spanner_dbapi/cursor.py +++ b/google/cloud/spanner_dbapi/cursor.py @@ -165,7 +165,7 @@ def execute(self, sql, args=None): if classification == parse_utils.STMT_NON_UPDATING: self._handle_DQL(sql, args or None) elif classification == parse_utils.STMT_INSERT: - _helpers.handle_insert(sql, args or None) + _helpers.handle_insert(self._connection, sql, args or None) else: self._connection.database.run_in_transaction( self._do_execute_update, sql, args or None From 02b3bd4d0b05d8c894124a6ad9b90c0c65336107 Mon Sep 17 00:00:00 2001 From: "STATION\\MF" Date: Sat, 10 Oct 2020 20:00:01 -0400 Subject: [PATCH 15/33] test: unit test coverage for `cursor.py' --- tests/unit/spanner_dbapi/test_cursor.py | 343 ++++++++++++++++-------- 1 file changed, 236 insertions(+), 107 deletions(-) diff --git a/tests/unit/spanner_dbapi/test_cursor.py b/tests/unit/spanner_dbapi/test_cursor.py index 6676f8ffb2..e3bbcb673f 100644 --- a/tests/unit/spanner_dbapi/test_cursor.py +++ b/tests/unit/spanner_dbapi/test_cursor.py @@ -7,7 +7,8 @@ """Cursor() class unit tests.""" import unittest -import mock + +from unittest import mock class TestCursor(unittest.TestCase): @@ -15,19 +16,30 @@ class TestCursor(unittest.TestCase): INSTANCE = "test-instance" DATABASE = "test-database" - def test_property_connection(self): - from google.cloud.spanner_dbapi import Connection, Cursor + def _get_target_class(self): + from google.cloud.spanner_dbapi import Cursor + + return Cursor + + def _make_one(self, *args, **kwargs): + return self._get_target_class()(*args, **kwargs) + + def _make_connection(self, *args, **kwargs): + from google.cloud.spanner_dbapi import Connection - connection = Connection(self.INSTANCE, self.DATABASE) - cursor = Cursor(connection) + return Connection(*args, **kwargs) + + def test_property_connection(self): + connection = self._make_connection(self.INSTANCE, self.DATABASE) + cursor = self._make_one(connection) self.assertEqual(cursor.connection, connection) def test_property_description(self): - from google.cloud.spanner_dbapi import Connection, Cursor from google.cloud.spanner_dbapi._helpers import ColumnInfo - connection = Connection(self.INSTANCE, self.DATABASE) - cursor = Cursor(connection) + connection = self._make_connection(self.INSTANCE, self.DATABASE) + cursor = self._make_one(connection) + self.assertIsNone(cursor.description) cursor._result_set = res_set = mock.MagicMock() res_set.metadata.row_type.fields = [mock.MagicMock()] @@ -35,29 +47,48 @@ def test_property_description(self): self.assertIsInstance(cursor.description[0], ColumnInfo) def test_property_rowcount(self): - from google.cloud.spanner_dbapi import Connection, Cursor from google.cloud.spanner_dbapi.cursor import _UNSET_COUNT - connection = Connection(self.INSTANCE, self.DATABASE) - cursor = Cursor(connection) + connection = self._make_connection(self.INSTANCE, self.DATABASE) + cursor = self._make_one(connection) self.assertEqual(cursor.rowcount, _UNSET_COUNT) def test_callproc(self): - from google.cloud.spanner_dbapi import Connection, Cursor from google.cloud.spanner_dbapi.exceptions import InterfaceError - connection = Connection(self.INSTANCE, self.DATABASE) - cursor = Cursor(connection) + connection = self._make_connection(self.INSTANCE, self.DATABASE) + cursor = self._make_one(connection) cursor._is_closed = True with self.assertRaises(InterfaceError): cursor.callproc(procname=None) + def test_close(self): + from google.cloud.spanner_dbapi import connect, InterfaceError + + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.exists", + return_value=True, + ): + with mock.patch( + "google.cloud.spanner_v1.database.Database.exists", + return_value=True, + ): + connection = connect(self.INSTANCE, self.DATABASE) + + cursor = connection.cursor() + self.assertFalse(cursor.is_closed) + + cursor.close() + + self.assertTrue(cursor.is_closed) + with self.assertRaises(InterfaceError): + cursor.execute("SELECT * FROM database") + def test_do_execute_update(self): - from google.cloud.spanner_dbapi import Connection, Cursor from google.cloud.spanner_dbapi.cursor import _UNSET_COUNT - connection = Connection(self.INSTANCE, self.DATABASE) - cursor = Cursor(connection) + connection = self._make_connection(self.INSTANCE, self.DATABASE) + cursor = self._make_one(connection) transaction = mock.MagicMock() def run_helper(ret_value): @@ -76,30 +107,26 @@ def run_helper(ret_value): self.assertEqual(cursor._row_count, expected) def test_execute_programming_error(self): - from google.cloud.spanner_dbapi import Connection, Cursor from google.cloud.spanner_dbapi.exceptions import ProgrammingError - connection = Connection(self.INSTANCE, self.DATABASE) - cursor = Cursor(connection) + connection = self._make_connection(self.INSTANCE, self.DATABASE) + cursor = self._make_one(connection) cursor._connection = None with self.assertRaises(ProgrammingError): cursor.execute(sql="") def test_execute_attribute_error(self): - from google.cloud.spanner_dbapi import Connection, Cursor - - connection = Connection(self.INSTANCE, self.DATABASE) - cursor = Cursor(connection) + connection = self._make_connection(self.INSTANCE, self.DATABASE) + cursor = self._make_one(connection) with self.assertRaises(AttributeError): cursor.execute(sql="") def test_execute_statement(self): - from google.cloud.spanner_dbapi import Connection, Cursor from google.cloud.spanner_dbapi import parse_utils - connection = Connection(self.INSTANCE, mock.MagicMock()) - cursor = Cursor(connection) + connection = self._make_connection(self.INSTANCE, mock.MagicMock()) + cursor = self._make_one(connection) with mock.patch( "google.cloud.spanner_dbapi.parse_utils.classify_stmt", @@ -134,11 +161,10 @@ def test_execute_statement(self): def test_execute_integrity_error(self): from google.api_core import exceptions - from google.cloud.spanner_dbapi import Connection, Cursor from google.cloud.spanner_dbapi.exceptions import IntegrityError - connection = Connection(self.INSTANCE, mock.MagicMock()) - cursor = Cursor(connection) + connection = self._make_connection(self.INSTANCE, mock.MagicMock()) + cursor = self._make_one(connection) with mock.patch( "google.cloud.spanner_dbapi.parse_utils.classify_stmt", @@ -156,11 +182,10 @@ def test_execute_integrity_error(self): def test_execute_invalid_argument(self): from google.api_core import exceptions - from google.cloud.spanner_dbapi import Connection, Cursor from google.cloud.spanner_dbapi.exceptions import ProgrammingError - connection = Connection(self.INSTANCE, mock.MagicMock()) - cursor = Cursor(connection) + connection = self._make_connection(self.INSTANCE, mock.MagicMock()) + cursor = self._make_one(connection) with mock.patch( "google.cloud.spanner_dbapi.parse_utils.classify_stmt", @@ -171,11 +196,10 @@ def test_execute_invalid_argument(self): def test_execute_internal_server_error(self): from google.api_core import exceptions - from google.cloud.spanner_dbapi import Connection, Cursor from google.cloud.spanner_dbapi.exceptions import OperationalError - connection = Connection(self.INSTANCE, mock.MagicMock()) - cursor = Cursor(connection) + connection = self._make_connection(self.INSTANCE, mock.MagicMock()) + cursor = self._make_one(connection) with mock.patch( "google.cloud.spanner_dbapi.parse_utils.classify_stmt", @@ -184,78 +208,9 @@ def test_execute_internal_server_error(self): with self.assertRaises(OperationalError): cursor.execute(sql="sql") - def test_fetchone(self): - pass - - def test_fetchmany(self): - pass - - def test_fetchall(self): - pass - - def test_nextset(self): - pass - - def test_setinputsizes(self): - pass - - def test_setoutputsize(self): - pass - - def test_handle_insert(self): - pass - - def test_do_execute_insert_heterogenous(self): - pass - - def test_do_execute_insert_homogenous(self): - pass - - def test_handle_ddl(self): - pass - - def test_context(self): - pass - - def test_next(self): - pass - - def test_iter(self): - pass - - def test_list_tables(self): - pass - - def test_run_sql_in_snapshot(self): - pass - - def test_get_table_column_schema(self): - pass - - def test_close(self): - from google.cloud.spanner_dbapi import connect, InterfaceError - - with mock.patch( - "google.cloud.spanner_v1.instance.Instance.exists", - return_value=True, - ): - with mock.patch( - "google.cloud.spanner_v1.database.Database.exists", - return_value=True, - ): - connection = connect("test-instance", "test-database") - - cursor = connection.cursor() - self.assertFalse(cursor.is_closed) - - cursor.close() - - self.assertTrue(cursor.is_closed) - with self.assertRaises(InterfaceError): - cursor.execute("SELECT * FROM database") - def test_executemany_on_closed_cursor(self): - from google.cloud.spanner_dbapi import connect, InterfaceError + from google.cloud.spanner_dbapi import InterfaceError + from google.cloud.spanner_dbapi import connect with mock.patch( "google.cloud.spanner_v1.instance.Instance.exists", @@ -300,3 +255,177 @@ def test_executemany(self): execute_mock.assert_has_calls( (mock.call(operation, (1,)), mock.call(operation, (2,))) ) + + def test_fetchone(self): + connection = self._make_connection(self.INSTANCE, mock.MagicMock()) + cursor = self._make_one(connection) + lst = [1, 2, 3] + cursor._itr = iter(lst) + for i in range(len(lst)): + self.assertEqual(cursor.fetchone(), lst[i]) + self.assertIsNone(cursor.fetchone()) + + def test_fetchmany(self): + connection = self._make_connection(self.INSTANCE, mock.MagicMock()) + cursor = self._make_one(connection) + lst = [(1,), (2,), (3,)] + cursor._itr = iter(lst) + + self.assertEqual(cursor.fetchmany(), [lst[0]]) + + result = cursor.fetchmany(len(lst)) + self.assertEqual(result, lst[1:]) + + def test_fetchall(self): + connection = self._make_connection(self.INSTANCE, mock.MagicMock()) + cursor = self._make_one(connection) + lst = [(1,), (2,), (3,)] + cursor._itr = iter(lst) + self.assertEqual(cursor.fetchall(), lst) + + def test_nextset(self): + from google.cloud.spanner_dbapi import exceptions + + connection = self._make_connection(self.INSTANCE, mock.MagicMock()) + cursor = self._make_one(connection) + cursor.close() + with self.assertRaises(exceptions.InterfaceError): + cursor.nextset() + + def test_setinputsizes(self): + from google.cloud.spanner_dbapi import exceptions + + connection = self._make_connection(self.INSTANCE, mock.MagicMock()) + cursor = self._make_one(connection) + cursor.close() + with self.assertRaises(exceptions.InterfaceError): + cursor.setinputsizes(sizes=None) + + def test_setoutputsize(self): + from google.cloud.spanner_dbapi import exceptions + + connection = self._make_connection(self.INSTANCE, mock.MagicMock()) + cursor = self._make_one(connection) + cursor.close() + with self.assertRaises(exceptions.InterfaceError): + cursor.setoutputsize(size=None) + + # def test_handle_insert(self): + # pass + # + # def test_do_execute_insert_heterogenous(self): + # pass + # + # def test_do_execute_insert_homogenous(self): + # pass + + def test_handle_dql(self): + from google.cloud.spanner_dbapi import utils + from google.cloud.spanner_dbapi.cursor import _UNSET_COUNT + + connection = self._make_connection(self.INSTANCE, mock.MagicMock()) + connection.database.snapshot.return_value.__enter__.return_value = ( + mock_snapshot + ) = mock.MagicMock() + cursor = self._make_one(connection) + + mock_snapshot.execute_sql.return_value = int(0) + cursor._handle_DQL("sql", params=None) + self.assertEqual(cursor._row_count, 0) + self.assertIsNone(cursor._itr) + + mock_snapshot.execute_sql.return_value = "0" + cursor._handle_DQL("sql", params=None) + self.assertEqual(cursor._result_set, "0") + self.assertIsInstance(cursor._itr, utils.PeekIterator) + self.assertEqual(cursor._row_count, _UNSET_COUNT) + + def test_context(self): + connection = self._make_connection(self.INSTANCE, self.DATABASE) + cursor = self._make_one(connection) + with cursor as c: + self.assertEqual(c, cursor) + + self.assertTrue(c.is_closed) + + def test_next(self): + from google.cloud.spanner_dbapi import exceptions + + connection = self._make_connection(self.INSTANCE, self.DATABASE) + cursor = self._make_one(connection) + with self.assertRaises(exceptions.ProgrammingError): + cursor.__next__() + + lst = [(1,), (2,), (3,)] + cursor._itr = iter(lst) + i = 0 + for c in cursor._itr: + self.assertEqual(c, lst[i]) + i += 1 + + def test_iter(self): + from google.cloud.spanner_dbapi import exceptions + + connection = self._make_connection(self.INSTANCE, self.DATABASE) + cursor = self._make_one(connection) + with self.assertRaises(exceptions.ProgrammingError): + _ = iter(cursor) + + iterator = iter([(1,), (2,), (3,)]) + cursor._itr = iterator + self.assertEqual(iter(cursor), iterator) + + def test_list_tables(self): + from google.cloud.spanner_dbapi import _helpers + + connection = self._make_connection(self.INSTANCE, self.DATABASE) + cursor = self._make_one(connection) + + table_list = ["table1", "table2", "table3"] + with mock.patch( + "google.cloud.spanner_dbapi.cursor.Cursor.run_sql_in_snapshot", + return_value=table_list, + ) as mock_run_sql: + cursor.list_tables() + mock_run_sql.assert_called_once_with(_helpers.SQL_LIST_TABLES) + + def test_run_sql_in_snapshot(self): + connection = self._make_connection(self.INSTANCE, mock.MagicMock()) + connection.database.snapshot.return_value.__enter__.return_value = ( + mock_snapshot + ) = mock.MagicMock() + cursor = self._make_one(connection) + + results = 1, 2, 3 + mock_snapshot.execute_sql.return_value = results + self.assertEqual(cursor.run_sql_in_snapshot("sql"), list(results)) + + def test_get_table_column_schema(self): + from google.cloud.spanner_dbapi.cursor import ColumnDetails + from google.cloud.spanner_dbapi import _helpers + from google.cloud.spanner_v1 import param_types + + connection = self._make_connection(self.INSTANCE, self.DATABASE) + cursor = self._make_one(connection) + + column_name = "column1" + is_nullable = "YES" + spanner_type = "spanner_type" + rows = [(column_name, is_nullable, spanner_type)] + expected = { + column_name: ColumnDetails( + null_ok=True, spanner_type=spanner_type, + ) + } + with mock.patch( + "google.cloud.spanner_dbapi.cursor.Cursor.run_sql_in_snapshot", + return_value=rows, + ) as mock_run_sql: + table_name = "table1" + result = cursor.get_table_column_schema(table_name=table_name) + mock_run_sql.assert_called_once_with( + sql=_helpers.SQL_GET_TABLE_COLUMN_SCHEMA, + params={"table_name": table_name}, + param_types={"table_name": param_types.STRING}, + ) + self.assertEqual(result, expected) From 5e0ee36e898b35d093ced287fe4f41f8fa7f253a Mon Sep 17 00:00:00 2001 From: "STATION\\MF" Date: Sat, 10 Oct 2020 20:40:28 -0400 Subject: [PATCH 16/33] fix: assertion call --- tests/unit/spanner_dbapi/test_cursor.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/unit/spanner_dbapi/test_cursor.py b/tests/unit/spanner_dbapi/test_cursor.py index e3bbcb673f..e905c3216d 100644 --- a/tests/unit/spanner_dbapi/test_cursor.py +++ b/tests/unit/spanner_dbapi/test_cursor.py @@ -145,8 +145,9 @@ def test_execute_statement(self): "google.cloud.spanner_dbapi.cursor.Cursor._handle_DQL", return_value=parse_utils.STMT_NON_UPDATING, ) as mock_handle_ddl: - cursor.execute(sql="sql") - mock_handle_ddl.assert_called() + sql = "sql" + cursor.execute(sql=sql) + mock_handle_ddl.assert_called_once_with(sql, None) with mock.patch( "google.cloud.spanner_dbapi.parse_utils.classify_stmt", From 628f04aefff5594b3cd50d333ed8b25d947b1ab2 Mon Sep 17 00:00:00 2001 From: "STATION\\MF" Date: Sat, 10 Oct 2020 20:58:15 -0400 Subject: [PATCH 17/33] fix: assertion call --- tests/unit/spanner_dbapi/test_cursor.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/unit/spanner_dbapi/test_cursor.py b/tests/unit/spanner_dbapi/test_cursor.py index e905c3216d..1251eda36c 100644 --- a/tests/unit/spanner_dbapi/test_cursor.py +++ b/tests/unit/spanner_dbapi/test_cursor.py @@ -157,8 +157,11 @@ def test_execute_statement(self): "google.cloud.spanner_dbapi._helpers.handle_insert", return_value=parse_utils.STMT_INSERT, ) as mock_handle_insert: - cursor.execute(sql="sql") - mock_handle_insert.assert_called() + sql = "sql" + cursor.execute(sql=sql) + mock_handle_insert.assert_called_once_with( + connection, sql, None + ) def test_execute_integrity_error(self): from google.api_core import exceptions From 58c0bb6e88e4da057f76e86eed39ebdbced7ee68 Mon Sep 17 00:00:00 2001 From: "STATION\\MF" Date: Sat, 10 Oct 2020 23:55:43 -0400 Subject: [PATCH 18/33] test: unit test coverage for `parser.py` --- google/cloud/spanner_dbapi/parser.py | 26 ++--- tests/unit/spanner_dbapi/test_parser.py | 146 +++++++++++++++++++++--- 2 files changed, 142 insertions(+), 30 deletions(-) diff --git a/google/cloud/spanner_dbapi/parser.py b/google/cloud/spanner_dbapi/parser.py index 755384b2c1..2fc0156b57 100644 --- a/google/cloud/spanner_dbapi/parser.py +++ b/google/cloud/spanner_dbapi/parser.py @@ -33,7 +33,7 @@ VALUES = "VALUES" -class func: +class func(object): def __init__(self, func_name, args): self.name = func_name self.args = args @@ -67,7 +67,7 @@ class terminal(str): pass -class a_args: +class a_args(object): def __init__(self, argv): self.argv = argv @@ -89,13 +89,11 @@ def __eq__(self, other): if type(self) != type(other): return False - s_len, o_len = len(self), len(other) - if s_len != o_len: + if len(self) != len(other): return False - for i, s_item in enumerate(self): - o_item = other[i] - if s_item != o_item: + for i, item in enumerate(self): + if item != other[i]: return False return True @@ -108,7 +106,7 @@ def homogenous(self): Return True if all the arguments are pyformat args and have the same number of arguments. """ - if not self.all_have_same_argc(): + if not self._is_equal_length(): return False for arg in self.argv: @@ -121,7 +119,7 @@ def homogenous(self): return False return True - def all_have_same_argc(self): + def _is_equal_length(self): """ Return False if all the arguments have the same length. """ @@ -200,7 +198,7 @@ def expect(word, token): # (%s, %s...) if not (word and word.startswith("(")): raise ProgrammingError( - "ARGS: supposed to begin with `(` in `%s`" % (word) + "ARGS: supposed to begin with `(` in `%s`" % word ) word = word[1:] @@ -226,7 +224,7 @@ def expect(word, token): if not (word and word.startswith(")")): raise ProgrammingError( - "ARGS: supposed to end with `)` in `%s`" % (word) + "ARGS: supposed to end with `)` in `%s`" % word ) word = word[1:] @@ -235,7 +233,7 @@ def expect(word, token): elif token == EXPR: if word == "%s": # Terminal symbol. - return "", (pyfmt_str) + return "", pyfmt_str # Otherwise we expect a function. return expect(word, FUNC) @@ -244,5 +242,5 @@ def expect(word, token): def as_values(values_stmt): - _, values = parse_values(values_stmt) - return values + _, _values = parse_values(values_stmt) + return _values diff --git a/tests/unit/spanner_dbapi/test_parser.py b/tests/unit/spanner_dbapi/test_parser.py index 9aecf38e42..60c7283e9c 100644 --- a/tests/unit/spanner_dbapi/test_parser.py +++ b/tests/unit/spanner_dbapi/test_parser.py @@ -4,23 +4,19 @@ # license that can be found in the LICENSE file or at # https://developers.google.com/open-source/licenses/bsd -from unittest import TestCase - -from google.cloud.spanner_dbapi.exceptions import ProgrammingError -from google.cloud.spanner_dbapi.parser import ( - ARGS, - FUNC, - VALUES, - a_args, - expect, - func, - pyfmt_str, - values, -) - - -class ParserTests(TestCase): +import unittest + +from unittest import mock + + +class TestParser(unittest.TestCase): def test_func(self): + from google.cloud.spanner_dbapi.parser import FUNC + from google.cloud.spanner_dbapi.parser import a_args + from google.cloud.spanner_dbapi.parser import expect + from google.cloud.spanner_dbapi.parser import func + from google.cloud.spanner_dbapi.parser import pyfmt_str + cases = [ ("_91())", ")", func("_91", a_args([]))), ("_a()", "", func("_a", a_args([]))), @@ -61,6 +57,10 @@ def test_func(self): self.assertEqual(got_unconsumed, want_unconsumed) def test_func_fail(self): + from google.cloud.spanner_dbapi.exceptions import ProgrammingError + from google.cloud.spanner_dbapi.parser import FUNC + from google.cloud.spanner_dbapi.parser import expect + cases = [ ("", "FUNC: `` does not begin with `a-zA-z` nor a `_`"), ("91", "FUNC: `91` does not begin with `a-zA-z` nor a `_`"), @@ -76,7 +76,30 @@ def test_func_fail(self): ProgrammingError, wantException, lambda: expect(text, FUNC) ) + def test_func_eq(self): + from google.cloud.spanner_dbapi.parser import func + + func1 = func('func1', None) + func2 = func('func2', None) + self.assertFalse(func1 == object) + self.assertFalse(func1 == func2) + func2.name = func1.name + func1.args = 0 + func2.args = '0' + self.assertFalse(func1 == func2) + func1.args = [0] + func2.args = [0, 0] + self.assertFalse(func1 == func2) + func2.args = func1.args + self.assertTrue(func1 == func2) + def test_a_args(self): + from google.cloud.spanner_dbapi.parser import ARGS + from google.cloud.spanner_dbapi.parser import a_args + from google.cloud.spanner_dbapi.parser import expect + from google.cloud.spanner_dbapi.parser import func + from google.cloud.spanner_dbapi.parser import pyfmt_str + cases = [ ("()", "", a_args([])), ("(%s)", "", a_args([pyfmt_str])), @@ -102,6 +125,10 @@ def test_a_args(self): self.assertEqual(got_unconsumed, want_unconsumed) def test_a_args_fail(self): + from google.cloud.spanner_dbapi.exceptions import ProgrammingError + from google.cloud.spanner_dbapi.parser import ARGS + from google.cloud.spanner_dbapi.parser import expect + cases = [ ("", "ARGS: supposed to begin with `\\(`"), ("(", "ARGS: supposed to end with `\\)`"), @@ -115,7 +142,80 @@ def test_a_args_fail(self): ProgrammingError, wantException, lambda: expect(text, ARGS) ) + def test_a_args_has_expr(self): + from google.cloud.spanner_dbapi.parser import a_args + + self.assertFalse(a_args([]).has_expr()) + self.assertTrue(a_args([[0]]).has_expr()) + + def test_a_args_eq(self): + from google.cloud.spanner_dbapi.parser import a_args + + a1 = a_args([0]) + self.assertFalse(a1 == object()) + a2 = a_args([0, 0]) + self.assertFalse(a1 == a2) + a1.argv = [0, 1] + self.assertFalse(a1 == a2) + a2.argv = [0, 1] + self.assertTrue(a1 == a2) + + def test_a_args_homogeneous(self): + from google.cloud.spanner_dbapi.parser import a_args + from google.cloud.spanner_dbapi.parser import terminal + + a_obj = a_args([a_args([terminal(10**i)]) for i in range(10)]) + self.assertTrue(a_obj.homogenous()) + + a_obj = a_args([a_args([[object()]]) for _ in range(10)]) + self.assertFalse(a_obj.homogenous()) + + def test_a_args__is_equal_length(self): + from google.cloud.spanner_dbapi.parser import a_args + + a_obj = a_args([]) + self.assertTrue(a_obj._is_equal_length()) + + def test_values(self): + from google.cloud.spanner_dbapi.parser import a_args + from google.cloud.spanner_dbapi.parser import terminal + from google.cloud.spanner_dbapi.parser import values + + a_obj = a_args([a_args([terminal(10**i)]) for i in range(10)]) + self.assertEqual(str(values(a_obj)), "VALUES%s" % str(a_obj)) + + def test_expect(self): + from google.cloud.spanner_dbapi.parser import ARGS + from google.cloud.spanner_dbapi.parser import EXPR + from google.cloud.spanner_dbapi.parser import FUNC + from google.cloud.spanner_dbapi.parser import expect + from google.cloud.spanner_dbapi.parser import pyfmt_str + from google.cloud.spanner_dbapi import exceptions + + with self.assertRaises(exceptions.ProgrammingError): + expect(word='', token=ARGS) + with self.assertRaises(exceptions.ProgrammingError): + expect(word='ABC', token=ARGS) + with self.assertRaises(exceptions.ProgrammingError): + expect(word='(', token=ARGS) + + expected = "", pyfmt_str + self.assertEqual(expect('%s', EXPR), expected) + + expected = expect('function()', FUNC) + self.assertEqual(expect('function()', EXPR), expected) + + with self.assertRaises(exceptions.ProgrammingError): + expect(word='', token='ABC') + def test_expect_values(self): + from google.cloud.spanner_dbapi.parser import VALUES + from google.cloud.spanner_dbapi.parser import a_args + from google.cloud.spanner_dbapi.parser import expect + from google.cloud.spanner_dbapi.parser import func + from google.cloud.spanner_dbapi.parser import pyfmt_str + from google.cloud.spanner_dbapi.parser import values + cases = [ ("VALUES ()", "", values([a_args([])])), ("VALUES", "", values([])), @@ -156,6 +256,10 @@ def test_expect_values(self): self.assertEqual(got_unconsumed, want_unconsumed) def test_expect_values_fail(self): + from google.cloud.spanner_dbapi.exceptions import ProgrammingError + from google.cloud.spanner_dbapi.parser import VALUES + from google.cloud.spanner_dbapi.parser import expect + cases = [ ("", "VALUES: `` does not start with VALUES"), ( @@ -172,3 +276,13 @@ def test_expect_values_fail(self): wantException, lambda: expect(text, VALUES), ) + + def test_as_values(self): + from google.cloud.spanner_dbapi.parser import as_values + + values = (1, 2) + with mock.patch( + "google.cloud.spanner_dbapi.parser.parse_values", + return_value=values, + ): + self.assertEqual(as_values(None), values[1]) From 301ca8418d42b2f720ef57470d24e6564cbf561e Mon Sep 17 00:00:00 2001 From: "STATION\\MF" Date: Sun, 11 Oct 2020 00:07:27 -0400 Subject: [PATCH 19/33] fix: lint --- tests/unit/spanner_dbapi/test_parser.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/tests/unit/spanner_dbapi/test_parser.py b/tests/unit/spanner_dbapi/test_parser.py index 60c7283e9c..d5baf9d824 100644 --- a/tests/unit/spanner_dbapi/test_parser.py +++ b/tests/unit/spanner_dbapi/test_parser.py @@ -79,13 +79,13 @@ def test_func_fail(self): def test_func_eq(self): from google.cloud.spanner_dbapi.parser import func - func1 = func('func1', None) - func2 = func('func2', None) + func1 = func("func1", None) + func2 = func("func2", None) self.assertFalse(func1 == object) self.assertFalse(func1 == func2) func2.name = func1.name func1.args = 0 - func2.args = '0' + func2.args = "0" self.assertFalse(func1 == func2) func1.args = [0] func2.args = [0, 0] @@ -164,7 +164,7 @@ def test_a_args_homogeneous(self): from google.cloud.spanner_dbapi.parser import a_args from google.cloud.spanner_dbapi.parser import terminal - a_obj = a_args([a_args([terminal(10**i)]) for i in range(10)]) + a_obj = a_args([a_args([terminal(10 ** i)]) for i in range(10)]) self.assertTrue(a_obj.homogenous()) a_obj = a_args([a_args([[object()]]) for _ in range(10)]) @@ -181,7 +181,7 @@ def test_values(self): from google.cloud.spanner_dbapi.parser import terminal from google.cloud.spanner_dbapi.parser import values - a_obj = a_args([a_args([terminal(10**i)]) for i in range(10)]) + a_obj = a_args([a_args([terminal(10 ** i)]) for i in range(10)]) self.assertEqual(str(values(a_obj)), "VALUES%s" % str(a_obj)) def test_expect(self): @@ -193,20 +193,20 @@ def test_expect(self): from google.cloud.spanner_dbapi import exceptions with self.assertRaises(exceptions.ProgrammingError): - expect(word='', token=ARGS) + expect(word="", token=ARGS) with self.assertRaises(exceptions.ProgrammingError): - expect(word='ABC', token=ARGS) + expect(word="ABC", token=ARGS) with self.assertRaises(exceptions.ProgrammingError): - expect(word='(', token=ARGS) + expect(word="(", token=ARGS) expected = "", pyfmt_str - self.assertEqual(expect('%s', EXPR), expected) + self.assertEqual(expect("%s", EXPR), expected) - expected = expect('function()', FUNC) - self.assertEqual(expect('function()', EXPR), expected) + expected = expect("function()", FUNC) + self.assertEqual(expect("function()", EXPR), expected) with self.assertRaises(exceptions.ProgrammingError): - expect(word='', token='ABC') + expect(word="", token="ABC") def test_expect_values(self): from google.cloud.spanner_dbapi.parser import VALUES From 2e5ba9b5c8f260077af1468b45d82f126b2a9310 Mon Sep 17 00:00:00 2001 From: "STATION\\MF" Date: Sun, 11 Oct 2020 01:35:02 -0400 Subject: [PATCH 20/33] test: unit test coverage for `parse_utils.py` --- google/cloud/spanner_dbapi/parse_utils.py | 17 ++- tests/unit/spanner_dbapi/test_parse_utils.py | 114 +++++++++++-------- tests/unit/spanner_dbapi/test_utils.py | 19 +++- 3 files changed, 98 insertions(+), 52 deletions(-) diff --git a/google/cloud/spanner_dbapi/parse_utils.py b/google/cloud/spanner_dbapi/parse_utils.py index d0e807435e..084eea315e 100644 --- a/google/cloud/spanner_dbapi/parse_utils.py +++ b/google/cloud/spanner_dbapi/parse_utils.py @@ -262,11 +262,18 @@ def parse_insert(insert_sql, params): if not params: # Case a) perhaps? # Check if any %s exists. - pyformat_str_count = after_values_sql.count("%s") - if pyformat_str_count > 0: - raise ProgrammingError( - 'no params yet there are %d "%s" tokens' % pyformat_str_count - ) + + # pyformat_str_count = after_values_sql.count("%s") + # if pyformat_str_count > 0: + # raise ProgrammingError( + # 'no params yet there are %d "%%s" tokens' % pyformat_str_count + # ) + for item in after_values_sql: + if item.count("%s") > 0: + raise ProgrammingError( + 'no params yet there are %d "%%s" tokens' + % item.count("%s") + ) insert_sql = sanitize_literals_for_upload(insert_sql) # Confirmed case of: diff --git a/tests/unit/spanner_dbapi/test_parse_utils.py b/tests/unit/spanner_dbapi/test_parse_utils.py index 815811ce21..1bd38c85eb 100644 --- a/tests/unit/spanner_dbapi/test_parse_utils.py +++ b/tests/unit/spanner_dbapi/test_parse_utils.py @@ -4,33 +4,19 @@ # license that can be found in the LICENSE file or at # https://developers.google.com/open-source/licenses/bsd -import datetime -import decimal -from unittest import TestCase +import unittest from google.cloud.spanner_v1 import param_types -from google.cloud.spanner_dbapi.exceptions import Error, ProgrammingError -from google.cloud.spanner_dbapi.parse_utils import ( - STMT_DDL, - STMT_INSERT, - STMT_NON_UPDATING, - STMT_UPDATING, - DateStr, - TimestampStr, - cast_for_spanner, - classify_stmt, - ensure_where_clause, - escape_name, - get_param_types, - parse_insert, - rows_for_insert_or_update, - sql_pyformat_args_to_spanner, -) -from google.cloud.spanner_dbapi.utils import backtick_unicode - - -class ParseUtilsTests(TestCase): + + +class TestParseUtils(unittest.TestCase): def test_classify_stmt(self): + from google.cloud.spanner_dbapi.parse_utils import STMT_DDL + from google.cloud.spanner_dbapi.parse_utils import STMT_INSERT + from google.cloud.spanner_dbapi.parse_utils import STMT_NON_UPDATING + from google.cloud.spanner_dbapi.parse_utils import STMT_UPDATING + from google.cloud.spanner_dbapi.parse_utils import classify_stmt + cases = ( ("SELECT 1", STMT_NON_UPDATING), ("SELECT s.SongName FROM Songs AS s", STMT_NON_UPDATING), @@ -61,6 +47,12 @@ def test_classify_stmt(self): self.assertEqual(classify_stmt(query), want_class) def test_parse_insert(self): + from google.cloud.spanner_dbapi.parse_utils import parse_insert + from google.cloud.spanner_dbapi.exceptions import ProgrammingError + + with self.assertRaises(ProgrammingError): + parse_insert("bad-sql", None) + cases = [ ( "INSERT INTO django_migrations (app, name, applied) VALUES (%s, %s, %s)", @@ -173,6 +165,10 @@ def test_parse_insert(self): ), ] + sql = "INSERT INTO django_migrations (app, name, applied) VALUES (%s, %s, %s)" + with self.assertRaises(ProgrammingError): + parse_insert(sql, None) + for sql, params, want in cases: with self.subTest(sql=sql): got = parse_insert(sql, params) @@ -181,6 +177,9 @@ def test_parse_insert(self): ) def test_parse_insert_invalid(self): + from google.cloud.spanner_dbapi import exceptions + from google.cloud.spanner_dbapi.parse_utils import parse_insert + cases = [ ( "INSERT INTO django_migrations (app, name, applied) VALUES (%s, %s, %s), (%s, %s, %s)", @@ -202,12 +201,23 @@ def test_parse_insert_invalid(self): for sql, params, wantException in cases: with self.subTest(sql=sql): self.assertRaisesRegex( - ProgrammingError, + exceptions.ProgrammingError, wantException, lambda: parse_insert(sql, params), ) def test_rows_for_insert_or_update(self): + from google.cloud.spanner_dbapi.parse_utils import ( + rows_for_insert_or_update, + ) + from google.cloud.spanner_dbapi.exceptions import Error + + with self.assertRaises(Error): + rows_for_insert_or_update([0], [[]]) + + with self.assertRaises(Error): + rows_for_insert_or_update([0], None, ["0", "%s"]) + cases = [ ( ["id", "app", "name"], @@ -255,6 +265,12 @@ def test_rows_for_insert_or_update(self): self.assertEqual(got, want) def test_sql_pyformat_args_to_spanner(self): + import decimal + + from google.cloud.spanner_dbapi.parse_utils import ( + sql_pyformat_args_to_spanner, + ) + cases = [ ( ( @@ -323,6 +339,11 @@ def test_sql_pyformat_args_to_spanner(self): ) def test_sql_pyformat_args_to_spanner_invalid(self): + from google.cloud.spanner_dbapi import exceptions + from google.cloud.spanner_dbapi.parse_utils import ( + sql_pyformat_args_to_spanner, + ) + cases = [ ( "SELECT * from t WHERE f1=%s, f2 = %s, f3=%s, extra=%s", @@ -332,12 +353,28 @@ def test_sql_pyformat_args_to_spanner_invalid(self): for sql, params in cases: with self.subTest(sql=sql): self.assertRaisesRegex( - Error, + exceptions.Error, "pyformat_args mismatch", lambda: sql_pyformat_args_to_spanner(sql, params), ) + def test_cast_for_spanner(self): + import decimal + + from google.cloud.spanner_dbapi.parse_utils import cast_for_spanner + + value = decimal.Decimal(3) + self.assertEqual(cast_for_spanner(value), float(3.0)) + self.assertEqual(cast_for_spanner(5), 5) + self.assertEqual(cast_for_spanner("string"), "string") + def test_get_param_types(self): + import datetime + + from google.cloud.spanner_dbapi.parse_utils import DateStr + from google.cloud.spanner_dbapi.parse_utils import TimestampStr + from google.cloud.spanner_dbapi.parse_utils import get_param_types + params = { "a1": 10, "b1": "string", @@ -365,15 +402,13 @@ def test_get_param_types(self): self.assertEqual(got_types, want_types) def test_get_param_types_none(self): - self.assertEqual(get_param_types(None), None) + from google.cloud.spanner_dbapi.parse_utils import get_param_types - def test_cast_for_spanner(self): - value = decimal.Decimal(3) - self.assertEqual(cast_for_spanner(value), float(3.0)) - self.assertEqual(cast_for_spanner(5), 5) - self.assertEqual(cast_for_spanner("string"), "string") + self.assertEqual(get_param_types(None), None) def test_ensure_where_clause(self): + from google.cloud.spanner_dbapi.parse_utils import ensure_where_clause + cases = [ ( "UPDATE a SET a.b=10 FROM articles a JOIN d c ON a.ai = c.ai WHERE c.ci = 1", @@ -404,6 +439,8 @@ def test_ensure_where_clause(self): self.assertEqual(got, want) def test_escape_name(self): + from google.cloud.spanner_dbapi.parse_utils import escape_name + cases = ( ("SELECT", "`SELECT`"), ("dashed-value", "`dashed-value`"), @@ -415,16 +452,3 @@ def test_escape_name(self): with self.subTest(name=name): got = escape_name(name) self.assertEqual(got, want) - - def test_backtick_unicode(self): - cases = [ - ("SELECT (1) as foo WHERE 1=1", "SELECT (1) as foo WHERE 1=1"), - ("SELECT (1) as föö", "SELECT (1) as `föö`"), - ("SELECT (1) as `föö`", "SELECT (1) as `föö`"), - ("SELECT (1) as `föö` `umläut", "SELECT (1) as `föö` `umläut"), - ("SELECT (1) as `föö", "SELECT (1) as `föö"), - ] - for sql, want in cases: - with self.subTest(sql=sql): - got = backtick_unicode(sql) - self.assertEqual(got, want) diff --git a/tests/unit/spanner_dbapi/test_utils.py b/tests/unit/spanner_dbapi/test_utils.py index 2ec10eefaf..007c36c36d 100644 --- a/tests/unit/spanner_dbapi/test_utils.py +++ b/tests/unit/spanner_dbapi/test_utils.py @@ -4,12 +4,12 @@ # license that can be found in the LICENSE file or at # https://developers.google.com/open-source/licenses/bsd -from unittest import TestCase +import unittest from google.cloud.spanner_dbapi.utils import PeekIterator -class UtilsTests(TestCase): +class TestUtils(unittest.TestCase): def test_PeekIterator(self): cases = [ ("list", [1, 2, 3, 4, 6, 7], [1, 2, 3, 4, 6, 7]), @@ -51,3 +51,18 @@ def test_peekIterator_nonlist_rows_unconverted(self): got = list(pi) want = ["a", "b", "c", "d", "e"] self.assertEqual(got, want, "Values should be returned unchanged") + + def test_backtick_unicode(self): + from google.cloud.spanner_dbapi.utils import backtick_unicode + + cases = [ + ("SELECT (1) as foo WHERE 1=1", "SELECT (1) as foo WHERE 1=1"), + ("SELECT (1) as föö", "SELECT (1) as `föö`"), + ("SELECT (1) as `föö`", "SELECT (1) as `föö`"), + ("SELECT (1) as `föö` `umläut", "SELECT (1) as `föö` `umläut"), + ("SELECT (1) as `föö", "SELECT (1) as `föö"), + ] + for sql, want in cases: + with self.subTest(sql=sql): + got = backtick_unicode(sql) + self.assertEqual(got, want) From b4fde4be0062453c343627cf19ec6ea037a6c44b Mon Sep 17 00:00:00 2001 From: "STATION\\MF" Date: Sun, 11 Oct 2020 13:14:32 -0400 Subject: [PATCH 21/33] test: unit test coverage for `_helpers.py` --- google/cloud/spanner_dbapi/_helpers.py | 8 +-- tests/unit/spanner_dbapi/test__helpers.py | 60 ++++++++++++++++++--- tests/unit/spanner_dbapi/test_connection.py | 1 + tests/unit/spanner_dbapi/test_globals.py | 12 ++--- tests/unit/spanner_dbapi/test_types.py | 22 ++++++-- tests/unit/spanner_dbapi/test_utils.py | 8 ++- 6 files changed, 86 insertions(+), 25 deletions(-) diff --git a/google/cloud/spanner_dbapi/_helpers.py b/google/cloud/spanner_dbapi/_helpers.py index d1f51e9523..f581fdebbd 100644 --- a/google/cloud/spanner_dbapi/_helpers.py +++ b/google/cloud/spanner_dbapi/_helpers.py @@ -47,7 +47,7 @@ } -def execute_insert_heterogenous(transaction, sql_params_list): +def _execute_insert_heterogenous(transaction, sql_params_list): for sql, params in sql_params_list: sql, params = sql_pyformat_args_to_spanner(sql, params) param_types = get_param_types(params) @@ -59,7 +59,7 @@ def execute_insert_heterogenous(transaction, sql_params_list): _ = list(res) -def execute_insert_homogenous(transaction, parts): +def _execute_insert_homogenous(transaction, parts): # Perform an insert in one shot. table = parts.get("table") columns = parts.get("columns") @@ -87,14 +87,14 @@ def handle_insert(connection, sql, params): # The common case of multiple values being passed in # non-complex pyformat args and need to be uploaded in one RPC. return connection.database.run_in_transaction( - execute_insert_homogenous, parts + _execute_insert_homogenous, parts ) else: # All the other cases that are esoteric and need # transaction.execute_sql sql_params_list = parts.get("sql_params_list") return connection.database.run_in_transaction( - execute_insert_heterogenous, sql_params_list + _execute_insert_heterogenous, sql_params_list ) diff --git a/tests/unit/spanner_dbapi/test__helpers.py b/tests/unit/spanner_dbapi/test__helpers.py index 7732a12ca4..e5316d254e 100644 --- a/tests/unit/spanner_dbapi/test__helpers.py +++ b/tests/unit/spanner_dbapi/test__helpers.py @@ -8,18 +8,64 @@ import unittest -# from unittest import mock +from unittest import mock class TestHelpers(unittest.TestCase): - def test_execute_insert_heterogenous(self): - pass - - def test_execute_insert_homogenous(self): - pass + def test__execute_insert_heterogenous(self): + from google.cloud.spanner_dbapi import _helpers + + sql = "sql" + params = (sql, None) + with mock.patch( + "google.cloud.spanner_dbapi._helpers.sql_pyformat_args_to_spanner", + return_value=params, + ) as mock_pyformat: + with mock.patch( + "google.cloud.spanner_dbapi._helpers.get_param_types", + return_value=None, + ) as mock_param_types: + transaction = mock.MagicMock() + transaction.execute_sql = mock_execute = mock.MagicMock() + _helpers._execute_insert_heterogenous(transaction, [params]) + + mock_pyformat.assert_called_once_with(params[0], params[1]) + mock_param_types.assert_called_once_with(None) + mock_execute.assert_called_once_with( + sql, params=None, param_types=None + ) + + def test__execute_insert_homogenous(self): + from google.cloud.spanner_dbapi import _helpers + + transaction = mock.MagicMock() + transaction.insert = mock.MagicMock() + parts = mock.MagicMock() + parts.get = mock.MagicMock(return_value=0) + + _helpers._execute_insert_homogenous(transaction, parts) + transaction.insert.assert_called_once_with(0, 0, 0) def test_handle_insert(self): - pass + from google.cloud.spanner_dbapi import _helpers + + connection = mock.MagicMock() + connection.database.run_in_transaction = mock_run_in = mock.MagicMock() + sql = "sql" + parts = mock.MagicMock() + with mock.patch( + "google.cloud.spanner_dbapi._helpers.parse_insert", + return_value=parts, + ): + parts.get = mock.MagicMock(return_value=True) + mock_run_in.return_value = 0 + result = _helpers.handle_insert(connection, sql, None) + self.assertEqual(result, 0) + + parts.get = mock.MagicMock(return_value=False) + mock_run_in.return_value = 1 + result = _helpers.handle_insert(connection, sql, None) + self.assertEqual(result, 1) class TestColumnInfo(unittest.TestCase): diff --git a/tests/unit/spanner_dbapi/test_connection.py b/tests/unit/spanner_dbapi/test_connection.py index 865db1cb61..2cbd6ac1ed 100644 --- a/tests/unit/spanner_dbapi/test_connection.py +++ b/tests/unit/spanner_dbapi/test_connection.py @@ -7,6 +7,7 @@ """Cloud Spanner DB-API Connection class unit tests.""" import unittest + from unittest import mock diff --git a/tests/unit/spanner_dbapi/test_globals.py b/tests/unit/spanner_dbapi/test_globals.py index 3b702f7be7..3f8360e2ea 100644 --- a/tests/unit/spanner_dbapi/test_globals.py +++ b/tests/unit/spanner_dbapi/test_globals.py @@ -4,16 +4,14 @@ # license that can be found in the LICENSE file or at # https://developers.google.com/open-source/licenses/bsd -from unittest import TestCase +import unittest -class DBAPIGlobalsTests(TestCase): +class TestDBAPIGlobals(unittest.TestCase): def test_apilevel(self): - from google.cloud.spanner_dbapi import ( - apilevel, - paramstyle, - threadsafety, - ) + from google.cloud.spanner_dbapi import apilevel + from google.cloud.spanner_dbapi import paramstyle + from google.cloud.spanner_dbapi import threadsafety self.assertEqual(apilevel, "2.0", "We implement PEP-0249 version 2.0") self.assertEqual(paramstyle, "format", "Cloud Spanner uses @param") diff --git a/tests/unit/spanner_dbapi/test_types.py b/tests/unit/spanner_dbapi/test_types.py index 6c41041628..4246a43e45 100644 --- a/tests/unit/spanner_dbapi/test_types.py +++ b/tests/unit/spanner_dbapi/test_types.py @@ -4,36 +4,48 @@ # license that can be found in the LICENSE file or at # https://developers.google.com/open-source/licenses/bsd -import datetime -from time import timezone -from unittest import TestCase +import unittest -from google.cloud.spanner_dbapi import types +from time import timezone -class TypesTests(TestCase): +class TestTypes(unittest.TestCase): TICKS = 1572822862.9782631 + timezone # Sun 03 Nov 2019 23:14:22 UTC def test__date_from_ticks(self): + import datetime + + from google.cloud.spanner_dbapi import types + actual = types._date_from_ticks(self.TICKS) expected = datetime.date(2019, 11, 3) self.assertEqual(actual, expected) def test__time_from_ticks(self): + import datetime + + from google.cloud.spanner_dbapi import types + actual = types._time_from_ticks(self.TICKS) expected = datetime.time(23, 14, 22) self.assertEqual(actual, expected) def test__timestamp_from_ticks(self): + import datetime + + from google.cloud.spanner_dbapi import types + actual = types._timestamp_from_ticks(self.TICKS) expected = datetime.datetime(2019, 11, 3, 23, 14, 22) self.assertEqual(actual, expected) def test_type_equal(self): + from google.cloud.spanner_dbapi import types + self.assertEqual(types.BINARY, "TYPE_CODE_UNSPECIFIED") self.assertEqual(types.BINARY, "BYTES") self.assertEqual(types.BINARY, "ARRAY") diff --git a/tests/unit/spanner_dbapi/test_utils.py b/tests/unit/spanner_dbapi/test_utils.py index 007c36c36d..90e1b7cf04 100644 --- a/tests/unit/spanner_dbapi/test_utils.py +++ b/tests/unit/spanner_dbapi/test_utils.py @@ -6,11 +6,11 @@ import unittest -from google.cloud.spanner_dbapi.utils import PeekIterator - class TestUtils(unittest.TestCase): def test_PeekIterator(self): + from google.cloud.spanner_dbapi.utils import PeekIterator + cases = [ ("list", [1, 2, 3, 4, 6, 7], [1, 2, 3, 4, 6, 7]), ("iter_from_list", iter([1, 2, 3, 4, 6, 7]), [1, 2, 3, 4, 6, 7]), @@ -26,6 +26,8 @@ def test_PeekIterator(self): self.assertEqual(actual, expected) def test_peekIterator_list_rows_converted_to_tuples(self): + from google.cloud.spanner_dbapi.utils import PeekIterator + # Cloud Spanner returns results in lists e.g. [result]. # PeekIterator is used by BaseCursor in its fetch* methods. # This test ensures that anything passed into PeekIterator @@ -47,6 +49,8 @@ def test_peekIterator_list_rows_converted_to_tuples(self): self.assertEqual(next(pit), ("Clark", "Kent")) def test_peekIterator_nonlist_rows_unconverted(self): + from google.cloud.spanner_dbapi.utils import PeekIterator + pi = PeekIterator(["a", "b", "c", "d", "e"]) got = list(pi) want = ["a", "b", "c", "d", "e"] From d5855793a8eb3f64f69a86d36cb8226a0a938537 Mon Sep 17 00:00:00 2001 From: "STATION\\MF" Date: Mon, 12 Oct 2020 13:17:00 -0400 Subject: [PATCH 22/33] chore: `connection` attrubute made public in `cursor.py` --- google/cloud/spanner_dbapi/cursor.py | 24 ++++++++++-------------- tests/unit/spanner_dbapi/test_cursor.py | 4 ++-- 2 files changed, 12 insertions(+), 16 deletions(-) diff --git a/google/cloud/spanner_dbapi/cursor.py b/google/cloud/spanner_dbapi/cursor.py index 044f980b8f..f6ba16d216 100644 --- a/google/cloud/spanner_dbapi/cursor.py +++ b/google/cloud/spanner_dbapi/cursor.py @@ -45,16 +45,12 @@ def __init__(self, connection): self._itr = None self._result_set = None self._row_count = _UNSET_COUNT - self._connection = connection + self.connection = connection self._is_closed = False # the number of rows to fetch at a time with fetchmany() self.arraysize = 1 - @property - def connection(self): - return self._connection - @property def is_closed(self): """The cursor close indicator. @@ -63,7 +59,7 @@ def is_closed(self): :returns: True if the cursor or the parent connection is closed, otherwise False. """ - return self._is_closed or self._connection.is_closed + return self._is_closed or self.connection.is_closed @property def description(self): @@ -143,7 +139,7 @@ def execute(self, sql, args=None): :type args: list :param args: Additional parameters to supplement the SQL query. """ - if not self._connection: + if not self.connection: raise ProgrammingError("Cursor is not connected to the database") self._raise_if_closed() @@ -154,20 +150,20 @@ def execute(self, sql, args=None): try: classification = parse_utils.classify_stmt(sql) if classification == parse_utils.STMT_DDL: - self._connection.ddl_statements.append(sql) + self.connection.ddl_statements.append(sql) return # For every other operation, we've got to ensure that # any prior DDL statements were run. # self._run_prior_DDL_statements() - self._connection.run_prior_DDL_statements() + self.connection.run_prior_DDL_statements() if classification == parse_utils.STMT_NON_UPDATING: self._handle_DQL(sql, args or None) elif classification == parse_utils.STMT_INSERT: - _helpers.handle_insert(self._connection, sql, args or None) + _helpers.handle_insert(self.connection, sql, args or None) else: - self._connection.database.run_in_transaction( + self.connection.database.run_in_transaction( self._do_execute_update, sql, args or None ) except (AlreadyExists, FailedPrecondition) as e: @@ -249,7 +245,7 @@ def setoutputsize(self, size, column=None): self._raise_if_closed() def _handle_DQL(self, sql, params): - with self._connection.database.snapshot() as snapshot: + with self.connection.database.snapshot() as snapshot: # Reference # https://googleapis.dev/python/spanner/latest/session-api.html#google.cloud.spanner_v1.session.Session.execute_sql sql, params = parse_utils.sql_pyformat_args_to_spanner(sql, params) @@ -298,9 +294,9 @@ def list_tables(self): def run_sql_in_snapshot(self, sql, params=None, param_types=None): # Some SQL e.g. for INFORMATION_SCHEMA cannot be run in read-write transactions # hence this method exists to circumvent that limit. - self._connection.run_prior_DDL_statements() + self.connection.run_prior_DDL_statements() - with self._connection.database.snapshot() as snapshot: + with self.connection.database.snapshot() as snapshot: res = snapshot.execute_sql( sql, params=params, param_types=param_types ) diff --git a/tests/unit/spanner_dbapi/test_cursor.py b/tests/unit/spanner_dbapi/test_cursor.py index 1251eda36c..337a645736 100644 --- a/tests/unit/spanner_dbapi/test_cursor.py +++ b/tests/unit/spanner_dbapi/test_cursor.py @@ -111,7 +111,7 @@ def test_execute_programming_error(self): connection = self._make_connection(self.INSTANCE, self.DATABASE) cursor = self._make_one(connection) - cursor._connection = None + cursor.connection = None with self.assertRaises(ProgrammingError): cursor.execute(sql="") @@ -135,7 +135,7 @@ def test_execute_statement(self): sql = "sql" cursor.execute(sql=sql) mock_classify_stmt.assert_called_once_with(sql) - self.assertEqual(cursor._connection.ddl_statements, [sql]) + self.assertEqual(cursor.connection.ddl_statements, [sql]) with mock.patch( "google.cloud.spanner_dbapi.parse_utils.classify_stmt", From b535bf11b80121debcc64e13b928bdf1cb6f95ba Mon Sep 17 00:00:00 2001 From: "STATION\\MF" Date: Fri, 23 Oct 2020 21:17:17 -0400 Subject: [PATCH 23/33] test: unit test for `Connection.autocommit` property setter --- tests/unit/spanner_dbapi/test_connection.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/tests/unit/spanner_dbapi/test_connection.py b/tests/unit/spanner_dbapi/test_connection.py index 409bb4a4d7..2ea1e04c15 100644 --- a/tests/unit/spanner_dbapi/test_connection.py +++ b/tests/unit/spanner_dbapi/test_connection.py @@ -42,6 +42,25 @@ def _make_connection(self): database = instance.database(self.DATABASE) return Connection(instance, database) + def test_property_autocommit_setter(self): + from google.cloud.spanner_dbapi import Connection + + connection = Connection(self.INSTANCE, self.DATABASE) + + with mock.patch( + "google.cloud.spanner_dbapi.connection.Connection.commit" + ) as mock_commit: + connection.autocommit = True + mock_commit.assert_called_once_with() + self.assertEqual(connection._autocommit, True) + + with mock.patch( + "google.cloud.spanner_dbapi.connection.Connection.commit" + ) as mock_commit: + connection.autocommit = False + mock_commit.assert_not_called() + self.assertEqual(connection._autocommit, False) + def test_close(self): from google.cloud.spanner_dbapi import connect, InterfaceError From 81db6f73a9c652b3238f1c93886465beef78d5fd Mon Sep 17 00:00:00 2001 From: "STATION\\MF" Date: Fri, 23 Oct 2020 21:41:03 -0400 Subject: [PATCH 24/33] test: unit test for `Connection._session_checkout` --- tests/unit/spanner_dbapi/test_connection.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/tests/unit/spanner_dbapi/test_connection.py b/tests/unit/spanner_dbapi/test_connection.py index 2ea1e04c15..a46152c672 100644 --- a/tests/unit/spanner_dbapi/test_connection.py +++ b/tests/unit/spanner_dbapi/test_connection.py @@ -61,6 +61,24 @@ def test_property_autocommit_setter(self): mock_commit.assert_not_called() self.assertEqual(connection._autocommit, False) + def test__session_checkout(self): + from google.cloud.spanner_dbapi import Connection + + with mock.patch( + "google.cloud.spanner_v1.database.Database", + ) as mock_database: + mock_database._pool = mock.MagicMock() + mock_database._pool.get = mock.MagicMock(return_value='db_session_pool') + connection = Connection(self.INSTANCE, mock_database) + + connection._session_checkout() + mock_database._pool.get.assert_called_once_with() + self.assertEqual(connection._session, 'db_session_pool') + + connection._session = 'db_session' + connection._session_checkout() + self.assertEqual(connection._session, 'db_session') + def test_close(self): from google.cloud.spanner_dbapi import connect, InterfaceError From d484496e1e09d00876be52e77e48f0a3e813441d Mon Sep 17 00:00:00 2001 From: "STATION\\MF" Date: Fri, 23 Oct 2020 22:20:30 -0400 Subject: [PATCH 25/33] test: unit test for `Connection.transaction_checkout` --- tests/unit/spanner_dbapi/test_connection.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/tests/unit/spanner_dbapi/test_connection.py b/tests/unit/spanner_dbapi/test_connection.py index a46152c672..99f78216cb 100644 --- a/tests/unit/spanner_dbapi/test_connection.py +++ b/tests/unit/spanner_dbapi/test_connection.py @@ -79,6 +79,21 @@ def test__session_checkout(self): connection._session_checkout() self.assertEqual(connection._session, 'db_session') + def test_transaction_checkout(self): + from google.cloud.spanner_dbapi import Connection + + connection = Connection(self.INSTANCE, self.DATABASE) + connection._session_checkout = mock_checkout = mock.MagicMock(autospec=True) + connection.transaction_checkout() + mock_checkout.assert_called_once_with() + + connection._transaction = mock_xaction = mock.MagicMock() + mock_xaction.committed = mock_xaction.rolled_back = False + self.assertEqual(connection.transaction_checkout(), mock_xaction) + + connection._autocommit = True + self.assertIsNone(connection.transaction_checkout()) + def test_close(self): from google.cloud.spanner_dbapi import connect, InterfaceError From 4c0b83cba8d50256da9ce03150fe1c2cc345b8fa Mon Sep 17 00:00:00 2001 From: "STATION\\MF" Date: Fri, 23 Oct 2020 22:21:08 -0400 Subject: [PATCH 26/33] test: unit test for `Connection._release_session` --- tests/unit/spanner_dbapi/test_connection.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/tests/unit/spanner_dbapi/test_connection.py b/tests/unit/spanner_dbapi/test_connection.py index 99f78216cb..f3896b7272 100644 --- a/tests/unit/spanner_dbapi/test_connection.py +++ b/tests/unit/spanner_dbapi/test_connection.py @@ -79,6 +79,21 @@ def test__session_checkout(self): connection._session_checkout() self.assertEqual(connection._session, 'db_session') + def test__release_session(self): + from google.cloud.spanner_dbapi import Connection + + with mock.patch( + "google.cloud.spanner_v1.database.Database", + ) as mock_database: + mock_database._pool = mock.MagicMock() + mock_database._pool.put = mock.MagicMock() + connection = Connection(self.INSTANCE, mock_database) + connection._session = 'session' + + connection._release_session() + mock_database._pool.put.assert_called_once_with('session') + self.assertIsNone(connection._session) + def test_transaction_checkout(self): from google.cloud.spanner_dbapi import Connection From 6580e6a127bd4a5e5ed51a4bd5e50e015cf44738 Mon Sep 17 00:00:00 2001 From: "STATION\\MF" Date: Fri, 23 Oct 2020 22:22:22 -0400 Subject: [PATCH 27/33] test: updated unit test for `Connection.close` --- tests/unit/spanner_dbapi/test_connection.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/unit/spanner_dbapi/test_connection.py b/tests/unit/spanner_dbapi/test_connection.py index f3896b7272..866fb5c8b0 100644 --- a/tests/unit/spanner_dbapi/test_connection.py +++ b/tests/unit/spanner_dbapi/test_connection.py @@ -129,6 +129,12 @@ def test_close(self): with self.assertRaises(InterfaceError): connection.cursor() + connection._transaction = mock_xaction = mock.MagicMock() + mock_xaction.committed = mock_xaction.rolled_back = False + mock_xaction.rollback = mock_rollback = mock.MagicMock() + connection.close() + mock_rollback.assert_called_once_with() + def test_commit(self): from google.cloud.spanner_dbapi import Connection, InterfaceError From 4b7f7d8946519708f0d8e36df521db5333c954e7 Mon Sep 17 00:00:00 2001 From: "STATION\\MF" Date: Fri, 23 Oct 2020 22:33:37 -0400 Subject: [PATCH 28/33] test: updated unit test for `Connection.commit` --- google/cloud/spanner_dbapi/connection.py | 2 +- tests/unit/spanner_dbapi/test_connection.py | 22 +++++++++------------ 2 files changed, 10 insertions(+), 14 deletions(-) diff --git a/google/cloud/spanner_dbapi/connection.py b/google/cloud/spanner_dbapi/connection.py index f13ef07157..1caf1fbf40 100644 --- a/google/cloud/spanner_dbapi/connection.py +++ b/google/cloud/spanner_dbapi/connection.py @@ -153,7 +153,7 @@ def close(self): def commit(self): """Commits any pending transaction to the database.""" - if self.autocommit: + if self._autocommit: warnings.warn(AUTOCOMMIT_MODE_WARNING, UserWarning, stacklevel=2) elif self._transaction: self._transaction.commit() diff --git a/tests/unit/spanner_dbapi/test_connection.py b/tests/unit/spanner_dbapi/test_connection.py index 866fb5c8b0..5b8452e669 100644 --- a/tests/unit/spanner_dbapi/test_connection.py +++ b/tests/unit/spanner_dbapi/test_connection.py @@ -7,6 +7,7 @@ """Cloud Spanner DB-API Connection class unit tests.""" import unittest +import warnings from unittest import mock @@ -135,19 +136,15 @@ def test_close(self): connection.close() mock_rollback.assert_called_once_with() - def test_commit(self): - from google.cloud.spanner_dbapi import Connection, InterfaceError + @mock.patch.object(warnings, 'warn') + def test_commit(self, mock_warn): + from google.cloud.spanner_dbapi import Connection + from google.cloud.spanner_dbapi.connection import AUTOCOMMIT_MODE_WARNING connection = Connection(self.INSTANCE, self.DATABASE) - - # with mock.patch( - # "google.cloud.spanner_dbapi.connection.Connection.run_prior_DDL_statements" - # ) as run_ddl_mock: - # connection.commit() - # run_ddl_mock.assert_called_once_with() - connection._transaction = mock_transaction = mock.MagicMock() mock_transaction.commit = mock_commit = mock.MagicMock() + with mock.patch( "google.cloud.spanner_dbapi.connection.Connection._release_session" ) as mock_release: @@ -155,10 +152,9 @@ def test_commit(self): mock_commit.assert_called_once_with() mock_release.assert_called_once_with() - # connection.is_closed = True - # - # with self.assertRaises(InterfaceError): - # connection.commit() + connection._autocommit = True + connection.commit() + mock_warn.assert_called_once_with(AUTOCOMMIT_MODE_WARNING, UserWarning, stacklevel=2) def test_rollback(self): from google.cloud.spanner_dbapi import Connection From a6fcbdfc21d326bae5f576a6c2abe13208eef5f4 Mon Sep 17 00:00:00 2001 From: "STATION\\MF" Date: Fri, 23 Oct 2020 22:47:56 -0400 Subject: [PATCH 29/33] test: updated unit test for `Connection.rollback` --- google/cloud/spanner_dbapi/connection.py | 2 +- tests/unit/spanner_dbapi/test_connection.py | 38 ++++++++++++++------- 2 files changed, 27 insertions(+), 13 deletions(-) diff --git a/google/cloud/spanner_dbapi/connection.py b/google/cloud/spanner_dbapi/connection.py index 1caf1fbf40..10e9361140 100644 --- a/google/cloud/spanner_dbapi/connection.py +++ b/google/cloud/spanner_dbapi/connection.py @@ -161,7 +161,7 @@ def commit(self): def rollback(self): """Rollback all the pending transactions.""" - if self.autocommit: + if self._autocommit: warnings.warn(AUTOCOMMIT_MODE_WARNING, UserWarning, stacklevel=2) elif self._transaction: self._transaction.rollback() diff --git a/tests/unit/spanner_dbapi/test_connection.py b/tests/unit/spanner_dbapi/test_connection.py index 5b8452e669..a2d65aa739 100644 --- a/tests/unit/spanner_dbapi/test_connection.py +++ b/tests/unit/spanner_dbapi/test_connection.py @@ -103,9 +103,9 @@ def test_transaction_checkout(self): connection.transaction_checkout() mock_checkout.assert_called_once_with() - connection._transaction = mock_xaction = mock.MagicMock() - mock_xaction.committed = mock_xaction.rolled_back = False - self.assertEqual(connection.transaction_checkout(), mock_xaction) + connection._transaction = mock_transaction = mock.MagicMock() + mock_transaction.committed = mock_transaction.rolled_back = False + self.assertEqual(connection.transaction_checkout(), mock_transaction) connection._autocommit = True self.assertIsNone(connection.transaction_checkout()) @@ -130,9 +130,9 @@ def test_close(self): with self.assertRaises(InterfaceError): connection.cursor() - connection._transaction = mock_xaction = mock.MagicMock() - mock_xaction.committed = mock_xaction.rolled_back = False - mock_xaction.rollback = mock_rollback = mock.MagicMock() + connection._transaction = mock_transaction = mock.MagicMock() + mock_transaction.committed = mock_transaction.rolled_back = False + mock_transaction.rollback = mock_rollback = mock.MagicMock() connection.close() mock_rollback.assert_called_once_with() @@ -142,6 +142,13 @@ def test_commit(self, mock_warn): from google.cloud.spanner_dbapi.connection import AUTOCOMMIT_MODE_WARNING connection = Connection(self.INSTANCE, self.DATABASE) + + with mock.patch( + "google.cloud.spanner_dbapi.connection.Connection._release_session" + ) as mock_release: + connection.commit() + mock_release.assert_not_called() + connection._transaction = mock_transaction = mock.MagicMock() mock_transaction.commit = mock_commit = mock.MagicMock() @@ -156,19 +163,22 @@ def test_commit(self, mock_warn): connection.commit() mock_warn.assert_called_once_with(AUTOCOMMIT_MODE_WARNING, UserWarning, stacklevel=2) - def test_rollback(self): + @mock.patch.object(warnings, 'warn') + def test_rollback(self, mock_warn): from google.cloud.spanner_dbapi import Connection + from google.cloud.spanner_dbapi.connection import AUTOCOMMIT_MODE_WARNING connection = Connection(self.INSTANCE, self.DATABASE) - # with mock.patch( - # "google.cloud.spanner_dbapi.connection.Connection._raise_if_closed" - # ) as check_closed_mock: - # connection.rollback() - # check_closed_mock.assert_called_once_with() + with mock.patch( + "google.cloud.spanner_dbapi.connection.Connection._release_session" + ) as mock_release: + connection.rollback() + mock_release.assert_not_called() connection._transaction = mock_transaction = mock.MagicMock() mock_transaction.rollback = mock_rollback = mock.MagicMock() + with mock.patch( "google.cloud.spanner_dbapi.connection.Connection._release_session" ) as mock_release: @@ -176,6 +186,10 @@ def test_rollback(self): mock_rollback.assert_called_once_with() mock_release.assert_called_once_with() + connection._autocommit = True + connection.rollback() + mock_warn.assert_called_once_with(AUTOCOMMIT_MODE_WARNING, UserWarning, stacklevel=2) + def test_run_prior_DDL_statements(self): from google.cloud.spanner_dbapi import Connection, InterfaceError From b38982aecd46cb8bd21ef2150d022a4cb83f3779 Mon Sep 17 00:00:00 2001 From: "STATION\\MF" Date: Fri, 23 Oct 2020 23:11:52 -0400 Subject: [PATCH 30/33] test: updated unit tests for `Cursor.execute` --- google/cloud/spanner_dbapi/cursor.py | 3 +-- tests/unit/spanner_dbapi/test_cursor.py | 24 ++++++++++++++++++++++++ 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/google/cloud/spanner_dbapi/cursor.py b/google/cloud/spanner_dbapi/cursor.py index 22827ae82d..6997752a42 100644 --- a/google/cloud/spanner_dbapi/cursor.py +++ b/google/cloud/spanner_dbapi/cursor.py @@ -26,8 +26,7 @@ from google.cloud.spanner_dbapi import parse_utils from google.cloud.spanner_dbapi.parse_utils import get_param_types - -from .utils import PeekIterator +from google.cloud.spanner_dbapi.utils import PeekIterator _UNSET_COUNT = -1 diff --git a/tests/unit/spanner_dbapi/test_cursor.py b/tests/unit/spanner_dbapi/test_cursor.py index 00e6ff960e..de20eaff98 100644 --- a/tests/unit/spanner_dbapi/test_cursor.py +++ b/tests/unit/spanner_dbapi/test_cursor.py @@ -122,6 +122,18 @@ def test_execute_attribute_error(self): with self.assertRaises(AttributeError): cursor.execute(sql="") + def test_execute_autocommit_off(self): + from google.cloud.spanner_dbapi.utils import PeekIterator + + connection = self._make_connection(self.INSTANCE, mock.MagicMock()) + cursor = self._make_one(connection) + cursor.connection._autocommit = False + cursor.connection.transaction_checkout = mock.MagicMock(autospec=True) + + cursor.execute('sql') + self.assertIsInstance(cursor._result_set, mock.MagicMock) + self.assertIsInstance(cursor._itr, PeekIterator) + def test_execute_statement(self): from google.cloud.spanner_dbapi import parse_utils @@ -164,6 +176,18 @@ def test_execute_statement(self): connection, sql, None ) + with mock.patch( + "google.cloud.spanner_dbapi.parse_utils.classify_stmt", + return_value='other_statement', + ): + cursor.connection._database = mock_db = mock.MagicMock() + mock_db.run_in_transaction = mock_run_in = mock.MagicMock() + sql = "sql" + cursor.execute(sql=sql) + mock_run_in.assert_called_once_with( + cursor._do_execute_update, sql, None + ) + def test_execute_integrity_error(self): from google.api_core import exceptions from google.cloud.spanner_dbapi.exceptions import IntegrityError From 594c9bcebabb29fd7d67b99967fa6a8388111323 Mon Sep 17 00:00:00 2001 From: "STATION\\MF" Date: Fri, 23 Oct 2020 23:55:45 -0400 Subject: [PATCH 31/33] chore: cleanup --- tests/unit/spanner_dbapi/test_cursor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unit/spanner_dbapi/test_cursor.py b/tests/unit/spanner_dbapi/test_cursor.py index de20eaff98..09288df94e 100644 --- a/tests/unit/spanner_dbapi/test_cursor.py +++ b/tests/unit/spanner_dbapi/test_cursor.py @@ -130,7 +130,7 @@ def test_execute_autocommit_off(self): cursor.connection._autocommit = False cursor.connection.transaction_checkout = mock.MagicMock(autospec=True) - cursor.execute('sql') + cursor.execute("sql") self.assertIsInstance(cursor._result_set, mock.MagicMock) self.assertIsInstance(cursor._itr, PeekIterator) @@ -178,7 +178,7 @@ def test_execute_statement(self): with mock.patch( "google.cloud.spanner_dbapi.parse_utils.classify_stmt", - return_value='other_statement', + return_value="other_statement", ): cursor.connection._database = mock_db = mock.MagicMock() mock_db.run_in_transaction = mock_run_in = mock.MagicMock() From da15d94e6c0ec0bd981ee4ce15a8b224edf53ff5 Mon Sep 17 00:00:00 2001 From: "STATION\\MF" Date: Fri, 23 Oct 2020 23:55:59 -0400 Subject: [PATCH 32/33] chore: cleanup --- tests/unit/spanner_dbapi/test_connection.py | 74 +++++++++++---------- 1 file changed, 40 insertions(+), 34 deletions(-) diff --git a/tests/unit/spanner_dbapi/test_connection.py b/tests/unit/spanner_dbapi/test_connection.py index a2d65aa739..d545472c57 100644 --- a/tests/unit/spanner_dbapi/test_connection.py +++ b/tests/unit/spanner_dbapi/test_connection.py @@ -38,7 +38,7 @@ def _make_connection(self): from google.cloud.spanner_dbapi import Connection from google.cloud.spanner_v1.instance import Instance - # we don't need real Client object to test the constructor + # We don't need a real Client object to test the constructor instance = Instance(self.INSTANCE, client=None) database = instance.database(self.DATABASE) return Connection(instance, database) @@ -62,6 +62,20 @@ def test_property_autocommit_setter(self): mock_commit.assert_not_called() self.assertEqual(connection._autocommit, False) + def test_property_database(self): + from google.cloud.spanner_v1.database import Database + + connection = self._make_connection() + self.assertIsInstance(connection.database, Database) + self.assertEqual(connection.database, connection._database) + + def test_property_instance(self): + from google.cloud.spanner_v1.instance import Instance + + connection = self._make_connection() + self.assertIsInstance(connection.instance, Instance) + self.assertEqual(connection.instance, connection._instance) + def test__session_checkout(self): from google.cloud.spanner_dbapi import Connection @@ -69,16 +83,18 @@ def test__session_checkout(self): "google.cloud.spanner_v1.database.Database", ) as mock_database: mock_database._pool = mock.MagicMock() - mock_database._pool.get = mock.MagicMock(return_value='db_session_pool') + mock_database._pool.get = mock.MagicMock( + return_value="db_session_pool" + ) connection = Connection(self.INSTANCE, mock_database) connection._session_checkout() mock_database._pool.get.assert_called_once_with() - self.assertEqual(connection._session, 'db_session_pool') + self.assertEqual(connection._session, "db_session_pool") - connection._session = 'db_session' + connection._session = "db_session" connection._session_checkout() - self.assertEqual(connection._session, 'db_session') + self.assertEqual(connection._session, "db_session") def test__release_session(self): from google.cloud.spanner_dbapi import Connection @@ -89,17 +105,19 @@ def test__release_session(self): mock_database._pool = mock.MagicMock() mock_database._pool.put = mock.MagicMock() connection = Connection(self.INSTANCE, mock_database) - connection._session = 'session' + connection._session = "session" connection._release_session() - mock_database._pool.put.assert_called_once_with('session') + mock_database._pool.put.assert_called_once_with("session") self.assertIsNone(connection._session) def test_transaction_checkout(self): from google.cloud.spanner_dbapi import Connection connection = Connection(self.INSTANCE, self.DATABASE) - connection._session_checkout = mock_checkout = mock.MagicMock(autospec=True) + connection._session_checkout = mock_checkout = mock.MagicMock( + autospec=True + ) connection.transaction_checkout() mock_checkout.assert_called_once_with() @@ -136,10 +154,12 @@ def test_close(self): connection.close() mock_rollback.assert_called_once_with() - @mock.patch.object(warnings, 'warn') + @mock.patch.object(warnings, "warn") def test_commit(self, mock_warn): from google.cloud.spanner_dbapi import Connection - from google.cloud.spanner_dbapi.connection import AUTOCOMMIT_MODE_WARNING + from google.cloud.spanner_dbapi.connection import ( + AUTOCOMMIT_MODE_WARNING, + ) connection = Connection(self.INSTANCE, self.DATABASE) @@ -161,12 +181,16 @@ def test_commit(self, mock_warn): connection._autocommit = True connection.commit() - mock_warn.assert_called_once_with(AUTOCOMMIT_MODE_WARNING, UserWarning, stacklevel=2) + mock_warn.assert_called_once_with( + AUTOCOMMIT_MODE_WARNING, UserWarning, stacklevel=2 + ) - @mock.patch.object(warnings, 'warn') + @mock.patch.object(warnings, "warn") def test_rollback(self, mock_warn): from google.cloud.spanner_dbapi import Connection - from google.cloud.spanner_dbapi.connection import AUTOCOMMIT_MODE_WARNING + from google.cloud.spanner_dbapi.connection import ( + AUTOCOMMIT_MODE_WARNING, + ) connection = Connection(self.INSTANCE, self.DATABASE) @@ -188,7 +212,9 @@ def test_rollback(self, mock_warn): connection._autocommit = True connection.rollback() - mock_warn.assert_called_once_with(AUTOCOMMIT_MODE_WARNING, UserWarning, stacklevel=2) + mock_warn.assert_called_once_with( + AUTOCOMMIT_MODE_WARNING, UserWarning, stacklevel=2 + ) def test_run_prior_DDL_statements(self): from google.cloud.spanner_dbapi import Connection, InterfaceError @@ -290,23 +316,3 @@ def test_sessions_pool(self): ): connect("test-instance", database_id, pool=pool) database_mock.assert_called_once_with(database_id, pool=pool) - - def test_database_property(self): - from google.cloud.spanner_v1.database import Database - - connection = self._make_connection() - self.assertIsInstance(connection.database, Database) - self.assertEqual(connection.database, connection._database) - - with self.assertRaises(AttributeError): - connection.database = None - - def test_instance_property(self): - from google.cloud.spanner_v1.instance import Instance - - connection = self._make_connection() - self.assertIsInstance(connection.instance, Instance) - self.assertEqual(connection.instance, connection._instance) - - with self.assertRaises(AttributeError): - connection.instance = None From 41abaebb6f2e0b1cf16704aa1e394acc5a47e68b Mon Sep 17 00:00:00 2001 From: "STATION\\MF" Date: Sun, 25 Oct 2020 12:45:27 -0400 Subject: [PATCH 33/33] fix: implementing suggested changes --- google/cloud/spanner_dbapi/connection.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/google/cloud/spanner_dbapi/connection.py b/google/cloud/spanner_dbapi/connection.py index 10e9361140..b572c8573b 100644 --- a/google/cloud/spanner_dbapi/connection.py +++ b/google/cloud/spanner_dbapi/connection.py @@ -55,7 +55,8 @@ def autocommit(self): @autocommit.setter def autocommit(self, value): - """Change this connection autocommit mode. + """Change this connection autocommit mode. Setting this value to True + while a transaction is active will commit the current transaction. :type value: bool :param value: New autocommit mode state. @@ -152,7 +153,10 @@ def close(self): self.is_closed = True def commit(self): - """Commits any pending transaction to the database.""" + """Commits any pending transaction to the database. + + This method is non-operational in autocommit mode. + """ if self._autocommit: warnings.warn(AUTOCOMMIT_MODE_WARNING, UserWarning, stacklevel=2) elif self._transaction: @@ -160,7 +164,11 @@ def commit(self): self._release_session() def rollback(self): - """Rollback all the pending transactions.""" + """Rolls back any pending transaction. + + This is a no-op if there is no active transaction or if the connection + is in autocommit mode. + """ if self._autocommit: warnings.warn(AUTOCOMMIT_MODE_WARNING, UserWarning, stacklevel=2) elif self._transaction: