4
4
# license that can be found in the LICENSE file or at
5
5
# https://developers.google.com/open-source/licenses/bsd
6
6
7
- import datetime
8
- import decimal
9
- from unittest import TestCase
7
+ import unittest
10
8
11
9
from google .cloud .spanner_v1 import param_types
12
- from google .cloud .spanner_dbapi .exceptions import Error , ProgrammingError
13
- from google .cloud .spanner_dbapi .parse_utils import (
14
- STMT_DDL ,
15
- STMT_INSERT ,
16
- STMT_NON_UPDATING ,
17
- STMT_UPDATING ,
18
- DateStr ,
19
- TimestampStr ,
20
- cast_for_spanner ,
21
- classify_stmt ,
22
- ensure_where_clause ,
23
- escape_name ,
24
- get_param_types ,
25
- parse_insert ,
26
- rows_for_insert_or_update ,
27
- sql_pyformat_args_to_spanner ,
28
- )
29
- from google .cloud .spanner_dbapi .utils import backtick_unicode
30
-
31
-
32
- class ParseUtilsTests (TestCase ):
10
+
11
+
12
+ class TestParseUtils (unittest .TestCase ):
33
13
def test_classify_stmt (self ):
14
+ from google .cloud .spanner_dbapi .parse_utils import STMT_DDL
15
+ from google .cloud .spanner_dbapi .parse_utils import STMT_INSERT
16
+ from google .cloud .spanner_dbapi .parse_utils import STMT_NON_UPDATING
17
+ from google .cloud .spanner_dbapi .parse_utils import STMT_UPDATING
18
+ from google .cloud .spanner_dbapi .parse_utils import classify_stmt
19
+
34
20
cases = (
35
21
("SELECT 1" , STMT_NON_UPDATING ),
36
22
("SELECT s.SongName FROM Songs AS s" , STMT_NON_UPDATING ),
@@ -61,6 +47,12 @@ def test_classify_stmt(self):
61
47
self .assertEqual (classify_stmt (query ), want_class )
62
48
63
49
def test_parse_insert (self ):
50
+ from google .cloud .spanner_dbapi .parse_utils import parse_insert
51
+ from google .cloud .spanner_dbapi .exceptions import ProgrammingError
52
+
53
+ with self .assertRaises (ProgrammingError ):
54
+ parse_insert ("bad-sql" , None )
55
+
64
56
cases = [
65
57
(
66
58
"INSERT INTO django_migrations (app, name, applied) VALUES (%s, %s, %s)" ,
@@ -173,6 +165,10 @@ def test_parse_insert(self):
173
165
),
174
166
]
175
167
168
+ sql = "INSERT INTO django_migrations (app, name, applied) VALUES (%s, %s, %s)"
169
+ with self .assertRaises (ProgrammingError ):
170
+ parse_insert (sql , None )
171
+
176
172
for sql , params , want in cases :
177
173
with self .subTest (sql = sql ):
178
174
got = parse_insert (sql , params )
@@ -181,6 +177,9 @@ def test_parse_insert(self):
181
177
)
182
178
183
179
def test_parse_insert_invalid (self ):
180
+ from google .cloud .spanner_dbapi import exceptions
181
+ from google .cloud .spanner_dbapi .parse_utils import parse_insert
182
+
184
183
cases = [
185
184
(
186
185
"INSERT INTO django_migrations (app, name, applied) VALUES (%s, %s, %s), (%s, %s, %s)" ,
@@ -202,12 +201,23 @@ def test_parse_insert_invalid(self):
202
201
for sql , params , wantException in cases :
203
202
with self .subTest (sql = sql ):
204
203
self .assertRaisesRegex (
205
- ProgrammingError ,
204
+ exceptions . ProgrammingError ,
206
205
wantException ,
207
206
lambda : parse_insert (sql , params ),
208
207
)
209
208
210
209
def test_rows_for_insert_or_update (self ):
210
+ from google .cloud .spanner_dbapi .parse_utils import (
211
+ rows_for_insert_or_update ,
212
+ )
213
+ from google .cloud .spanner_dbapi .exceptions import Error
214
+
215
+ with self .assertRaises (Error ):
216
+ rows_for_insert_or_update ([0 ], [[]])
217
+
218
+ with self .assertRaises (Error ):
219
+ rows_for_insert_or_update ([0 ], None , ["0" , "%s" ])
220
+
211
221
cases = [
212
222
(
213
223
["id" , "app" , "name" ],
@@ -255,6 +265,12 @@ def test_rows_for_insert_or_update(self):
255
265
self .assertEqual (got , want )
256
266
257
267
def test_sql_pyformat_args_to_spanner (self ):
268
+ import decimal
269
+
270
+ from google .cloud .spanner_dbapi .parse_utils import (
271
+ sql_pyformat_args_to_spanner ,
272
+ )
273
+
258
274
cases = [
259
275
(
260
276
(
@@ -323,6 +339,11 @@ def test_sql_pyformat_args_to_spanner(self):
323
339
)
324
340
325
341
def test_sql_pyformat_args_to_spanner_invalid (self ):
342
+ from google .cloud .spanner_dbapi import exceptions
343
+ from google .cloud .spanner_dbapi .parse_utils import (
344
+ sql_pyformat_args_to_spanner ,
345
+ )
346
+
326
347
cases = [
327
348
(
328
349
"SELECT * from t WHERE f1=%s, f2 = %s, f3=%s, extra=%s" ,
@@ -332,12 +353,28 @@ def test_sql_pyformat_args_to_spanner_invalid(self):
332
353
for sql , params in cases :
333
354
with self .subTest (sql = sql ):
334
355
self .assertRaisesRegex (
335
- Error ,
356
+ exceptions . Error ,
336
357
"pyformat_args mismatch" ,
337
358
lambda : sql_pyformat_args_to_spanner (sql , params ),
338
359
)
339
360
361
+ def test_cast_for_spanner (self ):
362
+ import decimal
363
+
364
+ from google .cloud .spanner_dbapi .parse_utils import cast_for_spanner
365
+
366
+ value = decimal .Decimal (3 )
367
+ self .assertEqual (cast_for_spanner (value ), float (3.0 ))
368
+ self .assertEqual (cast_for_spanner (5 ), 5 )
369
+ self .assertEqual (cast_for_spanner ("string" ), "string" )
370
+
340
371
def test_get_param_types (self ):
372
+ import datetime
373
+
374
+ from google .cloud .spanner_dbapi .parse_utils import DateStr
375
+ from google .cloud .spanner_dbapi .parse_utils import TimestampStr
376
+ from google .cloud .spanner_dbapi .parse_utils import get_param_types
377
+
341
378
params = {
342
379
"a1" : 10 ,
343
380
"b1" : "string" ,
@@ -365,15 +402,13 @@ def test_get_param_types(self):
365
402
self .assertEqual (got_types , want_types )
366
403
367
404
def test_get_param_types_none (self ):
368
- self . assertEqual ( get_param_types ( None ), None )
405
+ from google . cloud . spanner_dbapi . parse_utils import get_param_types
369
406
370
- def test_cast_for_spanner (self ):
371
- value = decimal .Decimal (3 )
372
- self .assertEqual (cast_for_spanner (value ), float (3.0 ))
373
- self .assertEqual (cast_for_spanner (5 ), 5 )
374
- self .assertEqual (cast_for_spanner ("string" ), "string" )
407
+ self .assertEqual (get_param_types (None ), None )
375
408
376
409
def test_ensure_where_clause (self ):
410
+ from google .cloud .spanner_dbapi .parse_utils import ensure_where_clause
411
+
377
412
cases = [
378
413
(
379
414
"UPDATE a SET a.b=10 FROM articles a JOIN d c ON a.ai = c.ai WHERE c.ci = 1" ,
@@ -404,6 +439,8 @@ def test_ensure_where_clause(self):
404
439
self .assertEqual (got , want )
405
440
406
441
def test_escape_name (self ):
442
+ from google .cloud .spanner_dbapi .parse_utils import escape_name
443
+
407
444
cases = (
408
445
("SELECT" , "`SELECT`" ),
409
446
("dashed-value" , "`dashed-value`" ),
@@ -415,16 +452,3 @@ def test_escape_name(self):
415
452
with self .subTest (name = name ):
416
453
got = escape_name (name )
417
454
self .assertEqual (got , want )
418
-
419
- def test_backtick_unicode (self ):
420
- cases = [
421
- ("SELECT (1) as foo WHERE 1=1" , "SELECT (1) as foo WHERE 1=1" ),
422
- ("SELECT (1) as föö" , "SELECT (1) as `föö`" ),
423
- ("SELECT (1) as `föö`" , "SELECT (1) as `föö`" ),
424
- ("SELECT (1) as `föö` `umläut" , "SELECT (1) as `föö` `umläut" ),
425
- ("SELECT (1) as `föö" , "SELECT (1) as `föö" ),
426
- ]
427
- for sql , want in cases :
428
- with self .subTest (sql = sql ):
429
- got = backtick_unicode (sql )
430
- self .assertEqual (got , want )
0 commit comments