6
6
from typing import Dict , List
7
7
from threading import Lock
8
8
import json
9
-
9
+ from hashlib import sha256
10
+ from azure .core import MatchConditions
10
11
from azure .cosmos import documents , http_constants
11
12
from jsonpickle .pickler import Pickler
12
13
from jsonpickle .unpickler import Unpickler
13
14
import azure .cosmos .cosmos_client as cosmos_client # pylint: disable=no-name-in-module,import-error
14
- import azure .cosmos .errors as cosmos_errors # pylint: disable=no-name-in-module,import-error
15
+ import azure .cosmos .exceptions as cosmos_exceptions
15
16
from botbuilder .core .storage import Storage
16
- from botbuilder .azure import CosmosDbKeyEscape
17
17
18
18
19
19
class CosmosDbPartitionedConfig :
@@ -63,6 +63,49 @@ def __init__(
63
63
self .compatibility_mode = compatibility_mode or kwargs .get ("compatibility_mode" )
64
64
65
65
66
+ class CosmosDbKeyEscape :
67
+ @staticmethod
68
+ def sanitize_key (
69
+ key : str , key_suffix : str = "" , compatibility_mode : bool = True
70
+ ) -> str :
71
+ """Return the sanitized key.
72
+
73
+ Replace characters that are not allowed in keys in Cosmos.
74
+
75
+ :param key: The provided key to be escaped.
76
+ :param key_suffix: The string to add a the end of all RowKeys.
77
+ :param compatibility_mode: True if keys should be truncated in order to support previous CosmosDb
78
+ max key length of 255. This behavior can be overridden by setting
79
+ cosmosdb_partitioned_config.compatibility_mode to False.
80
+ :return str:
81
+ """
82
+ # forbidden characters
83
+ bad_chars = ["\\ " , "?" , "/" , "#" , "\t " , "\n " , "\r " , "*" ]
84
+ # replace those with with '*' and the
85
+ # Unicode code point of the character and return the new string
86
+ key = "" .join (map (lambda x : "*" + str (ord (x )) if x in bad_chars else x , key ))
87
+
88
+ if key_suffix is None :
89
+ key_suffix = ""
90
+
91
+ return CosmosDbKeyEscape .truncate_key (f"{ key } { key_suffix } " , compatibility_mode )
92
+
93
+ @staticmethod
94
+ def truncate_key (key : str , compatibility_mode : bool = True ) -> str :
95
+ max_key_len = 255
96
+
97
+ if not compatibility_mode :
98
+ return key
99
+
100
+ if len (key ) > max_key_len :
101
+ aux_hash = sha256 (key .encode ("utf-8" ))
102
+ aux_hex = aux_hash .hexdigest ()
103
+
104
+ key = key [0 : max_key_len - len (aux_hex )] + aux_hex
105
+
106
+ return key
107
+
108
+
66
109
class CosmosDbPartitionedStorage (Storage ):
67
110
"""A CosmosDB based storage provider using partitioning for a bot."""
68
111
@@ -99,7 +142,8 @@ async def read(self, keys: List[str]) -> Dict[str, object]:
99
142
:return dict:
100
143
"""
101
144
if not keys :
102
- raise Exception ("Keys are required when reading" )
145
+ # No keys passed in, no result to return. Back-compat with original CosmosDBStorage.
146
+ return {}
103
147
104
148
await self .initialize ()
105
149
@@ -111,8 +155,8 @@ async def read(self, keys: List[str]) -> Dict[str, object]:
111
155
key , self .config .key_suffix , self .config .compatibility_mode
112
156
)
113
157
114
- read_item_response = self .client . ReadItem (
115
- self . __item_link ( escaped_key ) , self .__get_partition_key (escaped_key )
158
+ read_item_response = self .container . read_item (
159
+ escaped_key , self .__get_partition_key (escaped_key )
116
160
)
117
161
document_store_item = read_item_response
118
162
if document_store_item :
@@ -122,13 +166,8 @@ async def read(self, keys: List[str]) -> Dict[str, object]:
122
166
# When an item is not found a CosmosException is thrown, but we want to
123
167
# return an empty collection so in this instance we catch and do not rethrow.
124
168
# Throw for any other exception.
125
- except cosmos_errors .HTTPFailure as err :
126
- if (
127
- err .status_code
128
- == cosmos_errors .http_constants .StatusCodes .NOT_FOUND
129
- ):
130
- continue
131
- raise err
169
+ except cosmos_exceptions .CosmosResourceNotFoundError :
170
+ continue
132
171
except Exception as err :
133
172
raise err
134
173
return store_items
@@ -162,20 +201,16 @@ async def write(self, changes: Dict[str, object]):
162
201
if e_tag == "" :
163
202
raise Exception ("cosmosdb_storage.write(): etag missing" )
164
203
165
- access_condition = {
166
- "accessCondition" : {"type" : "IfMatch" , "condition" : e_tag }
167
- }
168
- options = (
169
- access_condition if e_tag != "*" and e_tag and e_tag != "" else None
170
- )
204
+ access_condition = e_tag != "*" and e_tag and e_tag != ""
205
+
171
206
try :
172
- self .client .UpsertItem (
173
- database_or_Container_link = self .__container_link ,
174
- document = doc ,
175
- options = options ,
207
+ self .container .upsert_item (
208
+ body = doc ,
209
+ etag = e_tag if access_condition else None ,
210
+ match_condition = (
211
+ MatchConditions .IfNotModified if access_condition else None
212
+ ),
176
213
)
177
- except cosmos_errors .HTTPFailure as err :
178
- raise err
179
214
except Exception as err :
180
215
raise err
181
216
@@ -192,69 +227,66 @@ async def delete(self, keys: List[str]):
192
227
key , self .config .key_suffix , self .config .compatibility_mode
193
228
)
194
229
try :
195
- self .client . DeleteItem (
196
- document_link = self . __item_link ( escaped_key ) ,
197
- options = self .__get_partition_key (escaped_key ),
230
+ self .container . delete_item (
231
+ escaped_key ,
232
+ self .__get_partition_key (escaped_key ),
198
233
)
199
- except cosmos_errors .HTTPFailure as err :
200
- if (
201
- err .status_code
202
- == cosmos_errors .http_constants .StatusCodes .NOT_FOUND
203
- ):
204
- continue
205
- raise err
234
+ except cosmos_exceptions .CosmosResourceNotFoundError :
235
+ continue
206
236
except Exception as err :
207
237
raise err
208
238
209
239
async def initialize (self ):
210
240
if not self .container :
211
241
if not self .client :
242
+ connection_policy = self .config .cosmos_client_options .get (
243
+ "connection_policy" , documents .ConnectionPolicy ()
244
+ )
245
+
246
+ # kwargs 'connection_verify' is to handle CosmosClient overwriting the
247
+ # ConnectionPolicy.DisableSSLVerification value.
212
248
self .client = cosmos_client .CosmosClient (
213
249
self .config .cosmos_db_endpoint ,
214
- {"masterKey" : self .config .auth_key },
215
- self .config .cosmos_client_options .get ("connection_policy" , None ),
250
+ self .config .auth_key ,
216
251
self .config .cosmos_client_options .get ("consistency_level" , None ),
252
+ ** {
253
+ "connection_policy" : connection_policy ,
254
+ "connection_verify" : not connection_policy .DisableSSLVerification ,
255
+ },
217
256
)
218
257
219
258
if not self .database :
220
259
with self .__lock :
221
- try :
222
- if not self .database :
223
- self .database = self .client .CreateDatabase (
224
- {"id" : self .config .database_id }
225
- )
226
- except cosmos_errors .HTTPFailure :
227
- self .database = self .client .ReadDatabase (
228
- "dbs/" + self .config .database_id
260
+ if not self .database :
261
+ self .database = self .client .create_database_if_not_exists (
262
+ self .config .database_id
229
263
)
230
264
231
265
self .__get_or_create_container ()
232
266
233
267
def __get_or_create_container (self ):
234
268
with self .__lock :
235
- container_def = {
236
- "id" : self .config .container_id ,
237
- "partitionKey" : {
238
- "paths" : ["/id" ],
239
- "kind" : documents .PartitionKind .Hash ,
240
- },
269
+ partition_key = {
270
+ "paths" : ["/id" ],
271
+ "kind" : documents .PartitionKind .Hash ,
241
272
}
242
273
try :
243
274
if not self .container :
244
- self .container = self .client . CreateContainer (
245
- "dbs/" + self .database [ "id" ] ,
246
- container_def ,
247
- { "offerThroughput" : self .config .container_throughput } ,
275
+ self .container = self .database . create_container (
276
+ self .config . container_id ,
277
+ partition_key ,
278
+ offer_throughput = self .config .container_throughput ,
248
279
)
249
- except cosmos_errors . HTTPFailure as err :
280
+ except cosmos_exceptions . CosmosHttpResponseError as err :
250
281
if err .status_code == http_constants .StatusCodes .CONFLICT :
251
- self .container = self .client . ReadContainer (
252
- "dbs/" + self .database [ "id" ] + "/colls/" + container_def [ "id" ]
282
+ self .container = self .database . get_container_client (
283
+ self .config . container_id
253
284
)
254
- if "partitionKey" not in self .container :
285
+ properties = self .container .read ()
286
+ if "partitionKey" not in properties :
255
287
self .compatability_mode_partition_key = True
256
288
else :
257
- paths = self . container ["partitionKey" ]["paths" ]
289
+ paths = properties ["partitionKey" ]["paths" ]
258
290
if "/partitionKey" in paths :
259
291
self .compatability_mode_partition_key = True
260
292
elif "/id" not in paths :
@@ -267,7 +299,7 @@ def __get_or_create_container(self):
267
299
raise err
268
300
269
301
def __get_partition_key (self , key : str ) -> str :
270
- return None if self .compatability_mode_partition_key else { "partitionKey" : key }
302
+ return None if self .compatability_mode_partition_key else key
271
303
272
304
@staticmethod
273
305
def __create_si (result ) -> object :
@@ -303,28 +335,3 @@ def __create_dict(store_item: object) -> Dict:
303
335
304
336
# loop through attributes and write and return a dict
305
337
return json_dict
306
-
307
- def __item_link (self , identifier ) -> str :
308
- """Return the item link of a item in the container.
309
-
310
- :param identifier:
311
- :return str:
312
- """
313
- return self .__container_link + "/docs/" + identifier
314
-
315
- @property
316
- def __container_link (self ) -> str :
317
- """Return the container link in the database.
318
-
319
- :param:
320
- :return str:
321
- """
322
- return self .__database_link + "/colls/" + self .config .container_id
323
-
324
- @property
325
- def __database_link (self ) -> str :
326
- """Return the database link.
327
-
328
- :return str:
329
- """
330
- return "dbs/" + self .config .database_id
0 commit comments