5
5
import pymysql
6
6
from pymysql .tests import base
7
7
from pymysql ._compat import text_type
8
+ from pymysql .constants import CLIENT
8
9
9
10
10
11
class TempUser :
@@ -411,7 +412,7 @@ def test_connection_gone_away(self):
411
412
http://dev.mysql.com/doc/refman/5.0/en/gone-away.html
412
413
http://dev.mysql.com/doc/refman/5.0/en/error-messages-client.html#error_cr_server_gone_error
413
414
"""
414
- con = self .connections [ 0 ]
415
+ con = self .connect ()
415
416
cur = con .cursor ()
416
417
cur .execute ("SET wait_timeout=1" )
417
418
time .sleep (2 )
@@ -422,10 +423,9 @@ def test_connection_gone_away(self):
422
423
self .assertIn (cm .exception .args [0 ], (2006 , 2013 ))
423
424
424
425
def test_init_command (self ):
425
- conn = pymysql .connect (
426
+ conn = self .connect (
426
427
init_command = 'SELECT "bar"; SELECT "baz"' ,
427
- ** self .databases [0 ]
428
- )
428
+ client_flag = CLIENT .MULTI_STATEMENTS )
429
429
c = conn .cursor ()
430
430
c .execute ('select "foobar";' )
431
431
self .assertEqual (('foobar' ,), c .fetchone ())
@@ -434,22 +434,21 @@ def test_init_command(self):
434
434
conn .ping (reconnect = False )
435
435
436
436
def test_read_default_group (self ):
437
- conn = pymysql .connect (
437
+ conn = self .connect (
438
438
read_default_group = 'client' ,
439
- ** self .databases [0 ]
440
439
)
441
440
self .assertTrue (conn .open )
442
441
443
442
def test_context (self ):
444
443
with self .assertRaises (ValueError ):
445
- c = pymysql .connect (** self . databases [ 0 ] )
444
+ c = self .connect ()
446
445
with c as cur :
447
446
cur .execute ('create table test ( a int )' )
448
447
c .begin ()
449
448
cur .execute ('insert into test values ((1))' )
450
449
raise ValueError ('pseudo abort' )
451
450
c .commit ()
452
- c = pymysql .connect (** self . databases [ 0 ] )
451
+ c = self .connect ()
453
452
with c as cur :
454
453
cur .execute ('select count(*) from test' )
455
454
self .assertEqual (0 , cur .fetchone ()[0 ])
@@ -460,31 +459,31 @@ def test_context(self):
460
459
cur .execute ('drop table test' )
461
460
462
461
def test_set_charset (self ):
463
- c = pymysql .connect (** self . databases [ 0 ] )
462
+ c = self .connect ()
464
463
c .set_charset ('utf8' )
465
464
# TODO validate setting here
466
465
467
466
def test_defer_connect (self ):
468
467
import socket
469
- for db in self .databases :
470
- d = db .copy ()
468
+
469
+ d = self .databases [0 ].copy ()
470
+ try :
471
+ sock = socket .socket (socket .AF_UNIX , socket .SOCK_STREAM )
472
+ sock .connect (d ['unix_socket' ])
473
+ except KeyError :
474
+ sock = socket .create_connection (
475
+ (d .get ('host' , 'localhost' ), d .get ('port' , 3306 )))
476
+ for k in ['unix_socket' , 'host' , 'port' ]:
471
477
try :
472
- sock = socket .socket (socket .AF_UNIX , socket .SOCK_STREAM )
473
- sock .connect (d ['unix_socket' ])
478
+ del d [k ]
474
479
except KeyError :
475
- sock = socket .create_connection (
476
- (d .get ('host' , 'localhost' ), d .get ('port' , 3306 )))
477
- for k in ['unix_socket' , 'host' , 'port' ]:
478
- try :
479
- del d [k ]
480
- except KeyError :
481
- pass
482
-
483
- c = pymysql .connect (defer_connect = True , ** d )
484
- self .assertFalse (c .open )
485
- c .connect (sock )
486
- c .close ()
487
- sock .close ()
480
+ pass
481
+
482
+ c = pymysql .connect (defer_connect = True , ** d )
483
+ self .assertFalse (c .open )
484
+ c .connect (sock )
485
+ c .close ()
486
+ sock .close ()
488
487
489
488
@unittest2 .skipUnless (sys .version_info [0 :2 ] >= (3 ,2 ), "required py-3.2" )
490
489
def test_no_delay_warning (self ):
@@ -560,15 +559,17 @@ def test_escape_list_item(self):
560
559
self .assertEqual (con .escape ([Foo ()], mapping ), "(bar)" )
561
560
562
561
def test_previous_cursor_not_closed (self ):
563
- con = self .connections [0 ]
562
+ con = self .connect (
563
+ init_command = 'SELECT "bar"; SELECT "baz"' ,
564
+ client_flag = CLIENT .MULTI_STATEMENTS )
564
565
cur1 = con .cursor ()
565
566
cur1 .execute ("SELECT 1; SELECT 2" )
566
567
cur2 = con .cursor ()
567
568
cur2 .execute ("SELECT 3" )
568
569
self .assertEqual (cur2 .fetchone ()[0 ], 3 )
569
570
570
571
def test_commit_during_multi_result (self ):
571
- con = self .connections [ 0 ]
572
+ con = self .connect ( client_flag = CLIENT . MULTI_STATEMENTS )
572
573
cur = con .cursor ()
573
574
cur .execute ("SELECT 1; SELECT 2" )
574
575
con .commit ()
0 commit comments