55import pymysql
66from pymysql .tests import base
77from pymysql ._compat import text_type
8+ from pymysql .constants import CLIENT
89
910
1011class TempUser :
@@ -411,7 +412,7 @@ def test_connection_gone_away(self):
411412 http://dev.mysql.com/doc/refman/5.0/en/gone-away.html
412413 http://dev.mysql.com/doc/refman/5.0/en/error-messages-client.html#error_cr_server_gone_error
413414 """
414- con = self .connections [ 0 ]
415+ con = self .connect ()
415416 cur = con .cursor ()
416417 cur .execute ("SET wait_timeout=1" )
417418 time .sleep (2 )
@@ -422,10 +423,9 @@ def test_connection_gone_away(self):
422423 self .assertIn (cm .exception .args [0 ], (2006 , 2013 ))
423424
424425 def test_init_command (self ):
425- conn = pymysql .connect (
426+ conn = self .connect (
426427 init_command = 'SELECT "bar"; SELECT "baz"' ,
427- ** self .databases [0 ]
428- )
428+ client_flag = CLIENT .MULTI_STATEMENTS )
429429 c = conn .cursor ()
430430 c .execute ('select "foobar";' )
431431 self .assertEqual (('foobar' ,), c .fetchone ())
@@ -434,22 +434,21 @@ def test_init_command(self):
434434 conn .ping (reconnect = False )
435435
436436 def test_read_default_group (self ):
437- conn = pymysql .connect (
437+ conn = self .connect (
438438 read_default_group = 'client' ,
439- ** self .databases [0 ]
440439 )
441440 self .assertTrue (conn .open )
442441
443442 def test_context (self ):
444443 with self .assertRaises (ValueError ):
445- c = pymysql .connect (** self . databases [ 0 ] )
444+ c = self .connect ()
446445 with c as cur :
447446 cur .execute ('create table test ( a int )' )
448447 c .begin ()
449448 cur .execute ('insert into test values ((1))' )
450449 raise ValueError ('pseudo abort' )
451450 c .commit ()
452- c = pymysql .connect (** self . databases [ 0 ] )
451+ c = self .connect ()
453452 with c as cur :
454453 cur .execute ('select count(*) from test' )
455454 self .assertEqual (0 , cur .fetchone ()[0 ])
@@ -460,31 +459,31 @@ def test_context(self):
460459 cur .execute ('drop table test' )
461460
462461 def test_set_charset (self ):
463- c = pymysql .connect (** self . databases [ 0 ] )
462+ c = self .connect ()
464463 c .set_charset ('utf8' )
465464 # TODO validate setting here
466465
467466 def test_defer_connect (self ):
468467 import socket
469- for db in self .databases :
470- d = db .copy ()
468+
469+ d = self .databases [0 ].copy ()
470+ try :
471+ sock = socket .socket (socket .AF_UNIX , socket .SOCK_STREAM )
472+ sock .connect (d ['unix_socket' ])
473+ except KeyError :
474+ sock = socket .create_connection (
475+ (d .get ('host' , 'localhost' ), d .get ('port' , 3306 )))
476+ for k in ['unix_socket' , 'host' , 'port' ]:
471477 try :
472- sock = socket .socket (socket .AF_UNIX , socket .SOCK_STREAM )
473- sock .connect (d ['unix_socket' ])
478+ del d [k ]
474479 except KeyError :
475- sock = socket .create_connection (
476- (d .get ('host' , 'localhost' ), d .get ('port' , 3306 )))
477- for k in ['unix_socket' , 'host' , 'port' ]:
478- try :
479- del d [k ]
480- except KeyError :
481- pass
482-
483- c = pymysql .connect (defer_connect = True , ** d )
484- self .assertFalse (c .open )
485- c .connect (sock )
486- c .close ()
487- sock .close ()
480+ pass
481+
482+ c = pymysql .connect (defer_connect = True , ** d )
483+ self .assertFalse (c .open )
484+ c .connect (sock )
485+ c .close ()
486+ sock .close ()
488487
489488 @unittest2 .skipUnless (sys .version_info [0 :2 ] >= (3 ,2 ), "required py-3.2" )
490489 def test_no_delay_warning (self ):
@@ -560,15 +559,17 @@ def test_escape_list_item(self):
560559 self .assertEqual (con .escape ([Foo ()], mapping ), "(bar)" )
561560
562561 def test_previous_cursor_not_closed (self ):
563- con = self .connections [0 ]
562+ con = self .connect (
563+ init_command = 'SELECT "bar"; SELECT "baz"' ,
564+ client_flag = CLIENT .MULTI_STATEMENTS )
564565 cur1 = con .cursor ()
565566 cur1 .execute ("SELECT 1; SELECT 2" )
566567 cur2 = con .cursor ()
567568 cur2 .execute ("SELECT 3" )
568569 self .assertEqual (cur2 .fetchone ()[0 ], 3 )
569570
570571 def test_commit_during_multi_result (self ):
571- con = self .connections [ 0 ]
572+ con = self .connect ( client_flag = CLIENT . MULTI_STATEMENTS )
572573 cur = con .cursor ()
573574 cur .execute ("SELECT 1; SELECT 2" )
574575 con .commit ()
0 commit comments