1
1
# Copyright (c) Microsoft Corporation. All rights reserved.
2
2
# Licensed under the MIT License.
3
3
4
+ from aiohttp import ClientSession , ClientTimeout
5
+
4
6
from botbuilder .schema import Activity
5
7
from botbuilder .core import BotTelemetryClient , NullTelemetryClient , TurnContext
6
8
from copy import copy
7
- import aiohttp , json , platform , requests
8
- from typing import Dict , List , NamedTuple
9
+ import json , platform , requests
10
+ from typing import Dict , List , NamedTuple , Union
9
11
10
12
from .metadata import Metadata
11
13
from .query_result import QueryResult
@@ -33,7 +35,8 @@ class QnAMaker(QnAMakerTelemetryClient):
33
35
def __init__ (
34
36
self ,
35
37
endpoint : QnAMakerEndpoint ,
36
- options : QnAMakerOptions = None ,
38
+ options : QnAMakerOptions = None ,
39
+ http_client : ClientSession = None ,
37
40
telemetry_client : BotTelemetryClient = None ,
38
41
log_personal_information : bool = None
39
42
):
@@ -43,16 +46,16 @@ def __init__(
43
46
if endpoint .host .endswith ('v2.0' ):
44
47
raise ValueError ('v2.0 of QnA Maker service is no longer supported in the Bot Framework. Please upgrade your QnA Maker service at www.qnamaker.ai.' )
45
48
46
- self ._endpoint = endpoint
49
+ self ._endpoint : str = endpoint
47
50
self ._is_legacy_protocol : bool = self ._endpoint .host .endswith ('v3.0' )
48
51
49
- self ._options : QnAMakerOptions = options or QnAMakerOptions ()
52
+ self ._options = options or QnAMakerOptions ()
50
53
self ._validate_options (self ._options )
51
-
52
- instance_timeout = aiohttp .ClientTimeout (total = self ._options .timeout / 1000 )
53
- self ._req_client = aiohttp .ClientSession (timeout = instance_timeout )
54
54
55
- self ._telemetry_client = telemetry_client or NullTelemetryClient ()
55
+ instance_timeout = ClientTimeout (total = self ._options .timeout / 1000 )
56
+ self ._req_client = http_client or ClientSession (timeout = instance_timeout )
57
+
58
+ self ._telemetry_client : Union [BotTelemetryClient , NullTelemetryClient ] = telemetry_client or NullTelemetryClient ()
56
59
self ._log_personal_information = log_personal_information or False
57
60
58
61
@property
@@ -176,6 +179,8 @@ def fill_qna_event(
176
179
177
180
return EventData (properties = properties , metrics = metrics )
178
181
182
+
183
+
179
184
async def get_answers (
180
185
self ,
181
186
context : TurnContext ,
@@ -261,35 +266,18 @@ async def _query_qna_service(self, turn_context: TurnContext, options: QnAMakerO
261
266
262
267
# Convert miliseconds to seconds (as other BotBuilder SDKs accept timeout value in miliseconds)
263
268
# aiohttp.ClientSession units are in seconds
264
- timeout = aiohttp .ClientTimeout (total = options .timeout / 1000 )
265
-
266
- async with self ._req_client as client :
267
- response = await client .post (
268
- url ,
269
- data = serialized_content ,
270
- headers = headers ,
271
- timeout = timeout
272
- )
273
-
274
- # result = self._format_qna_result(response, options)
275
- json_res = await response .json ()
276
-
277
- answers_within_threshold = [
278
- { ** answer ,'score' : answer ['score' ]/ 100 }
279
- if answer ['score' ]/ 100 > options .score_threshold
280
- else {** answer } for answer in json_res ['answers' ]
281
- ]
282
- sorted_answers = sorted (answers_within_threshold , key = lambda ans : ans ['score' ], reverse = True )
283
-
284
- # The old version of the protocol returns the id in a field called qnaId
285
- # The following translates this old structure to the new
286
- if self ._is_legacy_protocol :
287
- for answer in answers_within_threshold :
288
- answer ['id' ] = answer .pop ('qnaId' , None )
289
-
290
- answers_as_query_results = list (map (lambda answer : QueryResult (** answer ), sorted_answers ))
269
+ timeout = ClientTimeout (total = options .timeout / 1000 )
270
+
271
+ response = await self ._req_client .post (
272
+ url ,
273
+ data = serialized_content ,
274
+ headers = headers ,
275
+ timeout = timeout
276
+ )
291
277
292
- return answers_as_query_results
278
+ result = await self ._format_qna_result (response , options )
279
+
280
+ return result
293
281
294
282
async def _emit_trace_info (self , turn_context : TurnContext , result : [QueryResult ], options : QnAMakerOptions ):
295
283
trace_info = QnAMakerTraceInfo (
@@ -311,11 +299,13 @@ async def _emit_trace_info(self, turn_context: TurnContext, result: [QueryResult
311
299
312
300
await turn_context .send_activity (trace_activity )
313
301
314
- def _format_qna_result (self , result , options : QnAMakerOptions ) -> [QueryResult ]:
302
+ async def _format_qna_result (self , result , options : QnAMakerOptions ) -> [QueryResult ]:
303
+ json_res = await result .json ()
315
304
316
305
answers_within_threshold = [
317
- { ** answer ,'score' : answer ['score' ]/ 100 } for answer in result ['answers' ]
318
- if answer ['score' ]/ 100 > options .score_threshold
306
+ { ** answer ,'score' : answer ['score' ]/ 100 }
307
+ if answer ['score' ]/ 100 > options .score_threshold
308
+ else {** answer } for answer in json_res ['answers' ]
319
309
]
320
310
sorted_answers = sorted (answers_within_threshold , key = lambda ans : ans ['score' ], reverse = True )
321
311
@@ -350,4 +340,4 @@ def get_user_agent(self):
350
340
platform_user_agent = f'({ os_version } ; { py_version } )'
351
341
user_agent = f'{ package_user_agent } { platform_user_agent } '
352
342
353
- return user_agent
343
+ return user_agent
0 commit comments