|
| 1 | +# Copyright (c) Microsoft Corporation. All rights reserved. |
| 2 | +# Licensed under the MIT License. |
| 3 | + |
| 4 | +from uuid import uuid4 |
| 5 | +from typing import Any, List |
| 6 | + |
| 7 | +from jsonpickle import Pickler |
| 8 | +from botbuilder.core import BotState, ConversationState, TurnContext, UserState |
| 9 | +from botbuilder.schema import Activity, ActivityTypes, ConversationReference |
| 10 | +from botframework.connector.auth import MicrosoftAppCredentials |
| 11 | + |
| 12 | +from .inspection_session import InspectionSession |
| 13 | +from .inspection_sessions_by_status import ( |
| 14 | + InspectionSessionsByStatus, |
| 15 | + DEFAULT_INSPECTION_SESSIONS_BY_STATUS, |
| 16 | +) |
| 17 | +from .inspection_state import InspectionState |
| 18 | +from .interception_middleware import InterceptionMiddleware |
| 19 | +from .trace_activity import from_state, make_command_activity |
| 20 | + |
| 21 | + |
| 22 | +class InspectionMiddleware(InterceptionMiddleware): |
| 23 | + _COMMAND = "/INSPECT" |
| 24 | + |
| 25 | + def __init__( # pylint: disable=super-init-not-called |
| 26 | + self, |
| 27 | + inspection_state: InspectionState, |
| 28 | + user_state: UserState = None, |
| 29 | + conversation_state: ConversationState = None, |
| 30 | + credentials: MicrosoftAppCredentials = None, |
| 31 | + ): |
| 32 | + |
| 33 | + self.inspection_state = inspection_state |
| 34 | + self.inspection_state_accessor = inspection_state.create_property( |
| 35 | + "InspectionSessionByStatus" |
| 36 | + ) |
| 37 | + self.user_state = user_state |
| 38 | + self.conversation_state = conversation_state |
| 39 | + self.credentials = MicrosoftAppCredentials( |
| 40 | + credentials.microsoft_app_id if credentials else "", |
| 41 | + credentials.microsoft_app_password if credentials else "", |
| 42 | + ) |
| 43 | + |
| 44 | + async def process_command(self, context: TurnContext) -> Any: |
| 45 | + if context.activity.type == ActivityTypes.message and context.activity.text: |
| 46 | + |
| 47 | + original_text = context.activity.text |
| 48 | + TurnContext.remove_recipient_mention(context.activity) |
| 49 | + |
| 50 | + command = context.activity.text.strip().split(" ") |
| 51 | + if len(command) > 1 and command[0] == InspectionMiddleware._COMMAND: |
| 52 | + |
| 53 | + if len(command) == 2 and command[1] == "open": |
| 54 | + await self._process_open_command(context) |
| 55 | + return True |
| 56 | + |
| 57 | + if len(command) == 3 and command[1] == "attach": |
| 58 | + await self.process_attach_command(context, command[2]) |
| 59 | + return True |
| 60 | + |
| 61 | + context.activity.text = original_text |
| 62 | + |
| 63 | + return False |
| 64 | + |
| 65 | + async def _inbound(self, context: TurnContext, trace_activity: Activity) -> Any: |
| 66 | + if await self.process_command(context): |
| 67 | + return False, False |
| 68 | + |
| 69 | + session = await self._find_session(context) |
| 70 | + if session: |
| 71 | + if await self._invoke_send(context, session, trace_activity): |
| 72 | + return True, True |
| 73 | + return True, False |
| 74 | + |
| 75 | + async def _outbound( |
| 76 | + self, context: TurnContext, trace_activities: List[Activity] |
| 77 | + ) -> Any: |
| 78 | + session = await self._find_session(context) |
| 79 | + if session: |
| 80 | + for trace_activity in trace_activities: |
| 81 | + if not await self._invoke_send(context, session, trace_activity): |
| 82 | + break |
| 83 | + |
| 84 | + async def _trace_state(self, context: TurnContext) -> Any: |
| 85 | + session = await self._find_session(context) |
| 86 | + if session: |
| 87 | + if self.user_state: |
| 88 | + await self.user_state.load(context, False) |
| 89 | + |
| 90 | + if self.conversation_state: |
| 91 | + await self.conversation_state.load(context, False) |
| 92 | + |
| 93 | + bot_state = {} |
| 94 | + |
| 95 | + if self.user_state: |
| 96 | + bot_state["user_state"] = InspectionMiddleware._get_serialized_context( |
| 97 | + self.user_state, context |
| 98 | + ) |
| 99 | + |
| 100 | + if self.conversation_state: |
| 101 | + bot_state[ |
| 102 | + "conversation_state" |
| 103 | + ] = InspectionMiddleware._get_serialized_context( |
| 104 | + self.conversation_state, context |
| 105 | + ) |
| 106 | + |
| 107 | + await self._invoke_send(context, session, from_state(bot_state)) |
| 108 | + |
| 109 | + async def _process_open_command(self, context: TurnContext) -> Any: |
| 110 | + sessions = await self.inspection_state_accessor.get( |
| 111 | + context, DEFAULT_INSPECTION_SESSIONS_BY_STATUS |
| 112 | + ) |
| 113 | + session_id = self._open_command( |
| 114 | + sessions, TurnContext.get_conversation_reference(context.activity) |
| 115 | + ) |
| 116 | + await context.send_activity( |
| 117 | + make_command_activity( |
| 118 | + f"{InspectionMiddleware._COMMAND} attach {session_id}" |
| 119 | + ) |
| 120 | + ) |
| 121 | + await self.inspection_state.save_changes(context, False) |
| 122 | + |
| 123 | + async def process_attach_command( |
| 124 | + self, context: TurnContext, session_id: str |
| 125 | + ) -> None: |
| 126 | + sessions = await self.inspection_state_accessor.get( |
| 127 | + context, DEFAULT_INSPECTION_SESSIONS_BY_STATUS |
| 128 | + ) |
| 129 | + |
| 130 | + if self._attach_comamnd(context.activity.conversation.id, sessions, session_id): |
| 131 | + await context.send_activity( |
| 132 | + "Attached to session, all traffic is being replicated for inspection." |
| 133 | + ) |
| 134 | + else: |
| 135 | + await context.send_activity( |
| 136 | + f"Open session with id {session_id} does not exist." |
| 137 | + ) |
| 138 | + |
| 139 | + await self.inspection_state.save_changes(context, False) |
| 140 | + |
| 141 | + def _open_command( |
| 142 | + self, |
| 143 | + sessions: InspectionSessionsByStatus, |
| 144 | + conversation_reference: ConversationReference, |
| 145 | + ) -> str: |
| 146 | + session_id = str(uuid4()) |
| 147 | + sessions.opened_sessions[session_id] = conversation_reference |
| 148 | + return session_id |
| 149 | + |
| 150 | + def _attach_comamnd( |
| 151 | + self, |
| 152 | + conversation_id: str, |
| 153 | + sessions: InspectionSessionsByStatus, |
| 154 | + session_id: str, |
| 155 | + ) -> bool: |
| 156 | + inspection_session_state = sessions.opened_sessions.get(session_id) |
| 157 | + if inspection_session_state: |
| 158 | + sessions.attached_sessions[conversation_id] = inspection_session_state |
| 159 | + del sessions.opened_sessions[session_id] |
| 160 | + return True |
| 161 | + |
| 162 | + return False |
| 163 | + |
| 164 | + @staticmethod |
| 165 | + def _get_serialized_context(state: BotState, context: TurnContext): |
| 166 | + ctx = state.get(context) |
| 167 | + return Pickler(unpicklable=False).flatten(ctx) |
| 168 | + |
| 169 | + async def _find_session(self, context: TurnContext) -> Any: |
| 170 | + sessions = await self.inspection_state_accessor.get( |
| 171 | + context, DEFAULT_INSPECTION_SESSIONS_BY_STATUS |
| 172 | + ) |
| 173 | + |
| 174 | + conversation_reference = sessions.attached_sessions.get( |
| 175 | + context.activity.conversation.id |
| 176 | + ) |
| 177 | + if conversation_reference: |
| 178 | + return InspectionSession(conversation_reference, self.credentials) |
| 179 | + |
| 180 | + return None |
| 181 | + |
| 182 | + async def _invoke_send( |
| 183 | + self, context: TurnContext, session: InspectionSession, activity: Activity |
| 184 | + ) -> bool: |
| 185 | + if await session.send(activity): |
| 186 | + return True |
| 187 | + |
| 188 | + await self._clean_up_session(context) |
| 189 | + return False |
| 190 | + |
| 191 | + async def _clean_up_session(self, context: TurnContext) -> None: |
| 192 | + sessions = await self.inspection_state_accessor.get( |
| 193 | + context, DEFAULT_INSPECTION_SESSIONS_BY_STATUS |
| 194 | + ) |
| 195 | + |
| 196 | + del sessions.attached_sessions[context.activity.conversation.id] |
| 197 | + await self.inspection_state.save_changes(context, False) |
0 commit comments