8000 added option to allow user to pass their own ClientSession for parity · baruchiro/botbuilder-python@5e2e763 · GitHub
[go: up one dir, main page]

Skip to content

Commit 5e2e763

Browse files
committed
added option to allow user to pass their own ClientSession for parity
1 parent 1b1e5a0 commit 5e2e763

File tree

3 files changed

+32
-68
lines changed

3 files changed

+32
-68
lines changed

libraries/botbuilder-ai/botbuilder/ai/qna/qnamaker.py

Lines changed: 31 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
# Copyright (c) Microsoft Corporation. All rights reserved.
22
# Licensed under the MIT License.
33

4+
from aiohttp import ClientSession, ClientTimeout
5+
46
from botbuilder.schema import Activity
57
from botbuilder.core import BotTelemetryClient, NullTelemetryClient, TurnContext
68
from copy import copy
7-
import aiohttp, json, platform, requests
8-
from typing import Dict, List, NamedTuple
9+
import json, platform, requests
10+
from typing import Dict, List, NamedTuple, Union
911

1012
from .metadata import Metadata
1113
from .query_result import QueryResult
@@ -33,7 +35,8 @@ class QnAMaker(QnAMakerTelemetryClient):
3335
def __init__(
3436
self,
3537
endpoint: QnAMakerEndpoint,
36-
options: QnAMakerOptions = None,
38+
options: QnAMakerOptions = None,
39+
http_client: ClientSession = None,
3740
telemetry_client: BotTelemetryClient = None,
3841
log_personal_information: bool = None
3942
):
@@ -43,16 +46,16 @@ def __init__(
4346
if endpoint.host.endswith('v2.0'):
4447
raise ValueError('v2.0 of QnA Maker service is no longer supported in the Bot Framework. Please upgrade your QnA Maker service at www.qnamaker.ai.')
4548

46-
self._endpoint = endpoint
49+
self._endpoint: str = endpoint
4750
self._is_legacy_protocol: bool = self._endpoint.host.endswith('v3.0')
4851

49-
self._options: QnAMakerOptions = options or QnAMakerOptions()
52+
self._options = options or QnAMakerOptions()
5053
self._validate_options(self._options)
51-
52-
instance_timeout = aiohttp.ClientTimeout(total=self._options.timeout/1000)
53-
self._req_client = aiohttp.ClientSession(timeout=instance_timeout)
5454

55-
self._telemetry_client = telemetry_client or NullTelemetryClient()
55+
instance_timeout = ClientTimeout(total=self._options.timeout/1000)
56+
self._req_client = http_client or ClientSession(timeout=instance_timeout)
57+
58+
self._telemetry_client: Union[BotTelemetryClient, NullTelemetryClient] = telemetry_client or NullTelemetryClient()
5659
self._log_personal_information = log_personal_information or False
5760

5861
@property
@@ -176,6 +179,8 @@ def fill_qna_event(
176179

177180
return EventData(properties=properties, metrics=metrics)
178181

182+
183+
179184
async def get_answers(
180185
self,
181186
context: TurnContext,
@@ -261,35 +266,18 @@ async def _query_qna_service(self, turn_context: TurnContext, options: QnAMakerO
261266

262267
# Convert miliseconds to seconds (as other BotBuilder SDKs accept timeout value in miliseconds)
263268
# aiohttp.ClientSession units are in seconds
264-
timeout = aiohttp.ClientTimeout(total=options.timeout/1000)
265-
266-
async with self._req_client as client:
267-
response = await client.post(
268-
url,
269-
data = serialized_content,
270-
headers = headers,
271-
timeout = timeout
272-
)
273-
274-
# result = self._format_qna_result(response, options)
275-
json_res = await response.json()
276-
277-
answers_within_threshold = [
278-
{ **answer,'score': answer['score']/100 }
279-
if answer['score']/100 > options.score_threshold
280-
else {**answer} for answer in json_res['answers']
281-
]
282-
sorted_answers = sorted(answers_within_threshold, key = lambda ans: ans['score'], reverse = True)
283-
284-
# The old version of the protocol returns the id in a field called qnaId
285-
# The following translates this old structure to the new
286-
if self._is_legacy_protocol:
287-
for answer in answers_within_threshold:
288-
answer['id'] = answer.pop('qnaId', None)
289-
290-
answers_as_query_results = list(map(lambda answer: QueryResult(**answer), sorted_answers))
269+
timeout = ClientTimeout(total=options.timeout/1000)
270+
271+
response = await self._req_client.post(
272+
url,
273+
data = serialized_content,
274+
headers = headers,
275+
timeout = timeout
276+
)
291277

292-
return answers_as_query_results
278+
result = await self._format_qna_result(response, options)
279+
280+
return result
293281

294282
async def _emit_trace_info(self, turn_context: TurnContext, result: [QueryResult], options: QnAMakerOptions):
295283
trace_info = QnAMakerTraceInfo(
@@ -311,11 +299,13 @@ async def _emit_trace_info(self, turn_context: TurnContext, result: [QueryResult
311299

312300
await turn_context.send_activity(trace_activity)
313301

314-
def _format_qna_result(self, result, options: QnAMakerOptions) -> [QueryResult]:
302+
async def _format_qna_result(self, result, options: QnAMakerOptions) -> [QueryResult]:
303+
json_res = await result.json()
315304

316305
answers_within_threshold = [
317-
{ **answer,'score': answer['score']/100 } for answer in result['answers']
318-
if answer['score']/100 > options.score_threshold
306+
{ **answer,'score': answer['score']/100 }
307+
if answer['score']/100 > options.score_threshold
308+
else {**answer} for answer in json_res['answers']
319309
]
320310
sorted_answers = sorted(answers_within_threshold, key = lambda ans: ans['score'], reverse = True)
321311

@@ -350,4 +340,4 @@ def get_user_agent(self):
350340
platform_user_agent = f'({os_version}; {py_version})'
351341
user_agent = f'{package_user_agent} {platform_user_agent}'
352342

353-
return user_agent
343+
return user_agent

libraries/botbuilder-ai/setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
long_description=package_info["__summary__"],
3333
license=package_info["__license__"],
3434
packages=["botbuilder.ai", "botbuilder.ai.qna", "botbuilder.ai.luis"],
35-
install_requires=REQUIRES + TESTS_REQUIRES,
35+
install_requires=REQUIRES,
3636
tests_require=TESTS_REQUIRES,
3737
include_package_data=True,
3838
classifiers=[

libraries/botbuilder-ai/tests/qna/test_qna.py

Lines changed: 0 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,6 @@
1515
from botbuilder.core import BotAdapter, BotTelemetryClient, NullTelemetryClient, TurnContext
1616
from botbuilder.core.adapters import TestAdapter
1717
from botbuilder.schema import Activity, ActivityTypes, ChannelAccount, ResourceResponse, ConversationAccount
18-
19-
class InterceptRequestClient(ClientSession):
20-
def __init__(self, timeout):
21-
super().__init__(timeout=timeout.total)
22-
self.intercepted_headers = None
23-
2418

2519
class TestContext(TurnContext):
2620
def __init__(self, request):
@@ -52,9 +46,6 @@ def test_qnamaker_construction(self):
5246
endpoint = qna._endpoint
5347

5448
# Assert
55-
# self.assertEqual('a090f9f3-2f8e-41d1-a581-4f7a49269a0c', endpoint.knowledge_base_id)
56-
# self.assertEqual('4a439d5b-163b-47c3-b1d1-168cc0db5608', endpoint.endpoint_key)
57-
# self.assertEqual('https://ashleyNlpBot1-qnahost.azurewebsites.net/qnamaker', endpoint.host)
5849
self.assertEqual('f028d9k3-7g9z-11d3-d300-2b8x98227q8w', endpoint.knowledge_base_id)
5950
self.assertEqual('1k997n7w-207z-36p3-j2u1-09tas20ci6011', endpoint.endpoint_key)
6051
self.assertEqual('https://dummyqnahost.azurewebsites.net/qnamaker', endpoint.host)
@@ -275,19 +266,7 @@ async def test_returns_answer_with_timeout(self):
275266
self.assertIsNotNone(result)
276267
self.assertEqual(options.timeout, qna._options.timeout)
277268

278-
milisec_to_sec_timeout = options.timeout/1000
279-
self.assertEqual(milisec_to_sec_timeout, qna._req_client._timeout.total)
280269

281-
# Work in progress
282-
async def test_user_agent(self):
283-
question = 'up'
284-
timeout = ClientTimeout(total=300000)
285-
intercept_client = InterceptRequestClient(timeout=timeout)
286-
qna = QnAMaker(QnaApplicationTest.tests_endpoint)
287-
# context = QnaApplicationTest._get_context(question, TestAdapter())
288-
# response = await qna.get_answers(context)
289-
290-
pass
291270

292271
@classmethod
293272
async def _get_service_result(
@@ -337,8 +316,3 @@ def _get_context(utterance: str, bot_adapter: BotAdapter) -> TurnContext:
337316

338317
return TurnContext(test_adapter, activity)
339318

340-
341-
342-
343-
344-

0 commit comments

Comments
 (0)
0