1
1
import mamonsu .lib .platform as platform
2
2
from distutils .version import LooseVersion
3
- from . _connection import Connection , ConnectionInfo
3
+ from connection import Connection , ConnectionInfo
4
4
5
5
6
- class Pool (ConnectionInfo ):
6
+ class Pool (object ):
7
7
8
8
ExcludeDBs = ['template0' , 'template1' , 'postgres' ]
9
9
@@ -38,8 +38,9 @@ class Pool(ConnectionInfo):
38
38
),
39
39
}
40
40
41
- def __init__ (self ):
42
- super (Pool , self ).__init__ ()
41
+ def __init__ (self , params = {}):
42
+ self ._params = params
43
+ self ._primary_connection_hash = None
43
44
self ._connections = {}
44
45
self ._cache = {
45
46
'server_version' : {'storage' : {}},
@@ -49,38 +50,38 @@ def __init__(self):
49
50
'pgproee' : {'storage' : {}}
50
51
}
51
52
52
- def get_with_default_db (self , db = None ):
53
- if db is None :
54
- return self .default_db
55
- return db
56
-
57
53
def connection_string (self , db = None ):
58
- db = self .get_with_default_db (db )
59
- return self ._connections [db ]._connection_string ()
54
+ db = self ._normalize_db (db )
55
+ return self ._connections [db ].to_string ()
60
56
61
57
def query (self , query , db = None ):
62
- db = self .get_with_default_db (db )
58
+ db = self ._normalize_db (db )
63
59
self ._init_connection (db )
64
60
return self ._connections [db ].query (query )
65
61
66
62
def server_version (self , db = None ):
63
+ db = self ._normalize_db (db )
67
64
if db in self ._cache ['server_version' ]['storage' ]:
68
65
return self ._cache ['server_version' ]['storage' ][db ]
69
66
if platform .PY2 :
70
67
result = self .query ('show server_version' , db )[0 ][0 ]
71
68
elif platform .PY3 :
72
- result = bytes (self .query ('show server_version' , db )[0 ][0 ], 'utf-8' )
69
+ result = bytes (
70
+ self .query ('show server_version' , db )[0 ][0 ], 'utf-8' )
73
71
self ._cache ['server_version' ]['storage' ][db ] = '{0}' .format (
74
72
result .decode ('ascii' ))
75
73
return self ._cache ['server_version' ]['storage' ][db ]
76
74
77
75
def server_version_greater (self , version , db = None ):
76
+ db = self ._normalize_db (db )
78
77
return self .server_version (db ) >= LooseVersion (version )
79
78
80
79
def server_version_less (self , version , db = None ):
80
+ db = self ._normalize_db (db )
81
81
return self .server_version (db ) <= LooseVersion (version )
82
82
83
83
def in_recovery (self , db = None ):
84
+ db = self ._normalize_db (db )
84
85
if db in self ._cache ['recovery' ]['storage' ]:
85
86
if self ._cache ['recovery' ]['counter' ] < self ._cache ['recovery' ]['cache' ]:
86
87
self ._cache ['recovery' ]['counter' ] += 1
@@ -91,6 +92,7 @@ def in_recovery(self, db=None):
91
92
return self ._cache ['recovery' ]['storage' ][db ]
92
93
93
94
def is_bootstraped (self , db = None ):
95
+ db = self ._normalize_db (db )
94
96
if db in self ._cache ['bootstrap' ]['storage' ]:
95
97
if self ._cache ['bootstrap' ]['counter' ] < self ._cache ['bootstrap' ]['cache' ]:
96
98
self ._cache ['bootstrap' ]['counter' ] += 1
@@ -104,10 +106,12 @@ def is_bootstraped(self, db=None):
104
106
self ._connections [db ].log .info ('Found mamonsu bootstrap' )
105
107
else :
106
108
self ._connections [db ].log .info ('Can\' t found mamonsu bootstrap' )
107
- self ._connections [db ].log .info ('hint: run `mamonsu bootstrap` if you want to run without superuser rights' )
109
+ self ._connections [db ].log .info (
110
+ 'hint: run `mamonsu bootstrap` if you want to run without superuser rights' )
108
111
return self ._cache ['bootstrap' ]['storage' ][db ]
109
112
110
113
def is_pgpro (self , db = None ):
114
+ db = self ._normalize_db (db )
111
115
if db in self ._cache ['pgpro' ]:
112
116
return self ._cache ['pgpro' ][db ]
113
117
try :
@@ -118,6 +122,7 @@ def is_pgpro(self, db=None):
118
122
return self ._cache ['pgpro' ][db ]
119
123
120
124
def is_pgpro_ee (self , db = None ):
125
+ db = self ._normalize_db (db )
121
126
if not self .is_pgpro (db ):
122
127
return False
123
128
if db in self ._cache ['pgproee' ]:
@@ -128,6 +133,7 @@ def is_pgpro_ee(self, db=None):
128
133
return self ._cache ['pgproee' ][db ]
129
134
130
135
def extension_installed (self , ext , db = None ):
136
+ db = self ._normalize_db (db )
131
137
result = self .query ('select count(*) from pg_catalog.pg_extension\
132
138
where extname = \' {0}\' ' .format (ext ), db )
133
139
return (int (result [0 ][0 ])) == 1
@@ -141,6 +147,7 @@ def databases(self):
141
147
return databases
142
148
143
149
def get_sql (self , typ , db = None ):
150
+ db = self ._normalize_db (db )
144
151
if typ not in self .SQL :
145
152
raise LookupError ("Unknown SQL type: '{0}'" .format (typ ))
146
153
result = self .SQL [typ ]
@@ -152,10 +159,26 @@ def get_sql(self, typ, db=None):
152
159
def run_sql_type (self , typ , db = None ):
153
160
return self .query (self .get_sql (typ , db ), db )
154
161
162
+ def _normalize_db (self , db = None ):
163
+ if db is None :
164
+ connection_hash = self ._get_primary_connection_hash ()
165
+ db = connection_hash ['db' ]
166
+ return db
167
+
168
+ # cache function for get primary connection params
169
+ def _get_primary_connection_hash (self ):
170
+ if self ._primary_connection_hash is None :
171
+ self ._primary_connection_hash = ConnectionInfo (self ._params ).get_hash ()
172
+ return self ._primary_connection_hash
173
+
174
+ # build connection hash
175
+ def _build_connection_hash (self , db ):
176
+ info = ConnectionInfo (self ._get_primary_connection_hash ()).get_hash ()
177
+ info ['db' ] = self ._normalize_db (db )
178
+ return info
179
+
155
180
def _init_connection (self , db ):
156
- db = self .get_with_default_db (db )
181
+ db = self ._normalize_db (db )
157
182
if db not in self ._connections :
158
183
# create new connection
159
- connection_info = self .connection_info
160
- connection_info ['db' ] = db
161
- self ._connections [db ] = Connection (connection_info )
184
+ self ._connections [db ] = Connection (self ._build_connection_hash (db ))
0 commit comments