8000 fix(dbapi): executemany() hiding all the results except the last (#181) · larkee/python-spanner@020dc17 · GitHub
[go: up one dir, main page]

Skip to content

Commit 020dc17

Browse files
author
Ilya Gurov
authored
fix(dbapi): executemany() hiding all the results except the last (googleapis#181)
1 parent cbe6ec1 commit 020dc17

File tree

4 files changed

+111
-1
lines changed

4 files changed

+111
-1
lines changed

google/cloud/spanner_dbapi/cursor.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
from google.cloud.spanner_dbapi.parse_utils import get_param_types
3838
from google.cloud.spanner_dbapi.parse_utils import sql_pyformat_args_to_spanner
3939
from google.cloud.spanner_dbapi.utils import PeekIterator
40+
from google.cloud.spanner_dbapi.utils import StreamedManyResultSets
4041

4142
_UNSET_COUNT = -1
4243

@@ -210,8 +211,20 @@ def executemany(self, operation, seq_of_params):
210211
"""
211212
self._raise_if_closed()
212213

214+
classification = parse_utils.classify_stmt(operation)
215+
if classification == parse_utils.STMT_DDL:
216+
raise ProgrammingError(
217+
"Executing DDL statements with executemany() method is not allowed."
218+
)
219+
220+
many_result_set = StreamedManyResultSets()
221+
213222
for params in seq_of_params:
214223
self.execute(operation, params)
224+
many_result_set.add_iter(self._itr)
225+
226+
self._result_set = many_result_set
227+
self._itr = many_result_set
215228

216229
def fetchone(self):
217230
"""Fetch the next row of a query result set, returning a single

google/cloud/spanner_dbapi/utils.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414

1515
import re
1616

17+
re_UNICODE_POINTS = re.compile(r"([^\s]*[\u0080-\uFFFF]+[^\s]*)")
18+
1719

1820
class PeekIterator:
1921
"""
@@ -55,7 +57,43 @@ def __iter__(self):
5557
return self
5658

5759

58-
re_UNICODE_POINTS = re.compile(r"([^\s]*[\u0080-\uFFFF]+[^\s]*)")
60+
class StreamedManyResultSets:
61+
"""Iterator to walk through several `StreamedResultsSet` iterators.
62+
This type of iterator is used by `Cursor.executemany()`
63+
method to iterate through several `StreamedResultsSet`
64+
iterators like they all are merged into single iterator.
65+
"""
66+
67+
def __init__(se 8000 lf):
68+
self._iterators = []
69+
self._index = 0
70+
71+
def add_iter(self, iterator):
72+
"""Add new iterator into this one.
73+
:type iterator: :class:`google.cloud.spanner_v1.streamed.StreamedResultSet`
74+
:param iterator: Iterator to merge into this one.
75+
"""
76+
self._iterators.append(iterator)
77+
78+
def __next__(self):
79+
"""Return the next value from the currently streamed iterator.
80+
If the current iterator is streamed to the end,
81+
start to stream the next one.
82+
:rtype: list
83+
:returns: The next result row.
84+
"""
85+
try:
86+
res = next(self._iterators[self._index])
87+
except StopIteration:
88+
self._index += 1
89+
res = self.__next__()
90+
except IndexError:
91+
raise StopIteration
92+
93+
return res
94+
95+
def __iter__(self):
96+
return self
5997

6098

6199
def backtick_unicode(sql):

tests/system/test_system_dbapi.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,46 @@ def test_results_checksum(self):
305305

306306
self.assertEqual(cursor._checksum.checksum.digest(), checksum.digest())
307307

308+
def test_execute_many(self):
309+
# connect to the test database
310+
conn = Connection(Config.INSTANCE, self._db)
311+
cursor = conn.cursor()
312+
313+
cursor.execute(
314+
"""
315+
INSERT INTO contacts (contact_id, first_name, last_name, email)
316+
VALUES (1, 'first-name', 'last-name', 'test.email@example.com'),
317+
(2, 'first-name2', 'last-name2', 'test.email2@example.com')
318+
"""
319+
)
320+
conn.commit()
321+
322+
cursor.executemany(
323+
"""
324+
SELECT * FROM contacts WHERE contact_id = @a1
325+
""",
326+
({"a1": 1}, {"a1": 2}),
327+
)
328+
res = cursor.fetchall()
329+
conn.commit()
330+
331+
self.assertEqual(len(res), 2)
332+
self.assertEqual(res[0][0], 1)
333+
self.assertEqual(res[1][0], 2)
334+
335+
# checking that execute() and executemany()
336+
# results are not mixed together
337+
cursor.execute(
338+
"""
339+
SELECT * FROM contacts WHERE contact_id = 1
340+
""",
341+
)
342+
res = cursor.fetchone()
343+
conn.commit()
344+
345+
self.assertEqual(res[0], 1)
346+
conn.close()
347+
308348

309349
def clear_table(transaction):
310350
"""Clear the test table."""

tests/unit/spanner_dbapi/test_cursor.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,22 @@ def test_executemany_on_closed_cursor(self):
257257
with self.assertRaises(InterfaceError):
258258
cursor.executemany("""SELECT * FROM table1 WHERE "col1" = @a1""", ())
259259

260+
def test_executemany_DLL(self):
261+
from google.cloud.spanner_dbapi import connect, ProgrammingError
262+
263+
with mock.patch(
264+
"google.cloud.spanner_v1.instance.Instance.exists", return_value=True,
265+
):
266+
with mock.patch(
267+
"google.cloud.spanner_v1.database.Database.exists", return_value=True,
268+
):
269+
connection = connect("test-instance", "test-database")
270+
271+
cursor = connection.cursor()
272+
273+
with self.assertRaises(ProgrammingError):
274+
cursor.executemany("""DROP DATABASE database_name""", ())
275+
260276
def test_executemany(self):
261277
from google.cloud.spanner_dbapi import connect
262278

@@ -272,6 +288,9 @@ def test_executemany(self):
272288
connection = connect("test-instance", "test-database")
273289

274290
cursor = connection.cursor()
291+
cursor._result_set = [1, 2, 3]
292+
cursor._itr = iter([1, 2, 3])
293+
275294
with mock.patch(
276295
"google.cloud.spanner_dbapi.cursor.Cursor.execute"
277296
) as execute_mock:

0 commit comments

Comments
 (0)
0