|
| 1 | +# Copyright (c) Microsoft Corporation. All rights reserved. |
| 2 | +# Licensed under the MIT License. |
| 3 | + |
| 4 | +from datetime import datetime |
| 5 | + |
| 6 | +from botbuilder.core import ActivityHandler, ConversationState, TurnContext, UserState, MessageFactory |
| 7 | +from recognizers_number import recognize_number, Culture |
| 8 | +from recognizers_date_time import recognize_datetime |
| 9 | + |
| 10 | +from data_models import ConversationFlow, Question, UserProfile |
| 11 | + |
| 12 | + |
| 13 | +class ValidationResult: |
| 14 | + def __init__(self, is_valid: bool = False, value: object = None, message: str = None): |
| 15 | + self.is_valid = is_valid |
| 16 | + self.value = value |
| 17 | + self.message = message |
| 18 | + |
| 19 | + |
| 20 | +class CustomPromptBot(ActivityHandler): |
| 21 | + def __init__(self, conversation_state: ConversationState, user_state: UserState): |
| 22 | + if conversation_state is None: |
| 23 | + raise TypeError( |
| 24 | + "[CustomPromptBot]: Missing parameter. conversation_state is required but None was given" |
| 25 | + ) |
| 26 | + if user_state is None: |
| 27 | + raise TypeError( |
| 28 | + "[CustomPromptBot]: Missing parameter. user_state is required but None was given" |
| 29 | + ) |
| 30 | + |
| 31 | + self.conversation_state = conversation_state |
| 32 | + self.user_state = user_state |
| 33 | + |
| 34 | + self.flow_accessor = self.conversation_state.create_property("ConversationFlow") |
| 35 | + self.profile_accessor = self.conversation_state.create_property("UserProfile") |
| 36 | + |
| 37 | + async def on_message_activity(self, turn_context: TurnContext): |
| 38 | + # Get the state properties from the turn context. |
| 39 | + profile = await self.profile_accessor.get(turn_context, UserProfile) |
| 40 | + flow = await self.flow_accessor.get(turn_context, ConversationFlow) |
| 41 | + |
| 42 | + await self._fill_out_user_profile(flow, profile, turn_context) |
| 43 | + |
| 44 | + # Save changes to UserState and ConversationState |
| 45 | + await self.conversation_state.save_changes(turn_context) |
| 46 | + await self.user_state.save_changes(turn_context) |
| 47 | + |
| 48 | + async def _fill_out_user_profile(self, flow: ConversationFlow, profile: UserProfile, turn_context: TurnContext): |
| 49 | + user_input = turn_context.activity.text.strip() |
| 50 | + |
| 51 | + # ask for name |
| 52 | + if flow.last_question_asked == Question.NONE: |
| 53 | + await turn_context.send_activity(MessageFactory.text("Let's get started. What is your name?")) |
| 54 | + flow.last_question_asked = Question.NAME |
| 55 | + |
| 56 | + # validate name then ask for age |
| 57 | + elif flow.last_question_asked == Question.NAME: |
| 58 | + validate_result = self._validate_name(user_input) |
| 59 | + if not validate_result.is_valid: |
| 60 | + await turn_context.send_activity(MessageFactory.text(validate_result.message)) |
| 61 | + else: |
| 62 | + profile.name = validate_result.value |
| 63 | + await turn_context.send_activity(MessageFactory.text(f"Hi {profile.name}")) |
| 64 | + await turn_context.send_activity(MessageFactory.text("How old are you?")) |
| 65 | + flow.last_question_asked = Question.AGE |
| 66 | + |
| 67 | + # validate age then ask for date |
| 68 | + elif flow.last_question_asked == Question.AGE: |
| 69 | + validate_result = self._validate_age(user_input) |
| 70 | + if not validate_result.is_valid: |
| 71 | + await turn_context.send_activity(MessageFactory.text(validate_result.message)) |
| 72 | + else: |
| 73 | + profile.age = validate_result.value |
| 74 | + await turn_context.send_activity(MessageFactory.text(f"I have your age as {profile.age}.")) |
| 75 | + await turn_context.send_activity(MessageFactory.text("When is your flight?")) |
| 76 | + flow.last_question_asked = Question.DATE |
| 77 | + |
| 78 | + # validate date and wrap it up |
| 79 | + elif flow.last_question_asked == Question.DATE: |
| 80 | + validate_result = self._validate_date(user_input) |
| 81 | + if not validate_result.is_valid: |
| 82 | + await turn_context.send_activity(MessageFactory.text(validate_result.message)) |
| 83 | + else: |
| 84 | + profile.date = validate_result.value |
| 85 | + await turn_context.send_activity(MessageFactory.text( |
| 86 | + f"Your cab ride to the airport is scheduled for {profile.date}.") |
| 87 | + ) |
| 88 | + await turn_context.send_activity(MessageFactory.text( |
| 89 | + f"Thanks for completing the booking {profile.name}.") |
| 90 | + ) |
| 91 | + await turn_context.send_activity(MessageFactory.text("Type anything to run the bot again.")) |
| 92 | + flow.last_question_asked = Question.NONE |
| 93 | + |
| 94 | + def _validate_name(self, user_input: str) -> ValidationResult: |
| 95 | + if not user_input: |
| 96 | + return ValidationResult(is_valid=False, message="Please enter a name that contains at least one character.") |
| 97 | + else: |
| 98 | + return ValidationResult(is_valid=True, value=user_input) |
| 99 | + |
| 100 | + def _validate_age(self, user_input: str) -> ValidationResult: |
| 101 | + # Attempt to convert the Recognizer result to an integer. This works for "a dozen", "twelve", "12", and so on. |
| 102 | + # The recognizer returns a list of potential recognition results, if any. |
| 103 | + results = recognize_number(user_input, Culture.English) |
| 104 | + for result in results: |
| 105 | + if "value" in result.resolution: |
| 106 | + age = int(result.resolution["value"]) |
| 107 | + if 18 <= age <= 120: |
| 108 | + return ValidationResult(is_valid=True, value=age) |
| 109 | + |
| 110 | + return ValidationResult(is_valid=False, message="Please enter an age between 18 and 120.") |
| 111 | + |
| 112 | + def _validate_date(self, user_input: str) -> ValidationResult: |
| 113 | + try: |
| 114 | + # Try to recognize the input as a date-time. This works for responses such as "11/14/2018", "9pm", |
| 115 | + # "tomorrow", "Sunday at 5pm", and so on. The recognizer returns a list of potential recognition results, |
| 116 | + # if any. |
| 117 | + results = recognize_datetime(user_input, Culture.English) |
| 118 | + for result in results: |
| 119 | + for resolution in result.resolution["values"]: |
| 120 | + if "value" in resolution: |
| 121 | + now = datetime.now() |
| 122 | + |
| 123 | + value = resolution["value"] |
| 124 | + if resolution["type"] == "date": |
| 125 | + candidate = datetime.strptime(value, "%Y-%m-%d") |
| 126 | + elif resolution["type"] == "time": |
| 127 | + candidate = datetime.strptime(value, "%H:%M:%S") |
| 128 | + candidate = candidate.replace(year=now.year, month=now.month, day=now.day) |
| 129 | + else: |
| 130 | + candidate = datetime.strptime(value, "%Y-%m-%d %H:%M:%S") |
| 131 | + |
| 132 | + # user response must be more than an hour out |
| 133 | + diff = candidate - now |
| 134 | + if diff.total_seconds() >= 3600: |
| 135 | + return ValidationResult(is_valid=True, value=candidate.strftime("%m/%d/%y @ %H:%M")) |
| 136 | + |
| 137 | + return ValidationResult(is_valid=False, message="I'm sorry, please enter a date at least an hour out.") |
| 138 | + except ValueError: |
| 139 | + return ValidationResult(is_valid=False, message="I'm sorry, I could not interpret that as an appropriate " |
| 140 | + "date. Please enter a date at least an hour out.") |
0 commit comments