66from typing import Dict , List
77from threading import Lock
88import json
9-
9+ from hashlib import sha256
10+ from azure .core import MatchConditions
1011from azure .cosmos import documents , http_constants
1112from jsonpickle .pickler import Pickler
1213from jsonpickle .unpickler import Unpickler
1314import azure .cosmos .cosmos_client as cosmos_client # pylint: disable=no-name-in-module,import-error
14- import azure .cosmos .errors as cosmos_errors # pylint: disable=no-name-in-module,import-error
15+ import azure .cosmos .exceptions as cosmos_exceptions
1516from botbuilder .core .storage import Storage
16- from botbuilder .azure import CosmosDbKeyEscape
1717
1818
1919class CosmosDbPartitionedConfig :
@@ -63,6 +63,49 @@ def __init__(
6363 self .compatibility_mode = compatibility_mode or kwargs .get ("compatibility_mode" )
6464
6565
class CosmosDbKeyEscape:
    """Helpers that turn arbitrary storage keys into keys usable with Cosmos."""

    # Characters that are not allowed in Cosmos keys. Every entry must be a
    # single character: membership is tested one character at a time, so a
    # multi-character entry would silently never match.
    _FORBIDDEN_CHARS = frozenset(("\\", "?", "/", "#", "\t", "\n", "\r", "*"))

    @staticmethod
    def sanitize_key(
        key: str, key_suffix: str = "", compatibility_mode: bool = True
    ) -> str:
        """Return the sanitized key.

        Replace characters that are not allowed in keys in Cosmos, append the
        suffix and, in compatibility mode, truncate the result.

        :param key: The provided key to be escaped.
        :param key_suffix: The string to add at the end of all RowKeys.
            ``None`` is treated as an empty suffix.
        :param compatibility_mode: True if keys should be truncated in order to
            support previous CosmosDb max key length of 255. This behavior can be
            overridden by setting cosmosdb_partitioned_config.compatibility_mode
            to False.
        :return str: The escaped (and possibly truncated) key.
        """
        # Replace each forbidden character with '*' followed by its Unicode
        # code point, e.g. '/' becomes '*47'. '*' itself is escaped as well so
        # the marker stays unambiguous.
        escaped = "".join(
            f"*{ord(char)}" if char in CosmosDbKeyEscape._FORBIDDEN_CHARS else char
            for char in key
        )

        if key_suffix is None:
            key_suffix = ""

        return CosmosDbKeyEscape.truncate_key(
            f"{escaped}{key_suffix}", compatibility_mode
        )

    @staticmethod
    def truncate_key(key: str, compatibility_mode: bool = True) -> str:
        """Shorten ``key`` to at most 255 characters when in compatibility mode.

        :param key: The (already escaped) key.
        :param compatibility_mode: When False the key is returned unchanged.
        :return str: ``key`` itself if it fits; otherwise the leading part of
            ``key`` followed by the SHA-256 hex digest of the full key, so the
            result is exactly 255 characters long and still derived from the
            whole original key.
        """
        max_key_len = 255

        if not compatibility_mode or len(key) <= max_key_len:
            return key

        digest = sha256(key.encode("utf-8")).hexdigest()
        # Keep as much of the readable prefix as fits in front of the digest.
        return key[: max_key_len - len(digest)] + digest
108+
66109class CosmosDbPartitionedStorage (Storage ):
67110 """A CosmosDB based storage provider using partitioning for a bot."""
68111
@@ -99,7 +142,8 @@ async def read(self, keys: List[str]) -> Dict[str, object]:
99142 :return dict:
100143 """
101144 if not keys :
102- raise Exception ("Keys are required when reading" )
145+ # No keys passed in, no result to return. Back-compat with original CosmosDBStorage.
146+ return {}
103147
104148 await self .initialize ()
105149
@@ -111,8 +155,8 @@ async def read(self, keys: List[str]) -> Dict[str, object]:
111155 key , self .config .key_suffix , self .config .compatibility_mode
112156 )
113157
114- read_item_response = self .client . ReadItem (
115- self . __item_link ( escaped_key ) , self .__get_partition_key (escaped_key )
158+ read_item_response = self .container . read_item (
159+ escaped_key , self .__get_partition_key (escaped_key )
116160 )
117161 document_store_item = read_item_response
118162 if document_store_item :
@@ -122,13 +166,8 @@ async def read(self, keys: List[str]) -> Dict[str, object]:
122166 # When an item is not found a CosmosException is thrown, but we want to
123167 # return an empty collection so in this instance we catch and do not rethrow.
124168 # Throw for any other exception.
125- except cosmos_errors .HTTPFailure as err :
126- if (
127- err .status_code
128- == cosmos_errors .http_constants .StatusCodes .NOT_FOUND
129- ):
130- continue
131- raise err
169+ except cosmos_exceptions .CosmosResourceNotFoundError :
170+ continue
132171 except Exception as err :
133172 raise err
134173 return store_items
@@ -162,20 +201,16 @@ async def write(self, changes: Dict[str, object]):
162201 if e_tag == "" :
163202 raise Exception ("cosmosdb_storage.write(): etag missing" )
164203
165- access_condition = {
166- "accessCondition" : {"type" : "IfMatch" , "condition" : e_tag }
167- }
168- options = (
169- access_condition if e_tag != "*" and e_tag and e_tag != "" else None
170- )
204+ access_condition = e_tag != "*" and e_tag and e_tag != ""
205+
171206 try :
172- self .client .UpsertItem (
173- database_or_Container_link = self .__container_link ,
174- document = doc ,
175- options = options ,
207+ self .container .upsert_item (
208+ body = doc ,
209+ etag = e_tag if access_condition else None ,
210+ match_condition = (
211+ MatchConditions .IfNotModified if access_condition else None
212+ ),
176213 )
177- except cosmos_errors .HTTPFailure as err :
178- raise err
179214 except Exception as err :
180215 raise err
181216
@@ -192,69 +227,66 @@ async def delete(self, keys: List[str]):
192227 key , self .config .key_suffix , self .config .compatibility_mode
193228 )
194229 try :
195- self .client . DeleteItem (
196- document_link = self . __item_link ( escaped_key ) ,
197- options = self .__get_partition_key (escaped_key ),
230+ self .container . delete_item (
231+ escaped_key ,
232+ self .__get_partition_key (escaped_key ),
198233 )
199- except cosmos_errors .HTTPFailure as err :
200- if (
201- err .status_code
202- == cosmos_errors .http_constants .StatusCodes .NOT_FOUND
203- ):
204- continue
205- raise err
234+ except cosmos_exceptions .CosmosResourceNotFoundError :
235+ continue
206236 except Exception as err :
207237 raise err
208238
209239 async def initialize (self ):
210240 if not self .container :
211241 if not self .client :
242+ connection_policy = self .config .cosmos_client_options .get (
243+ "connection_policy" , documents .ConnectionPolicy ()
244+ )
245+
246+ # kwargs 'connection_verify' is to handle CosmosClient overwriting the
247+ # ConnectionPolicy.DisableSSLVerification value.
212248 self .client = cosmos_client .CosmosClient (
213249 self .config .cosmos_db_endpoint ,
214- {"masterKey" : self .config .auth_key },
215- self .config .cosmos_client_options .get ("connection_policy" , None ),
250+ self .config .auth_key ,
216251 self .config .cosmos_client_options .get ("consistency_level" , None ),
252+ ** {
253+ "connection_policy" : connection_policy ,
254+ "connection_verify" : not connection_policy .DisableSSLVerification ,
255+ },
217256 )
218257
219258 if not self .database :
220259 with self .__lock :
221- try :
222- if not self .database :
223- self .database = self .client .CreateDatabase (
224- {"id" : self .config .database_id }
225- )
226- except cosmos_errors .HTTPFailure :
227- self .database = self .client .ReadDatabase (
228- "dbs/" + self .config .database_id
260+ if not self .database :
261+ self .database = self .client .create_database_if_not_exists (
262+ self .config .database_id
229263 )
230264
231265 self .__get_or_create_container ()
232266
233267 def __get_or_create_container (self ):
234268 with self .__lock :
235- container_def = {
236- "id" : self .config .container_id ,
237- "partitionKey" : {
238- "paths" : ["/id" ],
239- "kind" : documents .PartitionKind .Hash ,
240- },
269+ partition_key = {
270+ "paths" : ["/id" ],
271+ "kind" : documents .PartitionKind .Hash ,
241272 }
242273 try :
243274 if not self .container :
244- self .container = self .client . CreateContainer (
245- "dbs/" + self .database [ "id" ] ,
246- container_def ,
247- { "offerThroughput" : self .config .container_throughput } ,
275+ self .container = self .database . create_container (
276+ self .config . container_id ,
277+ partition_key ,
278+ offer_throughput = self .config .container_throughput ,
248279 )
249- except cosmos_errors . HTTPFailure as err :
280+ except cosmos_exceptions . CosmosHttpResponseError as err :
250281 if err .status_code == http_constants .StatusCodes .CONFLICT :
251- self .container = self .client . ReadContainer (
252- "dbs/" + self .database [ "id" ] + "/colls/" + container_def [ "id" ]
282+ self .container = self .database . get_container_client (
283+ self .config . container_id
253284 )
254- if "partitionKey" not in self .container :
285+ properties = self .container .read ()
286+ if "partitionKey" not in properties :
255287 self .compatability_mode_partition_key = True
256288 else :
257- paths = self . container ["partitionKey" ]["paths" ]
289+ paths = properties ["partitionKey" ]["paths" ]
258290 if "/partitionKey" in paths :
259291 self .compatability_mode_partition_key = True
260292 elif "/id" not in paths :
@@ -267,7 +299,7 @@ def __get_or_create_container(self):
267299 raise err
268300
269301 def __get_partition_key (self , key : str ) -> str :
270- return None if self .compatability_mode_partition_key else { "partitionKey" : key }
302+ return None if self .compatability_mode_partition_key else key
271303
272304 @staticmethod
273305 def __create_si (result ) -> object :
@@ -303,28 +335,3 @@ def __create_dict(store_item: object) -> Dict:
303335
304336 # loop through attributes and write and return a dict
305337 return json_dict
306-
307- def __item_link (self , identifier ) -> str :
308- """Return the item link of a item in the container.
309-
310- :param identifier:
311- :return str:
312- """
313- return self .__container_link + "/docs/" + identifier
314-
315- @property
316- def __container_link (self ) -> str :
317- """Return the container link in the database.
318-
319- :param:
320- :return str:
321- """
322- return self .__database_link + "/colls/" + self .config .container_id
323-
324- @property
325- def __database_link (self ) -> str :
326- """Return the database link.
327-
328- :return str:
329- """
330- return "dbs/" + self .config .database_id
0 commit comments