From f52a96d2d09583ed25d2377cc8ab2ceed8e0e8f0 Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Sun, 6 Feb 2022 11:14:10 +0100 Subject: [PATCH 001/153] Single, squashed commit --- .deepsource.toml | 1 - .github/CONTRIBUTING.rst | 6 - .github/pull_request_template.md | 2 +- .github/workflows/test.yml | 9 - .gitmodules | 4 - .pre-commit-config.yaml | 19 +- AUTHORS.rst | 7 - README.rst | 6 - README_RAW.rst | 6 - ...tcher.rst => telegram.ext.application.rst} | 6 +- .../telegram.ext.applicationbuilder.rst | 7 + ...> telegram.ext.applicationhandlerstop.rst} | 6 +- .../source/telegram.ext.dispatcherbuilder.rst | 7 - docs/source/telegram.ext.rst | 7 +- docs/source/telegram.ext.updaterbuilder.rst | 7 - docs/source/telegram.request.baserequest.rst | 8 + docs/source/telegram.request.httpxrequest.rst | 8 + docs/source/telegram.request.requestdata.rst | 8 + docs/source/telegram.request.rst | 9 +- examples/arbitrarycallbackdatabot.py | 50 +- examples/chatmemberbot.py | 32 +- examples/contexttypesbot.py | 40 +- examples/conversationbot.py | 56 +- examples/conversationbot2.py | 42 +- examples/deeplinking.py | 62 +- examples/echobot.py | 38 +- examples/errorhandlerbot.py | 39 +- examples/inlinebot.py | 38 +- examples/inlinekeyboard.py | 34 +- examples/inlinekeyboard2.py | 66 +- examples/nestedconversationbot.py | 98 +- examples/passportbot.py | 45 +- examples/paymentbot.py | 66 +- examples/persistentconversationbot.py | 51 +- examples/pollbot.py | 67 +- examples/rawapibot.py | 24 +- examples/timerbot.py | 52 +- pyproject.toml | 3 +- requirements-dev.txt | 1 + requirements.txt | 4 +- setup.cfg | 11 +- setup.py | 11 +- telegram/__main__.py | 3 - telegram/_bot.py | 1712 ++++++++---- telegram/_callbackquery.py | 238 +- telegram/_chat.py | 633 +++-- telegram/_chatjoinrequest.py | 34 +- telegram/_files/_basemedium.py | 18 +- telegram/_files/chatphoto.py | 36 +- telegram/_files/file.py | 27 +- telegram/_files/inputfile.py | 36 +- telegram/_files/inputmedia.py | 13 +- telegram/_inline/inlinequery.py | 14 +- telegram/_message.py | 506 +++- telegram/_passport/passportfile.py | 18 +- telegram/_payment/precheckoutquery.py | 14 +- telegram/_payment/shippingquery.py | 14 +- telegram/_telegramobject.py | 3 +- telegram/_user.py | 372 ++- telegram/_utils/defaultvalue.py | 6 + telegram/_utils/files.py | 8 +- telegram/_utils/types.py | 14 +- telegram/error.py | 53 +- telegram/ext/__init__.py | 11 +- telegram/ext/_application.py | 1094 ++++++++ telegram/ext/_basepersistence.py | 144 +- telegram/ext/_builders.py | 1258 +++------ telegram/ext/_callbackcontext.py | 131 +- telegram/ext/_callbackqueryhandler.py | 29 +- telegram/ext/_chatjoinrequesthandler.py | 9 +- telegram/ext/_chatmemberhandler.py | 27 +- telegram/ext/_choseninlineresulthandler.py | 31 +- telegram/ext/_commandhandler.py | 46 +- telegram/ext/_conversationhandler.py | 303 ++- telegram/ext/_defaults.py | 63 +- telegram/ext/_dictpersistence.py | 51 +- telegram/ext/_dispatcher.py | 893 ------- telegram/ext/_extbot.py | 95 +- telegram/ext/_handler.py | 61 +- telegram/ext/_inlinequeryhandler.py | 30 +- telegram/ext/_jobqueue.py | 232 +- telegram/ext/_messagehandler.py | 28 +- telegram/ext/_picklepersistence.py | 49 +- telegram/ext/_pollanswerhandler.py | 9 +- telegram/ext/_pollhandler.py | 9 +- telegram/ext/_precheckoutqueryhandler.py | 9 +- telegram/ext/_shippingqueryhandler.py | 9 +- telegram/ext/_stringcommandhandler.py | 33 +- telegram/ext/_stringregexhandler.py | 31 +- telegram/ext/_typehandler.py | 27 +- telegram/ext/_updater.py | 785 +++--- telegram/ext/_utils/promise.py | 148 - telegram/ext/_utils/trackingdefaultdict.py | 219 ++ telegram/ext/_utils/types.py | 44 +- telegram/ext/_utils/webhookhandler.py | 142 +- telegram/request.py | 405 --- telegram/request/__init__.py | 24 + telegram/request/_baserequest.py | 345 +++ telegram/request/_httpxrequest.py | 211 ++ telegram/request/_requestdata.py | 109 + telegram/request/_requestparameter.py | 135 + telegram/vendor/__init__.py | 0 telegram/vendor/ptb_urllib3 | 1 - tests/bots.py | 19 - tests/conftest.py | 225 +- tests/data/text_file.txt | 2 +- tests/test_animation.py | 131 +- tests/test_audio.py | 120 +- tests/test_bot.py | 1390 ++++++---- tests/test_builders.py | 279 -- tests/test_callbackcontext.py | 228 -- tests/test_callbackdatacache.py | 381 --- tests/test_callbackquery.py | 177 +- tests/test_callbackqueryhandler.py | 211 -- tests/test_chat.py | 499 ++-- tests/test_chatjoinrequest.py | 26 +- tests/test_chatjoinrequesthandler.py | 141 - tests/test_chatmemberhandler.py | 153 -- tests/test_chatphoto.py | 100 +- tests/test_choseninlineresulthandler.py | 159 -- tests/test_commandhandler.py | 383 --- tests/test_constants.py | 16 +- tests/test_contact.py | 38 +- tests/test_conversationhandler.py | 1785 ------------- tests/test_defaults.py | 59 - tests/test_dispatcher.py | 1130 -------- tests/test_document.py | 124 +- tests/test_error.py | 24 +- tests/test_file.py | 79 +- tests/test_files.py | 10 +- tests/test_filters.py | 2274 ---------------- tests/test_forcereply.py | 5 +- tests/test_inlinekeyboardmarkup.py | 16 +- tests/test_inlinequery.py | 35 +- tests/test_inlinequeryhandler.py | 161 -- tests/test_inputfile.py | 22 +- tests/test_inputmedia.py | 126 +- tests/test_invoice.py | 55 +- tests/test_jobqueue.py | 528 ---- tests/test_location.py | 111 +- tests/test_message.py | 512 ++-- tests/test_messagehandler.py | 210 -- tests/test_official.py | 15 +- tests/test_passport.py | 49 +- tests/test_passportfile.py | 13 +- tests/test_persistence.py | 2371 ----------------- tests/test_photo.py | 341 +-- tests/test_pollanswerhandler.py | 109 - tests/test_pollhandler.py | 122 - tests/test_precheckoutquery.py | 20 +- tests/test_precheckoutqueryhandler.py | 114 - tests/test_promise.py | 149 -- tests/test_replykeyboardmarkup.py | 34 +- tests/test_replykeyboardremove.py | 7 +- tests/test_request.py | 463 +++- tests/test_requestdata.py | 208 ++ tests/test_requestparameter.py | 133 + tests/test_shippingquery.py | 18 +- tests/test_shippingqueryhandler.py | 118 - tests/test_slots.py | 4 +- tests/test_stack.py | 2 +- tests/test_sticker.py | 228 +- tests/test_stringcommandhandler.py | 118 - tests/test_stringregexhandler.py | 132 - tests/test_typehandler.py | 73 - tests/test_updater.py | 654 ----- tests/test_user.py | 297 ++- tests/test_venue.py | 40 +- tests/test_video.py | 135 +- tests/test_videonote.py | 110 +- tests/test_voice.py | 136 +- 171 files changed, 10788 insertions(+), 19484 deletions(-) delete mode 100644 .gitmodules rename docs/source/{telegram.ext.dispatcher.rst => telegram.ext.application.rst} (52%) create mode 100644 docs/source/telegram.ext.applicationbuilder.rst rename docs/source/{telegram.ext.dispatcherhandlerstop.rst => telegram.ext.applicationhandlerstop.rst} (50%) delete mode 100644 docs/source/telegram.ext.dispatcherbuilder.rst delete mode 100644 docs/source/telegram.ext.updaterbuilder.rst create mode 100644 docs/source/telegram.request.baserequest.rst create mode 100644 docs/source/telegram.request.httpxrequest.rst create mode 100644 docs/source/telegram.request.requestdata.rst create mode 100644 telegram/ext/_application.py delete mode 100644 telegram/ext/_dispatcher.py delete mode 100644 telegram/ext/_utils/promise.py create mode 100644 telegram/ext/_utils/trackingdefaultdict.py delete mode 100644 telegram/request.py create mode 100644 telegram/request/__init__.py create mode 100644 telegram/request/_baserequest.py create mode 100644 telegram/request/_httpxrequest.py create mode 100644 telegram/request/_requestdata.py create mode 100644 telegram/request/_requestparameter.py delete mode 100644 telegram/vendor/__init__.py delete mode 160000 telegram/vendor/ptb_urllib3 delete mode 100644 tests/test_builders.py delete mode 100644 tests/test_callbackcontext.py delete mode 100644 tests/test_callbackdatacache.py delete mode 100644 tests/test_callbackqueryhandler.py delete mode 100644 tests/test_chatjoinrequesthandler.py delete mode 100644 tests/test_chatmemberhandler.py delete mode 100644 tests/test_choseninlineresulthandler.py delete mode 100644 tests/test_commandhandler.py delete mode 100644 tests/test_conversationhandler.py delete mode 100644 tests/test_defaults.py delete mode 100644 tests/test_dispatcher.py delete mode 100644 tests/test_filters.py delete mode 100644 tests/test_inlinequeryhandler.py delete mode 100644 tests/test_jobqueue.py delete mode 100644 tests/test_messagehandler.py delete mode 100644 tests/test_persistence.py delete mode 100644 tests/test_pollanswerhandler.py delete mode 100644 tests/test_pollhandler.py delete mode 100644 tests/test_precheckoutqueryhandler.py delete mode 100644 tests/test_promise.py create mode 100644 tests/test_requestdata.py create mode 100644 tests/test_requestparameter.py delete mode 100644 tests/test_shippingqueryhandler.py delete mode 100644 tests/test_stringcommandhandler.py delete mode 100644 tests/test_stringregexhandler.py delete mode 100644 tests/test_typehandler.py delete mode 100644 tests/test_updater.py diff --git a/.deepsource.toml b/.deepsource.toml index a525644a99e..a081251483b 100644 --- a/.deepsource.toml +++ b/.deepsource.toml @@ -5,7 +5,6 @@ test_patterns = ["tests/**"] exclude_patterns = [ "tests/**", "docs/**", - "telegram/vendor/**", "setup.py", "setup-raw.py" ] diff --git a/.github/CONTRIBUTING.rst b/.github/CONTRIBUTING.rst index d204cd74bcc..68a98e2b27b 100644 --- a/.github/CONTRIBUTING.rst +++ b/.github/CONTRIBUTING.rst @@ -153,12 +153,6 @@ Here's how to make a one-off code change. $ git commit -a $ git push origin your-branch-name - - If after merging you see local modified files in ``telegram/vendor/`` directory, that you didn't actually touch, that means you need to update submodules with this command: - - .. code-block:: bash - - $ git submodule update --init --recursive - - At the end, the reviewer will merge the pull request. 6. **Tidy up!** Delete the feature branch from both your local clone and the GitHub repository: diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md index ca3b7bf4f9f..f84a1efef23 100644 --- a/.github/pull_request_template.md +++ b/.github/pull_request_template.md @@ -8,7 +8,7 @@ Hey! You're PRing? Cool! Please have a look at the below checklist. It's here to - [ ] Created new or adapted existing unit tests - [ ] Documented code changes according to the [CSI standard](https://standards.mousepawmedia.com/en/stable/csi.html) - [ ] Added myself alphabetically to `AUTHORS.rst` (optional) -- [ ] Added new classes & modules to the docs +- [ ] Added new classes & modules to the docs and all suitable `__all__` s ### If the PR contains API changes (otherwise, you can delete this passage) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index f43f62a8691..73a639512b9 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -20,9 +20,6 @@ jobs: fail-fast: False steps: - uses: actions/checkout@v2 - - name: Initialize vendored libs - run: - git submodule update --init --recursive - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v2 with: @@ -77,9 +74,6 @@ jobs: fail-fast: False steps: - uses: actions/checkout@v2 - - name: Initialize vendored libs - run: - git submodule update --init --recursive - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v2 with: @@ -106,9 +100,6 @@ jobs: fail-fast: False steps: - uses: actions/checkout@v2 - - name: Initialize vendored libs - run: - git submodule update --init --recursive - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v2 with: diff --git a/.gitmodules b/.gitmodules deleted file mode 100644 index ebe60816fb4..00000000000 --- a/.gitmodules +++ /dev/null @@ -1,4 +0,0 @@ -[submodule "telegram/vendor/urllib3"] - path = telegram/vendor/ptb_urllib3 - url = https://github.com/python-telegram-bot/urllib3.git - branch = ptb diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 9466cc51181..b78c12611f7 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -23,7 +23,7 @@ repos: # run pylint across multiple cpu cores to speed it up- - --jobs=0 # See https://pylint.pycqa.org/en/latest/user_guide/run.html?#parallel-execution to know more additional_dependencies: - - certifi + - httpx >= 0.20.0,<1.0 - tornado>=6.1 - APScheduler==3.6.3 - cachetools==4.2.2 @@ -38,25 +38,12 @@ repos: - types-ujson - types-pytz - types-cryptography - - types-certifi - types-cachetools - - certifi + - httpx >= 0.20.0,<1.0 - tornado>=6.1 - APScheduler==3.6.3 - cachetools==4.2.2 - - . # this basically does `pip install -e .` - - id: mypy - name: mypy-examples - files: ^examples/.*\.py$ - args: - - --no-strict-optional - - --follow-imports=silent - additional_dependencies: - - certifi - - tornado>=6.1 - - APScheduler==3.6.3 - - cachetools==4.2.2 - - . # this basically does `pip install -e .` + - . # this basically does `pip install -e .`n - repo: https://github.com/asottile/pyupgrade rev: v2.29.0 hooks: diff --git a/AUTHORS.rst b/AUTHORS.rst index dad9eb83d5e..c4b7129f37b 100644 --- a/AUTHORS.rst +++ b/AUTHORS.rst @@ -14,13 +14,6 @@ Emeritus maintainers include `Jannes Höke `_ (`@jh0ker `_ on Telegram), `Noam Meltzer `_, `Pieter Schutz `_ and `Jasmin Bom `_. -Vendored packages ------------------ - -We're vendoring urllib3 as part of ``python-telegram-bot`` which is distributed under the MIT -license. For more info, full credits & license terms, the sources can be found here: -`https://github.com/python-telegram-bot/urllib3`. - Contributors ------------ diff --git a/README.rst b/README.rst index c80c8692c4e..f573e62f0c9 100644 --- a/README.rst +++ b/README.rst @@ -130,12 +130,6 @@ Or you can install from source with: $ git clone https://github.com/python-telegram-bot/python-telegram-bot --recursive $ cd python-telegram-bot $ python setup.py install - -In case you have a previously cloned local repository already, you should initialize the added urllib3 submodule before installing with: - -.. code:: shell - - $ git submodule update --init --recursive --------------------- Optional Dependencies diff --git a/README_RAW.rst b/README_RAW.rst index d25ad6efc38..676cc1c7231 100644 --- a/README_RAW.rst +++ b/README_RAW.rst @@ -125,12 +125,6 @@ Or you can install from source with: $ cd python-telegram-bot $ python setup-raw.py install -In case you have a previously cloned local repository already, you should initialize the added urllib3 submodule before installing with: - -.. code:: shell - - $ git submodule update --init --recursive - ---- Note ---- diff --git a/docs/source/telegram.ext.dispatcher.rst b/docs/source/telegram.ext.application.rst similarity index 52% rename from docs/source/telegram.ext.dispatcher.rst rename to docs/source/telegram.ext.application.rst index 268be2ac0eb..8009517b743 100644 --- a/docs/source/telegram.ext.dispatcher.rst +++ b/docs/source/telegram.ext.application.rst @@ -1,8 +1,8 @@ -:github_url: https://github.com/python-telegram-bot/python-telegram-bot/blob/master/telegram/ext/dispatcher.py +:github_url: https://github.com/python-telegram-bot/python-telegram-bot/blob/master/telegram/ext/_application.py -telegram.ext.Dispatcher +telegram.ext.Application ======================= -.. autoclass:: telegram.ext.Dispatcher +.. autoclass:: telegram.ext.Application :members: :show-inheritance: diff --git a/docs/source/telegram.ext.applicationbuilder.rst b/docs/source/telegram.ext.applicationbuilder.rst new file mode 100644 index 00000000000..fbdec5357a1 --- /dev/null +++ b/docs/source/telegram.ext.applicationbuilder.rst @@ -0,0 +1,7 @@ +:github_url: https://github.com/python-telegram-bot/python-telegram-bot/blob/master/telegram/ext/_builders.py + +telegram.ext.ApplicationBuilder +=============================== + +.. autoclass:: telegram.ext.ApplicationBuilder + :members: diff --git a/docs/source/telegram.ext.dispatcherhandlerstop.rst b/docs/source/telegram.ext.applicationhandlerstop.rst similarity index 50% rename from docs/source/telegram.ext.dispatcherhandlerstop.rst rename to docs/source/telegram.ext.applicationhandlerstop.rst index 6894a840f46..15ad832cca6 100644 --- a/docs/source/telegram.ext.dispatcherhandlerstop.rst +++ b/docs/source/telegram.ext.applicationhandlerstop.rst @@ -1,8 +1,8 @@ -:github_url: https://github.com/python-telegram-bot/python-telegram-bot/blob/master/telegram/ext/dispatcher.py +:github_url: https://github.com/python-telegram-bot/python-telegram-bot/blob/master/telegram/ext/_application.py -telegram.ext.DispatcherHandlerStop +telegram.ext.ApplicationHandlerStop ================================== -.. autoclass:: telegram.ext.DispatcherHandlerStop +.. autoclass:: telegram.ext.ApplicationHandlerStop :members: :show-inheritance: diff --git a/docs/source/telegram.ext.dispatcherbuilder.rst b/docs/source/telegram.ext.dispatcherbuilder.rst deleted file mode 100644 index 292c2fb9e5e..00000000000 --- a/docs/source/telegram.ext.dispatcherbuilder.rst +++ /dev/null @@ -1,7 +0,0 @@ -:github_url: https://github.com/python-telegram-bot/python-telegram-bot/blob/master/telegram/ext/builders.py - -telegram.ext.DispatcherBuilder -============================== - -.. autoclass:: telegram.ext.DispatcherBuilder - :members: diff --git a/docs/source/telegram.ext.rst b/docs/source/telegram.ext.rst index 576fac40d5f..7a26d54cb5c 100644 --- a/docs/source/telegram.ext.rst +++ b/docs/source/telegram.ext.rst @@ -4,11 +4,10 @@ telegram.ext package .. toctree:: telegram.ext.extbot - telegram.ext.updaterbuilder + telegram.ext.applicationbuilder + telegram.ext.application + telegram.ext.applicationhandlerstop telegram.ext.updater - telegram.ext.dispatcherbuilder - telegram.ext.dispatcher - telegram.ext.dispatcherhandlerstop telegram.ext.callbackcontext telegram.ext.job telegram.ext.jobqueue diff --git a/docs/source/telegram.ext.updaterbuilder.rst b/docs/source/telegram.ext.updaterbuilder.rst deleted file mode 100644 index ee82f103c61..00000000000 --- a/docs/source/telegram.ext.updaterbuilder.rst +++ /dev/null @@ -1,7 +0,0 @@ -:github_url: https://github.com/python-telegram-bot/python-telegram-bot/blob/master/telegram/ext/builders.py - -telegram.ext.UpdaterBuilder -=========================== - -.. autoclass:: telegram.ext.UpdaterBuilder - :members: diff --git a/docs/source/telegram.request.baserequest.rst b/docs/source/telegram.request.baserequest.rst new file mode 100644 index 00000000000..4ab11dbb772 --- /dev/null +++ b/docs/source/telegram.request.baserequest.rst @@ -0,0 +1,8 @@ +:github_url: https://github.com/python-telegram-bot/python-telegram-bot/blob/master/telegram/request/_baserequest.py + +telegram.request.BaseRequest +============================ + +.. autoclass:: telegram.request.BaseRequest + :members: + :show-inheritance: \ No newline at end of file diff --git a/docs/source/telegram.request.httpxrequest.rst b/docs/source/telegram.request.httpxrequest.rst new file mode 100644 index 00000000000..676f3d1d1b0 --- /dev/null +++ b/docs/source/telegram.request.httpxrequest.rst @@ -0,0 +1,8 @@ +:github_url: https://github.com/python-telegram-bot/python-telegram-bot/blob/master/telegram/request/_httpxrequest.py + +telegram.request.HTTPXRequest +============================= + +.. autoclass:: telegram.request.HTTPXRequest + :members: + :show-inheritance: \ No newline at end of file diff --git a/docs/source/telegram.request.requestdata.rst b/docs/source/telegram.request.requestdata.rst new file mode 100644 index 00000000000..f020347bdaa --- /dev/null +++ b/docs/source/telegram.request.requestdata.rst @@ -0,0 +1,8 @@ +:github_url: https://github.com/python-telegram-bot/python-telegram-bot/blob/master/telegram/request/_requestdata.py + +telegram.request.RequestData +============================ + +.. autoclass:: telegram.request.RequestData + :members: + :show-inheritance: \ No newline at end of file diff --git a/docs/source/telegram.request.rst b/docs/source/telegram.request.rst index c05e4671390..ce724e76f28 100644 --- a/docs/source/telegram.request.rst +++ b/docs/source/telegram.request.rst @@ -1,8 +1,9 @@ -:github_url: https://github.com/python-telegram-bot/python-telegram-bot/blob/master/telegram/request.py +:github_url: https://github.com/python-telegram-bot/python-telegram-bot/blob/master/telegram/request telegram.request Module ======================= -.. automodule:: telegram.request - :members: - :show-inheritance: +.. toctree:: + telegram.request.baserequest + telegram.request.requestdata + telegram.request.httpxrequest diff --git a/examples/arbitrarycallbackdatabot.py b/examples/arbitrarycallbackdatabot.py index 3cb3e9aa9b2..354ec934c2f 100644 --- a/examples/arbitrarycallbackdatabot.py +++ b/examples/arbitrarycallbackdatabot.py @@ -16,7 +16,7 @@ CallbackQueryHandler, InvalidCallbackData, PicklePersistence, - Updater, + Application, CallbackContext, ) @@ -28,25 +28,25 @@ logger = logging.getLogger(__name__) -def start(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: +async def start(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: """Sends a message with 5 inline buttons attached.""" number_list: List[int] = [] - update.message.reply_text('Please choose:', reply_markup=build_keyboard(number_list)) + await update.message.reply_text('Please choose:', reply_markup=build_keyboard(number_list)) -def help_command(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: +async def help_command(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: """Displays info on how to use the bot.""" - update.message.reply_text( + await update.message.reply_text( "Use /start to test this bot. Use /clear to clear the stored data so that you can see " "what happens, if the button data is not available. " ) -def clear(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: +async def clear(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: """Clears the callback data cache""" context.bot.callback_data_cache.clear_callback_data() context.bot.callback_data_cache.clear_callback_queries() - update.effective_message.reply_text('All clear!') + await update.effective_message.reply_text('All clear!') def build_keyboard(current_list: List[int]) -> InlineKeyboardMarkup: @@ -56,10 +56,10 @@ def build_keyboard(current_list: List[int]) -> InlineKeyboardMarkup: ) -def list_button(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: +async def list_button(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: """Parses the CallbackQuery and updates the message text.""" query = update.callback_query - query.answer() + await query.answer() # Get the data from the callback_data. # If you're using a type checker like MyPy, you'll have to use typing.cast # to make the checker get the expected type of the callback_data @@ -67,7 +67,7 @@ def list_button(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: # append the number to the list number_list.append(number) - query.edit_message_text( + await query.edit_message_text( text=f"So far you've selected {number_list}. Choose the next item:", reply_markup=build_keyboard(number_list), ) @@ -76,10 +76,10 @@ def list_button(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: context.drop_callback_data(query) -def handle_invalid_button(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: +async def handle_invalid_button(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: """Informs the user that the button is no longer available.""" - update.callback_query.answer() - update.effective_message.edit_text( + await update.callback_query.answer() + await update.effective_message.edit_text( 'Sorry, I could not process this button click 😕 Please send /start to get a new keyboard.' ) @@ -88,29 +88,25 @@ def main() -> None: """Run the bot.""" # We use persistence to demonstrate how buttons can still work after the bot was restarted persistence = PicklePersistence(filepath='arbitrarycallbackdatabot') - # Create the Updater and pass it your bot's token. - updater = ( - Updater.builder() + # Create the Application and pass it your bot's token. + application = ( + Application.builder() .token("TOKEN") .persistence(persistence) .arbitrary_callback_data(True) .build() ) - updater.dispatcher.add_handler(CommandHandler('start', start)) - updater.dispatcher.add_handler(CommandHandler('help', help_command)) - updater.dispatcher.add_handler(CommandHandler('clear', clear)) - updater.dispatcher.add_handler( + application.application.add_handler(CommandHandler('start', start)) + application.application.add_handler(CommandHandler('help', help_command)) + application.application.add_handler(CommandHandler('clear', clear)) + application.application.add_handler( CallbackQueryHandler(handle_invalid_button, pattern=InvalidCallbackData) ) - updater.dispatcher.add_handler(CallbackQueryHandler(list_button)) + application.application.add_handler(CallbackQueryHandler(list_button)) - # Start the Bot - updater.start_polling() - - # Run the bot until the user presses Ctrl-C or the process receives SIGINT, - # SIGTERM or SIGABRT - updater.idle() + # Run the bot until the user presses Ctrl-C + application.run_polling() if __name__ == '__main__': diff --git a/examples/chatmemberbot.py b/examples/chatmemberbot.py index 4725e0661f0..c3606602be2 100644 --- a/examples/chatmemberbot.py +++ b/examples/chatmemberbot.py @@ -19,7 +19,7 @@ from telegram.ext import ( CommandHandler, ChatMemberHandler, - Updater, + Application, CallbackContext, ) @@ -68,7 +68,7 @@ def extract_status_change( return was_member, is_member -def track_chats(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: +async def track_chats(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: """Tracks the chats the bot is in.""" result = extract_status_change(update.my_chat_member) if result is None: @@ -103,7 +103,7 @@ def track_chats(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: context.bot_data.setdefault("channel_ids", set()).discard(chat.id) -def show_chats(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: +async def show_chats(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: """Shows which chats the bot is in""" user_ids = ", ".join(str(uid) for uid in context.bot_data.setdefault("user_ids", set())) group_ids = ", ".join(str(gid) for gid in context.bot_data.setdefault("group_ids", set())) @@ -113,10 +113,10 @@ def show_chats(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: f" Moreover it is a member of the groups with IDs {group_ids} " f"and administrator in the channels with IDs {channel_ids}." ) - update.effective_message.reply_text(text) + await update.effective_message.reply_text(text) -def greet_chat_members(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: +async def greet_chat_members(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: """Greets new users in chats and announces when someone leaves""" result = extract_status_change(update.chat_member) if result is None: @@ -140,28 +140,20 @@ def greet_chat_members(update: Update, context: CallbackContext.DEFAULT_TYPE) -> def main() -> None: """Start the bot.""" - # Create the Updater and pass it your bot's token. - updater = Updater.builder().token("TOKEN").build() - - # Get the dispatcher to register handlers - dispatcher = updater.dispatcher + # Create the Application and pass it your bot's token. + application = Application.builder().token("TOKEN").build() # Keep track of which chats the bot is in - dispatcher.add_handler(ChatMemberHandler(track_chats, ChatMemberHandler.MY_CHAT_MEMBER)) - dispatcher.add_handler(CommandHandler("show_chats", show_chats)) + application.add_handler(ChatMemberHandler(track_chats, ChatMemberHandler.MY_CHAT_MEMBER)) + application.add_handler(CommandHandler("show_chats", show_chats)) # Handle members joining/leaving chats. - dispatcher.add_handler(ChatMemberHandler(greet_chat_members, ChatMemberHandler.CHAT_MEMBER)) + application.add_handler(ChatMemberHandler(greet_chat_members, ChatMemberHandler.CHAT_MEMBER)) - # Start the Bot + # Run the bot until the user presses Ctrl-C # We pass 'allowed_updates' handle *all* updates including `chat_member` updates # To reset this, simply pass `allowed_updates=[]` - updater.start_polling(allowed_updates=Update.ALL_TYPES) - - # Run the bot until you press Ctrl-C or the process receives SIGINT, - # SIGTERM or SIGABRT. This should be used most of the time, since - # start_polling() is non-blocking and will stop the bot gracefully. - updater.idle() + application.run_polling()(allowed_updates=Update.ALL_TYPES) if __name__ == "__main__": diff --git a/examples/contexttypesbot.py b/examples/contexttypesbot.py index 07787813d38..c931f92ca33 100644 --- a/examples/contexttypesbot.py +++ b/examples/contexttypesbot.py @@ -21,9 +21,8 @@ ContextTypes, CallbackQueryHandler, TypeHandler, - Dispatcher, ExtBot, - Updater, + Application, ) @@ -38,8 +37,8 @@ def __init__(self) -> None: class CustomContext(CallbackContext[ExtBot, dict, ChatData, dict]): """Custom class for context.""" - def __init__(self, dispatcher: Dispatcher): - super().__init__(dispatcher=dispatcher) + def __init__(self, application: Application): + super().__init__(application=application) self._message_id: Optional[int] = None @property @@ -62,10 +61,10 @@ def message_clicks(self, value: int) -> None: self.chat_data.clicks_per_message[self._message_id] = value @classmethod - def from_update(cls, update: object, dispatcher: 'Dispatcher') -> 'CustomContext': + def from_update(cls, update: object, application: 'Application') -> 'CustomContext': """Override from_update to set _message_id.""" # Make sure to call super() - context = super().from_update(update, dispatcher) + context = super().from_update(update, application) if context.chat_data and isinstance(update, Update) and update.effective_message: # pylint: disable=protected-access @@ -75,9 +74,9 @@ def from_update(cls, update: object, dispatcher: 'Dispatcher') -> 'CustomContext return context -def start(update: Update, context: CustomContext) -> None: +async def start(update: Update, context: CustomContext) -> None: """Display a message with a button.""" - update.message.reply_html( + await update.message.reply_html( 'This button was clicked 0 times.', reply_markup=InlineKeyboardMarkup.from_button( InlineKeyboardButton(text='Click me!', callback_data='button') @@ -85,10 +84,10 @@ def start(update: Update, context: CustomContext) -> None: ) -def count_click(update: Update, context: CustomContext) -> None: +async def count_click(update: Update, context: CustomContext) -> None: """Update the click count for the message.""" context.message_clicks += 1 - update.callback_query.answer() + await update.callback_query.answer() update.effective_message.edit_text( f'This button was clicked {context.message_clicks} times.', reply_markup=InlineKeyboardMarkup.from_button( @@ -98,15 +97,15 @@ def count_click(update: Update, context: CustomContext) -> None: ) -def print_users(update: Update, context: CustomContext) -> None: +async def print_users(update: Update, context: CustomContext) -> None: """Show which users have been using this bot.""" - update.message.reply_text( + await update.message.reply_text( 'The following user IDs have used this bot: ' f'{", ".join(map(str, context.bot_user_ids))}' ) -def track_users(update: Update, context: CustomContext) -> None: +async def track_users(update: Update, context: CustomContext) -> None: """Store the user id of the incoming update, if any.""" if update.effective_user: context.bot_user_ids.add(update.effective_user.id) @@ -115,17 +114,16 @@ def track_users(update: Update, context: CustomContext) -> None: def main() -> None: """Run the bot.""" context_types = ContextTypes(context=CustomContext, chat_data=ChatData) - updater = Updater.builder().token("TOKEN").context_types(context_types).build() + application = Application.builder().token("TOKEN").context_types(context_types).build() - dispatcher = updater.dispatcher + application = application.application # run track_users in its own group to not interfere with the user handlers - dispatcher.add_handler(TypeHandler(Update, track_users), group=-1) - dispatcher.add_handler(CommandHandler("start", start)) - dispatcher.add_handler(CallbackQueryHandler(count_click)) - dispatcher.add_handler(CommandHandler("print_users", print_users)) + application.add_handler(TypeHandler(Update, track_users), group=-1) + application.add_handler(CommandHandler("start", start)) + application.add_handler(CallbackQueryHandler(count_click)) + application.add_handler(CommandHandler("print_users", print_users)) - updater.start_polling() - updater.idle() + application.run_polling() if __name__ == '__main__': diff --git a/examples/conversationbot.py b/examples/conversationbot.py index 1b0b1983042..691e982c13f 100644 --- a/examples/conversationbot.py +++ b/examples/conversationbot.py @@ -4,7 +4,7 @@ """ First, a few callback functions are defined. Then, those functions are passed to -the Dispatcher and registered at their respective places. +the Application and registered at their respective places. Then, the bot is started and runs until we press Ctrl-C on the command line. Usage: @@ -22,7 +22,7 @@ MessageHandler, filters, ConversationHandler, - Updater, + Application, CallbackContext, ) @@ -36,11 +36,11 @@ GENDER, PHOTO, LOCATION, BIO = range(4) -def start(update: Update, context: CallbackContext.DEFAULT_TYPE) -> int: +async def start(update: Update, context: CallbackContext.DEFAULT_TYPE) -> int: """Starts the conversation and asks the user about their gender.""" reply_keyboard = [['Boy', 'Girl', 'Other']] - update.message.reply_text( + await update.message.reply_text( 'Hi! My name is Professor Bot. I will hold a conversation with you. ' 'Send /cancel to stop talking to me.\n\n' 'Are you a boy or a girl?', @@ -52,11 +52,11 @@ def start(update: Update, context: CallbackContext.DEFAULT_TYPE) -> int: return GENDER -def gender(update: Update, context: CallbackContext.DEFAULT_TYPE) -> int: +async def gender(update: Update, context: CallbackContext.DEFAULT_TYPE) -> int: """Stores the selected gender and asks for a photo.""" user = update.message.from_user logger.info("Gender of %s: %s", user.first_name, update.message.text) - update.message.reply_text( + await update.message.reply_text( 'I see! Please send me a photo of yourself, ' 'so I know what you look like, or send /skip if you don\'t want to.', reply_markup=ReplyKeyboardRemove(), @@ -65,69 +65,69 @@ def gender(update: Update, context: CallbackContext.DEFAULT_TYPE) -> int: return PHOTO -def photo(update: Update, context: CallbackContext.DEFAULT_TYPE) -> int: +async def photo(update: Update, context: CallbackContext.DEFAULT_TYPE) -> int: """Stores the photo and asks for a location.""" user = update.message.from_user - photo_file = update.message.photo[-1].get_file() + photo_file = await update.message.photo[-1].get_file() photo_file.download('user_photo.jpg') logger.info("Photo of %s: %s", user.first_name, 'user_photo.jpg') - update.message.reply_text( + await update.message.reply_text( 'Gorgeous! Now, send me your location please, or send /skip if you don\'t want to.' ) return LOCATION -def skip_photo(update: Update, context: CallbackContext.DEFAULT_TYPE) -> int: +async def skip_photo(update: Update, context: CallbackContext.DEFAULT_TYPE) -> int: """Skips the photo and asks for a location.""" user = update.message.from_user logger.info("User %s did not send a photo.", user.first_name) - update.message.reply_text( + await update.message.reply_text( 'I bet you look great! Now, send me your location please, or send /skip.' ) return LOCATION -def location(update: Update, context: CallbackContext.DEFAULT_TYPE) -> int: +async def location(update: Update, context: CallbackContext.DEFAULT_TYPE) -> int: """Stores the location and asks for some info about the user.""" user = update.message.from_user user_location = update.message.location logger.info( "Location of %s: %f / %f", user.first_name, user_location.latitude, user_location.longitude ) - update.message.reply_text( + await update.message.reply_text( 'Maybe I can visit you sometime! At last, tell me something about yourself.' ) return BIO -def skip_location(update: Update, context: CallbackContext.DEFAULT_TYPE) -> int: +async def skip_location(update: Update, context: CallbackContext.DEFAULT_TYPE) -> int: """Skips the location and asks for info about the user.""" user = update.message.from_user logger.info("User %s did not send a location.", user.first_name) - update.message.reply_text( + await update.message.reply_text( 'You seem a bit paranoid! At last, tell me something about yourself.' ) return BIO -def bio(update: Update, context: CallbackContext.DEFAULT_TYPE) -> int: +async def bio(update: Update, context: CallbackContext.DEFAULT_TYPE) -> int: """Stores the info about the user and ends the conversation.""" user = update.message.from_user logger.info("Bio of %s: %s", user.first_name, update.message.text) - update.message.reply_text('Thank you! I hope we can talk again some day.') + await update.message.reply_text('Thank you! I hope we can talk again some day.') return ConversationHandler.END -def cancel(update: Update, context: CallbackContext.DEFAULT_TYPE) -> int: +async def cancel(update: Update, context: CallbackContext.DEFAULT_TYPE) -> int: """Cancels and ends the conversation.""" user = update.message.from_user logger.info("User %s canceled the conversation.", user.first_name) - update.message.reply_text( + await update.message.reply_text( 'Bye! I hope we can talk again some day.', reply_markup=ReplyKeyboardRemove() ) @@ -136,11 +136,8 @@ def cancel(update: Update, context: CallbackContext.DEFAULT_TYPE) -> int: def main() -> None: """Run the bot.""" - # Create the Updater and pass it your bot's token. - updater = Updater.builder().token("TOKEN").build() - - # Get the dispatcher to register handlers - dispatcher = updater.dispatcher + # Create the Application and pass it your bot's token. + application = Application.builder().token("TOKEN").build() # Add conversation handler with the states GENDER, PHOTO, LOCATION and BIO conv_handler = ConversationHandler( @@ -157,15 +154,10 @@ def main() -> None: fallbacks=[CommandHandler('cancel', cancel)], ) - dispatcher.add_handler(conv_handler) - - # Start the Bot - updater.start_polling() + application.add_handler(conv_handler) - # Run the bot until you press Ctrl-C or the process receives SIGINT, - # SIGTERM or SIGABRT. This should be used most of the time, since - # start_polling() is non-blocking and will stop the bot gracefully. - updater.idle() + # Run the bot until the user presses Ctrl-C + application.run_polling() if __name__ == '__main__': diff --git a/examples/conversationbot2.py b/examples/conversationbot2.py index dfdd5f5aa2c..6d5737e8f8b 100644 --- a/examples/conversationbot2.py +++ b/examples/conversationbot2.py @@ -4,7 +4,7 @@ """ First, a few callback functions are defined. Then, those functions are passed to -the Dispatcher and registered at their respective places. +the Application and registered at their respective places. Then, the bot is started and runs until we press Ctrl-C on the command line. Usage: @@ -23,7 +23,7 @@ MessageHandler, filters, ConversationHandler, - Updater, + Application, CallbackContext, ) @@ -50,9 +50,9 @@ def facts_to_str(user_data: Dict[str, str]) -> str: return "\n".join(facts).join(['\n', '\n']) -def start(update: Update, context: CallbackContext.DEFAULT_TYPE) -> int: +async def start(update: Update, context: CallbackContext.DEFAULT_TYPE) -> int: """Start the conversation and ask user for input.""" - update.message.reply_text( + await update.message.reply_text( "Hi! My name is Doctor Botter. I will hold a more complex conversation with you. " "Why don't you tell me something about yourself?", reply_markup=markup, @@ -61,25 +61,25 @@ def start(update: Update, context: CallbackContext.DEFAULT_TYPE) -> int: return CHOOSING -def regular_choice(update: Update, context: CallbackContext.DEFAULT_TYPE) -> int: +async def regular_choice(update: Update, context: CallbackContext.DEFAULT_TYPE) -> int: """Ask the user for info about the selected predefined choice.""" text = update.message.text context.user_data['choice'] = text - update.message.reply_text(f'Your {text.lower()}? Yes, I would love to hear about that!') + await update.message.reply_text(f'Your {text.lower()}? Yes, I would love to hear about that!') return TYPING_REPLY -def custom_choice(update: Update, context: CallbackContext.DEFAULT_TYPE) -> int: +async def custom_choice(update: Update, context: CallbackContext.DEFAULT_TYPE) -> int: """Ask the user for a description of a custom category.""" - update.message.reply_text( + await update.message.reply_text( 'Alright, please send me the category first, for example "Most impressive skill"' ) return TYPING_CHOICE -def received_information(update: Update, context: CallbackContext.DEFAULT_TYPE) -> int: +async def received_information(update: Update, context: CallbackContext.DEFAULT_TYPE) -> int: """Store info provided by user and ask for the next category.""" user_data = context.user_data text = update.message.text @@ -87,7 +87,7 @@ def received_information(update: Update, context: CallbackContext.DEFAULT_TYPE) user_data[category] = text del user_data['choice'] - update.message.reply_text( + await update.message.reply_text( "Neat! Just so you know, this is what you already told me:" f"{facts_to_str(user_data)} You can tell me more, or change your opinion" " on something.", @@ -97,13 +97,13 @@ def received_information(update: Update, context: CallbackContext.DEFAULT_TYPE) return CHOOSING -def done(update: Update, context: CallbackContext.DEFAULT_TYPE) -> int: +async def done(update: Update, context: CallbackContext.DEFAULT_TYPE) -> int: """Display the gathered info and end the conversation.""" user_data = context.user_data if 'choice' in user_data: del user_data['choice'] - update.message.reply_text( + await update.message.reply_text( f"I learned these facts about you: {facts_to_str(user_data)}Until next time!", reply_markup=ReplyKeyboardRemove(), ) @@ -114,11 +114,8 @@ def done(update: Update, context: CallbackContext.DEFAULT_TYPE) -> int: def main() -> None: """Run the bot.""" - # Create the Updater and pass it your bot's token. - updater = Updater.builder().token("TOKEN").build() - - # Get the dispatcher to register handlers - dispatcher = updater.dispatcher + # Create the Application and pass it your bot's token. + application = Application.builder().token("TOKEN").build() # Add conversation handler with the states CHOOSING, TYPING_CHOICE and TYPING_REPLY conv_handler = ConversationHandler( @@ -145,15 +142,10 @@ def main() -> None: fallbacks=[MessageHandler(filters.Regex('^Done$'), done)], ) - dispatcher.add_handler(conv_handler) - - # Start the Bot - updater.start_polling() + application.add_handler(conv_handler) - # Run the bot until you press Ctrl-C or the process receives SIGINT, - # SIGTERM or SIGABRT. This should be used most of the time, since - # start_polling() is non-blocking and will stop the bot gracefully. - updater.idle() + # Run the bot until the user presses Ctrl-C + application.run_polling() if __name__ == '__main__': diff --git a/examples/deeplinking.py b/examples/deeplinking.py index 88a7cd45bad..f0644d16aef 100644 --- a/examples/deeplinking.py +++ b/examples/deeplinking.py @@ -6,10 +6,10 @@ This program is dedicated to the public domain under the CC0 license. -This Bot uses the Updater class to handle the bot. +This Bot uses the Application class to handle the bot. First, a few handler functions are defined. Then, those functions are passed to -the Dispatcher and registered at their respective places. +the Application and registered at their respective places. Then, the bot is started and runs until we press Ctrl-C on the command line. Usage: @@ -26,7 +26,7 @@ CommandHandler, CallbackQueryHandler, filters, - Updater, + Application, CallbackContext, ) @@ -47,15 +47,15 @@ KEYBOARD_CALLBACKDATA = "keyboard-callback-data" -def start(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: +async def start(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: """Send a deep-linked URL when the command /start is issued.""" bot = context.bot url = helpers.create_deep_linked_url(bot.username, CHECK_THIS_OUT, group=True) text = "Feel free to tell your friends about it:\n\n" + url - update.message.reply_text(text) + await update.message.reply_text(text) -def deep_linked_level_1(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: +async def deep_linked_level_1(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: """Reached through the CHECK_THIS_OUT payload""" bot = context.bot url = helpers.create_deep_linked_url(bot.username, SO_COOL) @@ -66,20 +66,22 @@ def deep_linked_level_1(update: Update, context: CallbackContext.DEFAULT_TYPE) - keyboard = InlineKeyboardMarkup.from_button( InlineKeyboardButton(text="Continue here!", url=url) ) - update.message.reply_text(text, reply_markup=keyboard) + await update.message.reply_text(text, reply_markup=keyboard) -def deep_linked_level_2(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: +async def deep_linked_level_2(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: """Reached through the SO_COOL payload""" bot = context.bot url = helpers.create_deep_linked_url(bot.username, USING_ENTITIES) text = f"You can also mask the deep-linked URLs as links: [▶️ CLICK HERE]({url})." - update.message.reply_text(text, parse_mode=ParseMode.MARKDOWN, disable_web_page_preview=True) + await update.message.reply_text( + text, parse_mode=ParseMode.MARKDOWN, disable_web_page_preview=True + ) -def deep_linked_level_3(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: +async def deep_linked_level_3(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: """Reached through the USING_ENTITIES payload""" - update.message.reply_text( + await update.message.reply_text( "It is also possible to make deep-linking using InlineKeyboardButtons.", reply_markup=InlineKeyboardMarkup( [[InlineKeyboardButton(text="Like this!", callback_data=KEYBOARD_CALLBACKDATA)]] @@ -87,65 +89,59 @@ def deep_linked_level_3(update: Update, context: CallbackContext.DEFAULT_TYPE) - ) -def deep_link_level_3_callback(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: +async def deep_link_level_3_callback( + update: Update, context: CallbackContext.DEFAULT_TYPE +) -> None: """Answers CallbackQuery with deeplinking url.""" bot = context.bot url = helpers.create_deep_linked_url(bot.username, USING_KEYBOARD) - update.callback_query.answer(url=url) + await update.callback_query.answer(url=url) -def deep_linked_level_4(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: +async def deep_linked_level_4(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: """Reached through the USING_KEYBOARD payload""" payload = context.args - update.message.reply_text( + await update.message.reply_text( f"Congratulations! This is as deep as it gets 👏🏻\n\nThe payload was: {payload}" ) def main() -> None: """Start the bot.""" - # Create the Updater and pass it your bot's token. - updater = Updater.builder().token("TOKEN").build() - - # Get the dispatcher to register handlers - dispatcher = updater.dispatcher + # Create the Application and pass it your bot's token. + application = Application.builder().token("TOKEN").build() # More info on what deep linking actually is (read this first if it's unclear to you): # https://core.telegram.org/bots#deep-linking # Register a deep-linking handler - dispatcher.add_handler( + application.add_handler( CommandHandler("start", deep_linked_level_1, filters.Regex(CHECK_THIS_OUT)) ) # This one works with a textual link instead of an URL - dispatcher.add_handler(CommandHandler("start", deep_linked_level_2, filters.Regex(SO_COOL))) + application.add_handler(CommandHandler("start", deep_linked_level_2, filters.Regex(SO_COOL))) # We can also pass on the deep-linking payload - dispatcher.add_handler( + application.add_handler( CommandHandler("start", deep_linked_level_3, filters.Regex(USING_ENTITIES)) ) # Possible with inline keyboard buttons as well - dispatcher.add_handler( + application.add_handler( CommandHandler("start", deep_linked_level_4, filters.Regex(USING_KEYBOARD)) ) # register callback handler for inline keyboard button - dispatcher.add_handler( + application.add_handler( CallbackQueryHandler(deep_link_level_3_callback, pattern=KEYBOARD_CALLBACKDATA) ) # Make sure the deep-linking handlers occur *before* the normal /start handler. - dispatcher.add_handler(CommandHandler("start", start)) - - # Start the Bot - updater.start_polling() + application.add_handler(CommandHandler("start", start)) - # Run the bot until you press Ctrl-C or the process receives SIGINT, - # SIGTERM or SIGABRT. This should be used most of the time, since - # start_polling() is non-blocking and will stop the bot gracefully. - updater.idle() + # Run the bot until the user presses Ctrl-C + application.run_polling() if __name__ == "__main__": diff --git a/examples/echobot.py b/examples/echobot.py index 278df7d9a70..95c34d2d084 100644 --- a/examples/echobot.py +++ b/examples/echobot.py @@ -6,7 +6,7 @@ Simple Bot to reply to Telegram messages. First, a few handler functions are defined. Then, those functions are passed to -the Dispatcher and registered at their respective places. +the Application and registered at their respective places. Then, the bot is started and runs until we press Ctrl-C on the command line. Usage: @@ -22,7 +22,7 @@ CommandHandler, MessageHandler, filters, - Updater, + Application, CallbackContext, ) @@ -36,47 +36,39 @@ # Define a few command handlers. These usually take the two arguments update and # context. -def start(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: +async def start(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: """Send a message when the command /start is issued.""" user = update.effective_user - update.message.reply_markdown_v2( + await update.message.reply_markdown_v2( fr'Hi {user.mention_markdown_v2()}\!', reply_markup=ForceReply(selective=True), ) -def help_command(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: +async def help_command(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: """Send a message when the command /help is issued.""" - update.message.reply_text('Help!') + await update.message.reply_text('Help!') -def echo(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: +async def echo(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: """Echo the user message.""" - update.message.reply_text(update.message.text) + await update.message.reply_text(update.message.text) def main() -> None: """Start the bot.""" - # Create the Updater and pass it your bot's token. - updater = Updater.builder().token("TOKEN").build() - - # Get the dispatcher to register handlers - dispatcher = updater.dispatcher + # Create the Application and pass it your bot's token. + application = Application.builder().token("TOKEN").build() # on different commands - answer in Telegram - dispatcher.add_handler(CommandHandler("start", start)) - dispatcher.add_handler(CommandHandler("help", help_command)) + application.add_handler(CommandHandler("start", start)) + application.add_handler(CommandHandler("help", help_command)) # on non command i.e message - echo the message on Telegram - dispatcher.add_handler(MessageHandler(filters.TEXT & ~filters.COMMAND, echo)) - - # Start the Bot - updater.start_polling() + application.add_handler(MessageHandler(filters.TEXT & ~filters.COMMAND, echo)) - # Run the bot until you press Ctrl-C or the process receives SIGINT, - # SIGTERM or SIGABRT. This should be used most of the time, since - # start_polling() is non-blocking and will stop the bot gracefully. - updater.idle() + # Run the bot until the user presses Ctrl-C + application.run_polling() if __name__ == '__main__': diff --git a/examples/errorhandlerbot.py b/examples/errorhandlerbot.py index e6853e789ff..8b0079d1648 100644 --- a/examples/errorhandlerbot.py +++ b/examples/errorhandlerbot.py @@ -10,7 +10,7 @@ from telegram import Update from telegram.constants import ParseMode -from telegram.ext import CommandHandler, Updater, CallbackContext +from telegram.ext import CommandHandler, Application, CallbackContext # Enable logging logging.basicConfig( @@ -26,7 +26,7 @@ DEVELOPER_CHAT_ID = 123456789 -def error_handler(update: object, context: CallbackContext.DEFAULT_TYPE) -> None: +async def error_handler(update: object, context: CallbackContext.DEFAULT_TYPE) -> None: """Log the error and send a telegram message to notify the developer.""" # Log the error before we do anything else, so we can see it even if something breaks. logger.error(msg="Exception while handling an update:", exc_info=context.error) @@ -49,17 +49,19 @@ def error_handler(update: object, context: CallbackContext.DEFAULT_TYPE) -> None ) # Finally, send the message - context.bot.send_message(chat_id=DEVELOPER_CHAT_ID, text=message, parse_mode=ParseMode.HTML) + await context.bot.send_message( + chat_id=DEVELOPER_CHAT_ID, text=message, parse_mode=ParseMode.HTML + ) -def bad_command(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: +async def bad_command(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: """Raise an error to trigger the error handler.""" - context.bot.wrong_method_name() # type: ignore[attr-defined] + await context.bot.wrong_method_name() # type: ignore[attr-defined] -def start(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: +async def start(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: """Displays info on how to trigger an error.""" - update.effective_message.reply_html( + await update.effective_message.reply_html( 'Use /bad_command to cause an error.\n' f'Your chat id is {update.effective_chat.id}.' ) @@ -67,26 +69,21 @@ def start(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: def main() -> None: """Run the bot.""" - # Create the Updater and pass it your bot's token. - updater = Updater.builder().token(BOT_TOKEN).build() + # Create the Application and pass it your bot's token. + application = Application.builder().token(BOT_TOKEN).build() - # Get the dispatcher to register handlers - dispatcher = updater.dispatcher + # Get the application to register handlers + application = application.application # Register the commands... - dispatcher.add_handler(CommandHandler('start', start)) - dispatcher.add_handler(CommandHandler('bad_command', bad_command)) + application.add_handler(CommandHandler('start', start)) + application.add_handler(CommandHandler('bad_command', bad_command)) # ...and the error handler - dispatcher.add_error_handler(error_handler) - - # Start the Bot - updater.start_polling() + application.add_error_handler(error_handler) - # Run the bot until you press Ctrl-C or the process receives SIGINT, - # SIGTERM or SIGABRT. This should be used most of the time, since - # start_polling() is non-blocking and will stop the bot gracefully. - updater.idle() + # Run the bot until the user presses Ctrl-C + application.run_polling() if __name__ == '__main__': diff --git a/examples/inlinebot.py b/examples/inlinebot.py index ef86101a95d..c1cfac18547 100644 --- a/examples/inlinebot.py +++ b/examples/inlinebot.py @@ -4,7 +4,7 @@ """ First, a few handler functions are defined. Then, those functions are passed to -the Dispatcher and registered at their respective places. +the Application and registered at their respective places. Then, the bot is started and runs until we press Ctrl-C on the command line. Usage: @@ -18,7 +18,7 @@ from telegram import InlineQueryResultArticle, InputTextMessageContent, Update from telegram.constants import ParseMode from telegram.helpers import escape_markdown -from telegram.ext import Updater, InlineQueryHandler, CommandHandler, CallbackContext +from telegram.ext import Application, InlineQueryHandler, CommandHandler, CallbackContext # Enable logging logging.basicConfig( @@ -29,17 +29,17 @@ # Define a few command handlers. These usually take the two arguments update and # context. Error handlers also receive the raised TelegramError object in error. -def start(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: +async def start(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: """Send a message when the command /start is issued.""" - update.message.reply_text('Hi!') + await update.message.reply_text('Hi!') -def help_command(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: +async def help_command(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: """Send a message when the command /help is issued.""" - update.message.reply_text('Help!') + await update.message.reply_text('Help!') -def inlinequery(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: +async def inlinequery(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: """Handle the inline query.""" query = update.inline_query.query @@ -68,31 +68,23 @@ def inlinequery(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: ), ] - update.inline_query.answer(results) + await update.inline_query.answer(results) def main() -> None: """Run the bot.""" - # Create the Updater and pass it your bot's token. - updater = Updater.builder().token("TOKEN").build() - - # Get the dispatcher to register handlers - dispatcher = updater.dispatcher + # Create the Application and pass it your bot's token. + application = Application.builder().token("TOKEN").build() # on different commands - answer in Telegram - dispatcher.add_handler(CommandHandler("start", start)) - dispatcher.add_handler(CommandHandler("help", help_command)) + application.add_handler(CommandHandler("start", start)) + application.add_handler(CommandHandler("help", help_command)) # on non command i.e message - echo the message on Telegram - dispatcher.add_handler(InlineQueryHandler(inlinequery)) - - # Start the Bot - updater.start_polling() + application.add_handler(InlineQueryHandler(inlinequery)) - # Block until the user presses Ctrl-C or the process receives SIGINT, - # SIGTERM or SIGABRT. This should be used most of the time, since - # start_polling() is non-blocking and will stop the bot gracefully. - updater.idle() + # Run the bot until the user presses Ctrl-C + application.run_polling() if __name__ == '__main__': diff --git a/examples/inlinekeyboard.py b/examples/inlinekeyboard.py index d527e9071c3..730e70b23cd 100644 --- a/examples/inlinekeyboard.py +++ b/examples/inlinekeyboard.py @@ -12,7 +12,7 @@ from telegram.ext import ( CommandHandler, CallbackQueryHandler, - Updater, + Application, CallbackContext, ) @@ -24,7 +24,7 @@ logger = logging.getLogger(__name__) -def start(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: +async def start(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: """Sends a message with three inline buttons attached.""" keyboard = [ [ @@ -36,40 +36,36 @@ def start(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: reply_markup = InlineKeyboardMarkup(keyboard) - update.message.reply_text('Please choose:', reply_markup=reply_markup) + await update.message.reply_text('Please choose:', reply_markup=reply_markup) -def button(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: +async def button(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: """Parses the CallbackQuery and updates the message text.""" query = update.callback_query # CallbackQueries need to be answered, even if no notification to the user is needed # Some clients may have trouble otherwise. See https://core.telegram.org/bots/api#callbackquery - query.answer() + await query.answer() - query.edit_message_text(text=f"Selected option: {query.data}") + await query.edit_message_text(text=f"Selected option: {query.data}") -def help_command(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: +async def help_command(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: """Displays info on how to use the bot.""" - update.message.reply_text("Use /start to test this bot.") + await update.message.reply_text("Use /start to test this bot.") def main() -> None: """Run the bot.""" - # Create the Updater and pass it your bot's token. - updater = Updater.builder().token("TOKEN").build() + # Create the Application and pass it your bot's token. + application = Application.builder().token("TOKEN").build() - updater.dispatcher.add_handler(CommandHandler('start', start)) - updater.dispatcher.add_handler(CallbackQueryHandler(button)) - updater.dispatcher.add_handler(CommandHandler('help', help_command)) + application.application.add_handler(CommandHandler('start', start)) + application.application.add_handler(CallbackQueryHandler(button)) + application.application.add_handler(CommandHandler('help', help_command)) - # Start the Bot - updater.start_polling() - - # Run the bot until the user presses Ctrl-C or the process receives SIGINT, - # SIGTERM or SIGABRT - updater.idle() + # Run the bot until the user presses Ctrl-C + application.run_polling() if __name__ == '__main__': diff --git a/examples/inlinekeyboard2.py b/examples/inlinekeyboard2.py index a42bf5cf9fd..cb95637e666 100644 --- a/examples/inlinekeyboard2.py +++ b/examples/inlinekeyboard2.py @@ -4,9 +4,9 @@ """Simple inline keyboard bot with multiple CallbackQueryHandlers. -This Bot uses the Updater class to handle the bot. +This Bot uses the Application class to handle the bot. First, a few callback functions are defined as callback query handler. Then, those functions are -passed to the Dispatcher and registered at their respective places. +passed to the Application and registered at their respective places. Then, the bot is started and runs until we press Ctrl-C on the command line. Usage: Example of a bot that uses inline keyboard that has multiple CallbackQueryHandlers arranged in a @@ -20,7 +20,7 @@ CommandHandler, CallbackQueryHandler, ConversationHandler, - Updater, + Application, CallbackContext, ) @@ -37,7 +37,7 @@ ONE, TWO, THREE, FOUR = range(4) -def start(update: Update, context: CallbackContext.DEFAULT_TYPE) -> int: +async def start(update: Update, context: CallbackContext.DEFAULT_TYPE) -> int: """Send message on `/start`.""" # Get user that sent /start and log his name user = update.message.from_user @@ -54,18 +54,18 @@ def start(update: Update, context: CallbackContext.DEFAULT_TYPE) -> int: ] reply_markup = InlineKeyboardMarkup(keyboard) # Send message with text and appended InlineKeyboard - update.message.reply_text("Start handler, Choose a route", reply_markup=reply_markup) + await update.message.reply_text("Start handler, Choose a route", reply_markup=reply_markup) # Tell ConversationHandler that we're in state `FIRST` now return FIRST -def start_over(update: Update, context: CallbackContext.DEFAULT_TYPE) -> int: +async def start_over(update: Update, context: CallbackContext.DEFAULT_TYPE) -> int: """Prompt same text & keyboard as `start` does but not as new message""" # Get CallbackQuery from Update query = update.callback_query # CallbackQueries need to be answered, even if no notification to the user is needed # Some clients may have trouble otherwise. See https://core.telegram.org/bots/api#callbackquery - query.answer() + await query.answer() keyboard = [ [ InlineKeyboardButton("1", callback_data=str(ONE)), @@ -76,14 +76,14 @@ def start_over(update: Update, context: CallbackContext.DEFAULT_TYPE) -> int: # Instead of sending a new message, edit the message that # originated the CallbackQuery. This gives the feeling of an # interactive menu. - query.edit_message_text(text="Start handler, Choose a route", reply_markup=reply_markup) + await query.edit_message_text(text="Start handler, Choose a route", reply_markup=reply_markup) return FIRST -def one(update: Update, context: CallbackContext.DEFAULT_TYPE) -> int: +async def one(update: Update, context: CallbackContext.DEFAULT_TYPE) -> int: """Show new choice of buttons""" query = update.callback_query - query.answer() + await query.answer() keyboard = [ [ InlineKeyboardButton("3", callback_data=str(THREE)), @@ -91,16 +91,16 @@ def one(update: Update, context: CallbackContext.DEFAULT_TYPE) -> int: ] ] reply_markup = InlineKeyboardMarkup(keyboard) - query.edit_message_text( + await query.edit_message_text( text="First CallbackQueryHandler, Choose a route", reply_markup=reply_markup ) return FIRST -def two(update: Update, context: CallbackContext.DEFAULT_TYPE) -> int: +async def two(update: Update, context: CallbackContext.DEFAULT_TYPE) -> int: """Show new choice of buttons""" query = update.callback_query - query.answer() + await query.answer() keyboard = [ [ InlineKeyboardButton("1", callback_data=str(ONE)), @@ -108,16 +108,16 @@ def two(update: Update, context: CallbackContext.DEFAULT_TYPE) -> int: ] ] reply_markup = InlineKeyboardMarkup(keyboard) - query.edit_message_text( + await query.edit_message_text( text="Second CallbackQueryHandler, Choose a route", reply_markup=reply_markup ) return FIRST -def three(update: Update, context: CallbackContext.DEFAULT_TYPE) -> int: +async def three(update: Update, context: CallbackContext.DEFAULT_TYPE) -> int: """Show new choice of buttons""" query = update.callback_query - query.answer() + await query.answer() keyboard = [ [ InlineKeyboardButton("Yes, let's do it again!", callback_data=str(ONE)), @@ -125,17 +125,17 @@ def three(update: Update, context: CallbackContext.DEFAULT_TYPE) -> int: ] ] reply_markup = InlineKeyboardMarkup(keyboard) - query.edit_message_text( + await query.edit_message_text( text="Third CallbackQueryHandler. Do want to start over?", reply_markup=reply_markup ) # Transfer to conversation state `SECOND` return SECOND -def four(update: Update, context: CallbackContext.DEFAULT_TYPE) -> int: +async def four(update: Update, context: CallbackContext.DEFAULT_TYPE) -> int: """Show new choice of buttons""" query = update.callback_query - query.answer() + await query.answer() keyboard = [ [ InlineKeyboardButton("2", callback_data=str(TWO)), @@ -143,29 +143,26 @@ def four(update: Update, context: CallbackContext.DEFAULT_TYPE) -> int: ] ] reply_markup = InlineKeyboardMarkup(keyboard) - query.edit_message_text( + await query.edit_message_text( text="Fourth CallbackQueryHandler, Choose a route", reply_markup=reply_markup ) return FIRST -def end(update: Update, context: CallbackContext.DEFAULT_TYPE) -> int: +async def end(update: Update, context: CallbackContext.DEFAULT_TYPE) -> int: """Returns `ConversationHandler.END`, which tells the ConversationHandler that the conversation is over. """ query = update.callback_query - query.answer() - query.edit_message_text(text="See you next time!") + await query.answer() + await query.edit_message_text(text="See you next time!") return ConversationHandler.END def main() -> None: """Run the bot.""" - # Create the Updater and pass it your bot's token. - updater = Updater.builder().token("TOKEN").build() - - # Get the dispatcher to register handlers - dispatcher = updater.dispatcher + # Create the Application and pass it your bot's token. + application = Application.builder().token("TOKEN").build() # Setup conversation handler with the states FIRST and SECOND # Use the pattern parameter to pass CallbackQueries with specific @@ -190,16 +187,11 @@ def main() -> None: fallbacks=[CommandHandler('start', start)], ) - # Add ConversationHandler to dispatcher that will be used for handling updates - dispatcher.add_handler(conv_handler) - - # Start the Bot - updater.start_polling() + # Add ConversationHandler to application that will be used for handling updates + application.add_handler(conv_handler) - # Run the bot until you press Ctrl-C or the process receives SIGINT, - # SIGTERM or SIGABRT. This should be used most of the time, since - # start_polling() is non-blocking and will stop the bot gracefully. - updater.idle() + # Run the bot until the user presses Ctrl-C + application.run_polling() if __name__ == '__main__': diff --git a/examples/nestedconversationbot.py b/examples/nestedconversationbot.py index 414a90d61e5..d80c5f7e997 100644 --- a/examples/nestedconversationbot.py +++ b/examples/nestedconversationbot.py @@ -4,7 +4,7 @@ """ First, a few callback functions are defined. Then, those functions are passed to -the Dispatcher and registered at their respective places. +the Application and registered at their respective places. Then, the bot is started and runs until we press Ctrl-C on the command line. Usage: @@ -24,7 +24,7 @@ filters, ConversationHandler, CallbackQueryHandler, - Updater, + Application, CallbackContext, ) @@ -71,7 +71,7 @@ def _name_switcher(level: str) -> Tuple[str, str]: # Top level conversation callbacks -def start(update: Update, context: CallbackContext.DEFAULT_TYPE) -> str: +async def start(update: Update, context: CallbackContext.DEFAULT_TYPE) -> str: """Select an action: Adding parent/child or show data.""" text = ( "You may choose to add a family member, yourself, show the gathered data, or end the " @@ -92,32 +92,32 @@ def start(update: Update, context: CallbackContext.DEFAULT_TYPE) -> str: # If we're starting over we don't need to send a new message if context.user_data.get(START_OVER): - update.callback_query.answer() - update.callback_query.edit_message_text(text=text, reply_markup=keyboard) + await update.callback_query.answer() + await update.callback_query.edit_message_text(text=text, reply_markup=keyboard) else: - update.message.reply_text( + await update.message.reply_text( "Hi, I'm Family Bot and I'm here to help you gather information about your family." ) - update.message.reply_text(text=text, reply_markup=keyboard) + await update.message.reply_text(text=text, reply_markup=keyboard) context.user_data[START_OVER] = False return SELECTING_ACTION -def adding_self(update: Update, context: CallbackContext.DEFAULT_TYPE) -> str: +async def adding_self(update: Update, context: CallbackContext.DEFAULT_TYPE) -> str: """Add information about yourself.""" context.user_data[CURRENT_LEVEL] = SELF text = 'Okay, please tell me about yourself.' button = InlineKeyboardButton(text='Add info', callback_data=str(MALE)) keyboard = InlineKeyboardMarkup.from_button(button) - update.callback_query.answer() - update.callback_query.edit_message_text(text=text, reply_markup=keyboard) + await update.callback_query.answer() + await update.callback_query.edit_message_text(text=text, reply_markup=keyboard) return DESCRIBING_SELF -def show_data(update: Update, context: CallbackContext.DEFAULT_TYPE) -> str: +async def show_data(update: Update, context: CallbackContext.DEFAULT_TYPE) -> str: """Pretty print gathered data.""" def prettyprint(user_data: Dict[str, Any], level: str) -> str: @@ -145,32 +145,32 @@ def prettyprint(user_data: Dict[str, Any], level: str) -> str: buttons = [[InlineKeyboardButton(text='Back', callback_data=str(END))]] keyboard = InlineKeyboardMarkup(buttons) - update.callback_query.answer() - update.callback_query.edit_message_text(text=text, reply_markup=keyboard) + await update.callback_query.answer() + await update.callback_query.edit_message_text(text=text, reply_markup=keyboard) user_data[START_OVER] = True return SHOWING -def stop(update: Update, context: CallbackContext.DEFAULT_TYPE) -> int: +async def stop(update: Update, context: CallbackContext.DEFAULT_TYPE) -> int: """End Conversation by command.""" - update.message.reply_text('Okay, bye.') + await update.message.reply_text('Okay, bye.') return END -def end(update: Update, context: CallbackContext.DEFAULT_TYPE) -> int: +async def end(update: Update, context: CallbackContext.DEFAULT_TYPE) -> int: """End conversation from InlineKeyboardButton.""" - update.callback_query.answer() + await update.callback_query.answer() text = 'See you around!' - update.callback_query.edit_message_text(text=text) + await update.callback_query.edit_message_text(text=text) return END # Second level conversation callbacks -def select_level(update: Update, context: CallbackContext.DEFAULT_TYPE) -> str: +async def select_level(update: Update, context: CallbackContext.DEFAULT_TYPE) -> str: """Choose to add a parent or a child.""" text = 'You may add a parent or a child. Also you can show the gathered data or go back.' buttons = [ @@ -185,13 +185,13 @@ def select_level(update: Update, context: CallbackContext.DEFAULT_TYPE) -> str: ] keyboard = InlineKeyboardMarkup(buttons) - update.callback_query.answer() - update.callback_query.edit_message_text(text=text, reply_markup=keyboard) + await update.callback_query.answer() + await update.callback_query.edit_message_text(text=text, reply_markup=keyboard) return SELECTING_LEVEL -def select_gender(update: Update, context: CallbackContext.DEFAULT_TYPE) -> str: +async def select_gender(update: Update, context: CallbackContext.DEFAULT_TYPE) -> str: """Choose to add mother or father.""" level = update.callback_query.data context.user_data[CURRENT_LEVEL] = level @@ -212,22 +212,22 @@ def select_gender(update: Update, context: CallbackContext.DEFAULT_TYPE) -> str: ] keyboard = InlineKeyboardMarkup(buttons) - update.callback_query.answer() - update.callback_query.edit_message_text(text=text, reply_markup=keyboard) + await update.callback_query.answer() + await update.callback_query.edit_message_text(text=text, reply_markup=keyboard) return SELECTING_GENDER -def end_second_level(update: Update, context: CallbackContext.DEFAULT_TYPE) -> int: +async def end_second_level(update: Update, context: CallbackContext.DEFAULT_TYPE) -> int: """Return to top level conversation.""" context.user_data[START_OVER] = True - start(update, context) + await start(update, context) return END # Third level callbacks -def select_feature(update: Update, context: CallbackContext.DEFAULT_TYPE) -> str: +async def select_feature(update: Update, context: CallbackContext.DEFAULT_TYPE) -> str: """Select a feature to update for the person.""" buttons = [ [ @@ -243,39 +243,39 @@ def select_feature(update: Update, context: CallbackContext.DEFAULT_TYPE) -> str context.user_data[FEATURES] = {GENDER: update.callback_query.data} text = 'Please select a feature to update.' - update.callback_query.answer() - update.callback_query.edit_message_text(text=text, reply_markup=keyboard) + await update.callback_query.answer() + await update.callback_query.edit_message_text(text=text, reply_markup=keyboard) # But after we do that, we need to send a new message else: text = 'Got it! Please select a feature to update.' - update.message.reply_text(text=text, reply_markup=keyboard) + await update.message.reply_text(text=text, reply_markup=keyboard) context.user_data[START_OVER] = False return SELECTING_FEATURE -def ask_for_input(update: Update, context: CallbackContext.DEFAULT_TYPE) -> str: +async def ask_for_input(update: Update, context: CallbackContext.DEFAULT_TYPE) -> str: """Prompt user to input data for selected feature.""" context.user_data[CURRENT_FEATURE] = update.callback_query.data text = 'Okay, tell me.' - update.callback_query.answer() - update.callback_query.edit_message_text(text=text) + await update.callback_query.answer() + await update.callback_query.edit_message_text(text=text) return TYPING -def save_input(update: Update, context: CallbackContext.DEFAULT_TYPE) -> str: +async def save_input(update: Update, context: CallbackContext.DEFAULT_TYPE) -> str: """Save input for feature and return to feature selection.""" user_data = context.user_data user_data[FEATURES][user_data[CURRENT_FEATURE]] = update.message.text user_data[START_OVER] = True - return select_feature(update, context) + return await select_feature(update, context) -def end_describing(update: Update, context: CallbackContext.DEFAULT_TYPE) -> int: +async def end_describing(update: Update, context: CallbackContext.DEFAULT_TYPE) -> int: """End gathering of features and return to parent conversation.""" user_data = context.user_data level = user_data[CURRENT_LEVEL] @@ -286,27 +286,24 @@ def end_describing(update: Update, context: CallbackContext.DEFAULT_TYPE) -> int # Print upper level menu if level == SELF: user_data[START_OVER] = True - start(update, context) + await start(update, context) else: - select_level(update, context) + await select_level(update, context) return END -def stop_nested(update: Update, context: CallbackContext.DEFAULT_TYPE) -> str: +async def stop_nested(update: Update, context: CallbackContext.DEFAULT_TYPE) -> str: """Completely end conversation from within nested conversation.""" - update.message.reply_text('Okay, bye.') + await update.message.reply_text('Okay, bye.') return STOPPING def main() -> None: """Run the bot.""" - # Create the Updater and pass it your bot's token. - updater = Updater.builder().token("TOKEN").build() - - # Get the dispatcher to register handlers - dispatcher = updater.dispatcher + # Create the Application and pass it your bot's token. + application = Application.builder().token("TOKEN").build() # Set up third level ConversationHandler (collecting features) description_conv = ConversationHandler( @@ -378,15 +375,10 @@ def main() -> None: fallbacks=[CommandHandler('stop', stop)], ) - dispatcher.add_handler(conv_handler) - - # Start the Bot - updater.start_polling() + application.add_handler(conv_handler) - # Run the bot until you press Ctrl-C or the process receives SIGINT, - # SIGTERM or SIGABRT. This should be used most of the time, since - # start_polling() is non-blocking and will stop the bot gracefully. - updater.idle() + # Run the bot until the user presses Ctrl-C + application.run_polling() if __name__ == '__main__': diff --git a/examples/passportbot.py b/examples/passportbot.py index 69af6c1b13a..47d5402ab7d 100644 --- a/examples/passportbot.py +++ b/examples/passportbot.py @@ -15,7 +15,7 @@ from pathlib import Path from telegram import Update -from telegram.ext import MessageHandler, filters, Updater, CallbackContext +from telegram.ext import MessageHandler, filters, Application, CallbackContext # Enable logging @@ -26,7 +26,7 @@ logger = logging.getLogger(__name__) -def msg(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: +async def msg(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: """Downloads and prints the received passport data.""" # Retrieve passport data passport_data = update.message.passport_data @@ -62,27 +62,27 @@ def msg(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: ): print(data.type, len(data.files), 'files') for file in data.files: - actual_file = file.get_file() + actual_file = await file.get_file() print(actual_file) - actual_file.download() + await actual_file.download() if ( data.type in ('passport', 'driver_license', 'identity_card', 'internal_passport') and data.front_side ): - front_file = data.front_side.get_file() + front_file = await data.front_side.get_file() print(data.type, front_file) - front_file.download() + await front_file.download() if data.type in ('driver_license' and 'identity_card') and data.reverse_side: - reverse_file = data.reverse_side.get_file() + reverse_file = await data.reverse_side.get_file() print(data.type, reverse_file) - reverse_file.download() + await reverse_file.download() if ( data.type in ('passport', 'driver_license', 'identity_card', 'internal_passport') and data.selfie ): - selfie_file = data.selfie.get_file() + selfie_file = await data.selfie.get_file() print(data.type, selfie_file) - selfie_file.download() + await selfie_file.download() if data.type in ( 'passport', 'driver_license', @@ -96,30 +96,27 @@ def msg(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: ): print(data.type, len(data.translation), 'translation') for file in data.translation: - actual_file = file.get_file() + actual_file = await file.get_file() print(actual_file) - actual_file.download() + await actual_file.download() def main() -> None: """Start the bot.""" - # Create the Updater and pass it your token and private key + # Create the Application and pass it your token and private key private_key = Path('private.key') - updater = Updater.builder().token("TOKEN").private_key(private_key.read_bytes()).build() + application = ( + Application.builder().token("TOKEN").private_key(private_key.read_bytes()).build() + ) - # Get the dispatcher to register handlers - dispatcher = updater.dispatcher + # Get the application to register handlers + application = application.application # On messages that include passport data call msg - dispatcher.add_handler(MessageHandler(filters.PASSPORT_DATA, msg)) + application.add_handler(MessageHandler(filters.PASSPORT_DATA, msg)) - # Start the Bot - updater.start_polling() - - # Run the bot until you press Ctrl-C or the process receives SIGINT, - # SIGTERM or SIGABRT. This should be used most of the time, since - # start_polling() is non-blocking and will stop the bot gracefully. - updater.idle() + # Run the bot until the user presses Ctrl-C + application.run_polling() if __name__ == '__main__': diff --git a/examples/paymentbot.py b/examples/paymentbot.py index e44c0fcbf31..4ccafd1ed29 100644 --- a/examples/paymentbot.py +++ b/examples/paymentbot.py @@ -13,7 +13,7 @@ filters, PreCheckoutQueryHandler, ShippingQueryHandler, - Updater, + Application, CallbackContext, ) @@ -25,17 +25,19 @@ logger = logging.getLogger(__name__) -def start_callback(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: +async def start_callback(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: """Displays info on how to use the bot.""" msg = ( "Use /shipping to get an invoice for shipping-payment, or /noshipping for an " "invoice without shipping." ) - update.message.reply_text(msg) + await update.message.reply_text(msg) -def start_with_shipping_callback(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: +async def start_with_shipping_callback( + update: Update, context: CallbackContext.DEFAULT_TYPE +) -> None: """Sends an invoice with shipping-payment.""" chat_id = update.message.chat_id title = "Payment Example" @@ -53,7 +55,7 @@ def start_with_shipping_callback(update: Update, context: CallbackContext.DEFAUL # optionally pass need_name=True, need_phone_number=True, # need_email=True, need_shipping_address=True, is_flexible=True - context.bot.send_invoice( + await context.bot.send_invoice( chat_id, title, description, @@ -69,7 +71,9 @@ def start_with_shipping_callback(update: Update, context: CallbackContext.DEFAUL ) -def start_without_shipping_callback(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: +async def start_without_shipping_callback( + update: Update, context: CallbackContext.DEFAULT_TYPE +) -> None: """Sends an invoice without shipping-payment.""" chat_id = update.message.chat_id title = "Payment Example" @@ -86,18 +90,18 @@ def start_without_shipping_callback(update: Update, context: CallbackContext.DEF # optionally pass need_name=True, need_phone_number=True, # need_email=True, need_shipping_address=True, is_flexible=True - context.bot.send_invoice( + await context.bot.send_invoice( chat_id, title, description, payload, provider_token, currency, prices ) -def shipping_callback(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: +async def shipping_callback(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: """Answers the ShippingQuery with ShippingOptions""" query = update.shipping_query # check the payload, is this from your bot? if query.invoice_payload != 'Custom-Payload': # answer False pre_checkout_query - query.answer(ok=False, error_message="Something went wrong...") + await query.answer(ok=False, error_message="Something went wrong...") return # First option has a single LabeledPrice @@ -105,59 +109,55 @@ def shipping_callback(update: Update, context: CallbackContext.DEFAULT_TYPE) -> # second option has an array of LabeledPrice objects price_list = [LabeledPrice('B1', 150), LabeledPrice('B2', 200)] options.append(ShippingOption('2', 'Shipping Option B', price_list)) - query.answer(ok=True, shipping_options=options) + await query.answer(ok=True, shipping_options=options) # after (optional) shipping, it's the pre-checkout -def precheckout_callback(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: +async def precheckout_callback(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: """Answers the PreQecheckoutQuery""" query = update.pre_checkout_query # check the payload, is this from your bot? if query.invoice_payload != 'Custom-Payload': # answer False pre_checkout_query - query.answer(ok=False, error_message="Something went wrong...") + await query.answer(ok=False, error_message="Something went wrong...") else: - query.answer(ok=True) + await query.answer(ok=True) # finally, after contacting the payment provider... -def successful_payment_callback(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: +async def successful_payment_callback( + update: Update, context: CallbackContext.DEFAULT_TYPE +) -> None: """Confirms the successful payment.""" # do something after successfully receiving payment? - update.message.reply_text("Thank you for your payment!") + await update.message.reply_text("Thank you for your payment!") def main() -> None: """Run the bot.""" - # Create the Updater and pass it your bot's token. - updater = Updater.builder().token("TOKEN").build() - - # Get the dispatcher to register handlers - dispatcher = updater.dispatcher + # Create the Application and pass it your bot's token. + application = Application.builder().token("TOKEN").build() # simple start function - dispatcher.add_handler(CommandHandler("start", start_callback)) + application.add_handler(CommandHandler("start", start_callback)) # Add command handler to start the payment invoice - dispatcher.add_handler(CommandHandler("shipping", start_with_shipping_callback)) - dispatcher.add_handler(CommandHandler("noshipping", start_without_shipping_callback)) + application.add_handler(CommandHandler("shipping", start_with_shipping_callback)) + application.add_handler(CommandHandler("noshipping", start_without_shipping_callback)) # Optional handler if your product requires shipping - dispatcher.add_handler(ShippingQueryHandler(shipping_callback)) + application.add_handler(ShippingQueryHandler(shipping_callback)) # Pre-checkout handler to final check - dispatcher.add_handler(PreCheckoutQueryHandler(precheckout_callback)) + application.add_handler(PreCheckoutQueryHandler(precheckout_callback)) # Success! Notify your user! - dispatcher.add_handler(MessageHandler(filters.SUCCESSFUL_PAYMENT, successful_payment_callback)) - - # Start the Bot - updater.start_polling() + application.add_handler( + MessageHandler(filters.SUCCESSFUL_PAYMENT, successful_payment_callback) + ) - # Run the bot until you press Ctrl-C or the process receives SIGINT, - # SIGTERM or SIGABRT. This should be used most of the time, since - # start_polling() is non-blocking and will stop the bot gracefully. - updater.idle() + # Run the bot until the user presses Ctrl-C + application.run_polling() if __name__ == '__main__': diff --git a/examples/persistentconversationbot.py b/examples/persistentconversationbot.py index 8defda533bd..71eb2f5bcba 100644 --- a/examples/persistentconversationbot.py +++ b/examples/persistentconversationbot.py @@ -4,7 +4,7 @@ """ First, a few callback functions are defined. Then, those functions are passed to -the Dispatcher and registered at their respective places. +the Application and registered at their respective places. Then, the bot is started and runs until we press Ctrl-C on the command line. Usage: @@ -24,7 +24,7 @@ filters, ConversationHandler, PicklePersistence, - Updater, + Application, CallbackContext, ) @@ -51,7 +51,7 @@ def facts_to_str(user_data: Dict[str, str]) -> str: return "\n".join(facts).join(['\n', '\n']) -def start(update: Update, context: CallbackContext.DEFAULT_TYPE) -> int: +async def start(update: Update, context: CallbackContext.DEFAULT_TYPE) -> int: """Start the conversation, display any stored data and ask user for input.""" reply_text = "Hi! My name is Doctor Botter." if context.user_data: @@ -64,12 +64,12 @@ def start(update: Update, context: CallbackContext.DEFAULT_TYPE) -> int: " I will hold a more complex conversation with you. Why don't you tell me " "something about yourself?" ) - update.message.reply_text(reply_text, reply_markup=markup) + await update.message.reply_text(reply_text, reply_markup=markup) return CHOOSING -def regular_choice(update: Update, context: CallbackContext.DEFAULT_TYPE) -> int: +async def regular_choice(update: Update, context: CallbackContext.DEFAULT_TYPE) -> int: """Ask the user for info about the selected predefined choice.""" text = update.message.text.lower() context.user_data['choice'] = text @@ -79,28 +79,28 @@ def regular_choice(update: Update, context: CallbackContext.DEFAULT_TYPE) -> int ) else: reply_text = f'Your {text}? Yes, I would love to hear about that!' - update.message.reply_text(reply_text) + await update.message.reply_text(reply_text) return TYPING_REPLY -def custom_choice(update: Update, context: CallbackContext.DEFAULT_TYPE) -> int: +async def custom_choice(update: Update, context: CallbackContext.DEFAULT_TYPE) -> int: """Ask the user for a description of a custom category.""" - update.message.reply_text( + await update.message.reply_text( 'Alright, please send me the category first, for example "Most impressive skill"' ) return TYPING_CHOICE -def received_information(update: Update, context: CallbackContext.DEFAULT_TYPE) -> int: +async def received_information(update: Update, context: CallbackContext.DEFAULT_TYPE) -> int: """Store info provided by user and ask for the next category.""" text = update.message.text category = context.user_data['choice'] context.user_data[category] = text.lower() del context.user_data['choice'] - update.message.reply_text( + await update.message.reply_text( "Neat! Just so you know, this is what you already told me:" f"{facts_to_str(context.user_data)}" "You can tell me more, or change your opinion on something.", @@ -110,20 +110,20 @@ def received_information(update: Update, context: CallbackContext.DEFAULT_TYPE) return CHOOSING -def show_data(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: +async def show_data(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: """Display the gathered info.""" - update.message.reply_text( + await update.message.reply_text( f"This is what you already told me: {facts_to_str(context.user_data)}" ) -def done(update: Update, context: CallbackContext.DEFAULT_TYPE) -> int: +async def done(update: Update, context: CallbackContext.DEFAULT_TYPE) -> int: """Display the gathered info and end the conversation.""" if 'choice' in context.user_data: del context.user_data['choice'] - update.message.reply_text( - f"I learned these facts about you: {facts_to_str(context.user_data)}Until next time!", + await update.message.reply_text( + f"I learned these facts about you: {facts_to_str(context.user_data)} Until next time!", reply_markup=ReplyKeyboardRemove(), ) return ConversationHandler.END @@ -131,12 +131,12 @@ def done(update: Update, context: CallbackContext.DEFAULT_TYPE) -> int: def main() -> None: """Run the bot.""" - # Create the Updater and pass it your bot's token. + # Create the Application and pass it your bot's token. persistence = PicklePersistence(filepath='conversationbot') - updater = Updater.builder().token("TOKEN").persistence(persistence).build() + application = Application.builder().token("TOKEN").persistence(persistence).build() - # Get the dispatcher to register handlers - dispatcher = updater.dispatcher + # Get the application to register handlers + application = application.application # Add conversation handler with the states CHOOSING, TYPING_CHOICE and TYPING_REPLY conv_handler = ConversationHandler( @@ -165,18 +165,13 @@ def main() -> None: persistent=True, ) - dispatcher.add_handler(conv_handler) + application.add_handler(conv_handler) show_data_handler = CommandHandler('show_data', show_data) - dispatcher.add_handler(show_data_handler) + application.add_handler(show_data_handler) - # Start the Bot - updater.start_polling() - - # Run the bot until you press Ctrl-C or the process receives SIGINT, - # SIGTERM or SIGABRT. This should be used most of the time, since - # start_polling() is non-blocking and will stop the bot gracefully. - updater.idle() + # Run the bot until the user presses Ctrl-C + application.run_polling() if __name__ == '__main__': diff --git a/examples/pollbot.py b/examples/pollbot.py index 85680613bd7..2b7898cd1f3 100644 --- a/examples/pollbot.py +++ b/examples/pollbot.py @@ -24,7 +24,7 @@ PollHandler, MessageHandler, filters, - Updater, + Application, CallbackContext, ) @@ -36,18 +36,18 @@ logger = logging.getLogger(__name__) -def start(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: +async def start(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: """Inform user about what this bot can do""" - update.message.reply_text( + await update.message.reply_text( 'Please select /poll to get a Poll, /quiz to get a Quiz or /preview' ' to generate a preview for your poll' ) -def poll(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: +async def poll(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: """Sends a predefined poll""" questions = ["Good", "Really good", "Fantastic", "Great"] - message = context.bot.send_poll( + message = await context.bot.send_poll( update.effective_chat.id, "How are you?", questions, @@ -66,7 +66,7 @@ def poll(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: context.bot_data.update(payload) -def receive_poll_answer(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: +async def receive_poll_answer(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: """Summarize a users poll vote""" answer = update.poll_answer poll_id = answer.poll_id @@ -82,7 +82,7 @@ def receive_poll_answer(update: Update, context: CallbackContext.DEFAULT_TYPE) - answer_string += questions[question_id] + " and " else: answer_string += questions[question_id] - context.bot.send_message( + await context.bot.send_message( context.bot_data[poll_id]["chat_id"], f"{update.effective_user.mention_html()} feels {answer_string}!", parse_mode=ParseMode.HTML, @@ -90,15 +90,15 @@ def receive_poll_answer(update: Update, context: CallbackContext.DEFAULT_TYPE) - context.bot_data[poll_id]["answers"] += 1 # Close poll after three participants voted if context.bot_data[poll_id]["answers"] == 3: - context.bot.stop_poll( + await context.bot.stop_poll( context.bot_data[poll_id]["chat_id"], context.bot_data[poll_id]["message_id"] ) -def quiz(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: +async def quiz(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: """Send a predefined poll""" questions = ["1", "2", "4", "20"] - message = update.effective_message.reply_poll( + message = await update.effective_message.reply_poll( "How many eggs do you need for a cake?", questions, type=Poll.QUIZ, correct_option_id=2 ) # Save some info about the poll the bot_data for later use in receive_quiz_answer @@ -108,7 +108,7 @@ def quiz(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: context.bot_data.update(payload) -def receive_quiz_answer(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: +async def receive_quiz_answer(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: """Close quiz after three participants took it""" # the bot can receive closed poll updates we don't care about if update.poll.is_closed: @@ -119,26 +119,26 @@ def receive_quiz_answer(update: Update, context: CallbackContext.DEFAULT_TYPE) - # this means this poll answer update is from an old poll, we can't stop it then except KeyError: return - context.bot.stop_poll(quiz_data["chat_id"], quiz_data["message_id"]) + await context.bot.stop_poll(quiz_data["chat_id"], quiz_data["message_id"]) -def preview(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: +async def preview(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: """Ask user to create a poll and display a preview of it""" # using this without a type lets the user chooses what he wants (quiz or poll) button = [[KeyboardButton("Press me!", request_poll=KeyboardButtonPollType())]] message = "Press the button to let the bot generate a preview for your poll" # using one_time_keyboard to hide the keyboard - update.effective_message.reply_text( + await update.effective_message.reply_text( message, reply_markup=ReplyKeyboardMarkup(button, one_time_keyboard=True) ) -def receive_poll(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: +async def receive_poll(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: """On receiving polls, reply to it by a closed poll copying the received poll""" actual_poll = update.effective_message.poll # Only need to set the question and options, since all other parameters don't matter for # a closed poll - update.effective_message.reply_poll( + await update.effective_message.reply_poll( question=actual_poll.question, options=[o.text for o in actual_poll.options], # with is_closed true, the poll/quiz is immediately closed @@ -147,31 +147,26 @@ def receive_poll(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: ) -def help_handler(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: +async def help_handler(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: """Display a help message""" - update.message.reply_text("Use /quiz, /poll or /preview to test this bot.") + await update.message.reply_text("Use /quiz, /poll or /preview to test this bot.") def main() -> None: """Run bot.""" - # Create the Updater and pass it your bot's token. - updater = Updater.builder().token("TOKEN").build() - dispatcher = updater.dispatcher - dispatcher.add_handler(CommandHandler('start', start)) - dispatcher.add_handler(CommandHandler('poll', poll)) - dispatcher.add_handler(PollAnswerHandler(receive_poll_answer)) - dispatcher.add_handler(CommandHandler('quiz', quiz)) - dispatcher.add_handler(PollHandler(receive_quiz_answer)) - dispatcher.add_handler(CommandHandler('preview', preview)) - dispatcher.add_handler(MessageHandler(filters.POLL, receive_poll)) - dispatcher.add_handler(CommandHandler('help', help_handler)) - - # Start the Bot - updater.start_polling() - - # Run the bot until the user presses Ctrl-C or the process receives SIGINT, - # SIGTERM or SIGABRT - updater.idle() + # Create the Application and pass it your bot's token. + application = Application.builder().token("TOKEN").build() + application.add_handler(CommandHandler('start', start)) + application.add_handler(CommandHandler('poll', poll)) + application.add_handler(PollAnswerHandler(receive_poll_answer)) + application.add_handler(CommandHandler('quiz', quiz)) + application.add_handler(PollHandler(receive_quiz_answer)) + application.add_handler(CommandHandler('preview', preview)) + application.add_handler(MessageHandler(filters.POLL, receive_poll)) + application.add_handler(CommandHandler('help', help_handler)) + + # Run the bot until the user presses Ctrl-C + application.run_polling() if __name__ == '__main__': diff --git a/examples/rawapibot.py b/examples/rawapibot.py index 09e7e3a7c90..19a9d512fe0 100644 --- a/examples/rawapibot.py +++ b/examples/rawapibot.py @@ -6,27 +6,27 @@ on the telegram.ext bot framework. This program is dedicated to the public domain under the CC0 license. """ +import asyncio import logging from typing import NoReturn -from time import sleep import telegram -from telegram.error import NetworkError, Unauthorized +from telegram.error import NetworkError, Forbidden UPDATE_ID = None -def main() -> NoReturn: +async def main() -> NoReturn: """Run the bot.""" global UPDATE_ID # Telegram Bot Authorization Token bot = telegram.Bot('TOKEN') # get the first pending update_id, this is so we can skip over it in case - # we get an "Unauthorized" exception. + # we get an "Forbidden" exception. try: - UPDATE_ID = bot.get_updates()[0].update_id + UPDATE_ID = (await bot.get_updates())[0].update_id except IndexError: UPDATE_ID = None @@ -34,27 +34,27 @@ def main() -> NoReturn: while True: try: - echo(bot) + await echo(bot) except NetworkError: - sleep(1) - except Unauthorized: + await asyncio.sleep(1) + except Forbidden: # The user has removed or blocked the bot. UPDATE_ID += 1 -def echo(bot: telegram.Bot) -> None: +async def echo(bot: telegram.Bot) -> None: """Echo the message the user sent.""" global UPDATE_ID # Request updates after the last update_id - for update in bot.get_updates(offset=UPDATE_ID, timeout=10): + for update in await bot.get_updates(offset=UPDATE_ID, timeout=10): UPDATE_ID = update.update_id + 1 # your bot can receive updates without messages # and not all messages contain text if update.message and update.message.text: # Reply to the message - update.message.reply_text(update.message.text) + await update.message.reply_text(update.message.text) if __name__ == '__main__': - main() + asyncio.run(main()) diff --git a/examples/timerbot.py b/examples/timerbot.py index 19e864fcce9..0d04241b5f8 100644 --- a/examples/timerbot.py +++ b/examples/timerbot.py @@ -5,11 +5,11 @@ """ Simple Bot to send timed Telegram messages. -This Bot uses the Updater class to handle the bot and the JobQueue to send +This Bot uses the Application class to handle the bot and the JobQueue to send timed messages. First, a few handler functions are defined. Then, those functions are passed to -the Dispatcher and registered at their respective places. +the Application and registered at their respective places. Then, the bot is started and runs until we press Ctrl-C on the command line. Usage: @@ -21,7 +21,7 @@ import logging from telegram import Update -from telegram.ext import CommandHandler, Updater, CallbackContext +from telegram.ext import CommandHandler, Application, CallbackContext # Enable logging logging.basicConfig( @@ -36,15 +36,15 @@ # since context is an unused local variable. # This being an example and not having context present confusing beginners, # we decided to have it present as context. -def start(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: +async def start(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: """Sends explanation on how to use the bot.""" - update.message.reply_text('Hi! Use /set to set a timer') + await update.message.reply_text('Hi! Use /set to set a timer') -def alarm(context: CallbackContext.DEFAULT_TYPE) -> None: +async def alarm(context: CallbackContext.DEFAULT_TYPE) -> None: """Send the alarm message.""" job = context.job - context.bot.send_message(job.context, text='Beep!') + await context.bot.send_message(job.context, text='Beep!') def remove_job_if_exists(name: str, context: CallbackContext.DEFAULT_TYPE) -> bool: @@ -57,14 +57,14 @@ def remove_job_if_exists(name: str, context: CallbackContext.DEFAULT_TYPE) -> bo return True -def set_timer(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: +async def set_timer(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: """Add a job to the queue.""" chat_id = update.message.chat_id try: # args[0] should contain the time for the timer in seconds due = int(context.args[0]) if due < 0: - update.message.reply_text('Sorry we can not go back to future!') + await update.message.reply_text('Sorry we can not go back to future!') return job_removed = remove_job_if_exists(str(chat_id), context) @@ -73,41 +73,33 @@ def set_timer(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: text = 'Timer successfully set!' if job_removed: text += ' Old one was removed.' - update.message.reply_text(text) + await update.message.reply_text(text) except (IndexError, ValueError): - update.message.reply_text('Usage: /set ') + await update.message.reply_text('Usage: /set ') -def unset(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: +async def unset(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: """Remove the job if the user changed their mind.""" chat_id = update.message.chat_id job_removed = remove_job_if_exists(str(chat_id), context) text = 'Timer successfully cancelled!' if job_removed else 'You have no active timer.' - update.message.reply_text(text) + await update.message.reply_text(text) def main() -> None: """Run bot.""" - # Create the Updater and pass it your bot's token. - updater = Updater.builder().token("TOKEN").build() - - # Get the dispatcher to register handlers - dispatcher = updater.dispatcher + # Create the Application and pass it your bot's token. + application = Application.builder().token("TOKEN").build() # on different commands - answer in Telegram - dispatcher.add_handler(CommandHandler("start", start)) - dispatcher.add_handler(CommandHandler("help", start)) - dispatcher.add_handler(CommandHandler("set", set_timer)) - dispatcher.add_handler(CommandHandler("unset", unset)) - - # Start the Bot - updater.start_polling() - - # Block until you press Ctrl-C or the process receives SIGINT, SIGTERM or - # SIGABRT. This should be used most of the time, since start_polling() is - # non-blocking and will stop the bot gracefully. - updater.idle() + application.add_handler(CommandHandler("start", start)) + application.add_handler(CommandHandler("help", start)) + application.add_handler(CommandHandler("set", set_timer)) + application.add_handler(CommandHandler("unset", unset)) + + # Run the bot until the user presses Ctrl-C + application.run_polling() if __name__ == '__main__': diff --git a/pyproject.toml b/pyproject.toml index 38ece5d5b6e..a6c381a6ffb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,5 +7,4 @@ skip-string-normalization = true # so that pre-commit run --all-files does the correct thing # see https://github.com/psf/black/issues/1778 force-exclude = '^(?!/(telegram|examples|tests)/).*\.py$' -include = '(telegram|examples|tests)/.*\.py$' -exclude = 'telegram/vendor' \ No newline at end of file +include = '(telegram|examples|tests)/.*\.py$' \ No newline at end of file diff --git a/requirements-dev.txt b/requirements-dev.txt index 03ae811a147..ee7a8b6c744 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -10,6 +10,7 @@ mypy==0.910 pyupgrade==2.29.0 pytest==6.2.5 +pytest-asyncio==0.16.0 flaky beautifulsoup4 diff --git a/requirements.txt b/requirements.txt index 967fd782804..b452ab92b47 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,8 +1,8 @@ # Make sure to install those as additional_dependencies in the # pre-commit hooks for pylint & mypy -certifi +httpx ~= 0.22.0 # only telegram.ext: # Keep this line here; used in setup(-raw).py tornado>=6.1 -APScheduler==3.6.3 +APScheduler==3.8.1 pytz>=2018.6 cachetools==4.2.2 diff --git a/setup.cfg b/setup.cfg index b4510d76c69..b39b7bf5f68 100644 --- a/setup.cfg +++ b/setup.cfg @@ -13,10 +13,9 @@ upload-dir = docs/build/html max-line-length = 99 ignore = W503, W605 extend-ignore = E203 -exclude = setup.py, setup-raw.py docs/source/conf.py, telegram/vendor - -[pylint] -ignore=vendor +exclude = setup.py, setup-raw.py docs/source/conf.py +per-file-ignores = + telegram/ext/_jobqueue.py:E402 [pylint.message-control] disable = C0330,R0801,R0913,R0904,R0903,R0902,W0511,C0116,C0115,W0703,R0914,R0914,C0302,R0912,R0915,R0401 @@ -40,7 +39,6 @@ concurrency = thread, multiprocessing omit = tests/ telegram/__main__.py - telegram/vendor/* [coverage:report] exclude_lines = @@ -69,8 +67,5 @@ strict_optional = False [mypy-telegram.ext._utils.webhookhandler] warn_unused_ignores = False -[mypy-urllib3.*] -ignore_missing_imports = True - [mypy-apscheduler.*] ignore_missing_imports = True diff --git a/setup.py b/setup.py index 608d095af3e..f14126c96cd 100644 --- a/setup.py +++ b/setup.py @@ -1,13 +1,11 @@ #!/usr/bin/env python """The setup and build script for the python-telegram-bot library.""" import subprocess -import sys from pathlib import Path +import sys from setuptools import setup, find_packages -UPSTREAM_URLLIB3_FLAG = '--with-upstream-urllib3' - def get_requirements(raw=False): """Build the requirements list for this project""" @@ -33,11 +31,6 @@ def get_packages_requirements(raw=False): exclude.append('telegram.ext*') packs = find_packages(exclude=exclude) - # Allow for a package install to not use the vendored urllib3 - if UPSTREAM_URLLIB3_FLAG in sys.argv: - sys.argv.remove(UPSTREAM_URLLIB3_FLAG) - reqs.append('urllib3 >= 1.19.1') - packs = [x for x in packs if not x.startswith('telegram.vendor.ptb_urllib3')] return packs, reqs @@ -79,7 +72,7 @@ def get_setup_kwargs(raw=False): install_requires=requirements, extras_require={ 'json': 'ujson', - 'socks': 'PySocks', + 'socks': 'httpx[socks]', # 3.4-3.4.3 contained some cyclical import bugs 'passport': 'cryptography!=3.4,!=3.4.1,!=3.4.2,!=3.4.3', }, diff --git a/telegram/__main__.py b/telegram/__main__.py index f6025d6db43..542c56fe8d5 100644 --- a/telegram/__main__.py +++ b/telegram/__main__.py @@ -21,8 +21,6 @@ import sys from typing import Optional -import certifi - from . import __version__ as telegram_ver from .constants import BOT_API_VERSION @@ -41,7 +39,6 @@ def print_ver_info() -> None: # skipcq: PY-D0003 git_revision = _git_revision() print(f'python-telegram-bot {telegram_ver}' + (f' ({git_revision})' if git_revision else '')) print(f'Bot API {BOT_API_VERSION}') - print(f'certifi {certifi.__version__}') # type: ignore[attr-defined] sys_version = sys.version.replace('\n', ' ') print(f'Python {sys_version}') diff --git a/telegram/_bot.py b/telegram/_bot.py index 406d4394930..9d4e541f07d 100644 --- a/telegram/_bot.py +++ b/telegram/_bot.py @@ -19,10 +19,12 @@ # You should have received a copy of the GNU Lesser Public License # along with this program. If not, see [http://www.gnu.org/licenses/]. """This module contains an object that represents a Telegram Bot.""" - +import asyncio import functools import logging +from contextlib import AbstractAsyncContextManager from datetime import datetime +from types import TracebackType from typing import ( TYPE_CHECKING, @@ -37,6 +39,7 @@ cast, Sequence, Any, + Type, ) try: @@ -89,15 +92,13 @@ WebhookInfo, InlineKeyboardMarkup, ChatInviteLink, - ReplyKeyboardMarkup, - ReplyKeyboardRemove, - ForceReply, ) from telegram.error import InvalidToken, TelegramError from telegram.constants import InlineQueryLimit -from telegram.request import Request -from telegram._utils.defaultvalue import DEFAULT_NONE, DefaultValue, DEFAULT_20 -from telegram._utils.datetime import to_timestamp +from telegram.request import BaseRequest, RequestData +from telegram.request._requestparameter import RequestParameter +from telegram.request._httpxrequest import HTTPXRequest +from telegram._utils.defaultvalue import DEFAULT_NONE, DefaultValue from telegram._utils.files import is_local_file, parse_file_input from telegram._utils.types import FileInput, JSONDict, ODVInput, DVInput, ReplyMarkup @@ -113,14 +114,28 @@ ) RT = TypeVar('RT') +BT = TypeVar('BT', bound='Bot') -class Bot(TelegramObject): +class Bot(TelegramObject, AbstractAsyncContextManager): """This object represents a Telegram Bot. - .. versionadded:: 13.2 - Objects of this class are comparable in terms of equality. Two objects of this class are - considered equal, if their :attr:`bot` is equal. + Instances of this class can be used as asyncio context managers, where + + .. code:: python + + async with bot: + # code + + is roughly equivalent to + + .. code:: python + + try: + await bot.initialize() + # code + finally: + await request_object.shutdown() Note: Most bot methods have the argument ``api_kwargs`` which allows to pass arbitrary keywords @@ -128,6 +143,10 @@ class Bot(TelegramObject): incorporated into PTB. However, this is not guaranteed to work, i.e. it will fail for passing files. + .. versionadded:: 13.2 + Objects of this class are comparable in terms of equality. Two objects of this class are + considered equal, if their :attr:`bot` is equal. + .. versionchanged:: 14.0 * Removed the deprecated methods ``kick_chat_member``, ``kickChatMember``, @@ -141,8 +160,14 @@ class Bot(TelegramObject): token (:obj:`str`): Bot's unique authentication. base_url (:obj:`str`, optional): Telegram Bot API service URL. base_file_url (:obj:`str`, optional): Telegram Bot API file URL. - request (:obj:`telegram.request.Request`, optional): Pre initialized - :obj:`telegram.request.Request`. + request (:class:`telegram.request.BaseRequest`, optional): Pre initialized + :class:`telegram.request.BaseRequest` instances. Will be used for all bot methods + *except* for :attr:`get_updates`. If not passed, an instance of + :class:`telegram.request.HTTPXRequest` will be used. + request (:class:`telegram.request.BaseRequest`, optional): Pre initialized + :class:`telegram.request.BaseRequest` instances. Will be used exclusively for + :attr:`get_updates`. If not passed, an instance of + :class:`telegram.request.HTTPXRequest` will be used. private_key (:obj:`bytes`, optional): Private key for decryption of telegram passport data. private_key_password (:obj:`bytes`, optional): Password for above private key. @@ -155,7 +180,8 @@ class Bot(TelegramObject): 'private_key', '_bot_user', '_request', - 'logger', + '_logger', + '_initialized', ) def __init__( @@ -163,7 +189,8 @@ def __init__( token: str, base_url: str = 'https://api.telegram.org/bot', base_file_url: str = 'https://api.telegram.org/file/bot', - request: 'Request' = None, + request: BaseRequest = None, + get_updates_request: BaseRequest = None, private_key: bytes = None, private_key_password: bytes = None, ): @@ -172,9 +199,14 @@ def __init__( self.base_url = base_url + self.token self.base_file_url = base_file_url + self.token self._bot_user: Optional[User] = None - self._request = request or Request() self.private_key = None - self.logger = logging.getLogger(__name__) + self._logger = logging.getLogger(__name__) + self._initialized = False + + self._request: Tuple[BaseRequest, BaseRequest] = ( + HTTPXRequest() if get_updates_request is None else get_updates_request, + HTTPXRequest() if request is None else request, + ) if private_key: if not CRYPTO_INSTALLED: @@ -192,18 +224,16 @@ def _log(func: Any): # type: ignore[no-untyped-def] # skipcq: PY-D0003 logger = logging.getLogger(func.__module__) @functools.wraps(func) - def decorator(*args, **kwargs): # type: ignore[no-untyped-def] + async def decorator(*args, **kwargs): # type: ignore[no-untyped-def] logger.debug('Entering: %s', func.__name__) - result = func(*args, **kwargs) + result = await func(*args, **kwargs) logger.debug(result) logger.debug('Exiting: %s', func.__name__) return result return decorator - def _insert_defaults( # pylint: disable=no-self-use - self, data: Dict[str, object], timeout: ODVInput[float] - ) -> Optional[float]: + def _insert_defaults(self, data: Dict[str, object]) -> None: # pylint: disable=no-self-use """This method is here to make ext.Defaults work. Because we need to be able to tell e.g. `send_message(chat_id, text)` from `send_message(chat_id, text, parse_mode=None)`, the default values for `parse_mode` etc are not `None` but `DEFAULT_NONE`. While this *could* @@ -234,44 +264,50 @@ def _insert_defaults( # pylint: disable=no-self-use else: data[key] = DefaultValue.get_value(val) - return DefaultValue.get_value(timeout) - - def _post( + async def _post( self, - endpoint: str, - data: JSONDict = None, - timeout: ODVInput[float] = DEFAULT_NONE, - api_kwargs: JSONDict = None, + endpoint: str, # 'sendMessage', 'sendPhoto', 'getMe' + data: JSONDict = None, # {'chat_id': 123, 'text': 'Hello there!'} + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, + api_kwargs: JSONDict = None, # {'new_param': whatever} ) -> Union[bool, JSONDict, None]: if data is None: data = {} if api_kwargs: - if data: - data.update(api_kwargs) - else: - data = api_kwargs + data.update(api_kwargs) # Insert is in-place, so no return value for data - if endpoint != 'getUpdates': - effective_timeout = self._insert_defaults(data, timeout) - else: - effective_timeout = cast(float, timeout) + self._insert_defaults(data) # Drop any None values because Telegram doesn't handle them well data = {key: value for key, value in data.items() if value is not None} - # We do this here so that _insert_defaults (see above) has a chance to convert + # This also converts datetimes into timestamps. + # We don't do this earlier so that _insert_defaults (see above) has a chance to convert # to the default timezone in case this is called by ExtBot - for key, value in data.items(): - if isinstance(value, datetime): - data[key] = to_timestamp(value) + request_data = RequestData( + parameters=[RequestParameter.from_input(key, value) for key, value in data.items()], + ) - return self.request.post( - f'{self.base_url}/{endpoint}', data=data, timeout=effective_timeout + if endpoint == 'getUpdates': + request = self._request[0] + else: + request = self._request[1] + + return await request.post( + url=f"{self.base_url}/{endpoint}", + request_data=request_data, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, ) - def _message( + async def _send_message( self, endpoint: str, data: JSONDict, @@ -279,7 +315,10 @@ def _message( disable_notification: ODVInput[bool] = DEFAULT_NONE, reply_markup: ReplyMarkup = None, allow_sending_without_reply: ODVInput[bool] = DEFAULT_NONE, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, protect_content: ODVInput[bool] = DEFAULT_NONE, ) -> Union[bool, Message]: @@ -293,24 +332,68 @@ def _message( data['protect_content'] = protect_content if reply_markup is not None: - markups = (InlineKeyboardMarkup, ReplyKeyboardMarkup, ForceReply, ReplyKeyboardRemove) - if isinstance(reply_markup, markups): - # We need to_json() instead of to_dict() here, because reply_markups may be - # attached to media messages, which aren't json dumped by telegram.request - data['reply_markup'] = reply_markup.to_json() - else: - data['reply_markup'] = reply_markup + data['reply_markup'] = reply_markup - result = self._post(endpoint, data, timeout=timeout, api_kwargs=api_kwargs) + result = await self._post( + endpoint, + data, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, + api_kwargs=api_kwargs, + ) if result is True: return result return Message.de_json(result, self) # type: ignore[return-value, arg-type] + async def initialize(self) -> None: + """Initialize resources used by this class. Currently calls :meth:`get_me` to + cache :attr:`bot` and calls :meth:`telegram.request.BaseRequest.initialize` for + :attr:`request`. + """ + if not self._initialized: + await asyncio.gather(self._request[0].initialize(), self._request[1].initialize()) + await self.get_me() + self._initialized = True + + async def shutdown(self) -> None: + """Stop & clear resources used by this class. Currently just calls + :meth:`telegram.request.BaseRequest.stop` for :attr:`request`. + """ + if self._initialized: + await asyncio.gather(self._request[0].shutdown(), self._request[1].shutdown()) + self._initialized = False + + async def __aenter__(self: BT) -> BT: + try: + await self.initialize() + return self + except Exception as exc: + await self.shutdown() + raise exc + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + # Make sure not to return `True` so that exceptions are not suppressed + # https://docs.python.org/3/reference/datamodel.html?#object.__aexit__ + await self.shutdown() + @property - def request(self) -> Request: # skip-cq: PY-D0003 - return self._request + def request(self) -> BaseRequest: + """The :class:`~telegram.request.BaseRequest` object used by this bot. + + Warning: + Requests to the Bot API are made by the various methods of this class. This attribute + should *not* be used manually. + """ + return self._request[1] @staticmethod def _validate_token(token: str) -> str: @@ -326,29 +409,48 @@ def _validate_token(token: str) -> str: @property def bot(self) -> User: - """:class:`telegram.User`: User instance for the bot as returned by :meth:`get_me`.""" + """:class:`telegram.User`: User instance for the bot as returned by :meth:`get_me`. + + Warning: + This value is the cached return value of :meth:`get_me`. If the bots profile is + changed during runtime, this value won't reflect the changes until :meth:`get_me` is + called again. + + .. seealso:: :meth:`initialize` + """ if self._bot_user is None: - self._bot_user = self.get_me() + raise RuntimeError( + f'{self.__class__.__name__} is not properly initialized. Call ' + f'`{self.__class__.__name__}.initialize` before accessing this property.' + ) return self._bot_user @property def id(self) -> int: # pylint: disable=invalid-name - """:obj:`int`: Unique identifier for this bot.""" + """:obj:`int`: Unique identifier for this bot. Shortcut for the corresponding attribute of + :attr:`bot`. + """ return self.bot.id @property def first_name(self) -> str: - """:obj:`str`: Bot's first name.""" + """:obj:`str`: Bot's first name. Shortcut for the corresponding attribute of + :attr:`bot`. + """ return self.bot.first_name @property def last_name(self) -> str: - """:obj:`str`: Optional. Bot's last name.""" + """:obj:`str`: Optional. Bot's last name. Shortcut for the corresponding attribute of + :attr:`bot`. + """ return self.bot.last_name # type: ignore @property def username(self) -> str: - """:obj:`str`: Bot's username.""" + """:obj:`str`: Bot's username. Shortcut for the corresponding attribute of + :attr:`bot`. + """ return self.bot.username # type: ignore @property @@ -358,26 +460,39 @@ def link(self) -> str: @property def can_join_groups(self) -> bool: - """:obj:`bool`: Bot's :attr:`telegram.User.can_join_groups` attribute.""" + """:obj:`bool`: Bot's :attr:`telegram.User.can_join_groups` attribute. Shortcut for the + corresponding attribute of :attr:`bot`. + """ return self.bot.can_join_groups # type: ignore @property def can_read_all_group_messages(self) -> bool: - """:obj:`bool`: Bot's :attr:`telegram.User.can_read_all_group_messages` attribute.""" + """:obj:`bool`: Bot's :attr:`telegram.User.can_read_all_group_messages` attribute. + Shortcut for the corresponding attribute of :attr:`bot`. + """ return self.bot.can_read_all_group_messages # type: ignore @property def supports_inline_queries(self) -> bool: - """:obj:`bool`: Bot's :attr:`telegram.User.supports_inline_queries` attribute.""" + """:obj:`bool`: Bot's :attr:`telegram.User.supports_inline_queries` attribute. + Shortcut for the corresponding attribute of :attr:`bot`. + """ return self.bot.supports_inline_queries # type: ignore @property def name(self) -> str: - """:obj:`str`: Bot's @username.""" + """:obj:`str`: Bot's @username. Shortcut for the corresponding attribute of :attr:`bot`.""" return f'@{self.username}' @_log - def get_me(self, timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None) -> User: + async def get_me( + self, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, + api_kwargs: JSONDict = None, + ) -> User: """A simple method for testing your bot's auth token. Requires no parameters. Args: @@ -395,14 +510,19 @@ def get_me(self, timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = :class:`telegram.error.TelegramError` """ - result = self._post('getMe', timeout=timeout, api_kwargs=api_kwargs) - + result = await self._post( + 'getMe', + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, + api_kwargs=api_kwargs, + ) self._bot_user = User.de_json(result, self) # type: ignore[return-value, arg-type] - return self._bot_user # type: ignore[return-value] @_log - def send_message( + async def send_message( self, chat_id: Union[int, str], text: str, @@ -411,7 +531,10 @@ def send_message( disable_notification: DVInput[bool] = DEFAULT_NONE, reply_to_message_id: int = None, reply_markup: ReplyMarkup = None, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, allow_sending_without_reply: ODVInput[bool] = DEFAULT_NONE, entities: Union[List['MessageEntity'], Tuple['MessageEntity', ...]] = None, @@ -467,26 +590,32 @@ def send_message( } if entities: - data['entities'] = [me.to_dict() for me in entities] + data['entities'] = entities - return self._message( # type: ignore[return-value] + return await self._send_message( # type: ignore[return-value] 'sendMessage', data, - disable_notification=disable_notification, reply_to_message_id=reply_to_message_id, + disable_notification=disable_notification, reply_markup=reply_markup, allow_sending_without_reply=allow_sending_without_reply, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, protect_content=protect_content, ) @_log - def delete_message( + async def delete_message( self, chat_id: Union[str, int], message_id: int, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, ) -> bool: """ @@ -522,19 +651,28 @@ def delete_message( """ data: JSONDict = {'chat_id': chat_id, 'message_id': message_id} - - result = self._post('deleteMessage', data, timeout=timeout, api_kwargs=api_kwargs) - + result = await self._post( + 'deleteMessage', + data, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, + api_kwargs=api_kwargs, + ) return result # type: ignore[return-value] @_log - def forward_message( + async def forward_message( self, chat_id: Union[int, str], from_chat_id: Union[str, int], message_id: int, disable_notification: DVInput[bool] = DEFAULT_NONE, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, protect_content: ODVInput[bool] = DEFAULT_NONE, ) -> Message: @@ -582,17 +720,20 @@ def forward_message( data['from_chat_id'] = from_chat_id if message_id: data['message_id'] = message_id - return self._message( # type: ignore[return-value] + return await self._send_message( # type: ignore[return-value] 'forwardMessage', data, disable_notification=disable_notification, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, protect_content=protect_content, ) @_log - def send_photo( + async def send_photo( self, chat_id: Union[int, str], photo: Union[FileInput, 'PhotoSize'], @@ -600,7 +741,10 @@ def send_photo( disable_notification: DVInput[bool] = DEFAULT_NONE, reply_to_message_id: int = None, reply_markup: ReplyMarkup = None, - timeout: DVInput[float] = DEFAULT_20, + read_timeout: float = 20, + write_timeout: float = 20, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, parse_mode: ODVInput[str] = DEFAULT_NONE, api_kwargs: JSONDict = None, allow_sending_without_reply: ODVInput[bool] = DEFAULT_NONE, @@ -676,22 +820,25 @@ def send_photo( data['caption'] = caption if caption_entities: - data['caption_entities'] = [me.to_dict() for me in caption_entities] + data['caption_entities'] = caption_entities - return self._message( # type: ignore[return-value] + return await self._send_message( # type: ignore[return-value] 'sendPhoto', data, - timeout=timeout, - disable_notification=disable_notification, reply_to_message_id=reply_to_message_id, + disable_notification=disable_notification, reply_markup=reply_markup, allow_sending_without_reply=allow_sending_without_reply, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, protect_content=protect_content, ) @_log - def send_audio( + async def send_audio( self, chat_id: Union[int, str], audio: Union[FileInput, 'Audio'], @@ -702,7 +849,10 @@ def send_audio( disable_notification: DVInput[bool] = DEFAULT_NONE, reply_to_message_id: int = None, reply_markup: ReplyMarkup = None, - timeout: DVInput[float] = DEFAULT_20, + read_timeout: float = 20, + write_timeout: float = 20, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, parse_mode: ODVInput[str] = DEFAULT_NONE, thumb: FileInput = None, api_kwargs: JSONDict = None, @@ -805,24 +955,27 @@ def send_audio( data['caption'] = caption if caption_entities: - data['caption_entities'] = [me.to_dict() for me in caption_entities] + data['caption_entities'] = caption_entities if thumb: - data['thumb'] = parse_file_input(thumb, attach=True) + data['thumb'] = parse_file_input(thumb) - return self._message( # type: ignore[return-value] + return await self._send_message( # type: ignore[return-value] 'sendAudio', data, - timeout=timeout, - disable_notification=disable_notification, reply_to_message_id=reply_to_message_id, + disable_notification=disable_notification, reply_markup=reply_markup, allow_sending_without_reply=allow_sending_without_reply, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, protect_content=protect_content, ) @_log - def send_document( + async def send_document( self, chat_id: Union[int, str], document: Union[FileInput, 'Document'], @@ -831,7 +984,10 @@ def send_document( disable_notification: DVInput[bool] = DEFAULT_NONE, reply_to_message_id: int = None, reply_markup: ReplyMarkup = None, - timeout: DVInput[float] = DEFAULT_20, + read_timeout: float = 20, + write_timeout: float = 20, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, parse_mode: ODVInput[str] = DEFAULT_NONE, thumb: FileInput = None, api_kwargs: JSONDict = None, @@ -922,33 +1078,39 @@ def send_document( data['caption'] = caption if caption_entities: - data['caption_entities'] = [me.to_dict() for me in caption_entities] + data['caption_entities'] = caption_entities if disable_content_type_detection is not None: data['disable_content_type_detection'] = disable_content_type_detection if thumb: - data['thumb'] = parse_file_input(thumb, attach=True) + data['thumb'] = parse_file_input(thumb) - return self._message( # type: ignore[return-value] + return await self._send_message( # type: ignore[return-value] 'sendDocument', data, - timeout=timeout, - disable_notification=disable_notification, reply_to_message_id=reply_to_message_id, + disable_notification=disable_notification, reply_markup=reply_markup, allow_sending_without_reply=allow_sending_without_reply, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, protect_content=protect_content, ) @_log - def send_sticker( + async def send_sticker( self, chat_id: Union[int, str], sticker: Union[FileInput, 'Sticker'], disable_notification: DVInput[bool] = DEFAULT_NONE, reply_to_message_id: int = None, reply_markup: ReplyMarkup = None, - timeout: DVInput[float] = DEFAULT_20, + read_timeout: float = 20, + write_timeout: float = 20, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, allow_sending_without_reply: ODVInput[bool] = DEFAULT_NONE, protect_content: ODVInput[bool] = DEFAULT_NONE, @@ -999,21 +1161,23 @@ def send_sticker( """ data: JSONDict = {'chat_id': chat_id, 'sticker': parse_file_input(sticker, Sticker)} - - return self._message( # type: ignore[return-value] + return await self._send_message( # type: ignore[return-value] 'sendSticker', data, - timeout=timeout, - disable_notification=disable_notification, reply_to_message_id=reply_to_message_id, + disable_notification=disable_notification, reply_markup=reply_markup, allow_sending_without_reply=allow_sending_without_reply, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, protect_content=protect_content, ) @_log - def send_video( + async def send_video( self, chat_id: Union[int, str], video: Union[FileInput, 'Video'], @@ -1022,7 +1186,10 @@ def send_video( disable_notification: DVInput[bool] = DEFAULT_NONE, reply_to_message_id: int = None, reply_markup: ReplyMarkup = None, - timeout: DVInput[float] = DEFAULT_20, + read_timeout: float = 20, + write_timeout: float = 20, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, width: int = None, height: int = None, parse_mode: ODVInput[str] = DEFAULT_NONE, @@ -1126,7 +1293,7 @@ def send_video( if caption: data['caption'] = caption if caption_entities: - data['caption_entities'] = [me.to_dict() for me in caption_entities] + data['caption_entities'] = caption_entities if supports_streaming: data['supports_streaming'] = supports_streaming if width: @@ -1134,22 +1301,25 @@ def send_video( if height: data['height'] = height if thumb: - data['thumb'] = parse_file_input(thumb, attach=True) + data['thumb'] = parse_file_input(thumb) - return self._message( # type: ignore[return-value] + return await self._send_message( # type: ignore[return-value] 'sendVideo', data, - timeout=timeout, - disable_notification=disable_notification, reply_to_message_id=reply_to_message_id, + disable_notification=disable_notification, reply_markup=reply_markup, allow_sending_without_reply=allow_sending_without_reply, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, protect_content=protect_content, ) @_log - def send_video_note( + async def send_video_note( self, chat_id: Union[int, str], video_note: Union[FileInput, 'VideoNote'], @@ -1158,7 +1328,10 @@ def send_video_note( disable_notification: DVInput[bool] = DEFAULT_NONE, reply_to_message_id: int = None, reply_markup: ReplyMarkup = None, - timeout: DVInput[float] = DEFAULT_20, + read_timeout: float = 20, + write_timeout: float = 20, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, thumb: FileInput = None, api_kwargs: JSONDict = None, allow_sending_without_reply: ODVInput[bool] = DEFAULT_NONE, @@ -1241,22 +1414,25 @@ def send_video_note( if length is not None: data['length'] = length if thumb: - data['thumb'] = parse_file_input(thumb, attach=True) + data['thumb'] = parse_file_input(thumb) - return self._message( # type: ignore[return-value] + return await self._send_message( # type: ignore[return-value] 'sendVideoNote', data, - timeout=timeout, - disable_notification=disable_notification, reply_to_message_id=reply_to_message_id, + disable_notification=disable_notification, reply_markup=reply_markup, allow_sending_without_reply=allow_sending_without_reply, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, protect_content=protect_content, ) @_log - def send_animation( + async def send_animation( self, chat_id: Union[int, str], animation: Union[FileInput, 'Animation'], @@ -1269,7 +1445,10 @@ def send_animation( disable_notification: DVInput[bool] = DEFAULT_NONE, reply_to_message_id: int = None, reply_markup: ReplyMarkup = None, - timeout: DVInput[float] = DEFAULT_20, + read_timeout: float = 20, + write_timeout: float = 20, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, allow_sending_without_reply: ODVInput[bool] = DEFAULT_NONE, caption_entities: Union[List['MessageEntity'], Tuple['MessageEntity', ...]] = None, @@ -1365,26 +1544,29 @@ def send_animation( if height: data['height'] = height if thumb: - data['thumb'] = parse_file_input(thumb, attach=True) + data['thumb'] = parse_file_input(thumb) if caption: data['caption'] = caption if caption_entities: - data['caption_entities'] = [me.to_dict() for me in caption_entities] + data['caption_entities'] = caption_entities - return self._message( # type: ignore[return-value] + return await self._send_message( # type: ignore[return-value] 'sendAnimation', data, - timeout=timeout, - disable_notification=disable_notification, reply_to_message_id=reply_to_message_id, + disable_notification=disable_notification, reply_markup=reply_markup, allow_sending_without_reply=allow_sending_without_reply, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, protect_content=protect_content, ) @_log - def send_voice( + async def send_voice( self, chat_id: Union[int, str], voice: Union[FileInput, 'Voice'], @@ -1393,7 +1575,10 @@ def send_voice( disable_notification: DVInput[bool] = DEFAULT_NONE, reply_to_message_id: int = None, reply_markup: ReplyMarkup = None, - timeout: DVInput[float] = DEFAULT_20, + read_timeout: float = 20, + write_timeout: float = 20, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, parse_mode: ODVInput[str] = DEFAULT_NONE, api_kwargs: JSONDict = None, allow_sending_without_reply: ODVInput[bool] = DEFAULT_NONE, @@ -1477,22 +1662,25 @@ def send_voice( data['caption'] = caption if caption_entities: - data['caption_entities'] = [me.to_dict() for me in caption_entities] + data['caption_entities'] = caption_entities - return self._message( # type: ignore[return-value] + return await self._send_message( # type: ignore[return-value] 'sendVoice', data, - timeout=timeout, - disable_notification=disable_notification, reply_to_message_id=reply_to_message_id, + disable_notification=disable_notification, reply_markup=reply_markup, allow_sending_without_reply=allow_sending_without_reply, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, protect_content=protect_content, ) @_log - def send_media_group( + async def send_media_group( self, chat_id: Union[int, str], media: List[ @@ -1500,7 +1688,10 @@ def send_media_group( ], disable_notification: ODVInput[bool] = DEFAULT_NONE, reply_to_message_id: int = None, - timeout: DVInput[float] = DEFAULT_20, + read_timeout: float = 20, + write_timeout: float = 20, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, allow_sending_without_reply: ODVInput[bool] = DEFAULT_NONE, protect_content: ODVInput[bool] = DEFAULT_NONE, @@ -1545,12 +1736,20 @@ def send_media_group( if reply_to_message_id: data['reply_to_message_id'] = reply_to_message_id - result = self._post('sendMediaGroup', data, timeout=timeout, api_kwargs=api_kwargs) + result = await self._post( + 'sendMediaGroup', + data, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, + api_kwargs=api_kwargs, + ) return Message.de_list(result, self) # type: ignore @_log - def send_location( + async def send_location( self, chat_id: Union[int, str], latitude: float = None, @@ -1558,7 +1757,10 @@ def send_location( disable_notification: DVInput[bool] = DEFAULT_NONE, reply_to_message_id: int = None, reply_markup: ReplyMarkup = None, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, location: Location = None, live_period: int = None, api_kwargs: JSONDict = None, @@ -1643,20 +1845,23 @@ def send_location( if proximity_alert_radius: data['proximity_alert_radius'] = proximity_alert_radius - return self._message( # type: ignore[return-value] + return await self._send_message( # type: ignore[return-value] 'sendLocation', data, - timeout=timeout, - disable_notification=disable_notification, reply_to_message_id=reply_to_message_id, + disable_notification=disable_notification, reply_markup=reply_markup, allow_sending_without_reply=allow_sending_without_reply, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, protect_content=protect_content, ) @_log - def edit_message_live_location( + async def edit_message_live_location( self, chat_id: Union[str, int] = None, message_id: int = None, @@ -1665,7 +1870,10 @@ def edit_message_live_location( longitude: float = None, location: Location = None, reply_markup: InlineKeyboardMarkup = None, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, horizontal_accuracy: float = None, heading: int = None, @@ -1737,22 +1945,28 @@ def edit_message_live_location( if proximity_alert_radius: data['proximity_alert_radius'] = proximity_alert_radius - return self._message( + return await self._send_message( 'editMessageLiveLocation', data, - timeout=timeout, reply_markup=reply_markup, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, ) @_log - def stop_message_live_location( + async def stop_message_live_location( self, chat_id: Union[str, int] = None, message_id: int = None, inline_message_id: int = None, reply_markup: InlineKeyboardMarkup = None, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, ) -> Union[Message, bool]: """Use this method to stop updating a live location message sent by the bot or via the bot @@ -1787,16 +2001,19 @@ def stop_message_live_location( if inline_message_id: data['inline_message_id'] = inline_message_id - return self._message( + return await self._send_message( 'stopMessageLiveLocation', data, - timeout=timeout, reply_markup=reply_markup, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, ) @_log - def send_venue( + async def send_venue( self, chat_id: Union[int, str], latitude: float = None, @@ -1807,7 +2024,10 @@ def send_venue( disable_notification: DVInput[bool] = DEFAULT_NONE, reply_to_message_id: int = None, reply_markup: ReplyMarkup = None, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, venue: Venue = None, foursquare_type: str = None, api_kwargs: JSONDict = None, @@ -1903,20 +2123,23 @@ def send_venue( if google_place_type: data['google_place_type'] = google_place_type - return self._message( # type: ignore[return-value] + return await self._send_message( # type: ignore[return-value] 'sendVenue', data, - timeout=timeout, - disable_notification=disable_notification, reply_to_message_id=reply_to_message_id, + disable_notification=disable_notification, reply_markup=reply_markup, allow_sending_without_reply=allow_sending_without_reply, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, protect_content=protect_content, ) @_log - def send_contact( + async def send_contact( self, chat_id: Union[int, str], phone_number: str = None, @@ -1925,7 +2148,10 @@ def send_contact( disable_notification: DVInput[bool] = DEFAULT_NONE, reply_to_message_id: int = None, reply_markup: ReplyMarkup = None, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, contact: Contact = None, vcard: str = None, api_kwargs: JSONDict = None, @@ -1997,27 +2223,33 @@ def send_contact( if vcard: data['vcard'] = vcard - return self._message( # type: ignore[return-value] + return await self._send_message( # type: ignore[return-value] 'sendContact', data, - timeout=timeout, - disable_notification=disable_notification, reply_to_message_id=reply_to_message_id, + disable_notification=disable_notification, reply_markup=reply_markup, allow_sending_without_reply=allow_sending_without_reply, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, protect_content=protect_content, ) @_log - def send_game( + async def send_game( self, chat_id: Union[int, str], game_short_name: str, disable_notification: DVInput[bool] = DEFAULT_NONE, reply_to_message_id: int = None, reply_markup: InlineKeyboardMarkup = None, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, allow_sending_without_reply: ODVInput[bool] = DEFAULT_NONE, protect_content: ODVInput[bool] = DEFAULT_NONE, @@ -2057,24 +2289,30 @@ def send_game( """ data: JSONDict = {'chat_id': chat_id, 'game_short_name': game_short_name} - return self._message( # type: ignore[return-value] + return await self._send_message( # type: ignore[return-value] 'sendGame', data, - timeout=timeout, - disable_notification=disable_notification, reply_to_message_id=reply_to_message_id, + disable_notification=disable_notification, reply_markup=reply_markup, allow_sending_without_reply=allow_sending_without_reply, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, protect_content=protect_content, ) @_log - def send_chat_action( + async def send_chat_action( self, chat_id: Union[str, int], action: str, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, ) -> bool: """ @@ -2103,9 +2341,15 @@ def send_chat_action( """ data: JSONDict = {'chat_id': chat_id, 'action': action} - - result = self._post('sendChatAction', data, timeout=timeout, api_kwargs=api_kwargs) - + result = await self._post( + 'sendChatAction', + data, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, + api_kwargs=api_kwargs, + ) return result # type: ignore[return-value] def _effective_inline_results( # pylint: disable=no-self-use @@ -2187,7 +2431,7 @@ def _insert_defaults_for_ilq_results( # pylint: disable=no-self-use ) @_log - def answer_inline_query( + async def answer_inline_query( self, inline_query_id: str, results: Union[ @@ -2198,7 +2442,10 @@ def answer_inline_query( next_offset: str = None, switch_pm_text: str = None, switch_pm_parameter: str = None, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, current_offset: str = None, api_kwargs: JSONDict = None, ) -> bool: @@ -2268,9 +2515,7 @@ def answer_inline_query( for result in effective_results: self._insert_defaults_for_ilq_results(result) - results_dicts = [res.to_dict() for res in effective_results] - - data: JSONDict = {'inline_query_id': inline_query_id, 'results': results_dicts} + data: JSONDict = {'inline_query_id': inline_query_id, 'results': effective_results} if cache_time or cache_time == 0: data['cache_time'] = cache_time @@ -2283,20 +2528,26 @@ def answer_inline_query( if switch_pm_parameter: data['switch_pm_parameter'] = switch_pm_parameter - return self._post( # type: ignore[return-value] + return await self._post( # type: ignore[return-value] 'answerInlineQuery', data, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, ) @_log - def get_user_profile_photos( + async def get_user_profile_photos( self, user_id: Union[str, int], offset: int = None, limit: int = 100, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, ) -> Optional[UserProfilePhotos]: """Use this method to get a list of profile pictures for a user. @@ -2327,17 +2578,28 @@ def get_user_profile_photos( if limit: data['limit'] = limit - result = self._post('getUserProfilePhotos', data, timeout=timeout, api_kwargs=api_kwargs) + result = await self._post( + 'getUserProfilePhotos', + data, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, + api_kwargs=api_kwargs, + ) return UserProfilePhotos.de_json(result, self) # type: ignore[return-value, arg-type] @_log - def get_file( + async def get_file( self, file_id: Union[ str, Animation, Audio, ChatPhoto, Document, PhotoSize, Sticker, Video, VideoNote, Voice ], - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, ) -> File: """ @@ -2382,21 +2644,34 @@ def get_file( data: JSONDict = {'file_id': file_id} - result = self._post('getFile', data, timeout=timeout, api_kwargs=api_kwargs) + result = await self._post( + 'getFile', + data, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, + api_kwargs=api_kwargs, + ) if result.get('file_path') and not is_local_file( # type: ignore[union-attr] result['file_path'] # type: ignore[index] ): - result['file_path'] = f"{self.base_file_url}/{result['file_path']}" # type: ignore + result[ # type: ignore[index] + 'file_path' + ] = f"{self.base_file_url}/{result['file_path']}" # type: ignore[index] return File.de_json(result, self) # type: ignore[return-value, arg-type] @_log - def ban_chat_member( + async def ban_chat_member( self, chat_id: Union[str, int], user_id: Union[str, int], - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, until_date: Union[int, datetime] = None, api_kwargs: JSONDict = None, revoke_messages: bool = None, @@ -2446,16 +2721,27 @@ def ban_chat_member( if revoke_messages is not None: data['revoke_messages'] = revoke_messages - result = self._post('banChatMember', data, timeout=timeout, api_kwargs=api_kwargs) + result = await self._post( + 'banChatMember', + data, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, + api_kwargs=api_kwargs, + ) return result # type: ignore[return-value] @_log - def ban_chat_sender_chat( + async def ban_chat_sender_chat( self, chat_id: Union[str, int], sender_chat_id: int, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, ) -> bool: """ @@ -2485,16 +2771,27 @@ def ban_chat_sender_chat( """ data: JSONDict = {'chat_id': chat_id, 'sender_chat_id': sender_chat_id} - result = self._post('banChatSenderChat', data, timeout=timeout, api_kwargs=api_kwargs) + result = await self._post( + 'banChatSenderChat', + data, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, + api_kwargs=api_kwargs, + ) return result # type: ignore[return-value] @_log - def unban_chat_member( + async def unban_chat_member( self, chat_id: Union[str, int], user_id: Union[str, int], - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, only_if_banned: bool = None, ) -> bool: @@ -2529,16 +2826,27 @@ def unban_chat_member( if only_if_banned is not None: data['only_if_banned'] = only_if_banned - result = self._post('unbanChatMember', data, timeout=timeout, api_kwargs=api_kwargs) + result = await self._post( + 'unbanChatMember', + data, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, + api_kwargs=api_kwargs, + ) return result # type: ignore[return-value] @_log - def unban_chat_sender_chat( + async def unban_chat_sender_chat( self, chat_id: Union[str, int], sender_chat_id: int, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, ) -> bool: """Use this method to unban a previously banned channel in a supergroup or channel. @@ -2566,19 +2874,30 @@ def unban_chat_sender_chat( """ data: JSONDict = {'chat_id': chat_id, 'sender_chat_id': sender_chat_id} - result = self._post('unbanChatSenderChat', data, timeout=timeout, api_kwargs=api_kwargs) + result = await self._post( + 'unbanChatSenderChat', + data, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, + api_kwargs=api_kwargs, + ) return result # type: ignore[return-value] @_log - def answer_callback_query( + async def answer_callback_query( self, callback_query_id: str, text: str = None, show_alert: bool = False, url: str = None, cache_time: int = None, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, ) -> bool: """ @@ -2630,12 +2949,20 @@ def answer_callback_query( if cache_time is not None: data['cache_time'] = cache_time - result = self._post('answerCallbackQuery', data, timeout=timeout, api_kwargs=api_kwargs) + result = await self._post( + 'answerCallbackQuery', + data, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, + api_kwargs=api_kwargs, + ) return result # type: ignore[return-value] @_log - def edit_message_text( + async def edit_message_text( self, text: str, chat_id: Union[str, int] = None, @@ -2644,7 +2971,10 @@ def edit_message_text( parse_mode: ODVInput[str] = DEFAULT_NONE, disable_web_page_preview: ODVInput[bool] = DEFAULT_NONE, reply_markup: InlineKeyboardMarkup = None, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, entities: Union[List['MessageEntity'], Tuple['MessageEntity', ...]] = None, ) -> Union[Message, bool]: @@ -2700,23 +3030,29 @@ def edit_message_text( if entities: data['entities'] = [me.to_dict() for me in entities] - return self._message( + return await self._send_message( 'editMessageText', data, - timeout=timeout, reply_markup=reply_markup, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, ) @_log - def edit_message_caption( + async def edit_message_caption( self, chat_id: Union[str, int] = None, message_id: int = None, inline_message_id: int = None, caption: str = None, reply_markup: InlineKeyboardMarkup = None, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, parse_mode: ODVInput[str] = DEFAULT_NONE, api_kwargs: JSONDict = None, caption_entities: Union[List['MessageEntity'], Tuple['MessageEntity', ...]] = None, @@ -2768,7 +3104,7 @@ def edit_message_caption( if caption: data['caption'] = caption if caption_entities: - data['caption_entities'] = [me.to_dict() for me in caption_entities] + data['caption_entities'] = caption_entities if chat_id: data['chat_id'] = chat_id if message_id: @@ -2776,23 +3112,29 @@ def edit_message_caption( if inline_message_id: data['inline_message_id'] = inline_message_id - return self._message( + return await self._send_message( 'editMessageCaption', data, - timeout=timeout, reply_markup=reply_markup, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, ) @_log - def edit_message_media( + async def edit_message_media( self, media: 'InputMedia', chat_id: Union[str, int] = None, message_id: int = None, inline_message_id: int = None, reply_markup: InlineKeyboardMarkup = None, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, ) -> Union[Message, bool]: """ @@ -2842,22 +3184,28 @@ def edit_message_media( if inline_message_id: data['inline_message_id'] = inline_message_id - return self._message( + return await self._send_message( 'editMessageMedia', data, - timeout=timeout, reply_markup=reply_markup, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, ) @_log - def edit_message_reply_markup( + async def edit_message_reply_markup( self, chat_id: Union[str, int] = None, message_id: int = None, inline_message_id: int = None, reply_markup: Optional['InlineKeyboardMarkup'] = None, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, ) -> Union[Message, bool]: """ @@ -2903,21 +3251,27 @@ def edit_message_reply_markup( if inline_message_id: data['inline_message_id'] = inline_message_id - return self._message( + return await self._send_message( 'editMessageReplyMarkup', data, - timeout=timeout, reply_markup=reply_markup, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, ) @_log - def get_updates( + async def get_updates( self, offset: int = None, limit: int = 100, - timeout: float = 0, - read_latency: float = 2.0, + timeout: int = 0, + read_timeout: float = 2, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, allowed_updates: List[str] = None, api_kwargs: JSONDict = None, ) -> List[Update]: @@ -2980,27 +3334,33 @@ def get_updates( # dropped in real time. result = cast( List[JSONDict], - self._post( + await self._post( 'getUpdates', data, - timeout=float(read_latency) + float(timeout), + read_timeout=read_timeout + timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, ), ) if result: - self.logger.debug('Getting updates: %s', [u['update_id'] for u in result]) + self._logger.debug('Getting updates: %s', [u['update_id'] for u in result]) else: - self.logger.debug('No new updates found.') + self._logger.debug('No new updates found.') return Update.de_list(result, self) # type: ignore[return-value] @_log - def set_webhook( + async def set_webhook( self, url: str, certificate: FileInput = None, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, max_connections: int = 40, allowed_updates: List[str] = None, api_kwargs: JSONDict = None, @@ -3083,14 +3443,25 @@ def set_webhook( if drop_pending_updates: data['drop_pending_updates'] = drop_pending_updates - result = self._post('setWebhook', data, timeout=timeout, api_kwargs=api_kwargs) + result = await self._post( + 'setWebhook', + data, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, + api_kwargs=api_kwargs, + ) return result # type: ignore[return-value] @_log - def delete_webhook( + async def delete_webhook( self, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, drop_pending_updates: bool = None, ) -> bool: @@ -3119,15 +3490,26 @@ def delete_webhook( if drop_pending_updates: data['drop_pending_updates'] = drop_pending_updates - result = self._post('deleteWebhook', data, timeout=timeout, api_kwargs=api_kwargs) + result = await self._post( + 'deleteWebhook', + data, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, + api_kwargs=api_kwargs, + ) return result # type: ignore[return-value] @_log - def leave_chat( + async def leave_chat( self, chat_id: Union[str, int], - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, ) -> bool: """Use this method for your bot to leave a group, supergroup or channel. @@ -3150,15 +3532,26 @@ def leave_chat( """ data: JSONDict = {'chat_id': chat_id} - result = self._post('leaveChat', data, timeout=timeout, api_kwargs=api_kwargs) + result = await self._post( + 'leaveChat', + data, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, + api_kwargs=api_kwargs, + ) return result # type: ignore[return-value] @_log - def get_chat( + async def get_chat( self, chat_id: Union[str, int], - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, ) -> Chat: """ @@ -3183,15 +3576,26 @@ def get_chat( """ data: JSONDict = {'chat_id': chat_id} - result = self._post('getChat', data, timeout=timeout, api_kwargs=api_kwargs) + result = await self._post( + 'getChat', + data, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, + api_kwargs=api_kwargs, + ) return Chat.de_json(result, self) # type: ignore[return-value, arg-type] @_log - def get_chat_administrators( + async def get_chat_administrators( self, chat_id: Union[str, int], - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, ) -> List[ChatMember]: """ @@ -3217,16 +3621,25 @@ def get_chat_administrators( """ data: JSONDict = {'chat_id': chat_id} - - result = self._post('getChatAdministrators', data, timeout=timeout, api_kwargs=api_kwargs) - + result = await self._post( + 'getChatAdministrators', + data, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, + api_kwargs=api_kwargs, + ) return ChatMember.de_list(result, self) # type: ignore @_log - def get_chat_member_count( + async def get_chat_member_count( self, chat_id: Union[str, int], - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, ) -> int: """Use this method to get the number of members in a chat. @@ -3250,17 +3663,26 @@ def get_chat_member_count( """ data: JSONDict = {'chat_id': chat_id} - - result = self._post('getChatMemberCount', data, timeout=timeout, api_kwargs=api_kwargs) - + result = await self._post( + 'getChatMemberCount', + data, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, + api_kwargs=api_kwargs, + ) return result # type: ignore[return-value] @_log - def get_chat_member( + async def get_chat_member( self, chat_id: Union[str, int], user_id: Union[str, int], - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, ) -> ChatMember: """Use this method to get information about a member of a chat. @@ -3283,17 +3705,26 @@ def get_chat_member( """ data: JSONDict = {'chat_id': chat_id, 'user_id': user_id} - - result = self._post('getChatMember', data, timeout=timeout, api_kwargs=api_kwargs) - + result = await self._post( + 'getChatMember', + data, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, + api_kwargs=api_kwargs, + ) return ChatMember.de_json(result, self) # type: ignore[return-value, arg-type] @_log - def set_chat_sticker_set( + async def set_chat_sticker_set( self, chat_id: Union[str, int], sticker_set_name: str, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, ) -> bool: """Use this method to set a new group sticker set for a supergroup. @@ -3316,16 +3747,25 @@ def set_chat_sticker_set( :obj:`bool`: On success, :obj:`True` is returned. """ data: JSONDict = {'chat_id': chat_id, 'sticker_set_name': sticker_set_name} - - result = self._post('setChatStickerSet', data, timeout=timeout, api_kwargs=api_kwargs) - + result = await self._post( + 'setChatStickerSet', + data, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, + api_kwargs=api_kwargs, + ) return result # type: ignore[return-value] @_log - def delete_chat_sticker_set( + async def delete_chat_sticker_set( self, chat_id: Union[str, int], - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, ) -> bool: """Use this method to delete a group sticker set from a supergroup. The bot must be an @@ -3346,13 +3786,24 @@ def delete_chat_sticker_set( :obj:`bool`: On success, :obj:`True` is returned. """ data: JSONDict = {'chat_id': chat_id} - - result = self._post('deleteChatStickerSet', data, timeout=timeout, api_kwargs=api_kwargs) - + result = await self._post( + 'deleteChatStickerSet', + data, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, + api_kwargs=api_kwargs, + ) return result # type: ignore[return-value] - def get_webhook_info( - self, timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None + async def get_webhook_info( + self, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, + api_kwargs: JSONDict = None, ) -> WebhookInfo: """Use this method to get current webhook status. Requires no parameters. @@ -3370,12 +3821,19 @@ def get_webhook_info( :class:`telegram.WebhookInfo` """ - result = self._post('getWebhookInfo', None, timeout=timeout, api_kwargs=api_kwargs) - + result = await self._post( + 'getWebhookInfo', + None, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, + api_kwargs=api_kwargs, + ) return WebhookInfo.de_json(result, self) # type: ignore[return-value, arg-type] @_log - def set_game_score( + async def set_game_score( self, user_id: Union[int, str], score: int, @@ -3384,7 +3842,10 @@ def set_game_score( inline_message_id: int = None, force: bool = None, disable_edit_message: bool = None, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, ) -> Union[Message, bool]: """ @@ -3431,21 +3892,27 @@ def set_game_score( if disable_edit_message is not None: data['disable_edit_message'] = disable_edit_message - return self._message( + return await self._send_message( 'setGameScore', data, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, ) @_log - def get_game_high_scores( + async def get_game_high_scores( self, user_id: Union[int, str], chat_id: Union[str, int] = None, message_id: int = None, inline_message_id: int = None, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, ) -> List[GameHighScore]: """ @@ -3487,12 +3954,20 @@ def get_game_high_scores( if inline_message_id: data['inline_message_id'] = inline_message_id - result = self._post('getGameHighScores', data, timeout=timeout, api_kwargs=api_kwargs) + result = await self._post( + 'getGameHighScores', + data, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, + api_kwargs=api_kwargs, + ) return GameHighScore.de_list(result, self) # type: ignore @_log - def send_invoice( + async def send_invoice( self, chat_id: Union[int, str], title: str, @@ -3517,7 +3992,10 @@ def send_invoice( provider_data: Union[str, object] = None, send_phone_number_to_provider: bool = None, send_email_to_provider: bool = None, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, allow_sending_without_reply: ODVInput[bool] = DEFAULT_NONE, max_tip_amount: int = None, @@ -3629,7 +4107,7 @@ def send_invoice( 'payload': payload, 'provider_token': provider_token, 'currency': currency, - 'prices': [p.to_dict() for p in prices], + 'prices': prices, } if max_tip_amount is not None: data['max_tip_amount'] = max_tip_amount @@ -3665,26 +4143,32 @@ def send_invoice( if send_email_to_provider is not None: data['send_email_to_provider'] = send_email_to_provider - return self._message( # type: ignore[return-value] + return await self._send_message( # type: ignore[return-value] 'sendInvoice', data, - timeout=timeout, - disable_notification=disable_notification, reply_to_message_id=reply_to_message_id, + disable_notification=disable_notification, reply_markup=reply_markup, allow_sending_without_reply=allow_sending_without_reply, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, protect_content=protect_content, ) @_log - def answer_shipping_query( # pylint: disable=invalid-name + async def answer_shipping_query( # pylint: disable=invalid-name self, shipping_query_id: str, ok: bool, shipping_options: List[ShippingOption] = None, error_message: str = None, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, ) -> bool: """ @@ -3742,17 +4226,28 @@ def answer_shipping_query( # pylint: disable=invalid-name if error_message is not None: data['error_message'] = error_message - result = self._post('answerShippingQuery', data, timeout=timeout, api_kwargs=api_kwargs) + result = await self._post( + 'answerShippingQuery', + data, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, + api_kwargs=api_kwargs, + ) return result # type: ignore[return-value] @_log - def answer_pre_checkout_query( # pylint: disable=invalid-name + async def answer_pre_checkout_query( # pylint: disable=invalid-name self, pre_checkout_query_id: str, ok: bool, error_message: str = None, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, ) -> bool: """ @@ -3801,18 +4296,29 @@ def answer_pre_checkout_query( # pylint: disable=invalid-name if error_message is not None: data['error_message'] = error_message - result = self._post('answerPreCheckoutQuery', data, timeout=timeout, api_kwargs=api_kwargs) + result = await self._post( + 'answerPreCheckoutQuery', + data, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, + api_kwargs=api_kwargs, + ) return result # type: ignore[return-value] @_log - def restrict_chat_member( + async def restrict_chat_member( self, chat_id: Union[str, int], user_id: Union[str, int], permissions: ChatPermissions, until_date: Union[int, datetime] = None, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, ) -> bool: """ @@ -3852,18 +4358,26 @@ def restrict_chat_member( data: JSONDict = { 'chat_id': chat_id, 'user_id': user_id, - 'permissions': permissions.to_dict(), + 'permissions': permissions, } if until_date is not None: data['until_date'] = until_date - result = self._post('restrictChatMember', data, timeout=timeout, api_kwargs=api_kwargs) + result = await self._post( + 'restrictChatMember', + data, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, + api_kwargs=api_kwargs, + ) return result # type: ignore[return-value] @_log - def promote_chat_member( + async def promote_chat_member( self, chat_id: Union[str, int], user_id: Union[str, int], @@ -3875,7 +4389,10 @@ def promote_chat_member( can_restrict_members: bool = None, can_pin_messages: bool = None, can_promote_members: bool = None, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, is_anonymous: bool = None, can_manage_chat: bool = None, @@ -3960,16 +4477,27 @@ def promote_chat_member( if can_manage_voice_chats is not None: data['can_manage_voice_chats'] = can_manage_voice_chats - result = self._post('promoteChatMember', data, timeout=timeout, api_kwargs=api_kwargs) + result = await self._post( + 'promoteChatMember', + data, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, + api_kwargs=api_kwargs, + ) return result # type: ignore[return-value] @_log - def set_chat_permissions( + async def set_chat_permissions( self, chat_id: Union[str, int], permissions: ChatPermissions, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, ) -> bool: """ @@ -3994,19 +4522,28 @@ def set_chat_permissions( :class:`telegram.error.TelegramError` """ - data: JSONDict = {'chat_id': chat_id, 'permissions': permissions.to_dict()} - - result = self._post('setChatPermissions', data, timeout=timeout, api_kwargs=api_kwargs) - + data: JSONDict = {'chat_id': chat_id, 'permissions': permissions} + result = await self._post( + 'setChatPermissions', + data, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, + api_kwargs=api_kwargs, + ) return result # type: ignore[return-value] @_log - def set_chat_administrator_custom_title( + async def set_chat_administrator_custom_title( self, chat_id: Union[int, str], user_id: Union[int, str], custom_title: str, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, ) -> bool: """ @@ -4034,17 +4571,26 @@ def set_chat_administrator_custom_title( """ data: JSONDict = {'chat_id': chat_id, 'user_id': user_id, 'custom_title': custom_title} - result = self._post( - 'setChatAdministratorCustomTitle', data, timeout=timeout, api_kwargs=api_kwargs + result = await self._post( + 'setChatAdministratorCustomTitle', + data, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, + api_kwargs=api_kwargs, ) return result # type: ignore[return-value] @_log - def export_chat_invite_link( + async def export_chat_invite_link( self, chat_id: Union[str, int], - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, ) -> str: """ @@ -4076,18 +4622,27 @@ def export_chat_invite_link( """ data: JSONDict = {'chat_id': chat_id} - - result = self._post('exportChatInviteLink', data, timeout=timeout, api_kwargs=api_kwargs) - + result = await self._post( + 'exportChatInviteLink', + data, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, + api_kwargs=api_kwargs, + ) return result # type: ignore[return-value] @_log - def create_chat_invite_link( + async def create_chat_invite_link( self, chat_id: Union[str, int], expire_date: Union[int, datetime] = None, member_limit: int = None, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, name: str = None, creates_join_request: bool = None, @@ -4152,18 +4707,29 @@ def create_chat_invite_link( if creates_join_request is not None: data['creates_join_request'] = creates_join_request - result = self._post('createChatInviteLink', data, timeout=timeout, api_kwargs=api_kwargs) + result = await self._post( + 'createChatInviteLink', + data, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, + api_kwargs=api_kwargs, + ) return ChatInviteLink.de_json(result, self) # type: ignore[return-value, arg-type] @_log - def edit_chat_invite_link( + async def edit_chat_invite_link( self, chat_id: Union[str, int], invite_link: Union[str, 'ChatInviteLink'], expire_date: Union[int, datetime] = None, member_limit: int = None, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, name: str = None, creates_join_request: bool = None, @@ -4236,16 +4802,27 @@ def edit_chat_invite_link( if creates_join_request is not None: data['creates_join_request'] = creates_join_request - result = self._post('editChatInviteLink', data, timeout=timeout, api_kwargs=api_kwargs) + result = await self._post( + 'editChatInviteLink', + data, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, + api_kwargs=api_kwargs, + ) return ChatInviteLink.de_json(result, self) # type: ignore[return-value, arg-type] @_log - def revoke_chat_invite_link( + async def revoke_chat_invite_link( self, chat_id: Union[str, int], invite_link: Union[str, 'ChatInviteLink'], - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, ) -> ChatInviteLink: """ @@ -4278,16 +4855,27 @@ def revoke_chat_invite_link( link = invite_link.invite_link if isinstance(invite_link, ChatInviteLink) else invite_link data: JSONDict = {'chat_id': chat_id, 'invite_link': link} - result = self._post('revokeChatInviteLink', data, timeout=timeout, api_kwargs=api_kwargs) + result = await self._post( + 'revokeChatInviteLink', + data, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, + api_kwargs=api_kwargs, + ) return ChatInviteLink.de_json(result, self) # type: ignore[return-value, arg-type] @_log - def approve_chat_join_request( + async def approve_chat_join_request( self, chat_id: Union[str, int], user_id: int, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, ) -> bool: """Use this method to approve a chat join request. @@ -4315,16 +4903,27 @@ def approve_chat_join_request( """ data: JSONDict = {'chat_id': chat_id, 'user_id': user_id} - result = self._post('approveChatJoinRequest', data, timeout=timeout, api_kwargs=api_kwargs) + result = await self._post( + 'approveChatJoinRequest', + data, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, + api_kwargs=api_kwargs, + ) return result # type: ignore[return-value] @_log - def decline_chat_join_request( + async def decline_chat_join_request( self, chat_id: Union[str, int], user_id: int, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, ) -> bool: """Use this method to decline a chat join request. @@ -4352,16 +4951,27 @@ def decline_chat_join_request( """ data: JSONDict = {'chat_id': chat_id, 'user_id': user_id} - result = self._post('declineChatJoinRequest', data, timeout=timeout, api_kwargs=api_kwargs) + result = await self._post( + 'declineChatJoinRequest', + data, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, + api_kwargs=api_kwargs, + ) return result # type: ignore[return-value] @_log - def set_chat_photo( + async def set_chat_photo( self, chat_id: Union[str, int], photo: FileInput, - timeout: DVInput[float] = DEFAULT_20, + read_timeout: float = 20, + write_timeout: float = 20, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, ) -> bool: """Use this method to set a new profile photo for the chat. @@ -4390,16 +5000,25 @@ def set_chat_photo( """ data: JSONDict = {'chat_id': chat_id, 'photo': parse_file_input(photo)} - - result = self._post('setChatPhoto', data, timeout=timeout, api_kwargs=api_kwargs) - + result = await self._post( + 'setChatPhoto', + data, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, + api_kwargs=api_kwargs, + ) return result # type: ignore[return-value] @_log - def delete_chat_photo( + async def delete_chat_photo( self, chat_id: Union[str, int], - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, ) -> bool: """ @@ -4424,17 +5043,26 @@ def delete_chat_photo( """ data: JSONDict = {'chat_id': chat_id} - - result = self._post('deleteChatPhoto', data, timeout=timeout, api_kwargs=api_kwargs) - + result = await self._post( + 'deleteChatPhoto', + data, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, + api_kwargs=api_kwargs, + ) return result # type: ignore[return-value] @_log - def set_chat_title( + async def set_chat_title( self, chat_id: Union[str, int], title: str, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, ) -> bool: """ @@ -4460,17 +5088,26 @@ def set_chat_title( """ data: JSONDict = {'chat_id': chat_id, 'title': title} - - result = self._post('setChatTitle', data, timeout=timeout, api_kwargs=api_kwargs) - + result = await self._post( + 'setChatTitle', + data, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, + api_kwargs=api_kwargs, + ) return result # type: ignore[return-value] @_log - def set_chat_description( + async def set_chat_description( self, chat_id: Union[str, int], description: str = None, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, ) -> bool: """ @@ -4499,18 +5136,27 @@ def set_chat_description( if description is not None: data['description'] = description - - result = self._post('setChatDescription', data, timeout=timeout, api_kwargs=api_kwargs) - + result = await self._post( + 'setChatDescription', + data, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, + api_kwargs=api_kwargs, + ) return result # type: ignore[return-value] @_log - def pin_chat_message( + async def pin_chat_message( self, chat_id: Union[str, int], message_id: int, disable_notification: ODVInput[bool] = DEFAULT_NONE, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, ) -> bool: """ @@ -4545,15 +5191,24 @@ def pin_chat_message( 'disable_notification': disable_notification, } - return self._post( # type: ignore[return-value] - 'pinChatMessage', data, timeout=timeout, api_kwargs=api_kwargs + return await self._post( # type: ignore[return-value] + 'pinChatMessage', + data, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, + api_kwargs=api_kwargs, ) @_log - def unpin_chat_message( + async def unpin_chat_message( self, chat_id: Union[str, int], - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, message_id: int = None, ) -> bool: @@ -4586,15 +5241,24 @@ def unpin_chat_message( if message_id is not None: data['message_id'] = message_id - return self._post( # type: ignore[return-value] - 'unpinChatMessage', data, timeout=timeout, api_kwargs=api_kwargs + return await self._post( # type: ignore[return-value] + 'unpinChatMessage', + data, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, + api_kwargs=api_kwargs, ) @_log - def unpin_all_chat_messages( + async def unpin_all_chat_messages( self, chat_id: Union[str, int], - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, ) -> bool: """ @@ -4620,16 +5284,24 @@ def unpin_all_chat_messages( """ data: JSONDict = {'chat_id': chat_id} - - return self._post( # type: ignore[return-value] - 'unpinAllChatMessages', data, timeout=timeout, api_kwargs=api_kwargs + return await self._post( # type: ignore[return-value] + 'unpinAllChatMessages', + data, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, + api_kwargs=api_kwargs, ) @_log - def get_sticker_set( + async def get_sticker_set( self, name: str, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, ) -> StickerSet: """Use this method to get a sticker set. @@ -4650,17 +5322,26 @@ def get_sticker_set( """ data: JSONDict = {'name': name} - - result = self._post('getStickerSet', data, timeout=timeout, api_kwargs=api_kwargs) - + result = await self._post( + 'getStickerSet', + data, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, + api_kwargs=api_kwargs, + ) return StickerSet.de_json(result, self) # type: ignore[return-value, arg-type] @_log - def upload_sticker_file( + async def upload_sticker_file( self, user_id: Union[str, int], png_sticker: FileInput, - timeout: DVInput[float] = DEFAULT_20, + read_timeout: float = 20, + write_timeout: float = 20, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, ) -> File: """ @@ -4694,13 +5375,19 @@ def upload_sticker_file( """ data: JSONDict = {'user_id': user_id, 'png_sticker': parse_file_input(png_sticker)} - - result = self._post('uploadStickerFile', data, timeout=timeout, api_kwargs=api_kwargs) - + result = await self._post( + 'uploadStickerFile', + data, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, + api_kwargs=api_kwargs, + ) return File.de_json(result, self) # type: ignore[return-value, arg-type] @_log - def create_new_sticker_set( + async def create_new_sticker_set( self, user_id: Union[str, int], name: str, @@ -4709,7 +5396,10 @@ def create_new_sticker_set( png_sticker: FileInput = None, contains_masks: bool = None, mask_position: MaskPosition = None, - timeout: DVInput[float] = DEFAULT_20, + read_timeout: float = 20, + write_timeout: float = 20, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, tgs_sticker: FileInput = None, api_kwargs: JSONDict = None, webm_sticker: FileInput = None, @@ -4790,23 +5480,32 @@ def create_new_sticker_set( if contains_masks is not None: data['contains_masks'] = contains_masks if mask_position is not None: - # We need to_json() instead of to_dict() here, because we're sending a media - # message here, which isn't json dumped by telegram.request - data['mask_position'] = mask_position.to_json() + data['mask_position'] = mask_position - result = self._post('createNewStickerSet', data, timeout=timeout, api_kwargs=api_kwargs) + result = await self._post( + 'createNewStickerSet', + data, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, + api_kwargs=api_kwargs, + ) return result # type: ignore[return-value] @_log - def add_sticker_to_set( + async def add_sticker_to_set( self, user_id: Union[str, int], name: str, emojis: str, png_sticker: FileInput = None, mask_position: MaskPosition = None, - timeout: DVInput[float] = DEFAULT_20, + read_timeout: float = 20, + write_timeout: float = 20, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, tgs_sticker: FileInput = None, api_kwargs: JSONDict = None, webm_sticker: FileInput = None, @@ -4879,20 +5578,29 @@ def add_sticker_to_set( if webm_sticker is not None: data['webm_sticker'] = parse_file_input(webm_sticker) if mask_position is not None: - # We need to_json() instead of to_dict() here, because we're sending a media - # message here, which isn't json dumped by telegram.request - data['mask_position'] = mask_position.to_json() + data['mask_position'] = mask_position - result = self._post('addStickerToSet', data, timeout=timeout, api_kwargs=api_kwargs) + result = await self._post( + 'addStickerToSet', + data, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, + api_kwargs=api_kwargs, + ) return result # type: ignore[return-value] @_log - def set_sticker_position_in_set( + async def set_sticker_position_in_set( self, sticker: str, position: int, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, ) -> bool: """Use this method to move a sticker in a set created by the bot to a specific position. @@ -4914,18 +5622,25 @@ def set_sticker_position_in_set( """ data: JSONDict = {'sticker': sticker, 'position': position} - - result = self._post( - 'setStickerPositionInSet', data, timeout=timeout, api_kwargs=api_kwargs + result = await self._post( + 'setStickerPositionInSet', + data, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, + api_kwargs=api_kwargs, ) - return result # type: ignore[return-value] @_log - def delete_sticker_from_set( + async def delete_sticker_from_set( self, sticker: str, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, ) -> bool: """Use this method to delete a sticker from a set created by the bot. @@ -4946,18 +5661,27 @@ def delete_sticker_from_set( """ data: JSONDict = {'sticker': sticker} - - result = self._post('deleteStickerFromSet', data, timeout=timeout, api_kwargs=api_kwargs) - + result = await self._post( + 'deleteStickerFromSet', + data, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, + api_kwargs=api_kwargs, + ) return result # type: ignore[return-value] @_log - def set_sticker_set_thumb( + async def set_sticker_set_thumb( self, name: str, user_id: Union[str, int], thumb: FileInput = None, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, ) -> bool: """Use this method to set the thumbnail of a sticker set. Animated thumbnails can be set @@ -4999,20 +5723,30 @@ def set_sticker_set_thumb( """ data: JSONDict = {'name': name, 'user_id': user_id} - if thumb is not None: data['thumb'] = parse_file_input(thumb) - result = self._post('setStickerSetThumb', data, timeout=timeout, api_kwargs=api_kwargs) + result = await self._post( + 'setStickerSetThumb', + data, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, + api_kwargs=api_kwargs, + ) return result # type: ignore[return-value] @_log - def set_passport_data_errors( + async def set_passport_data_errors( self, user_id: Union[str, int], errors: List[PassportElementError], - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, ) -> bool: """ @@ -5042,14 +5776,20 @@ def set_passport_data_errors( :class:`telegram.error.TelegramError` """ - data: JSONDict = {'user_id': user_id, 'errors': [error.to_dict() for error in errors]} - - result = self._post('setPassportDataErrors', data, timeout=timeout, api_kwargs=api_kwargs) - + data: JSONDict = {'user_id': user_id, 'errors': errors} + result = await self._post( + 'setPassportDataErrors', + data, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, + api_kwargs=api_kwargs, + ) return result # type: ignore[return-value] @_log - def send_poll( + async def send_poll( self, chat_id: Union[int, str], question: str, @@ -5062,7 +5802,10 @@ def send_poll( disable_notification: ODVInput[bool] = DEFAULT_NONE, reply_to_message_id: int = None, reply_markup: ReplyMarkup = None, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, explanation: str = None, explanation_parse_mode: ODVInput[str] = DEFAULT_NONE, open_period: int = None, @@ -5158,31 +5901,37 @@ def send_poll( if explanation: data['explanation'] = explanation if explanation_entities: - data['explanation_entities'] = [me.to_dict() for me in explanation_entities] + data['explanation_entities'] = explanation_entities if open_period: data['open_period'] = open_period if close_date: data['close_date'] = close_date - return self._message( # type: ignore[return-value] + return await self._send_message( # type: ignore[return-value] 'sendPoll', data, - timeout=timeout, - disable_notification=disable_notification, reply_to_message_id=reply_to_message_id, + disable_notification=disable_notification, reply_markup=reply_markup, allow_sending_without_reply=allow_sending_without_reply, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, protect_content=protect_content, ) @_log - def stop_poll( + async def stop_poll( self, chat_id: Union[int, str], message_id: int, reply_markup: InlineKeyboardMarkup = None, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, ) -> Poll: """ @@ -5210,26 +5959,30 @@ def stop_poll( data: JSONDict = {'chat_id': chat_id, 'message_id': message_id} if reply_markup: - markups = (InlineKeyboardMarkup, ReplyKeyboardMarkup, ForceReply, ReplyKeyboardRemove) - if isinstance(reply_markup, markups): - # We need to_json() instead of to_dict() here, because reply_markups may be - # attached to media messages, which aren't json dumped by telegram.request - data['reply_markup'] = reply_markup.to_json() - else: - data['reply_markup'] = reply_markup - - result = self._post('stopPoll', data, timeout=timeout, api_kwargs=api_kwargs) + data['reply_markup'] = reply_markup + result = await self._post( + 'stopPoll', + data, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, + api_kwargs=api_kwargs, + ) return Poll.de_json(result, self) # type: ignore[return-value, arg-type] @_log - def send_dice( + async def send_dice( self, chat_id: Union[int, str], disable_notification: ODVInput[bool] = DEFAULT_NONE, reply_to_message_id: int = None, reply_markup: ReplyMarkup = None, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, emoji: str = None, api_kwargs: JSONDict = None, allow_sending_without_reply: ODVInput[bool] = DEFAULT_NONE, @@ -5282,26 +6035,31 @@ def send_dice( """ data: JSONDict = {'chat_id': chat_id} - if emoji: data['emoji'] = emoji - return self._message( # type: ignore[return-value] + return await self._send_message( # type: ignore[return-value] 'sendDice', data, - timeout=timeout, - disable_notification=disable_notification, reply_to_message_id=reply_to_message_id, + disable_notification=disable_notification, reply_markup=reply_markup, allow_sending_without_reply=allow_sending_without_reply, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, protect_content=protect_content, ) @_log - def get_my_commands( + async def get_my_commands( self, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, scope: BotCommandScope = None, language_code: str = None, @@ -5337,20 +6095,31 @@ def get_my_commands( data: JSONDict = {} if scope: - data['scope'] = scope.to_dict() + data['scope'] = scope if language_code: data['language_code'] = language_code - result = self._post('getMyCommands', data, timeout=timeout, api_kwargs=api_kwargs) + result = await self._post( + 'getMyCommands', + data, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, + api_kwargs=api_kwargs, + ) return BotCommand.de_list(result, self) # type: ignore[return-value,arg-type] @_log - def set_my_commands( + async def set_my_commands( self, commands: List[Union[BotCommand, Tuple[str, str]]], - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, scope: BotCommandScope = None, language_code: str = None, @@ -5389,16 +6158,23 @@ def set_my_commands( """ cmds = [c if isinstance(c, BotCommand) else BotCommand(c[0], c[1]) for c in commands] - - data: JSONDict = {'commands': [c.to_dict() for c in cmds]} + data: JSONDict = {'commands': cmds} if scope: - data['scope'] = scope.to_dict() + data['scope'] = scope if language_code: data['language_code'] = language_code - result = self._post('setMyCommands', data, timeout=timeout, api_kwargs=api_kwargs) + result = await self._post( + 'setMyCommands', + data, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, + api_kwargs=api_kwargs, + ) return result # type: ignore[return-value] @@ -5408,7 +6184,10 @@ def delete_my_commands( scope: BotCommandScope = None, language_code: str = None, api_kwargs: JSONDict = None, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, ) -> bool: """ Use this method to delete the list of the bot's commands for the given scope and user @@ -5440,17 +6219,31 @@ def delete_my_commands( data: JSONDict = {} if scope: - data['scope'] = scope.to_dict() + data['scope'] = scope if language_code: data['language_code'] = language_code - result = self._post('deleteMyCommands', data, timeout=timeout, api_kwargs=api_kwargs) + result = self._post( + 'deleteMyCommands', + data, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, + api_kwargs=api_kwargs, + ) return result # type: ignore[return-value] @_log - def log_out(self, timeout: ODVInput[float] = DEFAULT_NONE) -> bool: + async def log_out( + self, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, + ) -> bool: """ Use this method to log out from the cloud Bot API server before launching the bot locally. You *must* log out the bot before running it locally, otherwise there is no guarantee that @@ -5470,10 +6263,22 @@ def log_out(self, timeout: ODVInput[float] = DEFAULT_NONE) -> bool: :class:`telegram.error.TelegramError` """ - return self._post('logOut', timeout=timeout) # type: ignore[return-value] + return await self._post( # type: ignore[return-value] + 'logOut', + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, + ) @_log - def close(self, timeout: ODVInput[float] = DEFAULT_NONE) -> bool: + async def close( + self, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, + ) -> bool: """ Use this method to close the bot instance before moving it from one local server to another. You need to delete the webhook before calling this method to ensure that the bot @@ -5492,10 +6297,16 @@ def close(self, timeout: ODVInput[float] = DEFAULT_NONE) -> bool: :class:`telegram.error.TelegramError` """ - return self._post('close', timeout=timeout) # type: ignore[return-value] + return await self._post( # type: ignore[return-value] + 'close', + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, + ) @_log - def copy_message( + async def copy_message( self, chat_id: Union[int, str], from_chat_id: Union[str, int], @@ -5507,7 +6318,10 @@ def copy_message( reply_to_message_id: int = None, allow_sending_without_reply: DVInput[bool] = DEFAULT_NONE, reply_markup: ReplyMarkup = None, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, protect_content: ODVInput[bool] = DEFAULT_NONE, ) -> MessageId: @@ -5573,15 +6387,17 @@ def copy_message( if reply_to_message_id: data['reply_to_message_id'] = reply_to_message_id if reply_markup: - markups = (InlineKeyboardMarkup, ReplyKeyboardMarkup, ForceReply, ReplyKeyboardRemove) - if isinstance(reply_markup, markups): - # We need to_json() instead of to_dict() here, because reply_markups may be - # attached to media messages, which aren't json dumped by telegram.request - data['reply_markup'] = reply_markup.to_json() - else: - data['reply_markup'] = reply_markup + data['reply_markup'] = reply_markup - result = self._post('copyMessage', data, timeout=timeout, api_kwargs=api_kwargs) + result = await self._post( + 'copyMessage', + data, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, + api_kwargs=api_kwargs, + ) return MessageId.de_json(result, self) # type: ignore[return-value, arg-type] def to_dict(self) -> JSONDict: diff --git a/telegram/_callbackquery.py b/telegram/_callbackquery.py index 60abd5b7a45..f2e7e1d97da 100644 --- a/telegram/_callbackquery.py +++ b/telegram/_callbackquery.py @@ -141,13 +141,16 @@ def de_json(cls, data: Optional[JSONDict], bot: 'Bot') -> Optional['CallbackQuer return cls(bot=bot, **data) - def answer( + async def answer( self, text: str = None, show_alert: bool = False, url: str = None, cache_time: int = None, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, ) -> bool: """Shortcut for:: @@ -161,23 +164,29 @@ def answer( :obj:`bool`: On success, :obj:`True` is returned. """ - return self.get_bot().answer_callback_query( + return await self.get_bot().answer_callback_query( callback_query_id=self.id, text=text, show_alert=show_alert, url=url, cache_time=cache_time, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, ) - def edit_message_text( + async def edit_message_text( self, text: str, parse_mode: ODVInput[str] = DEFAULT_NONE, disable_web_page_preview: ODVInput[bool] = DEFAULT_NONE, reply_markup: 'InlineKeyboardMarkup' = None, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, entities: Union[List['MessageEntity'], Tuple['MessageEntity', ...]] = None, ) -> Union[Message, bool]: @@ -199,33 +208,42 @@ def edit_message_text( """ if self.inline_message_id: - return self.get_bot().edit_message_text( + return await self.get_bot().edit_message_text( inline_message_id=self.inline_message_id, text=text, parse_mode=parse_mode, disable_web_page_preview=disable_web_page_preview, reply_markup=reply_markup, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, entities=entities, chat_id=None, message_id=None, ) - return self.message.edit_text( + return await self.message.edit_text( text=text, parse_mode=parse_mode, disable_web_page_preview=disable_web_page_preview, reply_markup=reply_markup, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, entities=entities, ) - def edit_message_caption( + async def edit_message_caption( self, caption: str = None, reply_markup: 'InlineKeyboardMarkup' = None, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, parse_mode: ODVInput[str] = DEFAULT_NONE, api_kwargs: JSONDict = None, caption_entities: Union[List['MessageEntity'], Tuple['MessageEntity', ...]] = None, @@ -249,30 +267,39 @@ def edit_message_caption( """ if self.inline_message_id: - return self.get_bot().edit_message_caption( + return await self.get_bot().edit_message_caption( caption=caption, inline_message_id=self.inline_message_id, reply_markup=reply_markup, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, parse_mode=parse_mode, api_kwargs=api_kwargs, caption_entities=caption_entities, chat_id=None, message_id=None, ) - return self.message.edit_caption( + return await self.message.edit_caption( caption=caption, reply_markup=reply_markup, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, parse_mode=parse_mode, api_kwargs=api_kwargs, caption_entities=caption_entities, ) - def edit_message_reply_markup( + async def edit_message_reply_markup( self, reply_markup: Optional['InlineKeyboardMarkup'] = None, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, ) -> Union[Message, bool]: """Shortcut for either:: @@ -302,25 +329,34 @@ def edit_message_reply_markup( """ if self.inline_message_id: - return self.get_bot().edit_message_reply_markup( + return await self.get_bot().edit_message_reply_markup( reply_markup=reply_markup, inline_message_id=self.inline_message_id, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, chat_id=None, message_id=None, ) - return self.message.edit_reply_markup( + return await self.message.edit_reply_markup( reply_markup=reply_markup, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, ) - def edit_message_media( + async def edit_message_media( self, media: 'InputMedia', reply_markup: 'InlineKeyboardMarkup' = None, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, ) -> Union[Message, bool]: """Shortcut for either:: @@ -341,29 +377,38 @@ def edit_message_media( """ if self.inline_message_id: - return self.get_bot().edit_message_media( + return await self.get_bot().edit_message_media( inline_message_id=self.inline_message_id, media=media, reply_markup=reply_markup, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, chat_id=None, message_id=None, ) - return self.message.edit_media( + return await self.message.edit_media( media=media, reply_markup=reply_markup, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, ) - def edit_message_live_location( + async def edit_message_live_location( self, latitude: float = None, longitude: float = None, location: Location = None, reply_markup: 'InlineKeyboardMarkup' = None, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, horizontal_accuracy: float = None, heading: int = None, @@ -390,13 +435,16 @@ def edit_message_live_location( """ if self.inline_message_id: - return self.get_bot().edit_message_live_location( + return await self.get_bot().edit_message_live_location( inline_message_id=self.inline_message_id, latitude=latitude, longitude=longitude, location=location, reply_markup=reply_markup, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, horizontal_accuracy=horizontal_accuracy, heading=heading, @@ -404,22 +452,28 @@ def edit_message_live_location( chat_id=None, message_id=None, ) - return self.message.edit_live_location( + return await self.message.edit_live_location( latitude=latitude, longitude=longitude, location=location, reply_markup=reply_markup, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, horizontal_accuracy=horizontal_accuracy, heading=heading, proximity_alert_radius=proximity_alert_radius, ) - def stop_message_live_location( + async def stop_message_live_location( self, reply_markup: 'InlineKeyboardMarkup' = None, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, ) -> Union[Message, bool]: """Shortcut for either:: @@ -443,27 +497,36 @@ def stop_message_live_location( """ if self.inline_message_id: - return self.get_bot().stop_message_live_location( + return await self.get_bot().stop_message_live_location( inline_message_id=self.inline_message_id, reply_markup=reply_markup, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, chat_id=None, message_id=None, ) - return self.message.stop_live_location( + return await self.message.stop_live_location( reply_markup=reply_markup, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, ) - def set_game_score( + async def set_game_score( self, user_id: Union[int, str], score: int, force: bool = None, disable_edit_message: bool = None, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, ) -> Union[Message, bool]: """Shortcut for either:: @@ -484,30 +547,39 @@ def set_game_score( """ if self.inline_message_id: - return self.get_bot().set_game_score( + return await self.get_bot().set_game_score( inline_message_id=self.inline_message_id, user_id=user_id, score=score, force=force, disable_edit_message=disable_edit_message, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, chat_id=None, message_id=None, ) - return self.message.set_game_score( + return await self.message.set_game_score( user_id=user_id, score=score, force=force, disable_edit_message=disable_edit_message, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, ) - def get_game_high_scores( + async def get_game_high_scores( self, user_id: Union[int, str], - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, ) -> List['GameHighScore']: """Shortcut for either:: @@ -527,23 +599,32 @@ def get_game_high_scores( """ if self.inline_message_id: - return self.get_bot().get_game_high_scores( + return await self.get_bot().get_game_high_scores( inline_message_id=self.inline_message_id, user_id=user_id, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, chat_id=None, message_id=None, ) - return self.message.get_game_high_scores( + return await self.message.get_game_high_scores( user_id=user_id, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, ) - def delete_message( + async def delete_message( self, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, ) -> bool: """Shortcut for:: @@ -557,15 +638,21 @@ def delete_message( :obj:`bool`: On success, :obj:`True` is returned. """ - return self.message.delete( - timeout=timeout, + return await self.message.delete( + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, ) - def pin_message( + async def pin_message( self, disable_notification: ODVInput[bool] = DEFAULT_NONE, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, ) -> bool: """Shortcut for:: @@ -579,15 +666,21 @@ def pin_message( :obj:`bool`: On success, :obj:`True` is returned. """ - return self.message.pin( + return await self.message.pin( disable_notification=disable_notification, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, ) - def unpin_message( + async def unpin_message( self, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, ) -> bool: """Shortcut for:: @@ -601,12 +694,15 @@ def unpin_message( :obj:`bool`: On success, :obj:`True` is returned. """ - return self.message.unpin( - timeout=timeout, + return await self.message.unpin( + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, ) - def copy_message( + async def copy_message( self, chat_id: Union[int, str], caption: str = None, @@ -616,7 +712,10 @@ def copy_message( reply_to_message_id: int = None, allow_sending_without_reply: DVInput[bool] = DEFAULT_NONE, reply_markup: ReplyMarkup = None, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, protect_content: ODVInput[bool] = DEFAULT_NONE, ) -> 'MessageId': @@ -636,7 +735,7 @@ def copy_message( :class:`telegram.MessageId`: On success, returns the MessageId of the sent message. """ - return self.message.copy( + return await self.message.copy( chat_id=chat_id, caption=caption, parse_mode=parse_mode, @@ -645,7 +744,10 @@ def copy_message( reply_to_message_id=reply_to_message_id, allow_sending_without_reply=allow_sending_without_reply, reply_markup=reply_markup, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, protect_content=protect_content, ) diff --git a/telegram/_chat.py b/telegram/_chat.py index 8eda5ec351c..755adaadaa8 100644 --- a/telegram/_chat.py +++ b/telegram/_chat.py @@ -23,7 +23,7 @@ from telegram import ChatPhoto, TelegramObject, constants from telegram._utils.types import JSONDict, FileInput, ODVInput, DVInput, ReplyMarkup -from telegram._utils.defaultvalue import DEFAULT_NONE, DEFAULT_20 +from telegram._utils.defaultvalue import DEFAULT_NONE from telegram._chatpermissions import ChatPermissions from telegram._chatlocation import ChatLocation @@ -300,7 +300,14 @@ def de_json(cls, data: Optional[JSONDict], bot: 'Bot') -> Optional['Chat']: return cls(bot=bot, **data) - def leave(self, timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None) -> bool: + async def leave( + self, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, + api_kwargs: JSONDict = None, + ) -> bool: """Shortcut for:: bot.leave_chat(update.effective_chat.id, *args, **kwargs) @@ -311,14 +318,22 @@ def leave(self, timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = :obj:`bool`: On success, :obj:`True` is returned. """ - return self.get_bot().leave_chat( + return await self.get_bot().leave_chat( chat_id=self.id, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, ) - def get_administrators( - self, timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None + async def get_administrators( + self, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, + api_kwargs: JSONDict = None, ) -> List['ChatMember']: """Shortcut for:: @@ -334,14 +349,22 @@ def get_administrators( and no administrators were appointed, only the creator will be returned. """ - return self.get_bot().get_chat_administrators( + return await self.get_bot().get_chat_administrators( chat_id=self.id, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, ) - def get_member_count( - self, timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None + async def get_member_count( + self, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, + api_kwargs: JSONDict = None, ) -> int: """Shortcut for:: @@ -353,16 +376,22 @@ def get_member_count( Returns: :obj:`int` """ - return self.get_bot().get_chat_member_count( + return await self.get_bot().get_chat_member_count( chat_id=self.id, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, ) - def get_member( + async def get_member( self, user_id: Union[str, int], - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, ) -> 'ChatMember': """Shortcut for:: @@ -375,17 +404,23 @@ def get_member( :class:`telegram.ChatMember` """ - return self.get_bot().get_chat_member( + return await self.get_bot().get_chat_member( chat_id=self.id, user_id=user_id, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, ) - def ban_member( + async def ban_member( self, user_id: Union[str, int], - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, until_date: Union[int, datetime] = None, api_kwargs: JSONDict = None, revoke_messages: bool = None, @@ -400,19 +435,25 @@ def ban_member( Returns: :obj:`bool`: On success, :obj:`True` is returned. """ - return self.get_bot().ban_chat_member( + return await self.get_bot().ban_chat_member( chat_id=self.id, user_id=user_id, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, until_date=until_date, api_kwargs=api_kwargs, revoke_messages=revoke_messages, ) - def ban_sender_chat( + async def ban_sender_chat( self, sender_chat_id: int, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, ) -> bool: """Shortcut for:: @@ -428,14 +469,23 @@ def ban_sender_chat( :obj:`bool`: On success, :obj:`True` is returned. """ - return self.get_bot().ban_chat_sender_chat( - chat_id=self.id, sender_chat_id=sender_chat_id, timeout=timeout, api_kwargs=api_kwargs + return await self.get_bot().ban_chat_sender_chat( + chat_id=self.id, + sender_chat_id=sender_chat_id, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, + api_kwargs=api_kwargs, ) - def ban_chat( + async def ban_chat( self, chat_id: Union[str, int], - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, ) -> bool: """Shortcut for:: @@ -451,14 +501,23 @@ def ban_chat( :obj:`bool`: On success, :obj:`True` is returned. """ - return self.get_bot().ban_chat_sender_chat( - chat_id=chat_id, sender_chat_id=self.id, timeout=timeout, api_kwargs=api_kwargs + return await self.get_bot().ban_chat_sender_chat( + chat_id=chat_id, + sender_chat_id=self.id, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, + api_kwargs=api_kwargs, ) - def unban_sender_chat( + async def unban_sender_chat( self, sender_chat_id: int, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, ) -> bool: """Shortcut for:: @@ -474,14 +533,23 @@ def unban_sender_chat( :obj:`bool`: On success, :obj:`True` is returned. """ - return self.get_bot().unban_chat_sender_chat( - chat_id=self.id, sender_chat_id=sender_chat_id, timeout=timeout, api_kwargs=api_kwargs + return await self.get_bot().unban_chat_sender_chat( + chat_id=self.id, + sender_chat_id=sender_chat_id, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, + api_kwargs=api_kwargs, ) - def unban_chat( + async def unban_chat( self, chat_id: Union[str, int], - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, ) -> bool: """Shortcut for:: @@ -497,14 +565,23 @@ def unban_chat( :obj:`bool`: On success, :obj:`True` is returned. """ - return self.get_bot().unban_chat_sender_chat( - chat_id=chat_id, sender_chat_id=self.id, timeout=timeout, api_kwargs=api_kwargs + return await self.get_bot().unban_chat_sender_chat( + chat_id=chat_id, + sender_chat_id=self.id, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, + api_kwargs=api_kwargs, ) - def unban_member( + async def unban_member( self, user_id: Union[str, int], - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, only_if_banned: bool = None, ) -> bool: @@ -518,15 +595,18 @@ def unban_member( :obj:`bool`: On success, :obj:`True` is returned. """ - return self.get_bot().unban_chat_member( + return await self.get_bot().unban_chat_member( chat_id=self.id, user_id=user_id, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, only_if_banned=only_if_banned, ) - def promote_member( + async def promote_member( self, user_id: Union[str, int], can_change_info: bool = None, @@ -537,7 +617,10 @@ def promote_member( can_restrict_members: bool = None, can_pin_messages: bool = None, can_promote_members: bool = None, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, is_anonymous: bool = None, can_manage_chat: bool = None, @@ -556,7 +639,7 @@ def promote_member( :obj:`bool`: On success, :obj:`True` is returned. """ - return self.get_bot().promote_chat_member( + return await self.get_bot().promote_chat_member( chat_id=self.id, user_id=user_id, can_change_info=can_change_info, @@ -567,19 +650,25 @@ def promote_member( can_restrict_members=can_restrict_members, can_pin_messages=can_pin_messages, can_promote_members=can_promote_members, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, is_anonymous=is_anonymous, can_manage_chat=can_manage_chat, can_manage_voice_chats=can_manage_voice_chats, ) - def restrict_member( + async def restrict_member( self, user_id: Union[str, int], permissions: ChatPermissions, until_date: Union[int, datetime] = None, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, ) -> bool: """Shortcut for:: @@ -595,19 +684,25 @@ def restrict_member( :obj:`bool`: On success, :obj:`True` is returned. """ - return self.get_bot().restrict_chat_member( + return await self.get_bot().restrict_chat_member( chat_id=self.id, user_id=user_id, permissions=permissions, until_date=until_date, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, ) - def set_permissions( + async def set_permissions( self, permissions: ChatPermissions, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, ) -> bool: """Shortcut for:: @@ -621,18 +716,24 @@ def set_permissions( :obj:`bool`: On success, :obj:`True` is returned. """ - return self.get_bot().set_chat_permissions( + return await self.get_bot().set_chat_permissions( chat_id=self.id, permissions=permissions, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, ) - def set_administrator_custom_title( + async def set_administrator_custom_title( self, user_id: Union[int, str], custom_title: str, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, ) -> bool: """Shortcut for:: @@ -646,19 +747,25 @@ def set_administrator_custom_title( :obj:`bool`: On success, :obj:`True` is returned. """ - return self.get_bot().set_chat_administrator_custom_title( + return await self.get_bot().set_chat_administrator_custom_title( chat_id=self.id, user_id=user_id, custom_title=custom_title, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, ) - def pin_message( + async def pin_message( self, message_id: int, disable_notification: ODVInput[bool] = DEFAULT_NONE, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, ) -> bool: """Shortcut for:: @@ -674,17 +781,23 @@ def pin_message( :obj:`bool`: On success, :obj:`True` is returned. """ - return self.get_bot().pin_chat_message( + return await self.get_bot().pin_chat_message( chat_id=self.id, message_id=message_id, disable_notification=disable_notification, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, ) - def unpin_message( + async def unpin_message( self, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, message_id: int = None, ) -> bool: @@ -701,16 +814,22 @@ def unpin_message( :obj:`bool`: On success, :obj:`True` is returned. """ - return self.get_bot().unpin_chat_message( + return await self.get_bot().unpin_chat_message( chat_id=self.id, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, message_id=message_id, ) - def unpin_all_messages( + async def unpin_all_messages( self, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, ) -> bool: """Shortcut for:: @@ -726,13 +845,16 @@ def unpin_all_messages( :obj:`bool`: On success, :obj:`True` is returned. """ - return self.get_bot().unpin_all_chat_messages( + return await self.get_bot().unpin_all_chat_messages( chat_id=self.id, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, ) - def send_message( + async def send_message( self, text: str, parse_mode: ODVInput[str] = DEFAULT_NONE, @@ -740,7 +862,10 @@ def send_message( disable_notification: DVInput[bool] = DEFAULT_NONE, reply_to_message_id: int = None, reply_markup: ReplyMarkup = None, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, allow_sending_without_reply: ODVInput[bool] = DEFAULT_NONE, entities: Union[List['MessageEntity'], Tuple['MessageEntity', ...]] = None, @@ -756,7 +881,7 @@ def send_message( :class:`telegram.Message`: On success, instance representing the message posted. """ - return self.get_bot().send_message( + return await self.get_bot().send_message( chat_id=self.id, text=text, parse_mode=parse_mode, @@ -764,21 +889,27 @@ def send_message( disable_notification=disable_notification, reply_to_message_id=reply_to_message_id, reply_markup=reply_markup, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, allow_sending_without_reply=allow_sending_without_reply, entities=entities, protect_content=protect_content, ) - def send_media_group( + async def send_media_group( self, media: List[ Union['InputMediaAudio', 'InputMediaDocument', 'InputMediaPhoto', 'InputMediaVideo'] ], disable_notification: ODVInput[bool] = DEFAULT_NONE, reply_to_message_id: int = None, - timeout: DVInput[float] = DEFAULT_20, + read_timeout: float = 20, + write_timeout: float = 20, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, allow_sending_without_reply: ODVInput[bool] = DEFAULT_NONE, protect_content: ODVInput[bool] = DEFAULT_NONE, @@ -793,21 +924,27 @@ def send_media_group( List[:class:`telegram.Message`]: On success, instance representing the message posted. """ - return self.get_bot().send_media_group( + return await self.get_bot().send_media_group( chat_id=self.id, media=media, disable_notification=disable_notification, reply_to_message_id=reply_to_message_id, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, allow_sending_without_reply=allow_sending_without_reply, protect_content=protect_content, ) - def send_chat_action( + async def send_chat_action( self, action: str, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, ) -> bool: """Shortcut for:: @@ -820,24 +957,30 @@ def send_chat_action( :obj:`bool`: On success, :obj:`True` is returned. """ - return self.get_bot().send_chat_action( + return await self.get_bot().send_chat_action( chat_id=self.id, action=action, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, ) send_action = send_chat_action """Alias for :attr:`send_chat_action`""" - def send_photo( + async def send_photo( self, photo: Union[FileInput, 'PhotoSize'], caption: str = None, disable_notification: DVInput[bool] = DEFAULT_NONE, reply_to_message_id: int = None, reply_markup: ReplyMarkup = None, - timeout: DVInput[float] = DEFAULT_20, + read_timeout: float = 20, + write_timeout: float = 20, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, parse_mode: ODVInput[str] = DEFAULT_NONE, api_kwargs: JSONDict = None, allow_sending_without_reply: ODVInput[bool] = DEFAULT_NONE, @@ -855,14 +998,17 @@ def send_photo( :class:`telegram.Message`: On success, instance representing the message posted. """ - return self.get_bot().send_photo( + return await self.get_bot().send_photo( chat_id=self.id, photo=photo, caption=caption, disable_notification=disable_notification, reply_to_message_id=reply_to_message_id, reply_markup=reply_markup, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, parse_mode=parse_mode, api_kwargs=api_kwargs, allow_sending_without_reply=allow_sending_without_reply, @@ -871,7 +1017,7 @@ def send_photo( protect_content=protect_content, ) - def send_contact( + async def send_contact( self, phone_number: str = None, first_name: str = None, @@ -879,7 +1025,10 @@ def send_contact( disable_notification: DVInput[bool] = DEFAULT_NONE, reply_to_message_id: int = None, reply_markup: ReplyMarkup = None, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, contact: 'Contact' = None, vcard: str = None, api_kwargs: JSONDict = None, @@ -896,7 +1045,7 @@ def send_contact( :class:`telegram.Message`: On success, instance representing the message posted. """ - return self.get_bot().send_contact( + return await self.get_bot().send_contact( chat_id=self.id, phone_number=phone_number, first_name=first_name, @@ -904,7 +1053,10 @@ def send_contact( disable_notification=disable_notification, reply_to_message_id=reply_to_message_id, reply_markup=reply_markup, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, contact=contact, vcard=vcard, api_kwargs=api_kwargs, @@ -912,7 +1064,7 @@ def send_contact( protect_content=protect_content, ) - def send_audio( + async def send_audio( self, audio: Union[FileInput, 'Audio'], duration: int = None, @@ -922,7 +1074,10 @@ def send_audio( disable_notification: DVInput[bool] = DEFAULT_NONE, reply_to_message_id: int = None, reply_markup: ReplyMarkup = None, - timeout: DVInput[float] = DEFAULT_20, + read_timeout: float = 20, + write_timeout: float = 20, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, parse_mode: ODVInput[str] = DEFAULT_NONE, thumb: FileInput = None, api_kwargs: JSONDict = None, @@ -941,7 +1096,7 @@ def send_audio( :class:`telegram.Message`: On success, instance representing the message posted. """ - return self.get_bot().send_audio( + return await self.get_bot().send_audio( chat_id=self.id, audio=audio, duration=duration, @@ -951,7 +1106,10 @@ def send_audio( disable_notification=disable_notification, reply_to_message_id=reply_to_message_id, reply_markup=reply_markup, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, parse_mode=parse_mode, thumb=thumb, api_kwargs=api_kwargs, @@ -961,7 +1119,7 @@ def send_audio( protect_content=protect_content, ) - def send_document( + async def send_document( self, document: Union[FileInput, 'Document'], filename: str = None, @@ -969,7 +1127,10 @@ def send_document( disable_notification: DVInput[bool] = DEFAULT_NONE, reply_to_message_id: int = None, reply_markup: ReplyMarkup = None, - timeout: DVInput[float] = DEFAULT_20, + read_timeout: float = 20, + write_timeout: float = 20, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, parse_mode: ODVInput[str] = DEFAULT_NONE, thumb: FileInput = None, api_kwargs: JSONDict = None, @@ -988,7 +1149,7 @@ def send_document( :class:`telegram.Message`: On success, instance representing the message posted. """ - return self.get_bot().send_document( + return await self.get_bot().send_document( chat_id=self.id, document=document, filename=filename, @@ -996,7 +1157,10 @@ def send_document( disable_notification=disable_notification, reply_to_message_id=reply_to_message_id, reply_markup=reply_markup, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, parse_mode=parse_mode, thumb=thumb, api_kwargs=api_kwargs, @@ -1006,12 +1170,15 @@ def send_document( protect_content=protect_content, ) - def send_dice( + async def send_dice( self, disable_notification: ODVInput[bool] = DEFAULT_NONE, reply_to_message_id: int = None, reply_markup: ReplyMarkup = None, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, emoji: str = None, api_kwargs: JSONDict = None, allow_sending_without_reply: ODVInput[bool] = DEFAULT_NONE, @@ -1027,25 +1194,31 @@ def send_dice( :class:`telegram.Message`: On success, instance representing the message posted. """ - return self.get_bot().send_dice( + return await self.get_bot().send_dice( chat_id=self.id, disable_notification=disable_notification, reply_to_message_id=reply_to_message_id, reply_markup=reply_markup, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, emoji=emoji, api_kwargs=api_kwargs, allow_sending_without_reply=allow_sending_without_reply, protect_content=protect_content, ) - def send_game( + async def send_game( self, game_short_name: str, disable_notification: DVInput[bool] = DEFAULT_NONE, reply_to_message_id: int = None, reply_markup: 'InlineKeyboardMarkup' = None, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, allow_sending_without_reply: ODVInput[bool] = DEFAULT_NONE, protect_content: ODVInput[bool] = DEFAULT_NONE, @@ -1060,19 +1233,22 @@ def send_game( :class:`telegram.Message`: On success, instance representing the message posted. """ - return self.get_bot().send_game( + return await self.get_bot().send_game( chat_id=self.id, game_short_name=game_short_name, disable_notification=disable_notification, reply_to_message_id=reply_to_message_id, reply_markup=reply_markup, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, allow_sending_without_reply=allow_sending_without_reply, protect_content=protect_content, ) - def send_invoice( + async def send_invoice( self, title: str, description: str, @@ -1096,7 +1272,10 @@ def send_invoice( provider_data: Union[str, object] = None, send_phone_number_to_provider: bool = None, send_email_to_provider: bool = None, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, allow_sending_without_reply: ODVInput[bool] = DEFAULT_NONE, max_tip_amount: int = None, @@ -1121,7 +1300,7 @@ def send_invoice( :class:`telegram.Message`: On success, instance representing the message posted. """ - return self.get_bot().send_invoice( + return await self.get_bot().send_invoice( chat_id=self.id, title=title, description=description, @@ -1145,7 +1324,10 @@ def send_invoice( provider_data=provider_data, send_phone_number_to_provider=send_phone_number_to_provider, send_email_to_provider=send_email_to_provider, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, allow_sending_without_reply=allow_sending_without_reply, max_tip_amount=max_tip_amount, @@ -1153,14 +1335,17 @@ def send_invoice( protect_content=protect_content, ) - def send_location( + async def send_location( self, latitude: float = None, longitude: float = None, disable_notification: DVInput[bool] = DEFAULT_NONE, reply_to_message_id: int = None, reply_markup: ReplyMarkup = None, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, location: 'Location' = None, live_period: int = None, api_kwargs: JSONDict = None, @@ -1180,14 +1365,17 @@ def send_location( :class:`telegram.Message`: On success, instance representing the message posted. """ - return self.get_bot().send_location( + return await self.get_bot().send_location( chat_id=self.id, latitude=latitude, longitude=longitude, disable_notification=disable_notification, reply_to_message_id=reply_to_message_id, reply_markup=reply_markup, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, location=location, live_period=live_period, api_kwargs=api_kwargs, @@ -1198,7 +1386,7 @@ def send_location( protect_content=protect_content, ) - def send_animation( + async def send_animation( self, animation: Union[FileInput, 'Animation'], duration: int = None, @@ -1210,7 +1398,10 @@ def send_animation( disable_notification: DVInput[bool] = DEFAULT_NONE, reply_to_message_id: int = None, reply_markup: ReplyMarkup = None, - timeout: DVInput[float] = DEFAULT_20, + read_timeout: float = 20, + write_timeout: float = 20, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, allow_sending_without_reply: ODVInput[bool] = DEFAULT_NONE, caption_entities: Union[List['MessageEntity'], Tuple['MessageEntity', ...]] = None, @@ -1227,7 +1418,7 @@ def send_animation( :class:`telegram.Message`: On success, instance representing the message posted. """ - return self.get_bot().send_animation( + return await self.get_bot().send_animation( chat_id=self.id, animation=animation, duration=duration, @@ -1239,7 +1430,10 @@ def send_animation( disable_notification=disable_notification, reply_to_message_id=reply_to_message_id, reply_markup=reply_markup, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, allow_sending_without_reply=allow_sending_without_reply, caption_entities=caption_entities, @@ -1247,13 +1441,16 @@ def send_animation( protect_content=protect_content, ) - def send_sticker( + async def send_sticker( self, sticker: Union[FileInput, 'Sticker'], disable_notification: DVInput[bool] = DEFAULT_NONE, reply_to_message_id: int = None, reply_markup: ReplyMarkup = None, - timeout: DVInput[float] = DEFAULT_20, + read_timeout: float = 20, + write_timeout: float = 20, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, allow_sending_without_reply: ODVInput[bool] = DEFAULT_NONE, protect_content: ODVInput[bool] = DEFAULT_NONE, @@ -1268,19 +1465,22 @@ def send_sticker( :class:`telegram.Message`: On success, instance representing the message posted. """ - return self.get_bot().send_sticker( + return await self.get_bot().send_sticker( chat_id=self.id, sticker=sticker, disable_notification=disable_notification, reply_to_message_id=reply_to_message_id, reply_markup=reply_markup, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, allow_sending_without_reply=allow_sending_without_reply, protect_content=protect_content, ) - def send_venue( + async def send_venue( self, latitude: float = None, longitude: float = None, @@ -1290,7 +1490,10 @@ def send_venue( disable_notification: DVInput[bool] = DEFAULT_NONE, reply_to_message_id: int = None, reply_markup: ReplyMarkup = None, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, venue: 'Venue' = None, foursquare_type: str = None, api_kwargs: JSONDict = None, @@ -1309,7 +1512,7 @@ def send_venue( :class:`telegram.Message`: On success, instance representing the message posted. """ - return self.get_bot().send_venue( + return await self.get_bot().send_venue( chat_id=self.id, latitude=latitude, longitude=longitude, @@ -1319,7 +1522,10 @@ def send_venue( disable_notification=disable_notification, reply_to_message_id=reply_to_message_id, reply_markup=reply_markup, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, venue=venue, foursquare_type=foursquare_type, api_kwargs=api_kwargs, @@ -1329,7 +1535,7 @@ def send_venue( protect_content=protect_content, ) - def send_video( + async def send_video( self, video: Union[FileInput, 'Video'], duration: int = None, @@ -1337,7 +1543,10 @@ def send_video( disable_notification: DVInput[bool] = DEFAULT_NONE, reply_to_message_id: int = None, reply_markup: ReplyMarkup = None, - timeout: DVInput[float] = DEFAULT_20, + read_timeout: float = 20, + write_timeout: float = 20, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, width: int = None, height: int = None, parse_mode: ODVInput[str] = DEFAULT_NONE, @@ -1359,7 +1568,7 @@ def send_video( :class:`telegram.Message`: On success, instance representing the message posted. """ - return self.get_bot().send_video( + return await self.get_bot().send_video( chat_id=self.id, video=video, duration=duration, @@ -1367,7 +1576,10 @@ def send_video( disable_notification=disable_notification, reply_to_message_id=reply_to_message_id, reply_markup=reply_markup, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, width=width, height=height, parse_mode=parse_mode, @@ -1380,7 +1592,7 @@ def send_video( protect_content=protect_content, ) - def send_video_note( + async def send_video_note( self, video_note: Union[FileInput, 'VideoNote'], duration: int = None, @@ -1388,7 +1600,10 @@ def send_video_note( disable_notification: DVInput[bool] = DEFAULT_NONE, reply_to_message_id: int = None, reply_markup: ReplyMarkup = None, - timeout: DVInput[float] = DEFAULT_20, + read_timeout: float = 20, + write_timeout: float = 20, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, thumb: FileInput = None, api_kwargs: JSONDict = None, allow_sending_without_reply: ODVInput[bool] = DEFAULT_NONE, @@ -1405,7 +1620,7 @@ def send_video_note( :class:`telegram.Message`: On success, instance representing the message posted. """ - return self.get_bot().send_video_note( + return await self.get_bot().send_video_note( chat_id=self.id, video_note=video_note, duration=duration, @@ -1413,7 +1628,10 @@ def send_video_note( disable_notification=disable_notification, reply_to_message_id=reply_to_message_id, reply_markup=reply_markup, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, thumb=thumb, api_kwargs=api_kwargs, allow_sending_without_reply=allow_sending_without_reply, @@ -1421,7 +1639,7 @@ def send_video_note( protect_content=protect_content, ) - def send_voice( + async def send_voice( self, voice: Union[FileInput, 'Voice'], duration: int = None, @@ -1429,7 +1647,10 @@ def send_voice( disable_notification: DVInput[bool] = DEFAULT_NONE, reply_to_message_id: int = None, reply_markup: ReplyMarkup = None, - timeout: DVInput[float] = DEFAULT_20, + read_timeout: float = 20, + write_timeout: float = 20, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, parse_mode: ODVInput[str] = DEFAULT_NONE, api_kwargs: JSONDict = None, allow_sending_without_reply: ODVInput[bool] = DEFAULT_NONE, @@ -1447,7 +1668,7 @@ def send_voice( :class:`telegram.Message`: On success, instance representing the message posted. """ - return self.get_bot().send_voice( + return await self.get_bot().send_voice( chat_id=self.id, voice=voice, duration=duration, @@ -1455,7 +1676,10 @@ def send_voice( disable_notification=disable_notification, reply_to_message_id=reply_to_message_id, reply_markup=reply_markup, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, parse_mode=parse_mode, api_kwargs=api_kwargs, allow_sending_without_reply=allow_sending_without_reply, @@ -1464,7 +1688,7 @@ def send_voice( protect_content=protect_content, ) - def send_poll( + async def send_poll( self, question: str, options: List[str], @@ -1477,7 +1701,10 @@ def send_poll( disable_notification: ODVInput[bool] = DEFAULT_NONE, reply_to_message_id: int = None, reply_markup: ReplyMarkup = None, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, explanation: str = None, explanation_parse_mode: ODVInput[str] = DEFAULT_NONE, open_period: int = None, @@ -1497,7 +1724,7 @@ def send_poll( :class:`telegram.Message`: On success, instance representing the message posted. """ - return self.get_bot().send_poll( + return await self.get_bot().send_poll( chat_id=self.id, question=question, options=options, @@ -1509,7 +1736,10 @@ def send_poll( disable_notification=disable_notification, reply_to_message_id=reply_to_message_id, reply_markup=reply_markup, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, explanation=explanation, explanation_parse_mode=explanation_parse_mode, open_period=open_period, @@ -1520,7 +1750,7 @@ def send_poll( protect_content=protect_content, ) - def send_copy( + async def send_copy( self, from_chat_id: Union[str, int], message_id: int, @@ -1531,7 +1761,10 @@ def send_copy( reply_to_message_id: int = None, allow_sending_without_reply: DVInput[bool] = DEFAULT_NONE, reply_markup: ReplyMarkup = None, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, protect_content: ODVInput[bool] = DEFAULT_NONE, ) -> 'MessageId': @@ -1545,7 +1778,7 @@ def send_copy( :class:`telegram.Message`: On success, instance representing the message posted. """ - return self.get_bot().copy_message( + return await self.get_bot().copy_message( chat_id=self.id, from_chat_id=from_chat_id, message_id=message_id, @@ -1556,12 +1789,15 @@ def send_copy( reply_to_message_id=reply_to_message_id, allow_sending_without_reply=allow_sending_without_reply, reply_markup=reply_markup, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, protect_content=protect_content, ) - def copy_message( + async def copy_message( self, chat_id: Union[int, str], message_id: int, @@ -1572,7 +1808,10 @@ def copy_message( reply_to_message_id: int = None, allow_sending_without_reply: DVInput[bool] = DEFAULT_NONE, reply_markup: ReplyMarkup = None, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, protect_content: ODVInput[bool] = DEFAULT_NONE, ) -> 'MessageId': @@ -1586,7 +1825,7 @@ def copy_message( :class:`telegram.Message`: On success, instance representing the message posted. """ - return self.get_bot().copy_message( + return await self.get_bot().copy_message( from_chat_id=self.id, chat_id=chat_id, message_id=message_id, @@ -1597,14 +1836,20 @@ def copy_message( reply_to_message_id=reply_to_message_id, allow_sending_without_reply=allow_sending_without_reply, reply_markup=reply_markup, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, protect_content=protect_content, ) - def export_invite_link( + async def export_invite_link( self, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, ) -> str: """Shortcut for:: @@ -1620,15 +1865,23 @@ def export_invite_link( :obj:`str`: New invite link on success. """ - return self.get_bot().export_chat_invite_link( - chat_id=self.id, timeout=timeout, api_kwargs=api_kwargs + return await self.get_bot().export_chat_invite_link( + chat_id=self.id, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, + api_kwargs=api_kwargs, ) - def create_invite_link( + async def create_invite_link( self, expire_date: Union[int, datetime] = None, member_limit: int = None, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, name: str = None, creates_join_request: bool = None, @@ -1650,22 +1903,28 @@ def create_invite_link( :class:`telegram.ChatInviteLink` """ - return self.get_bot().create_chat_invite_link( + return await self.get_bot().create_chat_invite_link( chat_id=self.id, expire_date=expire_date, member_limit=member_limit, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, name=name, creates_join_request=creates_join_request, ) - def edit_invite_link( + async def edit_invite_link( self, invite_link: Union[str, 'ChatInviteLink'], expire_date: Union[int, datetime] = None, member_limit: int = None, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, name: str = None, creates_join_request: bool = None, @@ -1686,21 +1945,27 @@ def edit_invite_link( :class:`telegram.ChatInviteLink` """ - return self.get_bot().edit_chat_invite_link( + return await self.get_bot().edit_chat_invite_link( chat_id=self.id, invite_link=invite_link, expire_date=expire_date, member_limit=member_limit, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, name=name, creates_join_request=creates_join_request, ) - def revoke_invite_link( + async def revoke_invite_link( self, invite_link: Union[str, 'ChatInviteLink'], - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, ) -> 'ChatInviteLink': """Shortcut for:: @@ -1716,14 +1981,23 @@ def revoke_invite_link( :class:`telegram.ChatInviteLink` """ - return self.get_bot().revoke_chat_invite_link( - chat_id=self.id, invite_link=invite_link, timeout=timeout, api_kwargs=api_kwargs + return await self.get_bot().revoke_chat_invite_link( + chat_id=self.id, + invite_link=invite_link, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, + api_kwargs=api_kwargs, ) - def approve_join_request( + async def approve_join_request( self, user_id: int, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, ) -> bool: """Shortcut for:: @@ -1739,14 +2013,23 @@ def approve_join_request( :obj:`bool`: On success, :obj:`True` is returned. """ - return self.get_bot().approve_chat_join_request( - chat_id=self.id, user_id=user_id, timeout=timeout, api_kwargs=api_kwargs + return await self.get_bot().approve_chat_join_request( + chat_id=self.id, + user_id=user_id, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, + api_kwargs=api_kwargs, ) - def decline_join_request( + async def decline_join_request( self, user_id: int, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, ) -> bool: """Shortcut for:: @@ -1762,6 +2045,12 @@ def decline_join_request( :obj:`bool`: On success, :obj:`True` is returned. """ - return self.get_bot().decline_chat_join_request( - chat_id=self.id, user_id=user_id, timeout=timeout, api_kwargs=api_kwargs + return await self.get_bot().decline_chat_join_request( + chat_id=self.id, + user_id=user_id, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, + api_kwargs=api_kwargs, ) diff --git a/telegram/_chatjoinrequest.py b/telegram/_chatjoinrequest.py index 5440e85d8f6..98cc5808a4c 100644 --- a/telegram/_chatjoinrequest.py +++ b/telegram/_chatjoinrequest.py @@ -115,9 +115,12 @@ def to_dict(self) -> JSONDict: return data - def approve( + async def approve( self, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, ) -> bool: """Shortcut for:: @@ -132,13 +135,22 @@ def approve( :obj:`bool`: On success, :obj:`True` is returned. """ - return self.get_bot().approve_chat_join_request( - chat_id=self.chat.id, user_id=self.from_user.id, timeout=timeout, api_kwargs=api_kwargs + return await self.get_bot().approve_chat_join_request( + chat_id=self.chat.id, + user_id=self.from_user.id, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, + api_kwargs=api_kwargs, ) - def decline( + async def decline( self, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, ) -> bool: """Shortcut for:: @@ -153,6 +165,12 @@ def decline( :obj:`bool`: On success, :obj:`True` is returned. """ - return self.get_bot().decline_chat_join_request( - chat_id=self.chat.id, user_id=self.from_user.id, timeout=timeout, api_kwargs=api_kwargs + return await self.get_bot().decline_chat_join_request( + chat_id=self.chat.id, + user_id=self.from_user.id, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, + api_kwargs=api_kwargs, ) diff --git a/telegram/_files/_basemedium.py b/telegram/_files/_basemedium.py index 06e9a485ccd..ec0daf2c389 100644 --- a/telegram/_files/_basemedium.py +++ b/telegram/_files/_basemedium.py @@ -65,8 +65,13 @@ def __init__( self._id_attrs = (self.file_unique_id,) - def get_file( - self, timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None + async def get_file( + self, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, + api_kwargs: JSONDict = None, ) -> 'File': """Convenience wrapper over :attr:`telegram.Bot.get_file` @@ -79,6 +84,11 @@ def get_file( :class:`telegram.error.TelegramError` """ - return self.get_bot().get_file( - file_id=self.file_id, timeout=timeout, api_kwargs=api_kwargs + return await self.get_bot().get_file( + file_id=self.file_id, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, + api_kwargs=api_kwargs, ) diff --git a/telegram/_files/chatphoto.py b/telegram/_files/chatphoto.py index f3d8a7b9b49..01bcc990a67 100644 --- a/telegram/_files/chatphoto.py +++ b/telegram/_files/chatphoto.py @@ -93,8 +93,13 @@ def __init__( self.big_file_unique_id, ) - def get_small_file( - self, timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None + async def get_small_file( + self, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, + api_kwargs: JSONDict = None, ) -> 'File': """Convenience wrapper over :attr:`telegram.Bot.get_file` for getting the small (160x160) chat photo @@ -108,12 +113,22 @@ def get_small_file( :class:`telegram.error.TelegramError` """ - return self.get_bot().get_file( - file_id=self.small_file_id, timeout=timeout, api_kwargs=api_kwargs + return await self.get_bot().get_file( + file_id=self.small_file_id, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, + api_kwargs=api_kwargs, ) - def get_big_file( - self, timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None + async def get_big_file( + self, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, + api_kwargs: JSONDict = None, ) -> 'File': """Convenience wrapper over :attr:`telegram.Bot.get_file` for getting the big (640x640) chat photo @@ -127,6 +142,11 @@ def get_big_file( :class:`telegram.error.TelegramError` """ - return self.get_bot().get_file( - file_id=self.big_file_id, timeout=timeout, api_kwargs=api_kwargs + return await self.get_bot().get_file( + file_id=self.big_file_id, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, + api_kwargs=api_kwargs, ) diff --git a/telegram/_files/file.py b/telegram/_files/file.py index c1fe7ce0286..8473f812108 100644 --- a/telegram/_files/file.py +++ b/telegram/_files/file.py @@ -25,8 +25,9 @@ from telegram import TelegramObject from telegram._passport.credentials import decrypt +from telegram._utils.defaultvalue import DEFAULT_NONE from telegram._utils.files import is_local_file -from telegram._utils.types import FilePathInput +from telegram._utils.types import FilePathInput, ODVInput if TYPE_CHECKING: from telegram import Bot, FileCredentials @@ -96,8 +97,14 @@ def __init__( self._id_attrs = (self.file_unique_id,) - def download( - self, custom_path: FilePathInput = None, out: IO = None, timeout: int = None + async def download( + self, + custom_path: FilePathInput = None, + out: IO = None, + read_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, ) -> Union[Path, IO]: """ Download this file. By default, the file is saved in the current working directory with its @@ -146,7 +153,7 @@ def download( if local_file: buf = path.read_bytes() else: - buf = self.get_bot().request.retrieve(url) + buf = await self.get_bot().request.retrieve(url) if self._credentials: buf = decrypt( b64decode(self._credentials.secret), b64decode(self._credentials.hash), buf @@ -167,7 +174,13 @@ def download( else: filename = Path.cwd() / self.file_id - buf = self.get_bot().request.retrieve(url, timeout=timeout) + buf = await self.get_bot().request.retrieve( + url, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, + ) if self._credentials: buf = decrypt( b64decode(self._credentials.secret), b64decode(self._credentials.hash), buf @@ -184,7 +197,7 @@ def _get_encoded_url(self) -> str: ) ) - def download_as_bytearray(self, buf: bytearray = None) -> bytes: + async def download_as_bytearray(self, buf: bytearray = None) -> bytes: """Download this file and return it as a bytearray. Args: @@ -200,7 +213,7 @@ def download_as_bytearray(self, buf: bytearray = None) -> bytes: if is_local_file(self.file_path): buf.extend(Path(self.file_path).read_bytes()) else: - buf.extend(self.get_bot().request.retrieve(self._get_encoded_url())) + buf.extend(await self.get_bot().request.retrieve(self._get_encoded_url())) return buf def set_credentials(self, credentials: 'FileCredentials') -> None: diff --git a/telegram/_files/inputfile.py b/telegram/_files/inputfile.py index cffae12c477..398c05e79f2 100644 --- a/telegram/_files/inputfile.py +++ b/telegram/_files/inputfile.py @@ -23,10 +23,12 @@ import logging import mimetypes from pathlib import Path -from typing import IO, Optional, Tuple, Union +from typing import IO, Optional, Union from uuid import uuid4 -DEFAULT_MIME_TYPE = 'application/octet-stream' +from telegram._utils.types import FieldTuple + +_DEFAULT_MIME_TYPE = 'application/octet-stream' logger = logging.getLogger(__name__) @@ -36,29 +38,32 @@ class InputFile: Args: obj (:obj:`File handler` | :obj:`bytes`): An open file descriptor or the files content as bytes. + + Note: + If ``obj`` is a string, it will be encoded as bytes via ``obj.encode('utf-8')``. filename (:obj:`str`, optional): Filename for this InputFile. - attach (:obj:`bool`, optional): Whether this should be send as one file or is part of a - collection of files. Raises: TelegramError Attributes: input_file_content (:obj:`bytes`): The binary content of the file to send. + attach_name (:obj:`str`): Attach name. filename (:obj:`str`): Optional. Filename for the file to be sent. - attach (:obj:`str`): Optional. Attach id for sending multiple files. mimetype (:obj:`str`): Optional. The mimetype inferred from the file to be sent. """ - __slots__ = ('filename', 'attach', 'input_file_content', 'mimetype') + __slots__ = ('filename', 'attach_name', 'input_file_content', 'mimetype') - def __init__(self, obj: Union[IO, bytes], filename: str = None, attach: bool = None): + def __init__(self, obj: Union[IO[bytes], bytes, str], filename: str = None): if isinstance(obj, bytes): self.input_file_content = obj + elif isinstance(obj, str): + self.input_file_content = obj.encode('utf-8') else: self.input_file_content = obj.read() - self.attach = 'attached' + uuid4().hex if attach else None + self.attach_name = 'attached' + uuid4().hex if ( not filename @@ -71,14 +76,14 @@ def __init__(self, obj: Union[IO, bytes], filename: str = None, attach: bool = N if image_mime_type: self.mimetype = image_mime_type elif filename: - self.mimetype = mimetypes.guess_type(filename)[0] or DEFAULT_MIME_TYPE + self.mimetype = mimetypes.guess_type(filename)[0] or _DEFAULT_MIME_TYPE else: - self.mimetype = DEFAULT_MIME_TYPE + self.mimetype = _DEFAULT_MIME_TYPE self.filename = filename or self.mimetype.replace('/', '.') @property - def field_tuple(self) -> Tuple[str, bytes, str]: # skipcq: PY-D0003 + def field_tuple(self) -> FieldTuple: # skipcq: PY-D0003 return self.filename, self.input_file_content, self.mimetype @staticmethod @@ -108,8 +113,7 @@ def is_image(stream: bytes) -> Optional[str]: def is_file(obj: object) -> bool: # skipcq: PY-D0003 return hasattr(obj, 'read') - def to_dict(self) -> Optional[str]: - """See :meth:`telegram.TelegramObject.to_dict`.""" - if self.attach: - return 'attach://' + self.attach - return None + @property + def attach_uri(self) -> str: + """URI to insert into the JSON data for uploading the file.""" + return f'attach://{self.attach_name}' diff --git a/telegram/_files/inputmedia.py b/telegram/_files/inputmedia.py index c20909b7a78..e43c414daea 100644 --- a/telegram/_files/inputmedia.py +++ b/telegram/_files/inputmedia.py @@ -17,7 +17,6 @@ # You should have received a copy of the GNU Lesser Public License # along with this program. If not, see [http://www.gnu.org/licenses/]. """Base class for Telegram InputMedia Objects.""" - from typing import Union, List, Tuple, Optional from telegram import ( @@ -102,7 +101,7 @@ def to_dict(self) -> JSONDict: @staticmethod def _parse_thumb_input(thumb: Optional[FileInput]) -> Optional[Union[str, InputFile]]: - return parse_file_input(thumb, attach=True) if thumb is not None else thumb + return parse_file_input(thumb) if thumb is not None else thumb class InputMediaAnimation(InputMedia): @@ -182,7 +181,7 @@ def __init__( duration = media.duration if duration is None else duration media = media.file_id else: - media = parse_file_input(media, attach=True, filename=filename) + media = parse_file_input(media, filename=filename) super().__init__(InputMediaType.ANIMATION, media, caption, caption_entities, parse_mode) self.thumb = self._parse_thumb_input(thumb) @@ -237,7 +236,7 @@ def __init__( caption_entities: Union[List[MessageEntity], Tuple[MessageEntity, ...]] = None, filename: str = None, ): - media = parse_file_input(media, PhotoSize, attach=True, filename=filename) + media = parse_file_input(media, PhotoSize, filename=filename) super().__init__(InputMediaType.PHOTO, media, caption, caption_entities, parse_mode) @@ -327,7 +326,7 @@ def __init__( duration = duration if duration is not None else media.duration media = media.file_id else: - media = parse_file_input(media, attach=True, filename=filename) + media = parse_file_input(media, filename=filename) super().__init__(InputMediaType.VIDEO, media, caption, caption_entities, parse_mode) self.width = width @@ -417,7 +416,7 @@ def __init__( title = media.title if title is None else title media = media.file_id else: - media = parse_file_input(media, attach=True, filename=filename) + media = parse_file_input(media, filename=filename) super().__init__(InputMediaType.AUDIO, media, caption, caption_entities, parse_mode) self.thumb = self._parse_thumb_input(thumb) @@ -490,7 +489,7 @@ def __init__( caption_entities: Union[List[MessageEntity], Tuple[MessageEntity, ...]] = None, filename: str = None, ): - media = parse_file_input(media, Document, attach=True, filename=filename) + media = parse_file_input(media, Document, filename=filename) super().__init__(InputMediaType.DOCUMENT, media, caption, caption_entities, parse_mode) self.thumb = self._parse_thumb_input(thumb) self.disable_content_type_detection = disable_content_type_detection diff --git a/telegram/_inline/inlinequery.py b/telegram/_inline/inlinequery.py index 9708930377b..f29a0a8645c 100644 --- a/telegram/_inline/inlinequery.py +++ b/telegram/_inline/inlinequery.py @@ -110,7 +110,7 @@ def de_json(cls, data: Optional[JSONDict], bot: 'Bot') -> Optional['InlineQuery' return cls(bot=bot, **data) - def answer( + async def answer( self, results: Union[ Sequence['InlineQueryResult'], Callable[[int], Optional[Sequence['InlineQueryResult']]] @@ -120,7 +120,10 @@ def answer( next_offset: str = None, switch_pm_text: str = None, switch_pm_parameter: str = None, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, current_offset: str = None, api_kwargs: JSONDict = None, auto_pagination: bool = False, @@ -150,7 +153,7 @@ def answer( """ if current_offset and auto_pagination: raise ValueError('current_offset and auto_pagination are mutually exclusive!') - return self.get_bot().answer_inline_query( + return await self.get_bot().answer_inline_query( inline_query_id=self.id, current_offset=self.offset if auto_pagination else current_offset, results=results, @@ -159,7 +162,10 @@ def answer( next_offset=next_offset, switch_pm_text=switch_pm_text, switch_pm_parameter=switch_pm_parameter, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, ) diff --git a/telegram/_message.py b/telegram/_message.py index 48de2493fcd..484e641fe95 100644 --- a/telegram/_message.py +++ b/telegram/_message.py @@ -56,7 +56,7 @@ from telegram.constants import ParseMode, MessageAttachmentType from telegram.helpers import escape_markdown from telegram._utils.datetime import from_timestamp, to_timestamp -from telegram._utils.defaultvalue import DEFAULT_NONE, DEFAULT_20, DefaultValue +from telegram._utils.defaultvalue import DEFAULT_NONE, DefaultValue from telegram._utils.types import JSONDict, FileInput, ODVInput, DVInput, ReplyMarkup if TYPE_CHECKING: @@ -725,7 +725,7 @@ def _quote(self, quote: Optional[bool], reply_to_message_id: Optional[int]) -> O return None - def reply_text( + async def reply_text( self, text: str, parse_mode: ODVInput[str] = DEFAULT_NONE, @@ -733,7 +733,10 @@ def reply_text( disable_notification: DVInput[bool] = DEFAULT_NONE, reply_to_message_id: int = None, reply_markup: ReplyMarkup = None, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, allow_sending_without_reply: ODVInput[bool] = DEFAULT_NONE, entities: Union[List['MessageEntity'], Tuple['MessageEntity', ...]] = None, @@ -757,7 +760,7 @@ def reply_text( """ reply_to_message_id = self._quote(quote, reply_to_message_id) - return self.get_bot().send_message( + return await self.get_bot().send_message( chat_id=self.chat_id, text=text, parse_mode=parse_mode, @@ -765,21 +768,27 @@ def reply_text( disable_notification=disable_notification, reply_to_message_id=reply_to_message_id, reply_markup=reply_markup, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, allow_sending_without_reply=allow_sending_without_reply, entities=entities, protect_content=protect_content, ) - def reply_markdown( + async def reply_markdown( self, text: str, disable_web_page_preview: ODVInput[bool] = DEFAULT_NONE, disable_notification: DVInput[bool] = DEFAULT_NONE, reply_to_message_id: int = None, reply_markup: ReplyMarkup = None, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, allow_sending_without_reply: ODVInput[bool] = DEFAULT_NONE, entities: Union[List['MessageEntity'], Tuple['MessageEntity', ...]] = None, @@ -813,7 +822,7 @@ def reply_markdown( :class:`telegram.Message`: On success, instance representing the message posted. """ reply_to_message_id = self._quote(quote, reply_to_message_id) - return self.get_bot().send_message( + return await self.get_bot().send_message( chat_id=self.chat_id, text=text, parse_mode=ParseMode.MARKDOWN, @@ -821,21 +830,27 @@ def reply_markdown( disable_notification=disable_notification, reply_to_message_id=reply_to_message_id, reply_markup=reply_markup, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, allow_sending_without_reply=allow_sending_without_reply, entities=entities, protect_content=protect_content, ) - def reply_markdown_v2( + async def reply_markdown_v2( self, text: str, disable_web_page_preview: ODVInput[bool] = DEFAULT_NONE, disable_notification: DVInput[bool] = DEFAULT_NONE, reply_to_message_id: int = None, reply_markup: ReplyMarkup = None, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, allow_sending_without_reply: ODVInput[bool] = DEFAULT_NONE, entities: Union[List['MessageEntity'], Tuple['MessageEntity', ...]] = None, @@ -865,7 +880,7 @@ def reply_markdown_v2( :class:`telegram.Message`: On success, instance representing the message posted. """ reply_to_message_id = self._quote(quote, reply_to_message_id) - return self.get_bot().send_message( + return await self.get_bot().send_message( chat_id=self.chat_id, text=text, parse_mode=ParseMode.MARKDOWN_V2, @@ -873,21 +888,27 @@ def reply_markdown_v2( disable_notification=disable_notification, reply_to_message_id=reply_to_message_id, reply_markup=reply_markup, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, allow_sending_without_reply=allow_sending_without_reply, entities=entities, protect_content=protect_content, ) - def reply_html( + async def reply_html( self, text: str, disable_web_page_preview: ODVInput[bool] = DEFAULT_NONE, disable_notification: DVInput[bool] = DEFAULT_NONE, reply_to_message_id: int = None, reply_markup: ReplyMarkup = None, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, allow_sending_without_reply: ODVInput[bool] = DEFAULT_NONE, entities: Union[List['MessageEntity'], Tuple['MessageEntity', ...]] = None, @@ -917,7 +938,7 @@ def reply_html( :class:`telegram.Message`: On success, instance representing the message posted. """ reply_to_message_id = self._quote(quote, reply_to_message_id) - return self.get_bot().send_message( + return await self.get_bot().send_message( chat_id=self.chat_id, text=text, parse_mode=ParseMode.HTML, @@ -925,21 +946,27 @@ def reply_html( disable_notification=disable_notification, reply_to_message_id=reply_to_message_id, reply_markup=reply_markup, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, allow_sending_without_reply=allow_sending_without_reply, entities=entities, protect_content=protect_content, ) - def reply_media_group( + async def reply_media_group( self, media: List[ Union['InputMediaAudio', 'InputMediaDocument', 'InputMediaPhoto', 'InputMediaVideo'] ], disable_notification: ODVInput[bool] = DEFAULT_NONE, reply_to_message_id: int = None, - timeout: DVInput[float] = DEFAULT_20, + read_timeout: float = 20, + write_timeout: float = 20, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, allow_sending_without_reply: ODVInput[bool] = DEFAULT_NONE, quote: bool = None, @@ -964,25 +991,31 @@ def reply_media_group( :class:`telegram.error.TelegramError` """ reply_to_message_id = self._quote(quote, reply_to_message_id) - return self.get_bot().send_media_group( + return await self.get_bot().send_media_group( chat_id=self.chat_id, media=media, disable_notification=disable_notification, reply_to_message_id=reply_to_message_id, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, allow_sending_without_reply=allow_sending_without_reply, protect_content=protect_content, ) - def reply_photo( + async def reply_photo( self, photo: Union[FileInput, 'PhotoSize'], caption: str = None, disable_notification: DVInput[bool] = DEFAULT_NONE, reply_to_message_id: int = None, reply_markup: ReplyMarkup = None, - timeout: DVInput[float] = DEFAULT_20, + read_timeout: float = 20, + write_timeout: float = 20, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, parse_mode: ODVInput[str] = DEFAULT_NONE, api_kwargs: JSONDict = None, allow_sending_without_reply: ODVInput[bool] = DEFAULT_NONE, @@ -1008,14 +1041,17 @@ def reply_photo( """ reply_to_message_id = self._quote(quote, reply_to_message_id) - return self.get_bot().send_photo( + return await self.get_bot().send_photo( chat_id=self.chat_id, photo=photo, caption=caption, disable_notification=disable_notification, reply_to_message_id=reply_to_message_id, reply_markup=reply_markup, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, parse_mode=parse_mode, api_kwargs=api_kwargs, allow_sending_without_reply=allow_sending_without_reply, @@ -1024,7 +1060,7 @@ def reply_photo( protect_content=protect_content, ) - def reply_audio( + async def reply_audio( self, audio: Union[FileInput, 'Audio'], duration: int = None, @@ -1034,7 +1070,10 @@ def reply_audio( disable_notification: DVInput[bool] = DEFAULT_NONE, reply_to_message_id: int = None, reply_markup: ReplyMarkup = None, - timeout: DVInput[float] = DEFAULT_20, + read_timeout: float = 20, + write_timeout: float = 20, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, parse_mode: ODVInput[str] = DEFAULT_NONE, thumb: FileInput = None, api_kwargs: JSONDict = None, @@ -1061,7 +1100,7 @@ def reply_audio( """ reply_to_message_id = self._quote(quote, reply_to_message_id) - return self.get_bot().send_audio( + return await self.get_bot().send_audio( chat_id=self.chat_id, audio=audio, duration=duration, @@ -1071,7 +1110,10 @@ def reply_audio( disable_notification=disable_notification, reply_to_message_id=reply_to_message_id, reply_markup=reply_markup, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, parse_mode=parse_mode, thumb=thumb, api_kwargs=api_kwargs, @@ -1081,7 +1123,7 @@ def reply_audio( protect_content=protect_content, ) - def reply_document( + async def reply_document( self, document: Union[FileInput, 'Document'], filename: str = None, @@ -1089,7 +1131,10 @@ def reply_document( disable_notification: DVInput[bool] = DEFAULT_NONE, reply_to_message_id: int = None, reply_markup: ReplyMarkup = None, - timeout: DVInput[float] = DEFAULT_20, + read_timeout: float = 20, + write_timeout: float = 20, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, parse_mode: ODVInput[str] = DEFAULT_NONE, thumb: FileInput = None, api_kwargs: JSONDict = None, @@ -1116,7 +1161,7 @@ def reply_document( """ reply_to_message_id = self._quote(quote, reply_to_message_id) - return self.get_bot().send_document( + return await self.get_bot().send_document( chat_id=self.chat_id, document=document, filename=filename, @@ -1124,7 +1169,10 @@ def reply_document( disable_notification=disable_notification, reply_to_message_id=reply_to_message_id, reply_markup=reply_markup, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, parse_mode=parse_mode, thumb=thumb, api_kwargs=api_kwargs, @@ -1134,7 +1182,7 @@ def reply_document( protect_content=protect_content, ) - def reply_animation( + async def reply_animation( self, animation: Union[FileInput, 'Animation'], duration: int = None, @@ -1146,7 +1194,10 @@ def reply_animation( disable_notification: DVInput[bool] = DEFAULT_NONE, reply_to_message_id: int = None, reply_markup: ReplyMarkup = None, - timeout: DVInput[float] = DEFAULT_20, + read_timeout: float = 20, + write_timeout: float = 20, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, allow_sending_without_reply: ODVInput[bool] = DEFAULT_NONE, caption_entities: Union[List['MessageEntity'], Tuple['MessageEntity', ...]] = None, @@ -1171,7 +1222,7 @@ def reply_animation( """ reply_to_message_id = self._quote(quote, reply_to_message_id) - return self.get_bot().send_animation( + return await self.get_bot().send_animation( chat_id=self.chat_id, animation=animation, duration=duration, @@ -1183,7 +1234,10 @@ def reply_animation( disable_notification=disable_notification, reply_to_message_id=reply_to_message_id, reply_markup=reply_markup, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, allow_sending_without_reply=allow_sending_without_reply, caption_entities=caption_entities, @@ -1191,13 +1245,16 @@ def reply_animation( protect_content=protect_content, ) - def reply_sticker( + async def reply_sticker( self, sticker: Union[FileInput, 'Sticker'], disable_notification: DVInput[bool] = DEFAULT_NONE, reply_to_message_id: int = None, reply_markup: ReplyMarkup = None, - timeout: DVInput[float] = DEFAULT_20, + read_timeout: float = 20, + write_timeout: float = 20, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, allow_sending_without_reply: ODVInput[bool] = DEFAULT_NONE, quote: bool = None, @@ -1220,19 +1277,22 @@ def reply_sticker( """ reply_to_message_id = self._quote(quote, reply_to_message_id) - return self.get_bot().send_sticker( + return await self.get_bot().send_sticker( chat_id=self.chat_id, sticker=sticker, disable_notification=disable_notification, reply_to_message_id=reply_to_message_id, reply_markup=reply_markup, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, allow_sending_without_reply=allow_sending_without_reply, protect_content=protect_content, ) - def reply_video( + async def reply_video( self, video: Union[FileInput, 'Video'], duration: int = None, @@ -1240,7 +1300,10 @@ def reply_video( disable_notification: DVInput[bool] = DEFAULT_NONE, reply_to_message_id: int = None, reply_markup: ReplyMarkup = None, - timeout: DVInput[float] = DEFAULT_20, + read_timeout: float = 20, + write_timeout: float = 20, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, width: int = None, height: int = None, parse_mode: ODVInput[str] = DEFAULT_NONE, @@ -1270,7 +1333,7 @@ def reply_video( """ reply_to_message_id = self._quote(quote, reply_to_message_id) - return self.get_bot().send_video( + return await self.get_bot().send_video( chat_id=self.chat_id, video=video, duration=duration, @@ -1278,7 +1341,10 @@ def reply_video( disable_notification=disable_notification, reply_to_message_id=reply_to_message_id, reply_markup=reply_markup, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, width=width, height=height, parse_mode=parse_mode, @@ -1291,7 +1357,7 @@ def reply_video( protect_content=protect_content, ) - def reply_video_note( + async def reply_video_note( self, video_note: Union[FileInput, 'VideoNote'], duration: int = None, @@ -1299,7 +1365,10 @@ def reply_video_note( disable_notification: DVInput[bool] = DEFAULT_NONE, reply_to_message_id: int = None, reply_markup: ReplyMarkup = None, - timeout: DVInput[float] = DEFAULT_20, + read_timeout: float = 20, + write_timeout: float = 20, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, thumb: FileInput = None, api_kwargs: JSONDict = None, allow_sending_without_reply: ODVInput[bool] = DEFAULT_NONE, @@ -1324,7 +1393,7 @@ def reply_video_note( """ reply_to_message_id = self._quote(quote, reply_to_message_id) - return self.get_bot().send_video_note( + return await self.get_bot().send_video_note( chat_id=self.chat_id, video_note=video_note, duration=duration, @@ -1332,7 +1401,10 @@ def reply_video_note( disable_notification=disable_notification, reply_to_message_id=reply_to_message_id, reply_markup=reply_markup, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, thumb=thumb, api_kwargs=api_kwargs, allow_sending_without_reply=allow_sending_without_reply, @@ -1340,7 +1412,7 @@ def reply_video_note( protect_content=protect_content, ) - def reply_voice( + async def reply_voice( self, voice: Union[FileInput, 'Voice'], duration: int = None, @@ -1348,7 +1420,10 @@ def reply_voice( disable_notification: DVInput[bool] = DEFAULT_NONE, reply_to_message_id: int = None, reply_markup: ReplyMarkup = None, - timeout: DVInput[float] = DEFAULT_20, + read_timeout: float = 20, + write_timeout: float = 20, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, parse_mode: ODVInput[str] = DEFAULT_NONE, api_kwargs: JSONDict = None, allow_sending_without_reply: ODVInput[bool] = DEFAULT_NONE, @@ -1374,7 +1449,7 @@ def reply_voice( """ reply_to_message_id = self._quote(quote, reply_to_message_id) - return self.get_bot().send_voice( + return await self.get_bot().send_voice( chat_id=self.chat_id, voice=voice, duration=duration, @@ -1382,7 +1457,10 @@ def reply_voice( disable_notification=disable_notification, reply_to_message_id=reply_to_message_id, reply_markup=reply_markup, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, parse_mode=parse_mode, api_kwargs=api_kwargs, allow_sending_without_reply=allow_sending_without_reply, @@ -1391,14 +1469,17 @@ def reply_voice( protect_content=protect_content, ) - def reply_location( + async def reply_location( self, latitude: float = None, longitude: float = None, disable_notification: DVInput[bool] = DEFAULT_NONE, reply_to_message_id: int = None, reply_markup: ReplyMarkup = None, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, location: Location = None, live_period: int = None, api_kwargs: JSONDict = None, @@ -1426,14 +1507,17 @@ def reply_location( """ reply_to_message_id = self._quote(quote, reply_to_message_id) - return self.get_bot().send_location( + return await self.get_bot().send_location( chat_id=self.chat_id, latitude=latitude, longitude=longitude, disable_notification=disable_notification, reply_to_message_id=reply_to_message_id, reply_markup=reply_markup, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, location=location, live_period=live_period, api_kwargs=api_kwargs, @@ -1444,7 +1528,7 @@ def reply_location( protect_content=protect_content, ) - def reply_venue( + async def reply_venue( self, latitude: float = None, longitude: float = None, @@ -1454,7 +1538,10 @@ def reply_venue( disable_notification: DVInput[bool] = DEFAULT_NONE, reply_to_message_id: int = None, reply_markup: ReplyMarkup = None, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, venue: Venue = None, foursquare_type: str = None, api_kwargs: JSONDict = None, @@ -1481,7 +1568,7 @@ def reply_venue( """ reply_to_message_id = self._quote(quote, reply_to_message_id) - return self.get_bot().send_venue( + return await self.get_bot().send_venue( chat_id=self.chat_id, latitude=latitude, longitude=longitude, @@ -1491,7 +1578,10 @@ def reply_venue( disable_notification=disable_notification, reply_to_message_id=reply_to_message_id, reply_markup=reply_markup, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, venue=venue, foursquare_type=foursquare_type, api_kwargs=api_kwargs, @@ -1501,7 +1591,7 @@ def reply_venue( protect_content=protect_content, ) - def reply_contact( + async def reply_contact( self, phone_number: str = None, first_name: str = None, @@ -1509,7 +1599,10 @@ def reply_contact( disable_notification: DVInput[bool] = DEFAULT_NONE, reply_to_message_id: int = None, reply_markup: ReplyMarkup = None, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, contact: Contact = None, vcard: str = None, api_kwargs: JSONDict = None, @@ -1534,7 +1627,7 @@ def reply_contact( """ reply_to_message_id = self._quote(quote, reply_to_message_id) - return self.get_bot().send_contact( + return await self.get_bot().send_contact( chat_id=self.chat_id, phone_number=phone_number, first_name=first_name, @@ -1542,7 +1635,10 @@ def reply_contact( disable_notification=disable_notification, reply_to_message_id=reply_to_message_id, reply_markup=reply_markup, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, contact=contact, vcard=vcard, api_kwargs=api_kwargs, @@ -1550,7 +1646,7 @@ def reply_contact( protect_content=protect_content, ) - def reply_poll( + async def reply_poll( self, question: str, options: List[str], @@ -1562,7 +1658,10 @@ def reply_poll( disable_notification: ODVInput[bool] = DEFAULT_NONE, reply_to_message_id: int = None, reply_markup: ReplyMarkup = None, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, explanation: str = None, explanation_parse_mode: ODVInput[str] = DEFAULT_NONE, open_period: int = None, @@ -1590,7 +1689,7 @@ def reply_poll( """ reply_to_message_id = self._quote(quote, reply_to_message_id) - return self.get_bot().send_poll( + return await self.get_bot().send_poll( chat_id=self.chat_id, question=question, options=options, @@ -1602,7 +1701,10 @@ def reply_poll( disable_notification=disable_notification, reply_to_message_id=reply_to_message_id, reply_markup=reply_markup, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, explanation=explanation, explanation_parse_mode=explanation_parse_mode, open_period=open_period, @@ -1613,12 +1715,15 @@ def reply_poll( protect_content=protect_content, ) - def reply_dice( + async def reply_dice( self, disable_notification: ODVInput[bool] = DEFAULT_NONE, reply_to_message_id: int = None, reply_markup: ReplyMarkup = None, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, emoji: str = None, api_kwargs: JSONDict = None, allow_sending_without_reply: ODVInput[bool] = DEFAULT_NONE, @@ -1642,22 +1747,28 @@ def reply_dice( """ reply_to_message_id = self._quote(quote, reply_to_message_id) - return self.get_bot().send_dice( + return await self.get_bot().send_dice( chat_id=self.chat_id, disable_notification=disable_notification, reply_to_message_id=reply_to_message_id, reply_markup=reply_markup, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, emoji=emoji, api_kwargs=api_kwargs, allow_sending_without_reply=allow_sending_without_reply, protect_content=protect_content, ) - def reply_chat_action( + async def reply_chat_action( self, action: str, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, ) -> bool: """Shortcut for:: @@ -1672,20 +1783,26 @@ def reply_chat_action( :obj:`bool`: On success, :obj:`True` is returned. """ - return self.get_bot().send_chat_action( + return await self.get_bot().send_chat_action( chat_id=self.chat_id, action=action, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, ) - def reply_game( + async def reply_game( self, game_short_name: str, disable_notification: DVInput[bool] = DEFAULT_NONE, reply_to_message_id: int = None, reply_markup: 'InlineKeyboardMarkup' = None, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, allow_sending_without_reply: ODVInput[bool] = DEFAULT_NONE, quote: bool = None, @@ -1710,19 +1827,22 @@ def reply_game( """ reply_to_message_id = self._quote(quote, reply_to_message_id) - return self.get_bot().send_game( + return await self.get_bot().send_game( chat_id=self.chat_id, game_short_name=game_short_name, disable_notification=disable_notification, reply_to_message_id=reply_to_message_id, reply_markup=reply_markup, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, allow_sending_without_reply=allow_sending_without_reply, protect_content=protect_content, ) - def reply_invoice( + async def reply_invoice( self, title: str, description: str, @@ -1746,7 +1866,10 @@ def reply_invoice( provider_data: Union[str, object] = None, send_phone_number_to_provider: bool = None, send_email_to_provider: bool = None, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, allow_sending_without_reply: ODVInput[bool] = DEFAULT_NONE, quote: bool = None, @@ -1781,7 +1904,7 @@ def reply_invoice( """ reply_to_message_id = self._quote(quote, reply_to_message_id) - return self.get_bot().send_invoice( + return await self.get_bot().send_invoice( chat_id=self.chat_id, title=title, description=description, @@ -1805,7 +1928,10 @@ def reply_invoice( provider_data=provider_data, send_phone_number_to_provider=send_phone_number_to_provider, send_email_to_provider=send_email_to_provider, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, allow_sending_without_reply=allow_sending_without_reply, max_tip_amount=max_tip_amount, @@ -1813,11 +1939,14 @@ def reply_invoice( protect_content=protect_content, ) - def forward( + async def forward( self, chat_id: Union[int, str], disable_notification: DVInput[bool] = DEFAULT_NONE, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, protect_content: ODVInput[bool] = DEFAULT_NONE, ) -> 'Message': @@ -1843,17 +1972,20 @@ def forward( :class:`telegram.Message`: On success, instance representing the message forwarded. """ - return self.get_bot().forward_message( + return await self.get_bot().forward_message( chat_id=chat_id, from_chat_id=self.chat_id, message_id=self.message_id, disable_notification=disable_notification, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, protect_content=protect_content, ) - def copy( + async def copy( self, chat_id: Union[int, str], caption: str = None, @@ -1863,7 +1995,10 @@ def copy( reply_to_message_id: int = None, allow_sending_without_reply: DVInput[bool] = DEFAULT_NONE, reply_markup: ReplyMarkup = None, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, protect_content: ODVInput[bool] = DEFAULT_NONE, ) -> 'MessageId': @@ -1881,7 +2016,7 @@ def copy( :class:`telegram.MessageId`: On success, returns the MessageId of the sent message. """ - return self.get_bot().copy_message( + return await self.get_bot().copy_message( chat_id=chat_id, from_chat_id=self.chat_id, message_id=self.message_id, @@ -1892,12 +2027,15 @@ def copy( reply_to_message_id=reply_to_message_id, allow_sending_without_reply=allow_sending_without_reply, reply_markup=reply_markup, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, protect_content=protect_content, ) - def reply_copy( + async def reply_copy( self, from_chat_id: Union[str, int], message_id: int, @@ -1908,7 +2046,10 @@ def reply_copy( reply_to_message_id: int = None, allow_sending_without_reply: DVInput[bool] = DEFAULT_NONE, reply_markup: ReplyMarkup = None, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, quote: bool = None, protect_content: ODVInput[bool] = DEFAULT_NONE, @@ -1936,7 +2077,7 @@ def reply_copy( """ reply_to_message_id = self._quote(quote, reply_to_message_id) - return self.get_bot().copy_message( + return await self.get_bot().copy_message( chat_id=self.chat_id, from_chat_id=from_chat_id, message_id=message_id, @@ -1947,18 +2088,24 @@ def reply_copy( reply_to_message_id=reply_to_message_id, allow_sending_without_reply=allow_sending_without_reply, reply_markup=reply_markup, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, protect_content=protect_content, ) - def edit_text( + async def edit_text( self, text: str, parse_mode: ODVInput[str] = DEFAULT_NONE, disable_web_page_preview: ODVInput[bool] = DEFAULT_NONE, reply_markup: InlineKeyboardMarkup = None, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, entities: Union[List['MessageEntity'], Tuple['MessageEntity', ...]] = None, ) -> Union['Message', bool]: @@ -1981,24 +2128,30 @@ def edit_text( edited Message is returned, otherwise ``True`` is returned. """ - return self.get_bot().edit_message_text( + return await self.get_bot().edit_message_text( chat_id=self.chat_id, message_id=self.message_id, text=text, parse_mode=parse_mode, disable_web_page_preview=disable_web_page_preview, reply_markup=reply_markup, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, entities=entities, inline_message_id=None, ) - def edit_caption( + async def edit_caption( self, caption: str = None, reply_markup: InlineKeyboardMarkup = None, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, parse_mode: ODVInput[str] = DEFAULT_NONE, api_kwargs: JSONDict = None, caption_entities: Union[List['MessageEntity'], Tuple['MessageEntity', ...]] = None, @@ -2023,23 +2176,29 @@ def edit_caption( edited Message is returned, otherwise ``True`` is returned. """ - return self.get_bot().edit_message_caption( + return await self.get_bot().edit_message_caption( chat_id=self.chat_id, message_id=self.message_id, caption=caption, reply_markup=reply_markup, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, parse_mode=parse_mode, api_kwargs=api_kwargs, caption_entities=caption_entities, inline_message_id=None, ) - def edit_media( + async def edit_media( self, media: 'InputMedia', reply_markup: InlineKeyboardMarkup = None, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, ) -> Union['Message', bool]: """Shortcut for:: @@ -2062,20 +2221,26 @@ def edit_media( edited Message is returned, otherwise ``True`` is returned. """ - return self.get_bot().edit_message_media( + return await self.get_bot().edit_message_media( media=media, chat_id=self.chat_id, message_id=self.message_id, reply_markup=reply_markup, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, inline_message_id=None, ) - def edit_reply_markup( + async def edit_reply_markup( self, reply_markup: Optional['InlineKeyboardMarkup'] = None, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, ) -> Union['Message', bool]: """Shortcut for:: @@ -2097,22 +2262,28 @@ def edit_reply_markup( :class:`telegram.Message`: On success, if edited message is sent by the bot, the edited Message is returned, otherwise ``True`` is returned. """ - return self.get_bot().edit_message_reply_markup( + return await self.get_bot().edit_message_reply_markup( chat_id=self.chat_id, message_id=self.message_id, reply_markup=reply_markup, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, inline_message_id=None, ) - def edit_live_location( + async def edit_live_location( self, latitude: float = None, longitude: float = None, location: Location = None, reply_markup: InlineKeyboardMarkup = None, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, horizontal_accuracy: float = None, heading: int = None, @@ -2137,14 +2308,17 @@ def edit_live_location( :class:`telegram.Message`: On success, if edited message is sent by the bot, the edited Message is returned, otherwise :obj:`True` is returned. """ - return self.get_bot().edit_message_live_location( + return await self.get_bot().edit_message_live_location( chat_id=self.chat_id, message_id=self.message_id, latitude=latitude, longitude=longitude, location=location, reply_markup=reply_markup, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, horizontal_accuracy=horizontal_accuracy, heading=heading, @@ -2152,10 +2326,13 @@ def edit_live_location( inline_message_id=None, ) - def stop_live_location( + async def stop_live_location( self, reply_markup: InlineKeyboardMarkup = None, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, ) -> Union['Message', bool]: """Shortcut for:: @@ -2177,22 +2354,28 @@ def stop_live_location( :class:`telegram.Message`: On success, if edited message is sent by the bot, the edited Message is returned, otherwise :obj:`True` is returned. """ - return self.get_bot().stop_message_live_location( + return await self.get_bot().stop_message_live_location( chat_id=self.chat_id, message_id=self.message_id, reply_markup=reply_markup, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, inline_message_id=None, ) - def set_game_score( + async def set_game_score( self, user_id: Union[int, str], score: int, force: bool = None, disable_edit_message: bool = None, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, ) -> Union['Message', bool]: """Shortcut for:: @@ -2213,22 +2396,28 @@ def set_game_score( :class:`telegram.Message`: On success, if edited message is sent by the bot, the edited Message is returned, otherwise :obj:`True` is returned. """ - return self.get_bot().set_game_score( + return await self.get_bot().set_game_score( chat_id=self.chat_id, message_id=self.message_id, user_id=user_id, score=score, force=force, disable_edit_message=disable_edit_message, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, inline_message_id=None, ) - def get_game_high_scores( + async def get_game_high_scores( self, user_id: Union[int, str], - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, ) -> List['GameHighScore']: """Shortcut for:: @@ -2249,18 +2438,24 @@ def get_game_high_scores( Returns: List[:class:`telegram.GameHighScore`] """ - return self.get_bot().get_game_high_scores( + return await self.get_bot().get_game_high_scores( chat_id=self.chat_id, message_id=self.message_id, user_id=user_id, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, inline_message_id=None, ) - def delete( + async def delete( self, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, ) -> bool: """Shortcut for:: @@ -2276,17 +2471,23 @@ def delete( :obj:`bool`: On success, :obj:`True` is returned. """ - return self.get_bot().delete_message( + return await self.get_bot().delete_message( chat_id=self.chat_id, message_id=self.message_id, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, ) - def stop_poll( + async def stop_poll( self, reply_markup: InlineKeyboardMarkup = None, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, ) -> Poll: """Shortcut for:: @@ -2303,18 +2504,24 @@ def stop_poll( returned. """ - return self.get_bot().stop_poll( + return await self.get_bot().stop_poll( chat_id=self.chat_id, message_id=self.message_id, reply_markup=reply_markup, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, ) - def pin( + async def pin( self, disable_notification: ODVInput[bool] = DEFAULT_NONE, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, ) -> bool: """Shortcut for:: @@ -2330,17 +2537,23 @@ def pin( :obj:`bool`: On success, :obj:`True` is returned. """ - return self.get_bot().pin_chat_message( + return await self.get_bot().pin_chat_message( chat_id=self.chat_id, message_id=self.message_id, disable_notification=disable_notification, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, ) - def unpin( + async def unpin( self, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, ) -> bool: """Shortcut for:: @@ -2356,10 +2569,13 @@ def unpin( :obj:`bool`: On success, :obj:`True` is returned. """ - return self.get_bot().unpin_chat_message( + return await self.get_bot().unpin_chat_message( chat_id=self.chat_id, message_id=self.message_id, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, ) diff --git a/telegram/_passport/passportfile.py b/telegram/_passport/passportfile.py index a8df26a57e7..5565a485b40 100644 --- a/telegram/_passport/passportfile.py +++ b/telegram/_passport/passportfile.py @@ -136,8 +136,13 @@ def de_list_decrypted( for i, passport_file in enumerate(data) ] - def get_file( - self, timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None + async def get_file( + self, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, + api_kwargs: JSONDict = None, ) -> 'File': """ Wrapper over :attr:`telegram.Bot.get_file`. Will automatically assign the correct @@ -153,8 +158,13 @@ def get_file( :class:`telegram.error.TelegramError` """ - file = self.get_bot().get_file( - file_id=self.file_id, timeout=timeout, api_kwargs=api_kwargs + file = await self.get_bot().get_file( + file_id=self.file_id, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, + api_kwargs=api_kwargs, ) file.set_credentials(self._credentials) return file diff --git a/telegram/_payment/precheckoutquery.py b/telegram/_payment/precheckoutquery.py index b7497e73aef..1b4edc5f023 100644 --- a/telegram/_payment/precheckoutquery.py +++ b/telegram/_payment/precheckoutquery.py @@ -114,11 +114,14 @@ def de_json(cls, data: Optional[JSONDict], bot: 'Bot') -> Optional['PreCheckoutQ return cls(bot=bot, **data) - def answer( # pylint: disable=invalid-name + async def answer( # pylint: disable=invalid-name self, ok: bool, error_message: str = None, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, ) -> bool: """Shortcut for:: @@ -129,10 +132,13 @@ def answer( # pylint: disable=invalid-name :meth:`telegram.Bot.answer_pre_checkout_query`. """ - return self.get_bot().answer_pre_checkout_query( + return await self.get_bot().answer_pre_checkout_query( pre_checkout_query_id=self.id, ok=ok, error_message=error_message, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, ) diff --git a/telegram/_payment/shippingquery.py b/telegram/_payment/shippingquery.py index 8f3718e506c..de4b8a3463e 100644 --- a/telegram/_payment/shippingquery.py +++ b/telegram/_payment/shippingquery.py @@ -87,12 +87,15 @@ def de_json(cls, data: Optional[JSONDict], bot: 'Bot') -> Optional['ShippingQuer return cls(bot=bot, **data) - def answer( # pylint: disable=invalid-name + async def answer( # pylint: disable=invalid-name self, ok: bool, shipping_options: List[ShippingOption] = None, error_message: str = None, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, ) -> bool: """Shortcut for:: @@ -103,11 +106,14 @@ def answer( # pylint: disable=invalid-name :meth:`telegram.Bot.answer_shipping_query`. """ - return self.get_bot().answer_shipping_query( + return await self.get_bot().answer_shipping_query( shipping_query_id=self.id, ok=ok, shipping_options=shipping_options, error_message=error_message, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, ) diff --git a/telegram/_telegramobject.py b/telegram/_telegramobject.py index e70bb9c2d55..915eb248b51 100644 --- a/telegram/_telegramobject.py +++ b/telegram/_telegramobject.py @@ -171,8 +171,7 @@ def get_bot(self) -> 'Bot': """ if self._bot is None: raise RuntimeError( - 'This object has no bot associated with it. \ - Shortcuts cannot be used.' + 'This object has no bot associated with it. Shortcuts cannot be used.' ) return self._bot diff --git a/telegram/_user.py b/telegram/_user.py index 9a258e572f7..6c839b1729d 100644 --- a/telegram/_user.py +++ b/telegram/_user.py @@ -27,7 +27,7 @@ mention_markdown as helpers_mention_markdown, mention_html as helpers_mention_html, ) -from telegram._utils.defaultvalue import DEFAULT_NONE, DEFAULT_20 +from telegram._utils.defaultvalue import DEFAULT_NONE from telegram._utils.types import JSONDict, FileInput, ODVInput, DVInput, ReplyMarkup if TYPE_CHECKING: @@ -163,11 +163,14 @@ def link(self) -> Optional[str]: return f"https://t.me/{self.username}" return None - def get_profile_photos( + async def get_profile_photos( self, offset: int = None, limit: int = 100, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, ) -> Optional['UserProfilePhotos']: """ @@ -179,11 +182,14 @@ def get_profile_photos( :meth:`telegram.Bot.get_user_profile_photos`. """ - return self.get_bot().get_user_profile_photos( + return await self.get_bot().get_user_profile_photos( user_id=self.id, offset=offset, limit=limit, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, ) @@ -247,11 +253,14 @@ def mention_button(self, name: str = None) -> InlineKeyboardButton: """ return InlineKeyboardButton(text=name or self.full_name, url=f"tg://user?id={self.id}") - def pin_message( + async def pin_message( self, message_id: int, disable_notification: ODVInput[bool] = DEFAULT_NONE, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, ) -> bool: """Shortcut for:: @@ -266,17 +275,23 @@ def pin_message( :obj:`bool`: On success, :obj:`True` is returned. """ - return self.get_bot().pin_chat_message( + return await self.get_bot().pin_chat_message( chat_id=self.id, message_id=message_id, disable_notification=disable_notification, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, ) - def unpin_message( + async def unpin_message( self, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, message_id: int = None, ) -> bool: @@ -292,16 +307,22 @@ def unpin_message( :obj:`bool`: On success, :obj:`True` is returned. """ - return self.get_bot().unpin_chat_message( + return await self.get_bot().unpin_chat_message( chat_id=self.id, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, message_id=message_id, ) - def unpin_all_messages( + async def unpin_all_messages( self, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, ) -> bool: """Shortcut for:: @@ -317,13 +338,16 @@ def unpin_all_messages( :obj:`bool`: On success, :obj:`True` is returned. """ - return self.get_bot().unpin_all_chat_messages( + return await self.get_bot().unpin_all_chat_messages( chat_id=self.id, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, ) - def send_message( + async def send_message( self, text: str, parse_mode: ODVInput[str] = DEFAULT_NONE, @@ -331,7 +355,10 @@ def send_message( disable_notification: DVInput[bool] = DEFAULT_NONE, reply_to_message_id: int = None, reply_markup: ReplyMarkup = None, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, allow_sending_without_reply: ODVInput[bool] = DEFAULT_NONE, entities: Union[List['MessageEntity'], Tuple['MessageEntity', ...]] = None, @@ -347,7 +374,7 @@ def send_message( :class:`telegram.Message`: On success, instance representing the message posted. """ - return self.get_bot().send_message( + return await self.get_bot().send_message( chat_id=self.id, text=text, parse_mode=parse_mode, @@ -355,21 +382,27 @@ def send_message( disable_notification=disable_notification, reply_to_message_id=reply_to_message_id, reply_markup=reply_markup, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, allow_sending_without_reply=allow_sending_without_reply, entities=entities, protect_content=protect_content, ) - def send_photo( + async def send_photo( self, photo: Union[FileInput, 'PhotoSize'], caption: str = None, disable_notification: DVInput[bool] = DEFAULT_NONE, reply_to_message_id: int = None, reply_markup: ReplyMarkup = None, - timeout: DVInput[float] = DEFAULT_20, + read_timeout: float = 20, + write_timeout: float = 20, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, parse_mode: ODVInput[str] = DEFAULT_NONE, api_kwargs: JSONDict = None, allow_sending_without_reply: ODVInput[bool] = DEFAULT_NONE, @@ -387,14 +420,17 @@ def send_photo( :class:`telegram.Message`: On success, instance representing the message posted. """ - return self.get_bot().send_photo( + return await self.get_bot().send_photo( chat_id=self.id, photo=photo, caption=caption, disable_notification=disable_notification, reply_to_message_id=reply_to_message_id, reply_markup=reply_markup, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, parse_mode=parse_mode, api_kwargs=api_kwargs, allow_sending_without_reply=allow_sending_without_reply, @@ -403,14 +439,17 @@ def send_photo( protect_content=protect_content, ) - def send_media_group( + async def send_media_group( self, media: List[ Union['InputMediaAudio', 'InputMediaDocument', 'InputMediaPhoto', 'InputMediaVideo'] ], disable_notification: ODVInput[bool] = DEFAULT_NONE, reply_to_message_id: int = None, - timeout: DVInput[float] = DEFAULT_20, + read_timeout: float = 20, + write_timeout: float = 20, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, allow_sending_without_reply: ODVInput[bool] = DEFAULT_NONE, protect_content: ODVInput[bool] = DEFAULT_NONE, @@ -425,18 +464,21 @@ def send_media_group( List[:class:`telegram.Message`:] On success, instance representing the message posted. """ - return self.get_bot().send_media_group( + return await self.get_bot().send_media_group( chat_id=self.id, media=media, disable_notification=disable_notification, reply_to_message_id=reply_to_message_id, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, allow_sending_without_reply=allow_sending_without_reply, protect_content=protect_content, ) - def send_audio( + async def send_audio( self, audio: Union[FileInput, 'Audio'], duration: int = None, @@ -446,7 +488,10 @@ def send_audio( disable_notification: DVInput[bool] = DEFAULT_NONE, reply_to_message_id: int = None, reply_markup: ReplyMarkup = None, - timeout: DVInput[float] = DEFAULT_20, + read_timeout: float = 20, + write_timeout: float = 20, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, parse_mode: ODVInput[str] = DEFAULT_NONE, thumb: FileInput = None, api_kwargs: JSONDict = None, @@ -465,7 +510,7 @@ def send_audio( :class:`telegram.Message`: On success, instance representing the message posted. """ - return self.get_bot().send_audio( + return await self.get_bot().send_audio( chat_id=self.id, audio=audio, duration=duration, @@ -475,7 +520,10 @@ def send_audio( disable_notification=disable_notification, reply_to_message_id=reply_to_message_id, reply_markup=reply_markup, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, parse_mode=parse_mode, thumb=thumb, api_kwargs=api_kwargs, @@ -485,10 +533,13 @@ def send_audio( protect_content=protect_content, ) - def send_chat_action( + async def send_chat_action( self, action: str, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, ) -> bool: """Shortcut for:: @@ -501,17 +552,20 @@ def send_chat_action( :obj:`True`: On success. """ - return self.get_bot().send_chat_action( + return await self.get_bot().send_chat_action( chat_id=self.id, action=action, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, ) send_action = send_chat_action """Alias for :attr:`send_chat_action`""" - def send_contact( + async def send_contact( self, phone_number: str = None, first_name: str = None, @@ -519,7 +573,10 @@ def send_contact( disable_notification: DVInput[bool] = DEFAULT_NONE, reply_to_message_id: int = None, reply_markup: ReplyMarkup = None, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, contact: 'Contact' = None, vcard: str = None, api_kwargs: JSONDict = None, @@ -536,7 +593,7 @@ def send_contact( :class:`telegram.Message`: On success, instance representing the message posted. """ - return self.get_bot().send_contact( + return await self.get_bot().send_contact( chat_id=self.id, phone_number=phone_number, first_name=first_name, @@ -544,7 +601,10 @@ def send_contact( disable_notification=disable_notification, reply_to_message_id=reply_to_message_id, reply_markup=reply_markup, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, contact=contact, vcard=vcard, api_kwargs=api_kwargs, @@ -552,12 +612,15 @@ def send_contact( protect_content=protect_content, ) - def send_dice( + async def send_dice( self, disable_notification: ODVInput[bool] = DEFAULT_NONE, reply_to_message_id: int = None, reply_markup: ReplyMarkup = None, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, emoji: str = None, api_kwargs: JSONDict = None, allow_sending_without_reply: ODVInput[bool] = DEFAULT_NONE, @@ -573,19 +636,22 @@ def send_dice( :class:`telegram.Message`: On success, instance representing the message posted. """ - return self.get_bot().send_dice( + return await self.get_bot().send_dice( chat_id=self.id, disable_notification=disable_notification, reply_to_message_id=reply_to_message_id, reply_markup=reply_markup, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, emoji=emoji, api_kwargs=api_kwargs, allow_sending_without_reply=allow_sending_without_reply, protect_content=protect_content, ) - def send_document( + async def send_document( self, document: Union[FileInput, 'Document'], filename: str = None, @@ -593,7 +659,10 @@ def send_document( disable_notification: DVInput[bool] = DEFAULT_NONE, reply_to_message_id: int = None, reply_markup: ReplyMarkup = None, - timeout: DVInput[float] = DEFAULT_20, + read_timeout: float = 20, + write_timeout: float = 20, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, parse_mode: ODVInput[str] = DEFAULT_NONE, thumb: FileInput = None, api_kwargs: JSONDict = None, @@ -612,7 +681,7 @@ def send_document( :class:`telegram.Message`: On success, instance representing the message posted. """ - return self.get_bot().send_document( + return await self.get_bot().send_document( chat_id=self.id, document=document, filename=filename, @@ -620,7 +689,10 @@ def send_document( disable_notification=disable_notification, reply_to_message_id=reply_to_message_id, reply_markup=reply_markup, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, parse_mode=parse_mode, thumb=thumb, api_kwargs=api_kwargs, @@ -630,13 +702,16 @@ def send_document( protect_content=protect_content, ) - def send_game( + async def send_game( self, game_short_name: str, disable_notification: DVInput[bool] = DEFAULT_NONE, reply_to_message_id: int = None, reply_markup: 'InlineKeyboardMarkup' = None, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, allow_sending_without_reply: ODVInput[bool] = DEFAULT_NONE, protect_content: ODVInput[bool] = DEFAULT_NONE, @@ -651,19 +726,22 @@ def send_game( :class:`telegram.Message`: On success, instance representing the message posted. """ - return self.get_bot().send_game( + return await self.get_bot().send_game( chat_id=self.id, game_short_name=game_short_name, disable_notification=disable_notification, reply_to_message_id=reply_to_message_id, reply_markup=reply_markup, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, allow_sending_without_reply=allow_sending_without_reply, protect_content=protect_content, ) - def send_invoice( + async def send_invoice( self, title: str, description: str, @@ -687,7 +765,10 @@ def send_invoice( provider_data: Union[str, object] = None, send_phone_number_to_provider: bool = None, send_email_to_provider: bool = None, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, allow_sending_without_reply: ODVInput[bool] = DEFAULT_NONE, max_tip_amount: int = None, @@ -712,7 +793,7 @@ def send_invoice( :class:`telegram.Message`: On success, instance representing the message posted. """ - return self.get_bot().send_invoice( + return await self.get_bot().send_invoice( chat_id=self.id, title=title, description=description, @@ -736,7 +817,10 @@ def send_invoice( provider_data=provider_data, send_phone_number_to_provider=send_phone_number_to_provider, send_email_to_provider=send_email_to_provider, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, allow_sending_without_reply=allow_sending_without_reply, max_tip_amount=max_tip_amount, @@ -744,14 +828,17 @@ def send_invoice( protect_content=protect_content, ) - def send_location( + async def send_location( self, latitude: float = None, longitude: float = None, disable_notification: DVInput[bool] = DEFAULT_NONE, reply_to_message_id: int = None, reply_markup: ReplyMarkup = None, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, location: 'Location' = None, live_period: int = None, api_kwargs: JSONDict = None, @@ -771,14 +858,17 @@ def send_location( :class:`telegram.Message`: On success, instance representing the message posted. """ - return self.get_bot().send_location( + return await self.get_bot().send_location( chat_id=self.id, latitude=latitude, longitude=longitude, disable_notification=disable_notification, reply_to_message_id=reply_to_message_id, reply_markup=reply_markup, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, location=location, live_period=live_period, api_kwargs=api_kwargs, @@ -789,7 +879,7 @@ def send_location( protect_content=protect_content, ) - def send_animation( + async def send_animation( self, animation: Union[FileInput, 'Animation'], duration: int = None, @@ -801,7 +891,10 @@ def send_animation( disable_notification: DVInput[bool] = DEFAULT_NONE, reply_to_message_id: int = None, reply_markup: ReplyMarkup = None, - timeout: DVInput[float] = DEFAULT_20, + read_timeout: float = 20, + write_timeout: float = 20, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, allow_sending_without_reply: ODVInput[bool] = DEFAULT_NONE, caption_entities: Union[List['MessageEntity'], Tuple['MessageEntity', ...]] = None, @@ -818,7 +911,7 @@ def send_animation( :class:`telegram.Message`: On success, instance representing the message posted. """ - return self.get_bot().send_animation( + return await self.get_bot().send_animation( chat_id=self.id, animation=animation, duration=duration, @@ -830,7 +923,10 @@ def send_animation( disable_notification=disable_notification, reply_to_message_id=reply_to_message_id, reply_markup=reply_markup, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, allow_sending_without_reply=allow_sending_without_reply, caption_entities=caption_entities, @@ -838,13 +934,16 @@ def send_animation( protect_content=protect_content, ) - def send_sticker( + async def send_sticker( self, sticker: Union[FileInput, 'Sticker'], disable_notification: DVInput[bool] = DEFAULT_NONE, reply_to_message_id: int = None, reply_markup: ReplyMarkup = None, - timeout: DVInput[float] = DEFAULT_20, + read_timeout: float = 20, + write_timeout: float = 20, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, allow_sending_without_reply: ODVInput[bool] = DEFAULT_NONE, protect_content: ODVInput[bool] = DEFAULT_NONE, @@ -859,19 +958,22 @@ def send_sticker( :class:`telegram.Message`: On success, instance representing the message posted. """ - return self.get_bot().send_sticker( + return await self.get_bot().send_sticker( chat_id=self.id, sticker=sticker, disable_notification=disable_notification, reply_to_message_id=reply_to_message_id, reply_markup=reply_markup, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, allow_sending_without_reply=allow_sending_without_reply, protect_content=protect_content, ) - def send_video( + async def send_video( self, video: Union[FileInput, 'Video'], duration: int = None, @@ -879,7 +981,10 @@ def send_video( disable_notification: DVInput[bool] = DEFAULT_NONE, reply_to_message_id: int = None, reply_markup: ReplyMarkup = None, - timeout: DVInput[float] = DEFAULT_20, + read_timeout: float = 20, + write_timeout: float = 20, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, width: int = None, height: int = None, parse_mode: ODVInput[str] = DEFAULT_NONE, @@ -901,7 +1006,7 @@ def send_video( :class:`telegram.Message`: On success, instance representing the message posted. """ - return self.get_bot().send_video( + return await self.get_bot().send_video( chat_id=self.id, video=video, duration=duration, @@ -909,7 +1014,10 @@ def send_video( disable_notification=disable_notification, reply_to_message_id=reply_to_message_id, reply_markup=reply_markup, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, width=width, height=height, parse_mode=parse_mode, @@ -922,7 +1030,7 @@ def send_video( protect_content=protect_content, ) - def send_venue( + async def send_venue( self, latitude: float = None, longitude: float = None, @@ -932,7 +1040,10 @@ def send_venue( disable_notification: DVInput[bool] = DEFAULT_NONE, reply_to_message_id: int = None, reply_markup: ReplyMarkup = None, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, venue: 'Venue' = None, foursquare_type: str = None, api_kwargs: JSONDict = None, @@ -951,7 +1062,7 @@ def send_venue( :class:`telegram.Message`: On success, instance representing the message posted. """ - return self.get_bot().send_venue( + return await self.get_bot().send_venue( chat_id=self.id, latitude=latitude, longitude=longitude, @@ -961,7 +1072,10 @@ def send_venue( disable_notification=disable_notification, reply_to_message_id=reply_to_message_id, reply_markup=reply_markup, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, venue=venue, foursquare_type=foursquare_type, api_kwargs=api_kwargs, @@ -971,7 +1085,7 @@ def send_venue( protect_content=protect_content, ) - def send_video_note( + async def send_video_note( self, video_note: Union[FileInput, 'VideoNote'], duration: int = None, @@ -979,7 +1093,10 @@ def send_video_note( disable_notification: DVInput[bool] = DEFAULT_NONE, reply_to_message_id: int = None, reply_markup: ReplyMarkup = None, - timeout: DVInput[float] = DEFAULT_20, + read_timeout: float = 20, + write_timeout: float = 20, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, thumb: FileInput = None, api_kwargs: JSONDict = None, allow_sending_without_reply: ODVInput[bool] = DEFAULT_NONE, @@ -996,7 +1113,7 @@ def send_video_note( :class:`telegram.Message`: On success, instance representing the message posted. """ - return self.get_bot().send_video_note( + return await self.get_bot().send_video_note( chat_id=self.id, video_note=video_note, duration=duration, @@ -1004,7 +1121,10 @@ def send_video_note( disable_notification=disable_notification, reply_to_message_id=reply_to_message_id, reply_markup=reply_markup, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, thumb=thumb, api_kwargs=api_kwargs, allow_sending_without_reply=allow_sending_without_reply, @@ -1012,7 +1132,7 @@ def send_video_note( protect_content=protect_content, ) - def send_voice( + async def send_voice( self, voice: Union[FileInput, 'Voice'], duration: int = None, @@ -1020,7 +1140,10 @@ def send_voice( disable_notification: DVInput[bool] = DEFAULT_NONE, reply_to_message_id: int = None, reply_markup: ReplyMarkup = None, - timeout: DVInput[float] = DEFAULT_20, + read_timeout: float = 20, + write_timeout: float = 20, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, parse_mode: ODVInput[str] = DEFAULT_NONE, api_kwargs: JSONDict = None, allow_sending_without_reply: ODVInput[bool] = DEFAULT_NONE, @@ -1038,7 +1161,7 @@ def send_voice( :class:`telegram.Message`: On success, instance representing the message posted. """ - return self.get_bot().send_voice( + return await self.get_bot().send_voice( chat_id=self.id, voice=voice, duration=duration, @@ -1046,7 +1169,10 @@ def send_voice( disable_notification=disable_notification, reply_to_message_id=reply_to_message_id, reply_markup=reply_markup, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, parse_mode=parse_mode, api_kwargs=api_kwargs, allow_sending_without_reply=allow_sending_without_reply, @@ -1055,7 +1181,7 @@ def send_voice( protect_content=protect_content, ) - def send_poll( + async def send_poll( self, question: str, options: List[str], @@ -1068,7 +1194,10 @@ def send_poll( disable_notification: ODVInput[bool] = DEFAULT_NONE, reply_to_message_id: int = None, reply_markup: ReplyMarkup = None, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, explanation: str = None, explanation_parse_mode: ODVInput[str] = DEFAULT_NONE, open_period: int = None, @@ -1088,7 +1217,7 @@ def send_poll( :class:`telegram.Message`: On success, instance representing the message posted. """ - return self.get_bot().send_poll( + return await self.get_bot().send_poll( chat_id=self.id, question=question, options=options, @@ -1100,7 +1229,10 @@ def send_poll( disable_notification=disable_notification, reply_to_message_id=reply_to_message_id, reply_markup=reply_markup, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, explanation=explanation, explanation_parse_mode=explanation_parse_mode, open_period=open_period, @@ -1111,7 +1243,7 @@ def send_poll( protect_content=protect_content, ) - def send_copy( + async def send_copy( self, from_chat_id: Union[str, int], message_id: int, @@ -1122,7 +1254,10 @@ def send_copy( reply_to_message_id: int = None, allow_sending_without_reply: DVInput[bool] = DEFAULT_NONE, reply_markup: ReplyMarkup = None, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, protect_content: ODVInput[bool] = DEFAULT_NONE, ) -> 'MessageId': @@ -1136,7 +1271,7 @@ def send_copy( :class:`telegram.Message`: On success, instance representing the message posted. """ - return self.get_bot().copy_message( + return await self.get_bot().copy_message( chat_id=self.id, from_chat_id=from_chat_id, message_id=message_id, @@ -1147,12 +1282,15 @@ def send_copy( reply_to_message_id=reply_to_message_id, allow_sending_without_reply=allow_sending_without_reply, reply_markup=reply_markup, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, protect_content=protect_content, ) - def copy_message( + async def copy_message( self, chat_id: Union[int, str], message_id: int, @@ -1163,7 +1301,10 @@ def copy_message( reply_to_message_id: int = None, allow_sending_without_reply: DVInput[bool] = DEFAULT_NONE, reply_markup: ReplyMarkup = None, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, protect_content: ODVInput[bool] = DEFAULT_NONE, ) -> 'MessageId': @@ -1177,7 +1318,7 @@ def copy_message( :class:`telegram.Message`: On success, instance representing the message posted. """ - return self.get_bot().copy_message( + return await self.get_bot().copy_message( from_chat_id=self.id, chat_id=chat_id, message_id=message_id, @@ -1188,15 +1329,21 @@ def copy_message( reply_to_message_id=reply_to_message_id, allow_sending_without_reply=allow_sending_without_reply, reply_markup=reply_markup, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, protect_content=protect_content, ) - def approve_join_request( + async def approve_join_request( self, chat_id: Union[int, str], - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, ) -> bool: """Shortcut for:: @@ -1212,14 +1359,23 @@ def approve_join_request( :obj:`bool`: On success, :obj:`True` is returned. """ - return self.get_bot().approve_chat_join_request( - user_id=self.id, chat_id=chat_id, timeout=timeout, api_kwargs=api_kwargs + return await self.get_bot().approve_chat_join_request( + user_id=self.id, + chat_id=chat_id, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, + api_kwargs=api_kwargs, ) - def decline_join_request( + async def decline_join_request( self, chat_id: Union[int, str], - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, ) -> bool: """Shortcut for:: @@ -1235,6 +1391,12 @@ def decline_join_request( :obj:`bool`: On success, :obj:`True` is returned. """ - return self.get_bot().decline_chat_join_request( - user_id=self.id, chat_id=chat_id, timeout=timeout, api_kwargs=api_kwargs + return await self.get_bot().decline_chat_join_request( + user_id=self.id, + chat_id=chat_id, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, + api_kwargs=api_kwargs, ) diff --git a/telegram/_utils/defaultvalue.py b/telegram/_utils/defaultvalue.py index d5870d48733..94140b10f4c 100644 --- a/telegram/_utils/defaultvalue.py +++ b/telegram/_utils/defaultvalue.py @@ -129,5 +129,11 @@ def __repr__(self) -> str: DEFAULT_FALSE: DefaultValue = DefaultValue(False) """:class:`DefaultValue`: Default :obj:`False`""" +DEFAULT_TRUE: DefaultValue = DefaultValue(True) +""":class:`DefaultValue`: Default :obj:`True` + +.. versionadded:: 14.0 +""" + DEFAULT_20: DefaultValue = DefaultValue(20) """:class:`DefaultValue`: Default :obj:`20`""" diff --git a/telegram/_utils/files.py b/telegram/_utils/files.py index 2ba5a94f167..9789cd45ff4 100644 --- a/telegram/_utils/files.py +++ b/telegram/_utils/files.py @@ -57,7 +57,6 @@ def is_local_file(obj: Optional[FilePathInput]) -> bool: def parse_file_input( file_input: Union[FileInput, 'TelegramObject'], tg_type: Type['TelegramObject'] = None, - attach: bool = None, filename: str = None, ) -> Union[str, 'InputFile', Any]: """ @@ -76,9 +75,6 @@ def parse_file_input( input to parse. tg_type (:obj:`type`, optional): The Telegram media type the input can be. E.g. :class:`telegram.Animation`. - attach (:obj:`bool`, optional): Whether this file should be send as one file or is part of - a collection of files. Only relevant in case an :class:`telegram.InputFile` is - returned. filename (:obj:`str`, optional): The filename. Only relevant in case an :class:`telegram.InputFile` is returned. @@ -98,10 +94,10 @@ def parse_file_input( out = file_input # type: ignore[assignment] return out if isinstance(file_input, bytes): - return InputFile(file_input, attach=attach, filename=filename) + return InputFile(file_input, filename=filename) if InputFile.is_file(file_input): file_input = cast(IO, file_input) - return InputFile(file_input, attach=attach, filename=filename) + return InputFile(file_input, filename=filename) if tg_type and isinstance(file_input, tg_type): return file_input.file_id # type: ignore[attr-defined] return file_input diff --git a/telegram/_utils/types.py b/telegram/_utils/types.py index 73934e3a884..c211cd45458 100644 --- a/telegram/_utils/types.py +++ b/telegram/_utils/types.py @@ -41,15 +41,16 @@ from telegram._utils.defaultvalue import DefaultValue # noqa: F401 from telegram import InlineKeyboardMarkup, ReplyKeyboardMarkup, ReplyKeyboardRemove, ForceReply -FileLike = Union[IO, 'InputFile'] -"""Either an open file handler or a :class:`telegram.InputFile`.""" +FileLike = Union[IO[bytes], 'InputFile'] +"""Either a bytes-stream (e.g. open file handler) or a :class:`telegram.InputFile`.""" FilePathInput = Union[str, Path] """A filepath either as string or as :obj:`pathlib.Path` object.""" -FileInput = Union[FilePathInput, bytes, FileLike] +FileInput = Union[FilePathInput, FileLike, bytes, str] """Valid input for passing files to Telegram. Either a file id as string, a file like object, -a local file path as string, :class:`pathlib.Path` or the file contents as :obj:`bytes`.""" +a local file path as string, :class:`pathlib.Path` or the file contents as :obj:`bytes` or +:obj:`str`.""" JSONDict = Dict[str, Any] """Dictionary containing response from Telegram or data to send to the API.""" @@ -73,3 +74,8 @@ .. versionadded:: 14.0 """ + +FieldTuple = Tuple[str, bytes, str] +"""Alias for return type of `InputFile.field_tuple`.""" +UploadFileDict = Dict[str, FieldTuple] +"""Dictionary containing file data to be uploaded to the API.""" diff --git a/telegram/error.py b/telegram/error.py index 64fc703931d..e10b1359458 100644 --- a/telegram/error.py +++ b/telegram/error.py @@ -16,12 +16,17 @@ # # You should have received a copy of the GNU Lesser Public License # along with this program. If not, see [http://www.gnu.org/licenses/]. -"""This module contains an classes that represent Telegram errors.""" +"""This module contains an classes that represent Telegram errors. + +.. versionchanged:: 14.0 + Replaced ``Unauthorized`` by :class:`Forbidden`. +""" __all__ = ( 'BadRequest', 'ChatMigrated', 'Conflict', + 'Forbidden', 'InvalidToken', 'NetworkError', 'PassportDecryptionError', @@ -30,7 +35,7 @@ 'TimedOut', ) -from typing import Tuple, Union +from typing import Tuple, Union, Optional def _lstrip_str(in_s: str, lstr: str) -> str: @@ -73,22 +78,33 @@ def __reduce__(self) -> Tuple[type, Tuple[str]]: return self.__class__, (self.message,) -class Unauthorized(TelegramError): - """Raised when the bot has not enough rights to perform the requested action.""" +class Forbidden(TelegramError): + """Raised when the bot has not enough rights to perform the requested action. + + .. versionchanged:: 14.0 + This class was previously named ``Unauthorized``. + """ __slots__ = () class InvalidToken(TelegramError): - """Raised when the token is invalid.""" + """Raised when the token is invalid. - __slots__ = () + Args: + message (:obj:`str`, optional): Any additional information about the exception. + + .. versionadded:: 14.0 + """ + + __slots__ = ('_message',) - def __init__(self) -> None: - super().__init__('Invalid token') + def __init__(self, message: str = None) -> None: + self._message = message + super().__init__('Invalid token' if self._message is None else self._message) - def __reduce__(self) -> Tuple[type, Tuple]: # type: ignore[override] - return self.__class__, () + def __reduce__(self) -> Tuple[type, Tuple[Optional[str]]]: # type: ignore[override] + return self.__class__, (self._message,) class NetworkError(TelegramError): @@ -104,15 +120,18 @@ class BadRequest(NetworkError): class TimedOut(NetworkError): - """Raised when a request took too long to finish.""" + """Raised when a request took too long to finish. - __slots__ = () + Args: + message (:obj:`str`, optional): Any additional information about the exception. - def __init__(self) -> None: - super().__init__('Timed out') + .. versionadded:: 14.0 + """ + + __slots__ = () - def __reduce__(self) -> Tuple[type, Tuple]: # type: ignore[override] - return self.__class__, () + def __init__(self, message: str = None) -> None: + super().__init__(message or 'Timed out') class ChatMigrated(TelegramError): @@ -128,7 +147,7 @@ class ChatMigrated(TelegramError): def __init__(self, new_chat_id: int): super().__init__(f'Group migrated to supergroup. New chat id: {new_chat_id}') - self.new_chat_id = new_chat_id + self.new_chat_id = int(new_chat_id) def __reduce__(self) -> Tuple[type, Tuple[int]]: # type: ignore[override] return self.__class__, (self.new_chat_id,) diff --git a/telegram/ext/__init__.py b/telegram/ext/__init__.py index 3e44c998275..e2ad407d8bc 100644 --- a/telegram/ext/__init__.py +++ b/telegram/ext/__init__.py @@ -19,6 +19,9 @@ """Extensions over the Telegram Bot API to facilitate bot making""" __all__ = ( + 'Application', + 'ApplicationBuilder', + 'ApplicationHandlerStop', 'BasePersistence', 'CallbackContext', 'CallbackDataCache', @@ -31,9 +34,6 @@ 'ConversationHandler', 'Defaults', 'DictPersistence', - 'Dispatcher', - 'DispatcherBuilder', - 'DispatcherHandlerStop', 'ExtBot', 'filters', 'Handler', @@ -53,7 +53,6 @@ 'StringRegexHandler', 'TypeHandler', 'Updater', - 'UpdaterBuilder', ) from ._extbot import ExtBot @@ -63,9 +62,9 @@ from ._handler import Handler from ._callbackcontext import CallbackContext from ._contexttypes import ContextTypes -from ._dispatcher import Dispatcher, DispatcherHandlerStop from ._jobqueue import JobQueue, Job from ._updater import Updater +from ._application import Application, ApplicationHandlerStop from ._callbackqueryhandler import CallbackQueryHandler from ._choseninlineresulthandler import ChosenInlineResultHandler from ._inlinequeryhandler import InlineQueryHandler @@ -84,4 +83,4 @@ from ._chatjoinrequesthandler import ChatJoinRequestHandler from ._defaults import Defaults from ._callbackdatacache import CallbackDataCache, InvalidCallbackData -from ._builders import DispatcherBuilder, UpdaterBuilder +from ._builders import ApplicationBuilder diff --git a/telegram/ext/_application.py b/telegram/ext/_application.py new file mode 100644 index 00000000000..da73641cf44 --- /dev/null +++ b/telegram/ext/_application.py @@ -0,0 +1,1094 @@ +#!/usr/bin/env python +# +# A library that provides a Python interface to the Telegram Bot API +# Copyright (C) 2015-2022 +# Leandro Toledo de Souza +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser Public License for more details. +# +# You should have received a copy of the GNU Lesser Public License +# along with this program. If not, see [http://www.gnu.org/licenses/]. +"""This module contains the Application class.""" +import asyncio +import inspect +import itertools +import logging +from asyncio import Event +from collections import defaultdict +from copy import deepcopy +from pathlib import Path +from types import TracebackType, MappingProxyType +from typing import ( + Callable, + Dict, + List, + Optional, + Union, + Generic, + TypeVar, + TYPE_CHECKING, + Type, + Tuple, + Coroutine, + Any, + Set, + Mapping, + cast, + MutableMapping, +) + +from telegram._utils.types import DVInput, ODVInput +from telegram.error import TelegramError +from telegram.ext import BasePersistence, ContextTypes, ExtBot, Updater +from telegram.ext._handler import Handler +from telegram.ext._callbackdatacache import CallbackDataCache +from telegram._utils.defaultvalue import DefaultValue, DEFAULT_TRUE, DEFAULT_NONE +from telegram._utils.warnings import warn +from telegram.ext._utils.trackingdefaultdict import TrackingDefaultDict +from telegram.ext._utils.types import CCT, UD, CD, BD, BT, JQ, HandlerCallback +from telegram.ext._utils.stack import was_called_by + +if TYPE_CHECKING: + from telegram import Message + from telegram.ext._jobqueue import Job + from telegram.ext._builders import InitApplicationBuilder + +DEFAULT_GROUP: int = 0 + +_DispType = TypeVar('_DispType', bound="Application") +_RT = TypeVar('_RT') +_STOP_SIGNAL = object() + +_logger = logging.getLogger(__name__) + + +class ApplicationHandlerStop(Exception): + """ + Raise this in a handler or an error handler to prevent execution of any other handler (even in + different group). + + In order to use this exception in a :class:`telegram.ext.ConversationHandler`, pass the + optional ``state`` parameter instead of returning the next state: + + .. code-block:: python + + def callback(update, context): + ... + raise ApplicationHandlerStop(next_state) + + Note: + Has no effect, if the handler or error handler is run asynchronously. + + Args: + state (:obj:`object`, optional): The next state of the conversation. + + Attributes: + state (:obj:`object`): Optional. The next state of the conversation. + """ + + __slots__ = ('state',) + + def __init__(self, state: object = None) -> None: + super().__init__() + self.state = state + + +class Application(Generic[BT, CCT, UD, CD, BD, JQ]): + """This class dispatches all kinds of updates to its registered handlers. + + Note: + This class may not be initialized directly. Use :class:`telegram.ext.ApplicationBuilder` + or :meth:`builder` (for convenience). + + .. versionchanged:: 14.0 + + * Initialization is now done through the :class:`telegram.ext.ApplicationBuilder`. + * Removed the attribute ``groups``. + + Attributes: + bot (:class:`telegram.Bot`): The bot object that should be passed to the handlers. + update_queue (:class:`asyncio.Queue`): The synchronized queue that will contain the + updates. + job_queue (:class:`telegram.ext.JobQueue`): Optional. The :class:`telegram.ext.JobQueue` + instance to pass onto handler callbacks. + concurrent_updates (:obj:`int`, optional): Number of maximum concurrent worker threads for + the ``@run_async`` decorator and :meth:`run_async`. + chat_data (:obj:`types.MappingProxyType`): A dictionary handlers can use to store data for + the chat. + + .. versionchanged:: 14.0 + :attr:`chat_data` is now read-only + + .. tip:: + Manually modifying :attr:`chat_data` is almost never needed and unadvisable. + + user_data (:obj:`types.MappingProxyType`): A dictionary handlers can use to store data for + the user. + + .. versionchanged:: 14.0 + :attr:`user_data` is now read-only + + .. tip:: + Manually modifying :attr:`user_data` is almost never needed and unadvisable. + + bot_data (:obj:`dict`): A dictionary handlers can use to store data for the bot. + persistence (:class:`telegram.ext.BasePersistence`): Optional. The persistence class to + store data that should be persistent over restarts. + handlers (Dict[:obj:`int`, List[:class:`telegram.ext.Handler`]]): A dictionary mapping each + handler group to the list of handlers registered to that group. + + .. seealso:: + :meth:`add_handler`, :meth:`add_handlers`. + error_handlers (Dict[:obj:`callable`, :obj:`bool`]): A dict, where the keys are error + handlers and the values indicate whether they are to be run asynchronously via + :meth:`run_async`. + + .. seealso:: + :meth:`add_error_handler` + + """ + + # Allowing '__weakref__' creation here since we need it for the JobQueue + __slots__ = ( + '__create_task_tasks', + '__update_fetcher_task', + '__update_persistence_event', + '__update_persistence_lock', + '__update_persistence_task', + '__weakref__', + '_chat_data', + '_concurrent_updates', + '_concurrent_updates_sem', + '_conversation_handler_conversations', + '_initialized', + '_running', + '_user_data', + 'bot', + 'bot_data', + 'chat_data', + 'context_types', + 'error_handlers', + 'handlers', + 'job_queue', + 'persistence', + 'update_queue', + 'updater', + 'user_data', + ) + + def __init__( + self: 'Application[BT, CCT, UD, CD, BD, JQ]', + *, + bot: BT, + update_queue: asyncio.Queue, + updater: Optional[Updater], + job_queue: JQ, + concurrent_updates: Union[bool, int], + persistence: Optional[BasePersistence], + context_types: ContextTypes[CCT, UD, CD, BD], + ): + if not was_called_by( + inspect.currentframe(), Path(__file__).parent.resolve() / '_builders.py' + ): + warn( + '`Application` instances should be built via the `ApplicationBuilder`.', + stacklevel=2, + ) + + self.bot = bot + self.update_queue = update_queue + self.job_queue = job_queue + self.context_types = context_types + self.updater = updater + self.handlers: Dict[int, List[Handler]] = {} + self.error_handlers: Dict[Callable, Union[bool, DefaultValue]] = {} + + if isinstance(concurrent_updates, int) and concurrent_updates < 0: + raise ValueError('`concurrent_updates` must be a non-negative integer!') + if concurrent_updates is True: + concurrent_updates = 4096 + self._concurrent_updates_sem = asyncio.BoundedSemaphore(concurrent_updates or 1) + self._concurrent_updates = bool(concurrent_updates) + + if self.job_queue: + self.job_queue.set_application(self) + + self.bot_data = self.context_types.bot_data() + self.persistence: Optional[BasePersistence] = None + if persistence and not isinstance(persistence, BasePersistence): + raise TypeError("persistence must be based on telegram.ext.BasePersistence") + self.persistence = persistence + # Track access to chat_ids only if necessary for the persistence + if self.persistence and self.persistence.store_data.user_data: + self._user_data: MutableMapping[int, UD] = TrackingDefaultDict( + default_factory=self.context_types.user_data, track_read=True, track_write=True + ) + else: + self._user_data = defaultdict(self.context_types.user_data) + # Track access to user_ids only if necessary for the persistence + if self.persistence and self.persistence.store_data.chat_data: + self._chat_data: MutableMapping[int, CD] = TrackingDefaultDict( + # track_write = True for self.migrate_chat_data + default_factory=self.context_types.chat_data, + track_read=True, + track_write=True, + ) + else: + self._chat_data = defaultdict(self.context_types.chat_data) + # Read only mapping + self.user_data: Mapping[int, UD] = MappingProxyType(self._user_data) + self.chat_data: Mapping[int, CD] = MappingProxyType(self._chat_data) + + # This attribute will hold references to the conversation dicts of all conversation + # handlers so that we can extract the changed states during `update_persistence` + self._conversation_handler_conversations: Dict[ + str, TrackingDefaultDict[Tuple[int, ...], object] + ] = {} + + # A number of low-level helpers for the internal logic + self._initialized = False + self._running = False + self.__update_fetcher_task: Optional[asyncio.Task] = None + self.__update_persistence_task: Optional[asyncio.Task] = None + self.__update_persistence_event = asyncio.Event() + self.__update_persistence_lock = asyncio.Lock() + self.__create_task_tasks: Set[asyncio.Task] = set() + + @property + def running(self) -> bool: + """:obj:`bool`: Indicates if this application is running. + + .. seealso:: + :meth:`start`, :meth:`stop` + """ + return self._running + + @property + def concurrent_updates(self) -> bool: + return self._concurrent_updates + + async def initialize(self) -> None: + await self.bot.initialize() + if self.updater: + await self.updater.initialize() + + if not self.persistence: + self._initialized = True + return + + await self._initialize_persistence() + + # Unfortunately due to circular imports this has to be here + # pylint: disable=import-outside-toplevel + from telegram.ext._conversationhandler import ConversationHandler + + # Initialize the persistent conversation handlers with the stored states + for handler in itertools.chain.from_iterable(self.handlers.values()): + if isinstance(handler, ConversationHandler) and handler.persistent and handler.name: + self._conversation_handler_conversations[ + handler.name + ] = await handler._initialize_persistence( # pylint: disable=protected-access + self + ) + + self._initialized = True + + async def shutdown(self) -> None: + await self.bot.shutdown() + if self.updater: + await self.updater.shutdown() + + if self.persistence: + _logger.debug('Updating & flushing persistence before shutdown') + await self.update_persistence() + await self.persistence.flush() + _logger.debug('Updated and flushed persistence') + + self._initialized = False + + async def __aenter__(self: _DispType) -> _DispType: + try: + await self.initialize() + return self + except Exception as exc: + await self.shutdown() + raise exc + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + # Make sure not to return `True` so that exceptions are not suppressed + # https://docs.python.org/3/reference/datamodel.html?#object.__aexit__ + await self.shutdown() + + async def _initialize_persistence(self) -> None: + if not self.persistence: + return + + # This raises an exception if persistence.store_data.callback_data is True + # but self.bot is not an instance of ExtBot - so no need to check that later on + self.persistence.set_bot(self.bot) + + if self.persistence.store_data.user_data: + cast(TrackingDefaultDict, self._user_data).update_no_track( + await self.persistence.get_user_data() + ) + if self.persistence.store_data.chat_data: + cast(TrackingDefaultDict, self._chat_data).update_no_track( + await self.persistence.get_chat_data() + ) + if self.persistence.store_data.bot_data: + self.bot_data = await self.persistence.get_bot_data() + if not isinstance(self.bot_data, self.context_types.bot_data): + raise ValueError( + f"bot_data must be of type {self.context_types.bot_data.__name__}" + ) + if self.persistence.store_data.callback_data: + persistent_data = await self.persistence.get_callback_data() + if persistent_data is not None: + if not isinstance(persistent_data, tuple) and len(persistent_data) != 2: + raise ValueError('callback_data must be a tuple of length 2') + # Mypy doesn't know that persistence.set_bot (see above) already checks that + # self.bot is an instance of ExtBot if callback_data should be stored ... + self.bot.callback_data_cache = CallbackDataCache( # type: ignore[attr-defined] + self.bot, # type: ignore[arg-type] + self.bot.callback_data_cache.maxsize, # type: ignore[attr-defined] + persistent_data=persistent_data, + ) + + @staticmethod + def builder() -> 'InitApplicationBuilder': + """Convenience method. Returns a new :class:`telegram.ext.ApplicationBuilder`. + + .. versionadded:: 14.0 + """ + # Unfortunately this needs to be here due to cyclical imports + from telegram.ext import ApplicationBuilder # pylint: disable=import-outside-toplevel + + return ApplicationBuilder() + + async def start(self, ready: Event = None) -> None: + """Starts + + * a background task that fetches updates from :attr:`update_queue` and + processes them. + * :attr:`job_queue`, if set + * a background tasks that calls :meth:`update_persistence` in regular intervals, if + :attr:`persistence` is set. + + Note: + This does *not* start fetching updates from Telegram. You need either start + :attr:`updater` manually or use one of :attr:`run_polling` or :attr:`run_webhook`. + + Args: + ready (:obj:`asyncio.Event`, optional): If specified, the event will be set once the + application is ready. + + """ + if self.running: + _logger.warning('already running') + if ready is not None: + ready.set() + return + + self.__update_persistence_event.clear() + if self.persistence: + self.__update_persistence_task = asyncio.create_task( + self._persistence_updater() + # TODO: Add this once we drop py3.7 + # name=f'Application:{self.bot.id}:persistence_updater' + ) + _logger.debug('Loop for updating persistence started') + + if self.job_queue: + self.job_queue.start() + _logger.debug('JobQueue started') + + self.__update_fetcher_task = asyncio.create_task( + self._update_fetcher(), + # TODO: Add this once we drop py3.7 + # name=f'Application:{self.bot.id}:update_fetcher' + ) + self._running = True + _logger.info('Application started') + + if ready is not None: + ready.set() + + async def stop(self) -> None: + """Stops the process after processing any pending updates or tasks created by + :meth:`create_task`. Also stops :attr:`job_queue`, if set and :attr:`updater`, if set and + running. + Finally, calls :meth:`update_persistence` and :meth:`BasePersistence.flush` on + :attr:`persistence`, if set. + + Warning: + Once this method is called, no more updates will be fetched from :attr:`update_queue`, + even if it's not empty. + """ + if self.running: + self._running = False + _logger.info('Application is stopping. This might take a moment.') + + if self.updater and self.updater.running: + _logger.debug('Waiting for updater to stop fetching updates') + await self.updater.stop() + + # Stop listening for new updates and handle all pending ones + await self.update_queue.put(_STOP_SIGNAL) + _logger.debug('Waiting for update_queue to join') + await self.update_queue.join() + if self.__update_fetcher_task: + await self.__update_fetcher_task + _logger.debug("Application stopped fetching of updates.") + + if self.job_queue: + _logger.debug('Waiting for running jobs to finish') + await self.job_queue.stop(wait=True) + _logger.debug('JobQueue stopped') + + _logger.debug('Waiting for `create_task` calls to be processed') + await asyncio.gather(*self.__create_task_tasks, return_exceptions=True) + + # Make sure that this is the *last* step of stopping the application! + if self.persistence and self.__update_persistence_task: + _logger.debug('Waiting for persistence loop to finish') + self.__update_persistence_event.set() + await self.__update_persistence_task + + _logger.info('Application.stop() complete') + + def run_polling( + self, + poll_interval: float = 0.0, + timeout: int = 10, + bootstrap_retries: int = -1, + read_timeout: float = 2, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, + allowed_updates: List[str] = None, + drop_pending_updates: bool = None, + ready: asyncio.Event = None, + ) -> None: + if not self.updater: + raise RuntimeError( + 'Application.run_polling is only available if the application has an Updater.' + ) + + def error_callback(exc: TelegramError) -> None: + self.create_task(self.dispatch_error(update=None, error=exc)) + + return self.__run( + updater_coroutine=self.updater.start_polling( + poll_interval=poll_interval, + timeout=timeout, + bootstrap_retries=bootstrap_retries, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, + allowed_updates=allowed_updates, + drop_pending_updates=drop_pending_updates, + error_callback=error_callback, + ), + ready=ready, + ) + + def run_webhook( + self, + listen: str = '127.0.0.1', + port: int = 80, + url_path: str = '', + cert: Union[str, Path] = None, + key: Union[str, Path] = None, + bootstrap_retries: int = 0, + webhook_url: str = None, + allowed_updates: List[str] = None, + drop_pending_updates: bool = None, + ip_address: str = None, + max_connections: int = 40, + ready: asyncio.Event = None, + ) -> None: + if not self.updater: + raise RuntimeError( + 'Application.run_webhook is only available if the application has an Updater.' + ) + + return self.__run( + updater_coroutine=self.updater.start_webhook( + listen=listen, + port=port, + url_path=url_path, + cert=cert, + key=key, + bootstrap_retries=bootstrap_retries, + drop_pending_updates=drop_pending_updates, + webhook_url=webhook_url, + allowed_updates=allowed_updates, + ip_address=ip_address, + max_connections=max_connections, + ), + ready=ready, + ) + + def __run(self, updater_coroutine: Coroutine, ready: asyncio.Event = None) -> None: + loop = asyncio.get_event_loop() + loop.run_until_complete(self.initialize()) + loop.run_until_complete(self.start(ready=ready)) + loop.run_until_complete(updater_coroutine) + try: + loop.run_forever() + except (KeyboardInterrupt, SystemExit): + loop.run_until_complete(self.stop()) + loop.run_until_complete(self.shutdown()) + finally: + loop.close() + + def create_task(self, coroutine: Coroutine, update: object = None) -> asyncio.Task: + """Thin wrapper around :meth:`asyncio.create_task` that handles exceptions raised by + the ``coroutine`` with :meth:`dispatch_error`. + + Note: + * If ``coroutine`` raises an exception, it will be set on the task created by this + method even though it's handled by :meth:`dispatch_error`. + * If the application is currently running, tasks created by this methods will be + awaited by :meth:`stop`. + + Args: + coroutine: The coroutine to run as task. + update: Optional. If passed, will be passed to :meth:`dispatch_error` as additional + information for the error handlers. + + Returns: + :class:`asyncio.Task`: The created task. + """ + return self.__create_task(coroutine=coroutine, update=update) + + def __create_task( + self, coroutine: Coroutine, update: object = None, is_error_handler: bool = False + ) -> asyncio.Task: + # Unfortunately, we can't know if `coroutine` runs one of the error handler functions + # but by passing `is_error_handler=True` from `dispatch_error`, we can make sure that we + # get at most one recursion of the user calls `create_task` manually with an error handler + # function + task = asyncio.create_task( + self.__create_task_callback( + coroutine=coroutine, update=update, is_error_handler=is_error_handler + ) + ) + + if self.running: + self.__create_task_tasks.add(task) + task.add_done_callback(self.__create_task_tasks.discard) + else: + _logger.warning( + "Tasks created via `Application.create_task` while the application is not " + "running won't be automatically awaited!" + ) + + return task + + async def __create_task_callback( + self, + coroutine: Coroutine[Any, Any, _RT], + update: object = None, + is_error_handler: bool = False, + ) -> _RT: + try: + return await coroutine + except Exception as exception: + if isinstance(exception, ApplicationHandlerStop): + warn( + 'ApplicationHandlerStop is not supported with asynchronously running handlers.' + ) + + # Avoid infinite recursion of error handlers. + elif is_error_handler: + _logger.exception( + 'An error was raised and an uncaught error was raised while ' + 'handling the error with an error_handler.', + exc_info=exception, + ) + + else: + # If we arrive here, an exception happened in the task and was neither + # ApplicationHandlerStop nor raised by an error handler. + # So we can and must handle it + self.create_task(self.dispatch_error(update, exception, coroutine=coroutine)) + + raise exception + + async def _update_fetcher(self) -> None: + # Continuously fetch updates from the queue. Exit only once the signal object is found. + while True: + update = await self.update_queue.get() + + if update is _STOP_SIGNAL: + _logger.debug('Dropping pending updates') + while not self.update_queue.empty(): + self.update_queue.task_done() + + # For the _STOP_SIGNAL + self.update_queue.task_done() + return + + _logger.debug('Processing update %s', update) + + if self._concurrent_updates: + asyncio.create_task(self.__process_update_wrapper(update)) + else: + await self.__process_update_wrapper(update) + + async def __process_update_wrapper(self, update: object) -> None: + async with self._concurrent_updates_sem: + await self.process_update(update) + self.update_queue.task_done() + + async def process_update(self, update: object) -> None: + """Processes a single update and updates the persistence. + + .. versionchanged:: 14.0 + This calls :meth:`update_persistence` exactly once after handling of the update was + finished by *all* handlers that handled the update, including asynchronously running + handlers. + + Args: + update (:class:`telegram.Update` | :obj:`object` | \ + :class:`telegram.error.TelegramError`): + The update to process. + + """ + # An error happened while polling + if isinstance(update, TelegramError): + await self.dispatch_error(None, update) + return + + context = None + + for handlers in self.handlers.values(): + try: + for handler in handlers: + check = handler.check_update(update) + if check is not None and check is not False: + if not context: + context = self.context_types.context.from_update(update, self) + await context.refresh_data() + coroutine: Coroutine = handler.handle_update(update, self, check, context) + if handler.block: + await coroutine + else: + self.create_task(coroutine, update=update) + break + + # Stop processing with any other handler. + except ApplicationHandlerStop: + _logger.debug('Stopping further handlers due to ApplicationHandlerStop') + break + + # Dispatch any error. + except Exception as exc: + if await self.dispatch_error(update, exc): + _logger.debug('Error handler stopped further handlers.') + break + + def add_handler(self, handler: Handler[Any, CCT], group: int = DEFAULT_GROUP) -> None: + """Register a handler. + + TL;DR: Order and priority counts. 0 or 1 handlers per group will be used. End handling of + update with :class:`telegram.ext.ApplicationHandlerStop`. + + A handler must be an instance of a subclass of :class:`telegram.ext.Handler`. All handlers + are organized in groups with a numeric value. The default group is 0. All groups will be + evaluated for handling an update, but only 0 or 1 handler per group will be used. If + :class:`telegram.ext.ApplicationHandlerStop` is raised from one of the handlers, no further + handlers (regardless of the group) will be called. + + The priority/order of handlers is determined as follows: + + * Priority of the group (lower group number == higher priority) + * The first handler in a group which should handle an update (see + :attr:`telegram.ext.Handler.check_update`) will be used. Other handlers from the + group will not be used. The order in which handlers were added to the group defines the + priority. + + Args: + handler (:class:`telegram.ext.Handler`): A Handler instance. + group (:obj:`int`, optional): The group identifier. Default is 0. + + """ + # Unfortunately due to circular imports this has to be here + # pylint: disable=import-outside-toplevel + from telegram.ext._conversationhandler import ConversationHandler + + if not isinstance(handler, Handler): + raise TypeError(f'handler is not an instance of {Handler.__name__}') + if not isinstance(group, int): + raise TypeError('group is not int') + if isinstance(handler, ConversationHandler) and handler.persistent and handler.name: + if not self.persistence: + raise ValueError( + f"ConversationHandler {handler.name} " + f"can not be persistent if application has no persistence" + ) + if self._initialized: + warn( + 'A persistent `ConversationHandler` was passed to `add_handler`, ' + 'after `Application.initialize` was called. Conversation states will not be ' + 'loaded from persistence! ' + ) + + if group not in self.handlers: + self.handlers[group] = [] + self.handlers = dict(sorted(self.handlers.items())) # lower -> higher groups + + self.handlers[group].append(handler) + + def add_handlers( + self, + handlers: Union[ + Union[List[Handler], Tuple[Handler]], Dict[int, Union[List[Handler], Tuple[Handler]]] + ], + group: DVInput[int] = DefaultValue(0), + ) -> None: + """Registers multiple handlers at once. The order of the handlers in the passed + sequence(s) matters. See :meth:`add_handler` for details. + + .. versionadded:: 14.0 + .. seealso:: :meth:`add_handler` + + Args: + handlers (List[:obj:`telegram.ext.Handler`] | \ + Dict[int, List[:obj:`telegram.ext.Handler`]]): \ + Specify a sequence of handlers *or* a dictionary where the keys are groups and + values are handlers. + group (:obj:`int`, optional): Specify which group the sequence of ``handlers`` + should be added to. Defaults to ``0``. + + """ + if isinstance(handlers, dict) and not isinstance(group, DefaultValue): + raise ValueError('The `group` argument can only be used with a sequence of handlers.') + + if isinstance(handlers, dict): + for handler_group, grp_handlers in handlers.items(): + if not isinstance(grp_handlers, (list, tuple)): + raise ValueError(f'Handlers for group {handler_group} must be a list or tuple') + + for handler in grp_handlers: + self.add_handler(handler, handler_group) + + elif isinstance(handlers, (list, tuple)): + for handler in handlers: + self.add_handler(handler, DefaultValue.get_value(group)) + + else: + raise ValueError( + "The `handlers` argument must be a sequence of handlers or a " + "dictionary where the keys are groups and values are sequences of handlers." + ) + + def remove_handler(self, handler: Handler, group: int = DEFAULT_GROUP) -> None: + """Remove a handler from the specified group. + + Args: + handler (:class:`telegram.ext.Handler`): A Handler instance. + group (:obj:`object`, optional): The group identifier. Default is 0. + + """ + if handler in self.handlers[group]: + self.handlers[group].remove(handler) + if not self.handlers[group]: + del self.handlers[group] + + async def drop_chat_data(self, chat_id: int) -> None: + """Used for deleting a key from the :attr:`chat_data`. + + .. versionadded:: 14.0 + + Args: + chat_id (:obj:`int`): The chat id to delete from the persistence. The entry + will be deleted even if it is not empty. + """ + self._chat_data.pop(chat_id, None) # type: ignore[arg-type] + + async def drop_user_data(self, user_id: int) -> None: + """Used for deleting a key from the :attr:`user_data`. + + .. versionadded:: 14.0 + + Args: + user_id (:obj:`int`): The user id to delete from the persistence. The entry + will be deleted even if it is not empty. + """ + self._user_data.pop(user_id, None) # type: ignore[arg-type] + + async def migrate_chat_data( + self, message: 'Message' = None, old_chat_id: int = None, new_chat_id: int = None + ) -> None: + """Moves the contents of :attr:`chat_data` at key old_chat_id to the key new_chat_id. + Also updates the persistence by calling :attr:`update_persistence`. + Warning: + * Any data stored in :attr:`chat_data` at key `new_chat_id` will be overridden + * The key `old_chat_id` of :attr:`chat_data` will be deleted + Args: + message (:class:`Message`, optional): A message with either + :attr:`telegram.Message.migrate_from_chat_id` or + :attr:`telegram.Message.migrate_to_chat_id`. + Mutually exclusive with passing ``old_chat_id`` and ``new_chat_id`` + .. seealso: `telegram.ext.filters.StatusUpdate.MIGRATE` + old_chat_id (:obj:`int`, optional): The old chat ID. + Mutually exclusive with passing ``message`` + new_chat_id (:obj:`int`, optional): The new chat ID. + Mutually exclusive with passing ``message`` + """ + if message and (old_chat_id or new_chat_id): + raise ValueError("Message and chat_id pair are mutually exclusive") + if not any((message, old_chat_id, new_chat_id)): + raise ValueError("chat_id pair or message must be passed") + + if message: + if message.migrate_from_chat_id is None and message.migrate_to_chat_id is None: + raise ValueError( + "Invalid message instance. The message must have either " + "`Message.migrate_from_chat_id` or `Message.migrate_to_chat_id`." + ) + + old_chat_id = message.migrate_from_chat_id or message.chat.id + new_chat_id = message.migrate_to_chat_id or message.chat.id + + elif not (isinstance(old_chat_id, int) and isinstance(new_chat_id, int)): + raise ValueError("old_chat_id and new_chat_id must be integers") + + self._chat_data[new_chat_id] = self._chat_data[old_chat_id] + + async def _persistence_updater(self) -> None: + # Update the persistence in regular intervals. Exit only when the stop event has been set + while not self.__update_persistence_event.is_set(): + if not self.persistence: + return + + await self.update_persistence() + try: + await asyncio.wait_for( + self.__update_persistence_event.wait(), + timeout=self.persistence.update_interval, + ) + except asyncio.TimeoutError: + return + + async def update_persistence(self) -> None: + """Updates :attr:`user_data`, :attr:`chat_data`, :attr:`bot_data` in :attr:`persistence` + along with :attr:`~telegram.ExtBot.callback_data_cache` and the conversation states of + any persistent :class:`~telegram.ext.ConversationHandler` registered for this application. + + For :attr:`user_data`, :attr:`chat_data`, only entries accessed since the last run of this + method are updated. + + Tip: + This method will be called in regular intervals by the application. There is usually + no need to call it manually. + + .. seealso:: :attr:`telegram.ext.BasePersistence.update_interval`. + """ + async with self.__update_persistence_lock: + await self.__update_persistence() + + async def __update_persistence(self) -> None: + if not self.persistence: + return + + _logger.debug('Starting next run of updating the persistence.') + + coroutines: Set[Coroutine] = set() + + if self.persistence.store_data.callback_data: + # Mypy doesn't know that persistence.set_bot (see above) already checks that + # self.bot is an instance of ExtBot if callback_data should be stored ... + coroutines.add( + self.persistence.update_callback_data( + deepcopy( + self.bot.callback_data_cache.persistence_data # type: ignore[attr-defined] + ) + ) + ) + + if self.persistence.store_data.bot_data: + coroutines.add(self.persistence.update_bot_data(deepcopy(self.bot_data))) + + if self.persistence.store_data.chat_data: + # Mypy can't handle the conditional assignment in `__init__` + chat_data = cast(TrackingDefaultDict, self._chat_data) + for chat_id, data in chat_data.pop_accessed_read_items(): + coroutines.add(self.persistence.update_chat_data(chat_id, deepcopy(data))) + for chat_id, data in chat_data.pop_accessed_write_items(): + if data is not chat_data.DELETED: + _logger.critical('`Application._chat_data[%s]` was written manually', chat_id) + coroutines.add(self.persistence.update_chat_data(chat_id, deepcopy(data))) + else: + coroutines.add(self.persistence.drop_chat_data(chat_id)) + + if self.persistence.store_data.user_data: + # Mypy can't handle the conditional assignment in `__init__` + user_data = cast(TrackingDefaultDict, self._user_data) + for user_id, data in user_data.pop_accessed_read_items(): + coroutines.add(self.persistence.update_user_data(user_id, deepcopy(data))) + for user_id, data in user_data.pop_accessed_write_items(): + if data is not user_data.DELETED: + _logger.critical('`Application._user_data[%s]` was written manually', user_id) + coroutines.add(self.persistence.update_user_data(user_id, deepcopy(data))) + else: + coroutines.add(self.persistence.drop_user_data(user_id)) + + for name, (key, new_state) in itertools.chain.from_iterable( + zip(itertools.repeat(name), states_dict.pop_accessed_write_items()) + for name, states_dict in self._conversation_handler_conversations.items() + ): + if isinstance(new_state, tuple) and isinstance(new_state[1], asyncio.Task): + # If the handler was running non-blocking, we check if the new state is already + # available. Otherwise, we update with the old state, which is the next best + # guess. + # Note that when updating the persistence one last time during self.stop(), + # *all* tasks will be done. + try: + result = new_state[1].result() + coroutines.add( + self.persistence.update_conversation(name=name, key=key, new_state=result) + ) + except (asyncio.InvalidStateError, asyncio.CancelledError): + effective_new_state = ( + None if new_state[0] is TrackingDefaultDict.DELETED else new_state[0] + ) + coroutines.add( + self.persistence.update_conversation( + name=name, + key=key, + new_state=effective_new_state, + ) + ) + + results = await asyncio.gather(*coroutines, return_exceptions=True) + _logger.debug('Finished updating persistence.') + + # dispatch any errors + await asyncio.gather( + self.dispatch_error(update=None, error=result) + for result in results + if isinstance(result, Exception) + ) + + def add_error_handler( + self, + callback: HandlerCallback[object, CCT, None], + block: DVInput[bool] = DEFAULT_TRUE, + ) -> None: + """Registers an error handler in the Application. This handler will receive every error + which happens in your bot. See the docs of :meth:`dispatch_error` for more details on how + errors are handled. + + Note: + Attempts to add the same callback multiple times will be ignored. + + Args: + callback (:obj:`callable`): The callback function for this error handler. Will be + called when an error is raised. Callback signature: + ``def callback(update: object, context: CallbackContext)``. + The error that happened will be present in ``context.error``. + block (:obj:`bool`, optional): Determines whether the return value of the callback + should be awaited before processing the next error handler in + :meth:`dispatch_error`. Defaults to :obj:`True`. + """ + if callback in self.error_handlers: + _logger.warning('The callback is already registered as an error handler. Ignoring.') + return + + if ( + block is DEFAULT_TRUE + and isinstance(self.bot, ExtBot) + and self.bot.defaults + and not self.bot.defaults.block + ): + block = False + + self.error_handlers[callback] = block + + def remove_error_handler(self, callback: Callable[[object, CCT], None]) -> None: + """Removes an error handler. + + Args: + callback (:obj:`callable`): The error handler to remove. + + """ + self.error_handlers.pop(callback, None) + + async def dispatch_error( + self, + update: Optional[object], + error: Exception, + job: 'Job' = None, + coroutine: Coroutine = None, + ) -> bool: + """Dispatches an error by passing it to all error handlers registered with + :meth:`add_error_handler`. If one of the error handlers raises + :class:`telegram.ext.ApplicationHandlerStop`, the update will not be handled by other error + handlers or handlers (even in other groups). All other exceptions raised by an error + handler will just be logged. + + .. versionchanged:: 14.0 + + * Exceptions raised by error handlers are now properly logged. + * :class:`telegram.ext.ApplicationHandlerStop` is no longer reraised but converted into + the return value. + + Args: + update (:obj:`object` | :class:`telegram.Update`): The update that caused the error. + error (:obj:`Exception`): The error that was raised. + job (:class:`telegram.ext.Job`, optional): The job that caused the error. + + .. versionadded:: 14.0 + + Returns: + :obj:`bool`: :obj:`True` if one of the error handlers raised + :class:`telegram.ext.ApplicationHandlerStop`. :obj:`False`, otherwise. + """ + if self.error_handlers: + for ( + callback, + block, + ) in self.error_handlers.items(): # pylint: disable=redefined-outer-name + context = self.context_types.context.from_error( + update=update, + error=error, + application=self, + job=job, + coroutine=coroutine, + ) + if not block: + self.__create_task( + callback(update, context), update=update, is_error_handler=True + ) + else: + try: + await callback(update, context) + except ApplicationHandlerStop: + return True + except Exception as exc: + _logger.exception( + 'An error was raised and an uncaught error was raised while ' + 'handling the error with an error_handler.', + exc_info=exc, + ) + return False + + _logger.exception('No error handlers are registered, logging exception.', exc_info=error) + return False diff --git a/telegram/ext/_basepersistence.py b/telegram/ext/_basepersistence.py index 598d84dd264..056513948cf 100644 --- a/telegram/ext/_basepersistence.py +++ b/telegram/ext/_basepersistence.py @@ -19,7 +19,16 @@ """This module contains the BasePersistence class.""" from abc import ABC, abstractmethod from copy import copy -from typing import Dict, Optional, Tuple, cast, ClassVar, Generic, NamedTuple +from typing import ( + Dict, + Optional, + Tuple, + cast, + ClassVar, + Generic, + NamedTuple, + NoReturn, +) from telegram import Bot from telegram.ext import ExtBot @@ -100,6 +109,12 @@ class BasePersistence(Generic[UD, CD, BD], ABC): store_data (:class:`PersistenceInput`, optional): Specifies which kinds of data will be saved by this persistence instance. By default, all available kinds of data will be saved. + update_interval (:obj:`int` | :obj:`float:, optional): The + :class:`~telegram.ext.Application` will update + the persistence in regular intervals. This parameter specifies the time (in seconds) to + wait between two consecutive runs of updating the persistence. Defaults to 60 seconds. + + .. versionadded:: 14.0 Attributes: store_data (:class:`PersistenceInput`): Specifies which kinds of data will be saved by this @@ -109,6 +124,7 @@ class BasePersistence(Generic[UD, CD, BD], ABC): __slots__ = ( 'bot', 'store_data', + '_update_interval', '__dict__', # __dict__ is included because we replace methods in the __new__ ) @@ -132,33 +148,33 @@ def __new__( update_bot_data = instance.update_bot_data update_callback_data = instance.update_callback_data - def get_user_data_insert_bot() -> Dict[int, UD]: - return instance.insert_bot(get_user_data()) + async def get_user_data_insert_bot() -> Dict[int, UD]: + return instance.insert_bot(await get_user_data()) - def get_chat_data_insert_bot() -> Dict[int, CD]: - return instance.insert_bot(get_chat_data()) + async def get_chat_data_insert_bot() -> Dict[int, CD]: + return instance.insert_bot(await get_chat_data()) - def get_bot_data_insert_bot() -> BD: - return instance.insert_bot(get_bot_data()) + async def get_bot_data_insert_bot() -> BD: + return instance.insert_bot(await get_bot_data()) - def get_callback_data_insert_bot() -> Optional[CDCData]: - cdc_data = get_callback_data() + async def get_callback_data_insert_bot() -> Optional[CDCData]: + cdc_data = await get_callback_data() if cdc_data is None: return None return instance.insert_bot(cdc_data[0]), cdc_data[1] - def update_user_data_replace_bot(user_id: int, data: UD) -> None: - return update_user_data(user_id, instance.replace_bot(data)) + async def update_user_data_replace_bot(user_id: int, data: UD) -> None: + return await update_user_data(user_id, instance.replace_bot(data)) - def update_chat_data_replace_bot(chat_id: int, data: CD) -> None: - return update_chat_data(chat_id, instance.replace_bot(data)) + async def update_chat_data_replace_bot(chat_id: int, data: CD) -> None: + return await update_chat_data(chat_id, instance.replace_bot(data)) - def update_bot_data_replace_bot(data: BD) -> None: - return update_bot_data(instance.replace_bot(data)) + async def update_bot_data_replace_bot(data: BD) -> None: + return await update_bot_data(instance.replace_bot(data)) - def update_callback_data_replace_bot(data: CDCData) -> None: + async def update_callback_data_replace_bot(data: CDCData) -> None: obj_data, queue = data - return update_callback_data((instance.replace_bot(obj_data), queue)) + return await update_callback_data((instance.replace_bot(obj_data), queue)) # Adds to __dict__ setattr(instance, 'get_user_data', get_user_data_insert_bot) @@ -171,11 +187,31 @@ def update_callback_data_replace_bot(data: CDCData) -> None: setattr(instance, 'update_callback_data', update_callback_data_replace_bot) return instance - def __init__(self, store_data: PersistenceInput = None): + def __init__( + self, + store_data: PersistenceInput = None, + update_interval: float = 60, + ): self.store_data = store_data or PersistenceInput() + self._update_interval = update_interval self.bot: Bot = None # type: ignore[assignment] + @property + def update_interval(self) -> float: + """:obj:`int`, optional): Time (in seconds) that the :class:`~telegram.ext.Application` + will wait between two consecutive runs of updating the persistence. + + .. versionadded:: 14.0 + """ + return self._update_interval + + @update_interval.setter + def update_interval(self, value: object) -> NoReturn: # pylint: disable=no-self-use + raise AttributeError( + "You can not assign a new value to update_interval after initialization." + ) + def set_bot(self, bot: Bot) -> None: """Set the Bot to be used by this persistence instance. @@ -395,8 +431,8 @@ def _insert_bot(self, obj: object, memo: Dict[int, object]) -> object: return obj @abstractmethod - def get_user_data(self) -> Dict[int, UD]: - """Will be called by :class:`telegram.ext.Dispatcher` upon creation with a + async def get_user_data(self) -> Dict[int, UD]: + """Will be called by :class:`telegram.ext.Application` upon creation with a persistence object. It should return the ``user_data`` if stored, or an empty :obj:`dict`. In the latter case, the dictionary should produce values corresponding to one of the following: @@ -414,8 +450,8 @@ def get_user_data(self) -> Dict[int, UD]: """ @abstractmethod - def get_chat_data(self) -> Dict[int, CD]: - """Will be called by :class:`telegram.ext.Dispatcher` upon creation with a + async def get_chat_data(self) -> Dict[int, CD]: + """Will be called by :class:`telegram.ext.Application` upon creation with a persistence object. It should return the ``chat_data`` if stored, or an empty :obj:`dict`. In the latter case, the dictionary should produce values corresponding to one of the following: @@ -433,8 +469,8 @@ def get_chat_data(self) -> Dict[int, CD]: """ @abstractmethod - def get_bot_data(self) -> BD: - """Will be called by :class:`telegram.ext.Dispatcher` upon creation with a + async def get_bot_data(self) -> BD: + """Will be called by :class:`telegram.ext.Application` upon creation with a persistence object. It should return the ``bot_data`` if stored, or an empty :obj:`dict`. In the latter case, the :obj:`dict` should produce values corresponding to one of the following: @@ -449,8 +485,8 @@ def get_bot_data(self) -> BD: """ @abstractmethod - def get_callback_data(self) -> Optional[CDCData]: - """Will be called by :class:`telegram.ext.Dispatcher` upon creation with a + async def get_callback_data(self) -> Optional[CDCData]: + """Will be called by :class:`telegram.ext.Application` upon creation with a persistence object. If callback data was stored, it should be returned. .. versionadded:: 13.6 @@ -465,8 +501,8 @@ def get_callback_data(self) -> Optional[CDCData]: """ @abstractmethod - def get_conversations(self, name: str) -> ConversationDict: - """Will be called by :class:`telegram.ext.Dispatcher` when a + async def get_conversations(self, name: str) -> ConversationDict: + """Will be called by :class:`telegram.ext.Application` when a :class:`telegram.ext.ConversationHandler` is added if :attr:`telegram.ext.ConversationHandler.persistent` is :obj:`True`. It should return the conversations for the handler with `name` or an empty :obj:`dict` @@ -479,7 +515,7 @@ def get_conversations(self, name: str) -> ConversationDict: """ @abstractmethod - def update_conversation( + async def update_conversation( self, name: str, key: Tuple[int, ...], new_state: Optional[object] ) -> None: """Will be called when a :class:`telegram.ext.ConversationHandler` changes states. @@ -492,40 +528,40 @@ def update_conversation( """ @abstractmethod - def update_user_data(self, user_id: int, data: UD) -> None: - """Will be called by the :class:`telegram.ext.Dispatcher` after a handler has + async def update_user_data(self, user_id: int, data: UD) -> None: + """Will be called by the :class:`telegram.ext.Application` after a handler has handled an update. Args: user_id (:obj:`int`): The user the data might have been changed for. data (:obj:`dict` | :attr:`telegram.ext.ContextTypes.user_data`): - The :attr:`telegram.ext.Dispatcher.user_data` ``[user_id]``. + The :attr:`telegram.ext.Application.user_data` ``[user_id]``. """ @abstractmethod - def update_chat_data(self, chat_id: int, data: CD) -> None: - """Will be called by the :class:`telegram.ext.Dispatcher` after a handler has + async def update_chat_data(self, chat_id: int, data: CD) -> None: + """Will be called by the :class:`telegram.ext.Application` after a handler has handled an update. Args: chat_id (:obj:`int`): The chat the data might have been changed for. data (:obj:`dict` | :attr:`telegram.ext.ContextTypes.chat_data`): - The :attr:`telegram.ext.Dispatcher.chat_data` ``[chat_id]``. + The :attr:`telegram.ext.Application.chat_data` ``[chat_id]``. """ @abstractmethod - def update_bot_data(self, data: BD) -> None: - """Will be called by the :class:`telegram.ext.Dispatcher` after a handler has + async def update_bot_data(self, data: BD) -> None: + """Will be called by the :class:`telegram.ext.Application` after a handler has handled an update. Args: data (:obj:`dict` | :attr:`telegram.ext.ContextTypes.bot_data`): - The :attr:`telegram.ext.Dispatcher.bot_data`. + The :attr:`telegram.ext.Application.bot_data`. """ @abstractmethod - def update_callback_data(self, data: CDCData) -> None: - """Will be called by the :class:`telegram.ext.Dispatcher` after a handler has + async def update_callback_data(self, data: CDCData) -> None: + """Will be called by the :class:`telegram.ext.Application` after a handler has handled an update. .. versionadded:: 13.6 @@ -540,9 +576,9 @@ def update_callback_data(self, data: CDCData) -> None: """ @abstractmethod - def drop_chat_data(self, chat_id: int) -> None: - """Will be called by the :class:`telegram.ext.Dispatcher`, when using - :meth:`~telegram.ext.Dispatcher.drop_chat_data`. + async def drop_chat_data(self, chat_id: int) -> None: + """Will be called by the :class:`telegram.ext.Application`, when using + :meth:`~telegram.ext.Application.drop_chat_data`. .. versionadded:: 14.0 @@ -551,9 +587,9 @@ def drop_chat_data(self, chat_id: int) -> None: """ @abstractmethod - def drop_user_data(self, user_id: int) -> None: - """Will be called by the :class:`telegram.ext.Dispatcher`, when using - :meth:`~telegram.ext.Dispatcher.drop_user_data`. + async def drop_user_data(self, user_id: int) -> None: + """Will be called by the :class:`telegram.ext.Application`, when using + :meth:`~telegram.ext.Application.drop_user_data`. .. versionadded:: 14.0 @@ -562,8 +598,8 @@ def drop_user_data(self, user_id: int) -> None: """ @abstractmethod - def refresh_user_data(self, user_id: int, user_data: UD) -> None: - """Will be called by the :class:`telegram.ext.Dispatcher` before passing the + async def refresh_user_data(self, user_id: int, user_data: UD) -> None: + """Will be called by the :class:`telegram.ext.Application` before passing the :attr:`user_data` to a callback. Can be used to update data stored in :attr:`user_data` from an external source. @@ -579,8 +615,8 @@ def refresh_user_data(self, user_id: int, user_data: UD) -> None: """ @abstractmethod - def refresh_chat_data(self, chat_id: int, chat_data: CD) -> None: - """Will be called by the :class:`telegram.ext.Dispatcher` before passing the + async def refresh_chat_data(self, chat_id: int, chat_data: CD) -> None: + """Will be called by the :class:`telegram.ext.Application` before passing the :attr:`chat_data` to a callback. Can be used to update data stored in :attr:`chat_data` from an external source. @@ -596,8 +632,8 @@ def refresh_chat_data(self, chat_id: int, chat_data: CD) -> None: """ @abstractmethod - def refresh_bot_data(self, bot_data: BD) -> None: - """Will be called by the :class:`telegram.ext.Dispatcher` before passing the + async def refresh_bot_data(self, bot_data: BD) -> None: + """Will be called by the :class:`telegram.ext.Application` before passing the :attr:`bot_data` to a callback. Can be used to update data stored in :attr:`bot_data` from an external source. @@ -612,8 +648,8 @@ def refresh_bot_data(self, bot_data: BD) -> None: """ @abstractmethod - def flush(self) -> None: - """Will be called by :class:`telegram.ext.Updater` upon receiving a stop signal. Gives the + async def flush(self) -> None: + """Will be called by :meth:`telegram.ext.Application.stop`. Gives the persistence a chance to finish up saving or close a database connection gracefully. .. versionchanged:: 14.0 diff --git a/telegram/ext/_builders.py b/telegram/ext/_builders.py index dadb96c0316..a495a161730 100644 --- a/telegram/ext/_builders.py +++ b/telegram/ext/_builders.py @@ -1,7 +1,7 @@ #!/usr/bin/env python # # A library that provides a Python interface to the Telegram Bot API -# Copyright (C) 2021 +# Copyright (C) 2015-2022 # Leandro Toledo de Souza # # This program is free software: you can redistribute it and/or modify @@ -16,14 +16,9 @@ # # You should have received a copy of the GNU Lesser Public License # along with this program. If not, see [http://www.gnu.org/licenses/]. -# -# Some of the type hints are just ridiculously long ... -# flake8: noqa: E501 -# pylint: disable=line-too-long """This module contains the Builder classes for the telegram.ext module.""" +from asyncio import Queue from pathlib import Path -from queue import Queue -from threading import Event from typing import ( TypeVar, Generic, @@ -32,18 +27,17 @@ Any, Dict, Union, - Optional, - overload, Type, + Optional, ) from telegram import Bot -from telegram.request import Request from telegram._utils.types import ODVInput, DVInput, FilePathInput -from telegram._utils.warnings import warn from telegram._utils.defaultvalue import DEFAULT_NONE, DefaultValue, DEFAULT_FALSE -from telegram.ext import Dispatcher, JobQueue, Updater, ExtBot, ContextTypes, CallbackContext -from telegram.ext._utils.types import CCT, UD, CD, BD, BT, JQ, PT +from telegram.ext import Application, JobQueue, ExtBot, ContextTypes, CallbackContext, Updater +from telegram.request._httpxrequest import HTTPXRequest +from telegram.ext._utils.types import CCT, UD, CD, BD, BT, JQ +from telegram.request import BaseRequest if TYPE_CHECKING: from telegram.ext import ( @@ -53,57 +47,35 @@ # Type hinting is a bit complicated here because we try to get to a sane level of # leveraging generics and therefore need a number of type variables. -ODT = TypeVar('ODT', bound=Union[None, Dispatcher]) -DT = TypeVar('DT', bound=Dispatcher) +OAppT = TypeVar('OAppT', bound=Union[None, Application]) +AppT = TypeVar('AppT', bound=Application) InBT = TypeVar('InBT', bound=Bot) InJQ = TypeVar('InJQ', bound=Union[None, JobQueue]) InPT = TypeVar('InPT', bound=Union[None, 'BasePersistence']) -InDT = TypeVar('InDT', bound=Union[None, Dispatcher]) +InAppT = TypeVar('InAppT', bound=Union[None, Application]) InCCT = TypeVar('InCCT', bound='CallbackContext') InUD = TypeVar('InUD') InCD = TypeVar('InCD') InBD = TypeVar('InBD') -BuilderType = TypeVar('BuilderType', bound='_BaseBuilder') +BuilderType = TypeVar('BuilderType', bound='ApplicationBuilder') CT = TypeVar('CT', bound=Callable[..., Any]) if TYPE_CHECKING: DEF_CCT = CallbackContext.DEFAULT_TYPE # type: ignore[misc] - InitBaseBuilder = _BaseBuilder[ # noqa: F821 # pylint: disable=used-before-assignment - Dispatcher[ExtBot, DEF_CCT, Dict, Dict, Dict, JobQueue, None], - ExtBot, - DEF_CCT, - Dict, - Dict, - Dict, - JobQueue, - None, - ] - InitUpdaterBuilder = UpdaterBuilder[ # noqa: F821 # pylint: disable=used-before-assignment - Dispatcher[ExtBot, DEF_CCT, Dict, Dict, Dict, JobQueue, None], - ExtBot, - DEF_CCT, - Dict, - Dict, - Dict, - JobQueue, - None, - ] - InitDispatcherBuilder = ( - DispatcherBuilder[ # noqa: F821 # pylint: disable=used-before-assignment - Dispatcher[ExtBot, DEF_CCT, Dict, Dict, Dict, JobQueue, None], + InitApplicationBuilder = ( + ApplicationBuilder[ # noqa: F821 # pylint: disable=used-before-assignment ExtBot, DEF_CCT, Dict, Dict, Dict, JobQueue, - None, ] ) _BOT_CHECKS = [ - ('dispatcher', 'Dispatcher instance'), + ('updater', 'Updater instance'), ('request', 'Request instance'), ('request_kwargs', 'request_kwargs'), ('base_file_url', 'base_file_url'), @@ -114,103 +86,131 @@ ('private_key', 'private_key'), ] -_DISPATCHER_CHECKS = [ - ('bot', 'bot instance'), - ('update_queue', 'update_queue'), - ('workers', 'workers'), - ('exception_event', 'exception_event'), - ('job_queue', 'JobQueue instance'), - ('persistence', 'persistence instance'), - ('context_types', 'ContextTypes instance'), - ('dispatcher_class', 'Dispatcher Class'), -] + _BOT_CHECKS -_DISPATCHER_CHECKS.remove(('dispatcher', 'Dispatcher instance')) - _TWO_ARGS_REQ = "The parameter `{}` may only be set, if no {} was set." -# Base class for all builders. We do this mainly to reduce code duplication, because e.g. -# the UpdaterBuilder has all method that the DispatcherBuilder has -class _BaseBuilder(Generic[ODT, BT, CCT, UD, CD, BD, JQ, PT]): - # pylint reports false positives here: +class ApplicationBuilder(Generic[BT, CCT, UD, CD, BD, JQ]): + """This class serves as initializer for :class:`telegram.ext.Application` via the so called + `builder pattern`_. To build a :class:`telegram.ext.Application`, one first initializes an + instance of this class. Arguments for the :class:`telegram.ext.Application` to build are then + added by subsequently calling the methods of the builder. Finally, the + :class:`telegram.ext.Application` is built by calling :meth:`build`. In the simplest case this + can look like the following example. + + Example: + .. code:: python + + application = ApplicationBuilder().token('TOKEN').build() + + Please see the description of the individual methods for information on which arguments can be + set and what the defaults are when not called. When no default is mentioned, the argument will + not be used by default. + + Note: + * Some arguments are mutually exclusive. E.g. after calling :meth:`token`, you can't set + a custom bot with :meth:`bot` and vice versa. + * Unless a custom :class:`telegram.Bot` instance is set via :meth:`bot`, :meth:`build` will + use :class:`telegram.ext.ExtBot` for the bot. + + .. _`builder pattern`: https://en.wikipedia.org/wiki/Builder_pattern. + """ __slots__ = ( '_token', '_base_url', '_base_file_url', - '_request_kwargs', + '_connection_pool_size', + '_proxy_url', + '_connect_timeout', + '_read_timeout', + '_write_timeout', + '_pool_timeout', '_request', + '_get_updates_connection_pool_size', + '_get_updates_proxy_url', + '_get_updates_connect_timeout', + '_get_updates_read_timeout', + '_get_updates_write_timeout', + '_get_updates_pool_timeout', + '_get_updates_request', '_private_key', '_private_key_password', '_defaults', '_arbitrary_callback_data', '_bot', '_update_queue', - '_workers', - '_exception_event', '_job_queue', '_persistence', '_context_types', - '_dispatcher', - '_user_signal_handler', - '_dispatcher_class', - '_dispatcher_kwargs', - '_updater_class', - '_updater_kwargs', + '_application_class', + '_application_kwargs', + '_concurrent_updates', + '_updater', ) - def __init__(self: 'InitBaseBuilder'): + def __init__(self: 'InitApplicationBuilder'): self._token: DVInput[str] = DefaultValue('') self._base_url: DVInput[str] = DefaultValue('https://api.telegram.org/bot') self._base_file_url: DVInput[str] = DefaultValue('https://api.telegram.org/file/bot') - self._request_kwargs: DVInput[Dict[str, Any]] = DefaultValue({}) - self._request: ODVInput['Request'] = DEFAULT_NONE + self._connection_pool_size: Optional[int] = None + self._proxy_url: Optional[str] = None + self._connect_timeout: ODVInput[float] = DEFAULT_NONE + self._read_timeout: ODVInput[float] = DEFAULT_NONE + self._write_timeout: ODVInput[float] = DEFAULT_NONE + self._pool_timeout: ODVInput[float] = DEFAULT_NONE + self._request: DVInput['BaseRequest'] = DEFAULT_NONE + self._get_updates_connection_pool_size: Optional[int] = None + self._get_updates_proxy_url: Optional[str] = None + self._get_updates_connect_timeout: ODVInput[float] = DEFAULT_NONE + self._get_updates_read_timeout: ODVInput[float] = DEFAULT_NONE + self._get_updates_write_timeout: ODVInput[float] = DEFAULT_NONE + self._get_updates_pool_timeout: ODVInput[float] = DEFAULT_NONE + self._get_updates_request: DVInput['BaseRequest'] = DEFAULT_NONE self._private_key: ODVInput[bytes] = DEFAULT_NONE self._private_key_password: ODVInput[bytes] = DEFAULT_NONE self._defaults: ODVInput['Defaults'] = DEFAULT_NONE self._arbitrary_callback_data: DVInput[Union[bool, int]] = DEFAULT_FALSE - self._bot: Bot = DEFAULT_NONE # type: ignore[assignment] + self._bot: DVInput[Bot] = DEFAULT_NONE self._update_queue: DVInput[Queue] = DefaultValue(Queue()) - self._workers: DVInput[int] = DefaultValue(4) - self._exception_event: DVInput[Event] = DefaultValue(Event()) self._job_queue: ODVInput['JobQueue'] = DefaultValue(JobQueue()) self._persistence: ODVInput['BasePersistence'] = DEFAULT_NONE self._context_types: DVInput[ContextTypes] = DefaultValue(ContextTypes()) - self._dispatcher: ODVInput['Dispatcher'] = DEFAULT_NONE - self._user_signal_handler: Optional[Callable[[int, object], Any]] = None - self._dispatcher_class: DVInput[Type[Dispatcher]] = DefaultValue(Dispatcher) - self._dispatcher_kwargs: Dict[str, object] = {} - self._updater_class: Type[Updater] = Updater - self._updater_kwargs: Dict[str, object] = {} - - @staticmethod - def _get_connection_pool_size(workers: DVInput[int]) -> int: - # For the standard use case (Updater + Dispatcher + Bot) - # we need a connection pool the size of: - # * for each of the workers - # * 1 for Dispatcher - # * 1 for Updater (even if webhook is used, we can spare a connection) - # * 1 for JobQueue - # * 1 for main thread - return DefaultValue.get_value(workers) + 4 + self._application_class: DVInput[Type[Application]] = DefaultValue(Application) + self._application_kwargs: Dict[str, object] = {} + self._concurrent_updates: DVInput[Union[int, bool]] = DEFAULT_FALSE + self._updater: ODVInput[Updater] = DEFAULT_NONE + + def _build_request(self, get_updates: bool) -> BaseRequest: + prefix = 'get_updates_' if get_updates else '' + if not isinstance(getattr(self, f'{prefix}request'), DefaultValue): + return getattr(self, f'{prefix}request') + + proxy_url = getattr(self, f'{prefix}proxy_url') + if get_updates: + connection_pool_size = getattr(self, f'{prefix}connection_pool_size') or 1 + else: + connection_pool_size = getattr(self, f'{prefix}connection_pool_size') or 128 + + timeouts = dict( + connect_timeout=getattr(self, f'{prefix}connect_timeout'), + read_timeout=getattr(self, f'{prefix}read_timeout'), + write_timeout=getattr(self, f'{prefix}write_timeout'), + pool_timeout=getattr(self, f'{prefix}pool_timeout'), + ) + effective_timeouts = { + key: value for key, value in timeouts.items() if not isinstance(value, DefaultValue) + } + + return HTTPXRequest( + connection_pool_size=connection_pool_size, + proxy_url=proxy_url, + **effective_timeouts, + ) def _build_ext_bot(self) -> ExtBot: if isinstance(self._token, DefaultValue): raise RuntimeError('No bot token was set.') - if not isinstance(self._request, DefaultValue): - request = self._request - else: - request_kwargs = DefaultValue.get_value(self._request_kwargs) - if ( - 'con_pool_size' - not in request_kwargs # pylint: disable=unsupported-membership-test - ): - request_kwargs[ # pylint: disable=unsupported-assignment-operation - 'con_pool_size' - ] = self._get_connection_pool_size(self._workers) - request = Request(**request_kwargs) # pylint: disable=not-a-mapping - return ExtBot( token=self._token, base_url=DefaultValue.get_value(self._base_url), @@ -219,313 +219,56 @@ def _build_ext_bot(self) -> ExtBot: private_key_password=DefaultValue.get_value(self._private_key_password), defaults=DefaultValue.get_value(self._defaults), arbitrary_callback_data=DefaultValue.get_value(self._arbitrary_callback_data), - request=request, + request=self._build_request(get_updates=False), + get_updates_request=self._build_request(get_updates=True), ) - def _build_dispatcher( - self: '_BaseBuilder[ODT, BT, CCT, UD, CD, BD, JQ, PT]', stack_level: int = 3 - ) -> Dispatcher[BT, CCT, UD, CD, BD, JQ, PT]: + def build( + self: 'ApplicationBuilder[BT, CCT, UD, CD, BD, JQ]', + ) -> Application[BT, CCT, UD, CD, BD, JQ]: + """Builds a :class:`telegram.ext.Application` with the provided arguments. + + Returns: + :class:`telegram.ext.Application` + """ job_queue = DefaultValue.get_value(self._job_queue) - dispatcher: Dispatcher[ - BT, CCT, UD, CD, BD, JQ, PT + + if isinstance(self._updater, DefaultValue) or self._updater is None: + if isinstance(self._bot, DefaultValue): + bot: Bot = self._build_ext_bot() + else: + bot = self._bot + update_queue = DefaultValue.get_value(self._update_queue) + updater = Updater(bot=bot, update_queue=update_queue) + else: + updater = self._updater + bot = self._updater.bot + update_queue = self._updater.update_queue + + application: Application[ + BT, CCT, UD, CD, BD, JQ ] = DefaultValue.get_value( # type: ignore[call-arg] # pylint: disable=not-callable - self._dispatcher_class + self._application_class )( - bot=self._bot if self._bot is not DEFAULT_NONE else self._build_ext_bot(), - update_queue=DefaultValue.get_value(self._update_queue), - workers=DefaultValue.get_value(self._workers), - exception_event=DefaultValue.get_value(self._exception_event), + bot=bot, + update_queue=update_queue, + updater=updater, + concurrent_updates=DefaultValue.get_value(self._concurrent_updates), job_queue=job_queue, persistence=DefaultValue.get_value(self._persistence), context_types=DefaultValue.get_value(self._context_types), - stack_level=stack_level + 1, - **self._dispatcher_kwargs, + **self._application_kwargs, ) if job_queue is not None: - job_queue.set_dispatcher(dispatcher) - - con_pool_size = self._get_connection_pool_size(self._workers) - actual_size = dispatcher.bot.request.con_pool_size - if actual_size < con_pool_size: - warn( - f'The Connection pool of Request object is smaller ({actual_size}) than the ' - f'recommended value of {con_pool_size}.', - stacklevel=stack_level, - ) - - return dispatcher - - def _build_updater( - self: '_BaseBuilder[ODT, BT, Any, Any, Any, Any, Any, Any]', - ) -> Updater[BT, ODT]: - if isinstance(self._dispatcher, DefaultValue): - dispatcher = self._build_dispatcher(stack_level=4) - return self._updater_class( - dispatcher=dispatcher, - user_signal_handler=self._user_signal_handler, - exception_event=dispatcher.exception_event, - **self._updater_kwargs, # type: ignore[arg-type] - ) - - if self._dispatcher: - exception_event = self._dispatcher.exception_event - bot = self._dispatcher.bot - else: - exception_event = DefaultValue.get_value(self._exception_event) - bot = self._bot or self._build_ext_bot() - - return self._updater_class( # type: ignore[call-arg] - dispatcher=self._dispatcher, - bot=bot, - update_queue=DefaultValue.get_value(self._update_queue), - user_signal_handler=self._user_signal_handler, - exception_event=exception_event, - **self._updater_kwargs, - ) - - @property - def _dispatcher_check(self) -> bool: - return self._dispatcher not in (DEFAULT_NONE, None) - - def _set_dispatcher_class( - self: BuilderType, dispatcher_class: Type[Dispatcher], kwargs: Dict[str, object] = None - ) -> BuilderType: - if self._dispatcher is not DEFAULT_NONE: - raise RuntimeError(_TWO_ARGS_REQ.format('dispatcher_class', 'Dispatcher instance')) - self._dispatcher_class = dispatcher_class - self._dispatcher_kwargs = kwargs or {} - return self - - def _set_updater_class( - self: BuilderType, updater_class: Type[Updater], kwargs: Dict[str, object] = None - ) -> BuilderType: - self._updater_class = updater_class - self._updater_kwargs = kwargs or {} - return self - - def _set_token(self: BuilderType, token: str) -> BuilderType: - if self._bot is not DEFAULT_NONE: - raise RuntimeError(_TWO_ARGS_REQ.format('token', 'bot instance')) - if self._dispatcher_check: - raise RuntimeError(_TWO_ARGS_REQ.format('token', 'Dispatcher instance')) - self._token = token - return self - - def _set_base_url(self: BuilderType, base_url: str) -> BuilderType: - if self._bot is not DEFAULT_NONE: - raise RuntimeError(_TWO_ARGS_REQ.format('base_url', 'bot instance')) - if self._dispatcher_check: - raise RuntimeError(_TWO_ARGS_REQ.format('base_url', 'Dispatcher instance')) - self._base_url = base_url - return self - - def _set_base_file_url(self: BuilderType, base_file_url: str) -> BuilderType: - if self._bot is not DEFAULT_NONE: - raise RuntimeError(_TWO_ARGS_REQ.format('base_file_url', 'bot instance')) - if self._dispatcher_check: - raise RuntimeError(_TWO_ARGS_REQ.format('base_file_url', 'Dispatcher instance')) - self._base_file_url = base_file_url - return self - - def _set_request_kwargs(self: BuilderType, request_kwargs: Dict[str, Any]) -> BuilderType: - if self._request is not DEFAULT_NONE: - raise RuntimeError(_TWO_ARGS_REQ.format('request_kwargs', 'Request instance')) - if self._bot is not DEFAULT_NONE: - raise RuntimeError(_TWO_ARGS_REQ.format('request_kwargs', 'bot instance')) - if self._dispatcher_check: - raise RuntimeError(_TWO_ARGS_REQ.format('request_kwargs', 'Dispatcher instance')) - self._request_kwargs = request_kwargs - return self - - def _set_request(self: BuilderType, request: Request) -> BuilderType: - if not isinstance(self._request_kwargs, DefaultValue): - raise RuntimeError(_TWO_ARGS_REQ.format('request', 'request_kwargs')) - if self._bot is not DEFAULT_NONE: - raise RuntimeError(_TWO_ARGS_REQ.format('request', 'bot instance')) - if self._dispatcher_check: - raise RuntimeError(_TWO_ARGS_REQ.format('request', 'Dispatcher instance')) - self._request = request - return self - - def _set_private_key( - self: BuilderType, - private_key: Union[bytes, FilePathInput], - password: Union[bytes, FilePathInput] = None, - ) -> BuilderType: - if self._bot is not DEFAULT_NONE: - raise RuntimeError(_TWO_ARGS_REQ.format('private_key', 'bot instance')) - if self._dispatcher_check: - raise RuntimeError(_TWO_ARGS_REQ.format('private_key', 'Dispatcher instance')) - - self._private_key = ( - private_key if isinstance(private_key, bytes) else Path(private_key).read_bytes() - ) - if password is None or isinstance(password, bytes): - self._private_key_password = password - else: - self._private_key_password = Path(password).read_bytes() - - return self - - def _set_defaults(self: BuilderType, defaults: 'Defaults') -> BuilderType: - if self._bot is not DEFAULT_NONE: - raise RuntimeError(_TWO_ARGS_REQ.format('defaults', 'bot instance')) - if self._dispatcher_check: - raise RuntimeError(_TWO_ARGS_REQ.format('defaults', 'Dispatcher instance')) - self._defaults = defaults - return self - - def _set_arbitrary_callback_data( - self: BuilderType, arbitrary_callback_data: Union[bool, int] - ) -> BuilderType: - if self._bot is not DEFAULT_NONE: - raise RuntimeError(_TWO_ARGS_REQ.format('arbitrary_callback_data', 'bot instance')) - if self._dispatcher_check: - raise RuntimeError( - _TWO_ARGS_REQ.format('arbitrary_callback_data', 'Dispatcher instance') - ) - self._arbitrary_callback_data = arbitrary_callback_data - return self - - def _set_bot( - self: '_BaseBuilder[Dispatcher[BT, CCT, UD, CD, BD, JQ, PT], BT, CCT, UD, CD, BD, ' - 'JQ, PT]', - bot: InBT, - ) -> '_BaseBuilder[Dispatcher[InBT, CCT, UD, CD, BD, JQ, PT], InBT, CCT, UD, CD, BD, JQ, PT]': - for attr, error in _BOT_CHECKS: - if ( - not isinstance(getattr(self, f'_{attr}'), DefaultValue) - if attr != 'dispatcher' - else self._dispatcher_check - ): - raise RuntimeError(_TWO_ARGS_REQ.format('bot', error)) - self._bot = bot - return self # type: ignore[return-value] - - def _set_update_queue(self: BuilderType, update_queue: Queue) -> BuilderType: - if self._dispatcher_check: - raise RuntimeError(_TWO_ARGS_REQ.format('update_queue', 'Dispatcher instance')) - self._update_queue = update_queue - return self - - def _set_workers(self: BuilderType, workers: int) -> BuilderType: - if self._dispatcher_check: - raise RuntimeError(_TWO_ARGS_REQ.format('workers', 'Dispatcher instance')) - self._workers = workers - return self - - def _set_exception_event(self: BuilderType, exception_event: Event) -> BuilderType: - if self._dispatcher_check: - raise RuntimeError(_TWO_ARGS_REQ.format('exception_event', 'Dispatcher instance')) - self._exception_event = exception_event - return self - - def _set_job_queue( - self: '_BaseBuilder[Dispatcher[BT, CCT, UD, CD, BD, JQ, PT], BT, CCT, UD, CD, BD, JQ, PT]', - job_queue: InJQ, - ) -> '_BaseBuilder[Dispatcher[BT, CCT, UD, CD, BD, InJQ, PT], BT, CCT, UD, CD, BD, InJQ, PT]': - if self._dispatcher_check: - raise RuntimeError(_TWO_ARGS_REQ.format('job_queue', 'Dispatcher instance')) - self._job_queue = job_queue - return self # type: ignore[return-value] - - def _set_persistence( - self: '_BaseBuilder[Dispatcher[BT, CCT, UD, CD, BD, JQ, PT], BT, CCT, UD, CD, BD, JQ, PT]', - persistence: InPT, - ) -> '_BaseBuilder[Dispatcher[BT, CCT, UD, CD, BD, JQ, InPT], BT, CCT, UD, CD, BD, JQ, InPT]': - if self._dispatcher_check: - raise RuntimeError(_TWO_ARGS_REQ.format('persistence', 'Dispatcher instance')) - self._persistence = persistence - return self # type: ignore[return-value] - - def _set_context_types( - self: '_BaseBuilder[Dispatcher[BT, CCT, UD, CD, BD, JQ, PT], BT, CCT, UD, CD, BD, JQ, PT]', - context_types: 'ContextTypes[InCCT, InUD, InCD, InBD]', - ) -> '_BaseBuilder[Dispatcher[BT, InCCT, InUD, InCD, InBD, JQ, PT], BT, InCCT, InUD, InCD, InBD, JQ, PT]': - if self._dispatcher_check: - raise RuntimeError(_TWO_ARGS_REQ.format('context_types', 'Dispatcher instance')) - self._context_types = context_types - return self # type: ignore[return-value] - - @overload - def _set_dispatcher( - self: '_BaseBuilder[ODT, BT, CCT, UD, CD, BD, JQ, PT]', dispatcher: None - ) -> '_BaseBuilder[None, BT, CCT, UD, CD, BD, JQ, PT]': - ... - - @overload - def _set_dispatcher( - self: BuilderType, dispatcher: Dispatcher[InBT, InCCT, InUD, InCD, InBD, InJQ, InPT] - ) -> '_BaseBuilder[Dispatcher[InBT, InCCT, InUD, InCD, InBD, InJQ, InPT], InBT, InCCT, InUD, InCD, InBD, InJQ, InPT]': - ... + job_queue.set_application(application) - def _set_dispatcher( # type: ignore[misc] - self: BuilderType, - dispatcher: Optional[Dispatcher[InBT, InCCT, InUD, InCD, InBD, InJQ, InPT]], - ) -> '_BaseBuilder[Optional[Dispatcher[InBT, InCCT, InUD, InCD, InBD, InJQ, InPT]], InBT, InCCT, InUD, InCD, InBD, InJQ, InPT]': - for attr, error in _DISPATCHER_CHECKS: - if not isinstance(getattr(self, f'_{attr}'), DefaultValue): - raise RuntimeError(_TWO_ARGS_REQ.format('dispatcher', error)) - self._dispatcher = dispatcher - return self + return application - def _set_user_signal_handler( - self: BuilderType, user_signal_handler: Callable[[int, object], Any] + def application_class( + self: BuilderType, application_class: Type[Application], kwargs: Dict[str, object] = None ) -> BuilderType: - self._user_signal_handler = user_signal_handler - return self - - -class DispatcherBuilder(_BaseBuilder[ODT, BT, CCT, UD, CD, BD, JQ, PT]): - """This class serves as initializer for :class:`telegram.ext.Dispatcher` via the so called - `builder pattern`_. To build a :class:`telegram.ext.Dispatcher`, one first initializes an - instance of this class. Arguments for the :class:`telegram.ext.Dispatcher` to build are then - added by subsequently calling the methods of the builder. Finally, the - :class:`telegram.ext.Dispatcher` is built by calling :meth:`build`. In the simplest case this - can look like the following example. - - Example: - .. code:: python - - dispatcher = DispatcherBuilder().token('TOKEN').build() - - Please see the description of the individual methods for information on which arguments can be - set and what the defaults are when not called. When no default is mentioned, the argument will - not be used by default. - - Note: - * Some arguments are mutually exclusive. E.g. after calling :meth:`token`, you can't set - a custom bot with :meth:`bot` and vice versa. - * Unless a custom :class:`telegram.Bot` instance is set via :meth:`bot`, :meth:`build` will - use :class:`telegram.ext.ExtBot` for the bot. - - .. seealso:: - :class:`telegram.ext.UpdaterBuilder` - - .. _`builder pattern`: https://en.wikipedia.org/wiki/Builder_pattern. - """ - - __slots__ = () - - # The init is just here for mypy - def __init__(self: 'InitDispatcherBuilder'): - super().__init__() - - def build( - self: 'DispatcherBuilder[ODT, BT, CCT, UD, CD, BD, JQ, PT]', - ) -> Dispatcher[BT, CCT, UD, CD, BD, JQ, PT]: - """Builds a :class:`telegram.ext.Dispatcher` with the provided arguments. - - Returns: - :class:`telegram.ext.Dispatcher` - """ - return self._build_dispatcher() - - def dispatcher_class( - self: BuilderType, dispatcher_class: Type[Dispatcher], kwargs: Dict[str, object] = None - ) -> BuilderType: - """Sets a custom subclass to be used instead of :class:`telegram.ext.Dispatcher`. The + """Sets a custom subclass to be used instead of :class:`telegram.ext.Application`. The subclasses ``__init__`` should look like this .. code:: python @@ -536,28 +279,33 @@ def __init__(self, custom_arg_1, custom_arg_2, ..., **kwargs): self.custom_arg_2 = custom_arg_2 Args: - dispatcher_class (:obj:`type`): A subclass of :class:`telegram.ext.Dispatcher` + application_class (:obj:`type`): A subclass of :class:`telegram.ext.Application` kwargs (Dict[:obj:`str`, :obj:`object`], optional): Keyword arguments for the initialization. Defaults to an empty dict. Returns: - :class:`DispatcherBuilder`: The same builder with the updated argument. + :class:`ApplicationBuilder`: The same builder with the updated argument. """ - return self._set_dispatcher_class(dispatcher_class, kwargs) + self._application_class = application_class + self._application_kwargs = kwargs or {} + return self def token(self: BuilderType, token: str) -> BuilderType: - """Sets the token to be used for :attr:`telegram.ext.Dispatcher.bot`. + """Sets the token to be used for :attr:`telegram.ext.Application.bot`. Args: token (:obj:`str`): The token. Returns: - :class:`DispatcherBuilder`: The same builder with the updated argument. + :class:`ApplicationBuilder`: The same builder with the updated argument. """ - return self._set_token(token) + if self._bot is not DEFAULT_NONE: + raise RuntimeError(_TWO_ARGS_REQ.format('token', 'bot instance')) + self._token = token + return self def base_url(self: BuilderType, base_url: str) -> BuilderType: - """Sets the base URL to be used for :attr:`telegram.ext.Dispatcher.bot`. If not called, + """Sets the base URL to be used for :attr:`telegram.ext.Application.bot`. If not called, will default to ``'https://api.telegram.org/bot'``. .. seealso:: :attr:`telegram.Bot.base_url`, `Local Bot API Server BuilderType: base_url (:obj:`str`): The URL. Returns: - :class:`DispatcherBuilder`: The same builder with the updated argument. + :class:`ApplicationBuilder`: The same builder with the updated argument. """ - return self._set_base_url(base_url) + if self._bot is not DEFAULT_NONE: + raise RuntimeError(_TWO_ARGS_REQ.format('base_url', 'bot instance')) + self._base_url = base_url + return self def base_file_url(self: BuilderType, base_file_url: str) -> BuilderType: - """Sets the base file URL to be used for :attr:`telegram.ext.Dispatcher.bot`. If not + """Sets the base file URL to be used for :attr:`telegram.ext.Application.bot`. If not called, will default to ``'https://api.telegram.org/file/bot'``. .. seealso:: :attr:`telegram.Bot.base_file_url`, `Local Bot API Server BuilderType: base_file_url (:obj:`str`): The URL. Returns: - :class:`DispatcherBuilder`: The same builder with the updated argument. - """ - return self._set_base_file_url(base_file_url) - - def request_kwargs(self: BuilderType, request_kwargs: Dict[str, Any]) -> BuilderType: - """Sets keyword arguments that will be passed to the :class:`telegram.utils.Request` object - that is created when :attr:`telegram.ext.Dispatcher.bot` is created. If not called, no - keyword arguments will be passed. - - .. seealso:: :meth:`request` - - Args: - request_kwargs (Dict[:obj:`str`, :obj:`object`]): The keyword arguments. - - Returns: - :class:`DispatcherBuilder`: The same builder with the updated argument. - """ - return self._set_request_kwargs(request_kwargs) - - def request(self: BuilderType, request: Request) -> BuilderType: - """Sets a :class:`telegram.utils.Request` object to be used for - :attr:`telegram.ext.Dispatcher.bot`. - - .. seealso:: :meth:`request_kwargs` - - Args: - request (:class:`telegram.utils.Request`): The request object. - - Returns: - :class:`DispatcherBuilder`: The same builder with the updated argument. - """ - return self._set_request(request) - - def private_key( - self: BuilderType, - private_key: Union[bytes, FilePathInput], - password: Union[bytes, FilePathInput] = None, - ) -> BuilderType: - """Sets the private key and corresponding password for decryption of telegram passport data - to be used for :attr:`telegram.ext.Dispatcher.bot`. - - .. seealso:: `passportbot.py `_, `Telegram Passports `_ - - Args: - private_key (:obj:`bytes` | :obj:`str` | :obj:`pathlib.Path`): The private key or the - file path of a file that contains the key. In the latter case, the file's content - will be read automatically. - password (:obj:`bytes` | :obj:`str` | :obj:`pathlib.Path`, optional): The corresponding - password or the file path of a file that contains the password. In the latter case, - the file's content will be read automatically. - - Returns: - :class:`DispatcherBuilder`: The same builder with the updated argument. - """ - return self._set_private_key(private_key=private_key, password=password) - - def defaults(self: BuilderType, defaults: 'Defaults') -> BuilderType: - """Sets the :class:`telegram.ext.Defaults` object to be used for - :attr:`telegram.ext.Dispatcher.bot`. - - .. seealso:: `Adding Defaults `_ - - Args: - defaults (:class:`telegram.ext.Defaults`): The defaults. - - Returns: - :class:`DispatcherBuilder`: The same builder with the updated argument. - """ - return self._set_defaults(defaults) - - def arbitrary_callback_data( - self: BuilderType, arbitrary_callback_data: Union[bool, int] - ) -> BuilderType: - """Specifies whether :attr:`telegram.ext.Dispatcher.bot` should allow arbitrary objects as - callback data for :class:`telegram.InlineKeyboardButton` and how many keyboards should be - cached in memory. If not called, only strings can be used as callback data and no data will - be stored in memory. - - .. seealso:: `Arbitrary callback_data `_, - `arbitrarycallbackdatabot.py `_ - - Args: - arbitrary_callback_data (:obj:`bool` | :obj:`int`): If :obj:`True` is passed, the - default cache size of 1024 will be used. Pass an integer to specify a different - cache size. - - Returns: - :class:`DispatcherBuilder`: The same builder with the updated argument. - """ - return self._set_arbitrary_callback_data(arbitrary_callback_data) - - def bot( - self: 'DispatcherBuilder[Dispatcher[BT, CCT, UD, CD, BD, JQ, PT], BT, CCT, UD, CD, BD, ' - 'JQ, PT]', - bot: InBT, - ) -> 'DispatcherBuilder[Dispatcher[InBT, CCT, UD, CD, BD, JQ, PT], InBT, CCT, UD, CD, BD, JQ, PT]': - """Sets a :class:`telegram.Bot` instance to be used for - :attr:`telegram.ext.Dispatcher.bot`. Instances of subclasses like - :class:`telegram.ext.ExtBot` are also valid. - - Args: - bot (:class:`telegram.Bot`): The bot. - - Returns: - :class:`DispatcherBuilder`: The same builder with the updated argument. - """ - return self._set_bot(bot) # type: ignore[return-value] - - def update_queue(self: BuilderType, update_queue: Queue) -> BuilderType: - """Sets a :class:`queue.Queue` instance to be used for - :attr:`telegram.ext.Dispatcher.update_queue`, i.e. the queue that the dispatcher will fetch - updates from. If not called, a queue will be instantiated. - - .. seealso:: :attr:`telegram.ext.Updater.update_queue`, - :meth:`telegram.ext.UpdaterBuilder.update_queue` - - Args: - update_queue (:class:`queue.Queue`): The queue. - - Returns: - :class:`DispatcherBuilder`: The same builder with the updated argument. + :class:`ApplicationBuilder`: The same builder with the updated argument. """ - return self._set_update_queue(update_queue) + if self._bot is not DEFAULT_NONE: + raise RuntimeError(_TWO_ARGS_REQ.format('base_file_url', 'bot instance')) + self._base_file_url = base_file_url + return self - def workers(self: BuilderType, workers: int) -> BuilderType: - """Sets the number of worker threads to be used for - :meth:`telegram.ext.Dispatcher.run_async`, i.e. the number of callbacks that can be run - asynchronously at the same time. + def _request_check(self, get_updates: bool) -> None: + name = 'get_updates_request' if get_updates else 'request' - .. seealso:: :attr:`telegram.ext.Handler.run_sync`, - :attr:`telegram.ext.Defaults.run_async` + for attr in ('connect_timeout', 'read_timeout', 'write_timeout', 'pool_timeout'): + if not isinstance(getattr(self, f"_{attr}"), DefaultValue): + raise RuntimeError(_TWO_ARGS_REQ.format(name, attr)) + if self._connection_pool_size is not None: + raise RuntimeError(_TWO_ARGS_REQ.format(name, 'connection_pool_size')) + if self._proxy_url is not None: + raise RuntimeError(_TWO_ARGS_REQ.format(name, 'proxy_url')) + if self._bot is not DEFAULT_NONE: + raise RuntimeError(_TWO_ARGS_REQ.format(name, 'bot instance')) - Args: - workers (:obj:`int`): The number of worker threads. + def _request_param_check(self, get_updates: bool) -> None: + if get_updates and self._get_updates_request is not DEFAULT_NONE: + raise RuntimeError(_TWO_ARGS_REQ.format('get_updates_request', 'bot instance')) + if self._request is not DEFAULT_NONE: + raise RuntimeError(_TWO_ARGS_REQ.format('request', 'bot instance')) - Returns: - :class:`DispatcherBuilder`: The same builder with the updated argument. - """ - return self._set_workers(workers) + if self._bot is not DEFAULT_NONE: + raise RuntimeError( + _TWO_ARGS_REQ.format( + 'get_updates_request' if get_updates else 'request', 'bot instance' + ) + ) - def exception_event(self: BuilderType, exception_event: Event) -> BuilderType: - """Sets a :class:`threading.Event` instance to be used for - :attr:`telegram.ext.Dispatcher.exception_event`. When this event is set, the dispatcher - will stop processing updates. If not called, an event will be instantiated. - If the dispatcher is passed to :meth:`telegram.ext.UpdaterBuilder.dispatcher`, then this - event will also be used for :attr:`telegram.ext.Updater.exception_event`. + def request(self: BuilderType, request: BaseRequest) -> BuilderType: + """Sets a :class:`telegram.utils.Request` object to be used for the ``request`` parameter + of :attr:`telegram.ext.Application.bot`. - .. seealso:: :attr:`telegram.ext.Updater.exception_event`, - :meth:`telegram.ext.UpdaterBuilder.exception_event` + .. seealso:: :meth:`get_updates_request` Args: - exception_event (:class:`threading.Event`): The event. + request (:class:`telegram.request.BaseRequest`): The request object. Returns: - :class:`DispatcherBuilder`: The same builder with the updated argument. + :class:`ApplicationBuilder`: The same builder with the updated argument. """ - return self._set_exception_event(exception_event) - - def job_queue( - self: 'DispatcherBuilder[Dispatcher[BT, CCT, UD, CD, BD, JQ, PT], BT, CCT, UD, CD, BD, JQ, PT]', - job_queue: InJQ, - ) -> 'DispatcherBuilder[Dispatcher[BT, CCT, UD, CD, BD, InJQ, PT], BT, CCT, UD, CD, BD, InJQ, PT]': - """Sets a :class:`telegram.ext.JobQueue` instance to be used for - :attr:`telegram.ext.Dispatcher.job_queue`. If not called, a job queue will be instantiated. - - .. seealso:: `JobQueue `_, `timerbot.py `_ - - Note: - * :meth:`telegram.ext.JobQueue.set_dispatcher` will be called automatically by - :meth:`build`. - * The job queue will be automatically started and stopped by - :meth:`telegram.ext.Dispatcher.start` and :meth:`telegram.ext.Dispatcher.stop`, - respectively. - * When passing :obj:`None`, - :attr:`telegram.ext.ConversationHandler.conversation_timeout` can not be used, as - this uses :attr:`telegram.ext.Dispatcher.job_queue` internally. - - Args: - job_queue (:class:`telegram.ext.JobQueue`, optional): The job queue. Pass :obj:`None` - if you don't want to use a job queue. + self._request_check(get_updates=False) + self._request = request + return self - Returns: - :class:`DispatcherBuilder`: The same builder with the updated argument. - """ - return self._set_job_queue(job_queue) # type: ignore[return-value] + def connection_pool_size(self: BuilderType, connection_pool_size: int) -> BuilderType: + self._request_param_check(get_updates=False) + self._connection_pool_size = connection_pool_size + return self - def persistence( - self: 'DispatcherBuilder[Dispatcher[BT, CCT, UD, CD, BD, JQ, PT], BT, CCT, UD, CD, BD, JQ, PT]', - persistence: InPT, - ) -> 'DispatcherBuilder[Dispatcher[BT, CCT, UD, CD, BD, JQ, InPT], BT, CCT, UD, CD, BD, JQ, InPT]': - """Sets a :class:`telegram.ext.BasePersistence` instance to be used for - :attr:`telegram.ext.Dispatcher.persistence`. + def proxy_url(self: BuilderType, proxy_url: str) -> BuilderType: + self._request_param_check(get_updates=False) + self._proxy_url = proxy_url + return self - .. seealso:: `Making your bot persistent `_, - `persistentconversationbot.py `_ + def connect_timeout(self: BuilderType, connect_timeout: Optional[float]) -> BuilderType: + self._request_param_check(get_updates=False) + self._connect_timeout = connect_timeout + return self - Warning: - If a :class:`telegram.ext.ContextTypes` instance is set via :meth:`context_types`, - the persistence instance must use the same types! + def read_timeout(self: BuilderType, read_timeout: Optional[float]) -> BuilderType: + self._request_param_check(get_updates=False) + self._read_timeout = read_timeout + return self - Args: - persistence (:class:`telegram.ext.BasePersistence`, optional): The persistence - instance. + def write_timeout(self: BuilderType, write_timeout: Optional[float]) -> BuilderType: + self._request_param_check(get_updates=False) + self._write_timeout = write_timeout + return self - Returns: - :class:`DispatcherBuilder`: The same builder with the updated argument. - """ - return self._set_persistence(persistence) # type: ignore[return-value] + def pool_timeout(self: BuilderType, pool_timeout: Optional[float]) -> BuilderType: + self._request_param_check(get_updates=False) + self._pool_timeout = pool_timeout + return self - def context_types( - self: 'DispatcherBuilder[Dispatcher[BT, CCT, UD, CD, BD, JQ, PT], BT, CCT, UD, CD, BD, JQ, PT]', - context_types: 'ContextTypes[InCCT, InUD, InCD, InBD]', - ) -> 'DispatcherBuilder[Dispatcher[BT, InCCT, InUD, InCD, InBD, JQ, PT], BT, InCCT, InUD, InCD, InBD, JQ, PT]': - """Sets a :class:`telegram.ext.ContextTypes` instance to be used for - :attr:`telegram.ext.Dispatcher.context_types`. + def get_updates_request(self: BuilderType, request: BaseRequest) -> BuilderType: + """Sets a :class:`telegram.utils.Request` object to be used for the ``get_updates_request`` + parameter of :attr:`telegram.ext.Application.bot`. - .. seealso:: `contexttypesbot.py `_ + .. seealso:: :meth:`request` Args: - context_types (:class:`telegram.ext.ContextTypes`, optional): The context types. - - Returns: - :class:`DispatcherBuilder`: The same builder with the updated argument. - """ - return self._set_context_types(context_types) # type: ignore[return-value] - - -class UpdaterBuilder(_BaseBuilder[ODT, BT, CCT, UD, CD, BD, JQ, PT]): - """This class serves as initializer for :class:`telegram.ext.Updater` via the so called - `builder pattern`_. To build an :class:`telegram.ext.Updater`, one first initializes an - instance of this class. Arguments for the :class:`telegram.ext.Updater` to build are then - added by subsequently calling the methods of the builder. Finally, the - :class:`telegram.ext.Updater` is built by calling :meth:`build`. In the simplest case this - can look like the following example. - - Example: - .. code:: python - - updater = UpdaterBuilder().token('TOKEN').build() - - Please see the description of the individual methods for information on which arguments can be - set and what the defaults are when not called. When no default is mentioned, the argument will - not be used by default. - - Note: - * Some arguments are mutually exclusive. E.g. after calling :meth:`token`, you can't set - a custom bot with :meth:`bot` and vice versa. - * Unless a custom :class:`telegram.Bot` instance is set via :meth:`bot`, :meth:`build` will - use :class:`telegram.ext.ExtBot` for the bot. - - .. seealso:: - :class:`telegram.ext.DispatcherBuilder` - - .. _`builder pattern`: https://en.wikipedia.org/wiki/Builder_pattern. - """ - - __slots__ = () - - # The init is just here for mypy - def __init__(self: 'InitUpdaterBuilder'): - super().__init__() - - def build( - self: 'UpdaterBuilder[ODT, BT, Any, Any, Any, Any, Any, Any]', - ) -> Updater[BT, ODT]: - """Builds a :class:`telegram.ext.Updater` with the provided arguments. + request (:class:`telegram.request.BaseRequest`): The request object. Returns: - :class:`telegram.ext.Updater` + :class:`ApplicationBuilder`: The same builder with the updated argument. """ - return self._build_updater() + self._request_check(get_updates=True) + self._request = request + return self - def dispatcher_class( - self: BuilderType, dispatcher_class: Type[Dispatcher], kwargs: Dict[str, object] = None + def get_updates_connection_pool_size( + self: BuilderType, get_updates_connection_pool_size: int ) -> BuilderType: - """Sets a custom subclass to be used instead of :class:`telegram.ext.Dispatcher`. The - subclasses ``__init__`` should look like this - - .. code:: python - - def __init__(self, custom_arg_1, custom_arg_2, ..., **kwargs): - super().__init__(**kwargs) - self.custom_arg_1 = custom_arg_1 - self.custom_arg_2 = custom_arg_2 - - Args: - dispatcher_class (:obj:`type`): A subclass of :class:`telegram.ext.Dispatcher` - kwargs (Dict[:obj:`str`, :obj:`object`], optional): Keyword arguments for the - initialization. Defaults to an empty dict. + self._request_param_check(get_updates=True) + self._get_updates_connection_pool_size = get_updates_connection_pool_size + return self - Returns: - :class:`DispatcherBuilder`: The same builder with the updated argument. - """ - return self._set_dispatcher_class(dispatcher_class, kwargs) + def get_updates_proxy_url(self: BuilderType, get_updates_proxy_url: str) -> BuilderType: + self._request_param_check(get_updates=True) + self._get_updates_proxy_url = get_updates_proxy_url + return self - def updater_class( - self: BuilderType, updater_class: Type[Updater], kwargs: Dict[str, object] = None + def get_updates_connect_timeout( + self: BuilderType, get_updates_connect_timeout: Optional[float] ) -> BuilderType: - """Sets a custom subclass to be used instead of :class:`telegram.ext.Updater`. The - subclasses ``__init__`` should look like this - - .. code:: python - - def __init__(self, custom_arg_1, custom_arg_2, ..., **kwargs): - super().__init__(**kwargs) - self.custom_arg_1 = custom_arg_1 - self.custom_arg_2 = custom_arg_2 - - Args: - updater_class (:obj:`type`): A subclass of :class:`telegram.ext.Updater` - kwargs (Dict[:obj:`str`, :obj:`object`], optional): Keyword arguments for the - initialization. Defaults to an empty dict. - - Returns: - :class:`UpdaterBuilder`: The same builder with the updated argument. - """ - return self._set_updater_class(updater_class, kwargs) - - def token(self: BuilderType, token: str) -> BuilderType: - """Sets the token to be used for :attr:`telegram.ext.Updater.bot`. - - Args: - token (:obj:`str`): The token. - - Returns: - :class:`UpdaterBuilder`: The same builder with the updated argument. - """ - return self._set_token(token) - - def base_url(self: BuilderType, base_url: str) -> BuilderType: - """Sets the base URL to be used for :attr:`telegram.ext.Updater.bot`. If not called, - will default to ``'https://api.telegram.org/bot'``. - - .. seealso:: :attr:`telegram.Bot.base_url`, `Local Bot API Server `_, - :meth:`base_url` - - Args: - base_url (:obj:`str`): The URL. - - Returns: - :class:`UpdaterBuilder`: The same builder with the updated argument. - """ - return self._set_base_url(base_url) - - def base_file_url(self: BuilderType, base_file_url: str) -> BuilderType: - """Sets the base file URL to be used for :attr:`telegram.ext.Updater.bot`. If not - called, will default to ``'https://api.telegram.org/file/bot'``. - - .. seealso:: :attr:`telegram.Bot.base_file_url`, `Local Bot API Server `_, - :meth:`base_file_url` - - Args: - base_file_url (:obj:`str`): The URL. - - Returns: - :class:`UpdaterBuilder`: The same builder with the updated argument. - """ - return self._set_base_file_url(base_file_url) - - def request_kwargs(self: BuilderType, request_kwargs: Dict[str, Any]) -> BuilderType: - """Sets keyword arguments that will be passed to the :class:`telegram.utils.Request` object - that is created when :attr:`telegram.ext.Updater.bot` is created. If not called, no - keyword arguments will be passed. - - .. seealso:: :meth:`request` - - Args: - request_kwargs (Dict[:obj:`str`, :obj:`object`]): The keyword arguments. - - Returns: - :class:`UpdaterBuilder`: The same builder with the updated argument. - """ - return self._set_request_kwargs(request_kwargs) - - def request(self: BuilderType, request: Request) -> BuilderType: - """Sets a :class:`telegram.utils.Request` object to be used for - :attr:`telegram.ext.Updater.bot`. + self._request_param_check(get_updates=True) + self._get_updates_connect_timeout = get_updates_connect_timeout + return self - .. seealso:: :meth:`request_kwargs` + def get_updates_read_timeout( + self: BuilderType, get_updates_read_timeout: Optional[float] + ) -> BuilderType: + self._request_param_check(get_updates=True) + self._get_updates_read_timeout = get_updates_read_timeout + return self - Args: - request (:class:`telegram.utils.Request`): The request object. + def get_updates_write_timeout( + self: BuilderType, get_updates_write_timeout: Optional[float] + ) -> BuilderType: + self._request_param_check(get_updates=True) + self._get_updates_write_timeout = get_updates_write_timeout + return self - Returns: - :class:`UpdaterBuilder`: The same builder with the updated argument. - """ - return self._set_request(request) + def get_updates_pool_timeout( + self: BuilderType, get_updates_pool_timeout: Optional[float] + ) -> BuilderType: + self._request_param_check(get_updates=True) + self._get_updates_pool_timeout = get_updates_pool_timeout + return self def private_key( self: BuilderType, @@ -990,11 +476,11 @@ def private_key( password: Union[bytes, FilePathInput] = None, ) -> BuilderType: """Sets the private key and corresponding password for decryption of telegram passport data - to be used for :attr:`telegram.ext.Updater.bot`. + to be used for :attr:`telegram.ext.Application.bot`. .. seealso:: `passportbot.py `_, `Telegram Passports `_ + /tree/master/examples#passportbotpy>`_, `Telegram Passports + `_ Args: private_key (:obj:`bytes` | :obj:`str` | :obj:`pathlib.Path`): The private key or the @@ -1005,13 +491,24 @@ def private_key( the file's content will be read automatically. Returns: - :class:`UpdaterBuilder`: The same builder with the updated argument. + :class:`ApplicationBuilder`: The same builder with the updated argument. """ - return self._set_private_key(private_key=private_key, password=password) + if self._bot is not DEFAULT_NONE: + raise RuntimeError(_TWO_ARGS_REQ.format('private_key', 'bot instance')) + + self._private_key = ( + private_key if isinstance(private_key, bytes) else Path(private_key).read_bytes() + ) + if password is None or isinstance(password, bytes): + self._private_key_password = password + else: + self._private_key_password = Path(password).read_bytes() + + return self def defaults(self: BuilderType, defaults: 'Defaults') -> BuilderType: """Sets the :class:`telegram.ext.Defaults` object to be used for - :attr:`telegram.ext.Updater.bot`. + :attr:`telegram.ext.Application.bot`. .. seealso:: `Adding Defaults `_ @@ -1020,14 +517,17 @@ def defaults(self: BuilderType, defaults: 'Defaults') -> BuilderType: defaults (:class:`telegram.ext.Defaults`): The defaults. Returns: - :class:`UpdaterBuilder`: The same builder with the updated argument. + :class:`ApplicationBuilder`: The same builder with the updated argument. """ - return self._set_defaults(defaults) + if self._bot is not DEFAULT_NONE: + raise RuntimeError(_TWO_ARGS_REQ.format('defaults', 'bot instance')) + self._defaults = defaults + return self def arbitrary_callback_data( self: BuilderType, arbitrary_callback_data: Union[bool, int] ) -> BuilderType: - """Specifies whether :attr:`telegram.ext.Updater.bot` should allow arbitrary objects as + """Specifies whether :attr:`telegram.ext.Application.bot` should allow arbitrary objects as callback data for :class:`telegram.InlineKeyboardButton` and how many keyboards should be cached in memory. If not called, only strings can be used as callback data and no data will be stored in memory. @@ -1043,123 +543,109 @@ def arbitrary_callback_data( cache size. Returns: - :class:`UpdaterBuilder`: The same builder with the updated argument. + :class:`ApplicationBuilder`: The same builder with the updated argument. """ - return self._set_arbitrary_callback_data(arbitrary_callback_data) + if self._bot is not DEFAULT_NONE: + raise RuntimeError(_TWO_ARGS_REQ.format('arbitrary_callback_data', 'bot instance')) + self._arbitrary_callback_data = arbitrary_callback_data + return self def bot( - self: 'UpdaterBuilder[Dispatcher[BT, CCT, UD, CD, BD, JQ, PT], BT, CCT, UD, CD, BD, ' - 'JQ, PT]', + self: 'ApplicationBuilder[BT, CCT, UD, CD, BD, JQ]', bot: InBT, - ) -> 'UpdaterBuilder[Dispatcher[InBT, CCT, UD, CD, BD, JQ, PT], InBT, CCT, UD, CD, BD, JQ, PT]': + ) -> 'ApplicationBuilder[InBT, CCT, UD, CD, BD, JQ]': """Sets a :class:`telegram.Bot` instance to be used for - :attr:`telegram.ext.Updater.bot`. Instances of subclasses like + :attr:`telegram.ext.Application.bot`. Instances of subclasses like :class:`telegram.ext.ExtBot` are also valid. Args: bot (:class:`telegram.Bot`): The bot. Returns: - :class:`UpdaterBuilder`: The same builder with the updated argument. + :class:`ApplicationBuilder`: The same builder with the updated argument. """ - return self._set_bot(bot) # type: ignore[return-value] + for attr, error in _BOT_CHECKS: + if not isinstance(getattr(self, f'_{attr}'), DefaultValue): + raise RuntimeError(_TWO_ARGS_REQ.format('bot', error)) + self._bot = bot + return self # type: ignore[return-value] def update_queue(self: BuilderType, update_queue: Queue) -> BuilderType: """Sets a :class:`queue.Queue` instance to be used for - :attr:`telegram.ext.Updater.update_queue`, i.e. the queue that the fetched updates will - be queued into. If not called, a queue will be instantiated. - If :meth:`dispatcher` is not called, this queue will also be used for - :attr:`telegram.ext.Dispatcher.update_queue`. + :attr:`telegram.ext.Application.update_queue`, i.e. the queue that the application will + fetch updates from. Will also be used for the :attr:`telegram.ext.Application.updater`. + If not called, a queue will be instantiated. - .. seealso:: :attr:`telegram.ext.Dispatcher.update_queue`, - :meth:`telegram.ext.DispatcherBuilder.update_queue` + .. seealso:: :attr:`telegram.ext.Updater.update_queue` Args: update_queue (:class:`queue.Queue`): The queue. Returns: - :class:`UpdaterBuilder`: The same builder with the updated argument. + :class:`ApplicationBuilder`: The same builder with the updated argument. """ - return self._set_update_queue(update_queue) + if isinstance(self._updater, DefaultValue): + raise RuntimeError(_TWO_ARGS_REQ.format('update_queue', 'updater instance')) + self._update_queue = update_queue + return self - def workers(self: BuilderType, workers: int) -> BuilderType: + def concurrent_updates(self: BuilderType, concurrent_updates: Union[bool, int]) -> BuilderType: """Sets the number of worker threads to be used for - :meth:`telegram.ext.Dispatcher.run_async`, i.e. the number of callbacks that can be run + :meth:`telegram.ext.Application.run_async`, i.e. the number of callbacks that can be run asynchronously at the same time. .. seealso:: :attr:`telegram.ext.Handler.run_sync`, - :attr:`telegram.ext.Defaults.run_async` - - Args: - workers (:obj:`int`): The number of worker threads. - - Returns: - :class:`DispatcherBuilder`: The same builder with the updated argument. - """ - return self._set_workers(workers) - - def exception_event(self: BuilderType, exception_event: Event) -> BuilderType: - """Sets a :class:`threading.Event` instance to be used by the - :class:`telegram.ext.Updater`. When an unhandled exception happens while fetching updates, - this event will be set and the ``Updater`` will stop fetching for updates. If not called, - an event will be instantiated. - If :meth:`dispatcher` is not called, this event will also be used for - :attr:`telegram.ext.Dispatcher.exception_event`. - - .. seealso:: :attr:`telegram.ext.Dispatcher.exception_event`, - :meth:`telegram.ext.DispatcherBuilder.exception_event` + :attr:`telegram.ext.Defaults.block` Args: - exception_event (:class:`threading.Event`): The event. + concurrent_updates (:obj:`int`): The number of worker threads. Returns: - :class:`UpdaterBuilder`: The same builder with the updated argument. + :class:`ApplicationBuilder`: The same builder with the updated argument. """ - return self._set_exception_event(exception_event) + self._concurrent_updates = concurrent_updates + return self def job_queue( - self: 'UpdaterBuilder[Dispatcher[BT, CCT, UD, CD, BD, JQ, PT], BT, CCT, UD, CD, BD, JQ, PT]', + self: 'ApplicationBuilder[BT, CCT, UD, CD, BD, JQ]', job_queue: InJQ, - ) -> 'UpdaterBuilder[Dispatcher[BT, CCT, UD, CD, BD, InJQ, PT], BT, CCT, UD, CD, BD, InJQ, PT]': - """Sets a :class:`telegram.ext.JobQueue` instance to be used for the - :attr:`telegram.ext.Updater.dispatcher`. If not called, a job queue will be instantiated. + ) -> 'ApplicationBuilder[BT, CCT, UD, CD, BD, InJQ]': + """Sets a :class:`telegram.ext.JobQueue` instance to be used for + :attr:`telegram.ext.Application.job_queue`. If not called, a job queue will be + instantiated. - .. seealso:: `JobQueue `_, `timerbot.py `_, - :attr:`telegram.ext.Dispatcher.job_queue` + .. seealso:: `JobQueue `_, `timerbot.py `_ Note: - * :meth:`telegram.ext.JobQueue.set_dispatcher` will be called automatically by + * :meth:`telegram.ext.JobQueue.set_application` will be called automatically by :meth:`build`. - * The job queue will be automatically started/stopped by starting/stopping the - ``Updater``, which automatically calls :meth:`telegram.ext.Dispatcher.start` - and :meth:`telegram.ext.Dispatcher.stop`, respectively. + * The job queue will be automatically started and stopped by + :meth:`telegram.ext.Application.start` and :meth:`telegram.ext.Application.stop`, + respectively. * When passing :obj:`None`, :attr:`telegram.ext.ConversationHandler.conversation_timeout` can not be used, as - this uses :attr:`telegram.ext.Dispatcher.job_queue` internally. + this uses :attr:`telegram.ext.Application.job_queue` internally. Args: job_queue (:class:`telegram.ext.JobQueue`, optional): The job queue. Pass :obj:`None` if you don't want to use a job queue. Returns: - :class:`UpdaterBuilder`: The same builder with the updated argument. + :class:`ApplicationBuilder`: The same builder with the updated argument. """ - return self._set_job_queue(job_queue) # type: ignore[return-value] + self._job_queue = job_queue + return self # type: ignore[return-value] - def persistence( - self: 'UpdaterBuilder[Dispatcher[BT, CCT, UD, CD, BD, JQ, PT], BT, CCT, UD, CD, BD, JQ, PT]', - persistence: InPT, - ) -> 'UpdaterBuilder[Dispatcher[BT, CCT, UD, CD, BD, JQ, InPT], BT, CCT, UD, CD, BD, JQ, InPT]': - """Sets a :class:`telegram.ext.BasePersistence` instance to be used for the - :attr:`telegram.ext.Updater.dispatcher`. + def persistence(self: BuilderType, persistence: 'BasePersistence') -> BuilderType: + """Sets a :class:`telegram.ext.BasePersistence` instance to be used for + :attr:`telegram.ext.Application.persistence`. .. seealso:: `Making your bot persistent `_, `persistentconversationbot.py `_, - :attr:`telegram.ext.Dispatcher.persistence` + /python-telegram-bot/tree/master/examples#persistentconversationbotpy>`_ Warning: If a :class:`telegram.ext.ContextTypes` instance is set via :meth:`context_types`, @@ -1170,80 +656,44 @@ def persistence( instance. Returns: - :class:`UpdaterBuilder`: The same builder with the updated argument. + :class:`ApplicationBuilder`: The same builder with the updated argument. """ - return self._set_persistence(persistence) # type: ignore[return-value] + self._persistence = persistence + return self def context_types( - self: 'UpdaterBuilder[Dispatcher[BT, CCT, UD, CD, BD, JQ, PT], BT, CCT, UD, CD, BD, JQ, PT]', + self: 'ApplicationBuilder[BT, CCT, UD, CD, BD, JQ]', context_types: 'ContextTypes[InCCT, InUD, InCD, InBD]', - ) -> 'UpdaterBuilder[Dispatcher[BT, InCCT, InUD, InCD, InBD, JQ, PT], BT, InCCT, InUD, InCD, InBD, JQ, PT]': - """Sets a :class:`telegram.ext.ContextTypes` instance to be used for the - :attr:`telegram.ext.Updater.dispatcher`. + ) -> 'ApplicationBuilder[BT, InCCT, InUD, InCD, InBD, JQ]': + """Sets a :class:`telegram.ext.ContextTypes` instance to be used for + :attr:`telegram.ext.Application.context_types`. - .. seealso:: `contexttypesbot.py `_, - :attr:`telegram.ext.Dispatcher.context_types`. + .. seealso:: `contexttypesbot.py `_ Args: context_types (:class:`telegram.ext.ContextTypes`, optional): The context types. Returns: - :class:`UpdaterBuilder`: The same builder with the updated argument. + :class:`ApplicationBuilder`: The same builder with the updated argument. """ - return self._set_context_types(context_types) # type: ignore[return-value] - - @overload - def dispatcher( - self: 'UpdaterBuilder[ODT, BT, CCT, UD, CD, BD, JQ, PT]', dispatcher: None - ) -> 'UpdaterBuilder[None, BT, CCT, UD, CD, BD, JQ, PT]': - ... - - @overload - def dispatcher( - self: BuilderType, dispatcher: Dispatcher[InBT, InCCT, InUD, InCD, InBD, InJQ, InPT] - ) -> 'UpdaterBuilder[Dispatcher[InBT, InCCT, InUD, InCD, InBD, InJQ, InPT], InBT, InCCT, InUD, InCD, InBD, InJQ, InPT]': - ... + self._context_types = context_types + return self # type: ignore[return-value] - def dispatcher( # type: ignore[misc] - self: BuilderType, - dispatcher: Optional[Dispatcher[InBT, InCCT, InUD, InCD, InBD, InJQ, InPT]], - ) -> 'UpdaterBuilder[Optional[Dispatcher[InBT, InCCT, InUD, InCD, InBD, InJQ, InPT]], InBT, InCCT, InUD, InCD, InBD, InJQ, InPT]': - """Sets a :class:`telegram.ext.Dispatcher` instance to be used for - :attr:`telegram.ext.Updater.dispatcher`. - The dispatchers :attr:`telegram.ext.Dispatcher.bot`, - :attr:`telegram.ext.Dispatcher.update_queue` and - :attr:`telegram.ext.Dispatcher.exception_event` will be used for the respective arguments - of the updater. - If not called, a dispatcher will be instantiated. + def updater(self: BuilderType, updater: Union[Updater, None]) -> BuilderType: + """Sets a :class:`telegram.ext.Updater` instance to be used for + :attr:`telegram.ext.Application.updater`. Args: - dispatcher (:class:`telegram.ext.Dispatcher`): The dispatcher. + updater (:class:`telegram.ext.Updater` | :obj:`None`): The updater instance or + :obj:`None` if no updater should be used. Returns: - :class:`UpdaterBuilder`: The same builder with the updated argument. + :class:`ApplicationBuilder`: The same builder with the updated argument. """ - return self._set_dispatcher(dispatcher) # type: ignore[return-value] - - def user_signal_handler( - self: BuilderType, user_signal_handler: Callable[[int, object], Any] - ) -> BuilderType: - """Sets a callback to be used for :attr:`telegram.ext.Updater.user_signal_handler`. - The callback will be called when :meth:`telegram.ext.Updater.idle()` receives a signal. - It will be called with the two arguments ``signum, frame`` as for the - :meth:`signal.signal` of the standard library. + for attr, error in (self._bot, 'bot instance'), (self._update_queue, 'update queue'): + if not isinstance(attr, DefaultValue): + raise RuntimeError(_TWO_ARGS_REQ.format('updater', error)) - Note: - Signal handlers are an advanced feature that come with some culprits and are not thread - safe. This should therefore only be used for tasks like closing threads or database - connections on shutdown. Note that for many tasks a viable alternative is to simply - put your code *after* calling :meth:`telegram.ext.Updater.idle`. In this case it will - be executed after the updater has shut down. - - Args: - user_signal_handler (Callable[signum, frame]): The signal handler. - - Returns: - :class:`UpdaterBuilder`: The same builder with the updated argument. - """ - return self._set_user_signal_handler(user_signal_handler) + self._updater = updater + return self diff --git a/telegram/ext/_callbackcontext.py b/telegram/ext/_callbackcontext.py index 4663243a740..4fcaab44cd8 100644 --- a/telegram/ext/_callbackcontext.py +++ b/telegram/ext/_callbackcontext.py @@ -18,7 +18,7 @@ # along with this program. If not, see [http://www.gnu.org/licenses/]. # pylint: disable=no-self-use """This module contains the CallbackContext class.""" -from queue import Queue +from asyncio import Queue from typing import ( TYPE_CHECKING, Dict, @@ -27,17 +27,17 @@ NoReturn, Optional, Tuple, - Union, Generic, Type, + Coroutine, ) from telegram import Update, CallbackQuery from telegram.ext import ExtBot -from telegram.ext._utils.types import UD, CD, BD, BT, JQ, PT # pylint: disable=unused-import +from telegram.ext._utils.types import UD, CD, BD, BT, JQ # pylint: disable=unused-import if TYPE_CHECKING: - from telegram.ext import Dispatcher, Job, JobQueue + from telegram.ext import Application, Job, JobQueue from telegram.ext._utils.types import CCT _STORING_DATA_WIKI = ( @@ -49,12 +49,12 @@ class CallbackContext(Generic[BT, UD, CD, BD]): """ This is a context object passed to the callback called by :class:`telegram.ext.Handler` - or by the :class:`telegram.ext.Dispatcher` in an error handler added by - :attr:`telegram.ext.Dispatcher.add_error_handler` or to the callback of a + or by the :class:`telegram.ext.Application` in an error handler added by + :attr:`telegram.ext.Application.add_error_handler` or to the callback of a :class:`telegram.ext.Job`. Note: - :class:`telegram.ext.Dispatcher` will create a single context for an entire update. This + :class:`telegram.ext.Application` will create a single context for an entire update. This means that if you got 2 handlers in different groups and they both get called, they will get passed the same `CallbackContext` object (of course with proper attributes like `.matches` differing). This allows you to add custom attributes in a lower handler group @@ -69,7 +69,8 @@ class CallbackContext(Generic[BT, UD, CD, BD]): that you think you added will not be present. Args: - dispatcher (:class:`telegram.ext.Dispatcher`): The dispatcher associated with this context. + application (:class:`telegram.ext.Application`): The application associated with this + context. Attributes: matches (List[:obj:`re match object`]): Optional. If the associated update originated from @@ -81,13 +82,13 @@ class CallbackContext(Generic[BT, UD, CD, BD]): or :class:`telegram.ext.StringCommandHandler`. It contains a list of the words in the text after the command, using any whitespace string as a delimiter. error (:obj:`Exception`): Optional. The error that was raised. Only present when passed - to a error handler registered with :attr:`telegram.ext.Dispatcher.add_error_handler`. + to a error handler registered with :attr:`telegram.ext.Application.add_error_handler`. async_args (List[:obj:`object`]): Optional. Positional arguments of the function that raised the error. Only present when the raising function was run asynchronously using - :meth:`telegram.ext.Dispatcher.run_async`. + :meth:`telegram.ext.Application.run_async`. async_kwargs (Dict[:obj:`str`, :obj:`object`]): Optional. Keyword arguments of the function that raised the error. Only present when the raising function was run asynchronously - using :meth:`telegram.ext.Dispatcher.run_async`. + using :meth:`telegram.ext.Application.run_async`. job (:class:`telegram.ext.Job`): Optional. The job which originated this callback. Only present when passed to the callback of :class:`telegram.ext.Job` or in error handlers if the error is caused by a job. @@ -118,44 +119,42 @@ def callback(update: Update, context: CallbackContext.DEFAULT_TYPE): """ __slots__ = ( - '_dispatcher', + '_application', '_chat_id_and_data', '_user_id_and_data', 'args', 'matches', 'error', 'job', - 'async_args', - 'async_kwargs', + 'coroutine', '__dict__', ) - def __init__(self: 'CCT', dispatcher: 'Dispatcher[BT, CCT, UD, CD, BD, JQ, PT]'): + def __init__(self: 'CCT', application: 'Application[BT, CCT, UD, CD, BD, JQ]'): """ Args: - dispatcher (:class:`telegram.ext.Dispatcher`): + application (:class:`telegram.ext.Application`): """ - self._dispatcher = dispatcher + self._application = application self._chat_id_and_data: Optional[Tuple[int, CD]] = None self._user_id_and_data: Optional[Tuple[int, UD]] = None self.args: Optional[List[str]] = None self.matches: Optional[List[Match]] = None self.error: Optional[Exception] = None self.job: Optional['Job'] = None - self.async_args: Optional[Union[List, Tuple]] = None - self.async_kwargs: Optional[Dict[str, object]] = None + self.coroutine: Optional[Coroutine] = None @property - def dispatcher(self) -> 'Dispatcher[BT, CCT, UD, CD, BD, JQ, PT]': - """:class:`telegram.ext.Dispatcher`: The dispatcher associated with this context.""" - return self._dispatcher + def application(self) -> 'Application[BT, CCT, UD, CD, BD, JQ]': + """:class:`telegram.ext.Application`: The application associated with this context.""" + return self._application @property def bot_data(self) -> BD: """:obj:`dict`: Optional. A dict that can be used to keep any data in. For each update it will be the same ``dict``. """ - return self.dispatcher.bot_data + return self.application.bot_data @bot_data.setter def bot_data(self, value: object) -> NoReturn: @@ -199,28 +198,31 @@ def user_data(self, value: object) -> NoReturn: f"You can not assign a new value to user_data, see {_STORING_DATA_WIKI}" ) - def refresh_data(self) -> None: - """If :attr:`dispatcher` uses persistence, calls + async def refresh_data(self) -> None: + """If :attr:`application` uses persistence, calls :meth:`telegram.ext.BasePersistence.refresh_bot_data` on :attr:`bot_data`, :meth:`telegram.ext.BasePersistence.refresh_chat_data` on :attr:`chat_data` and :meth:`telegram.ext.BasePersistence.refresh_user_data` on :attr:`user_data`, if appropriate. + Will be called by :meth:`telegram.ext.Application.process_update` and + :meth:`telegram.ext.Job.run`. + .. versionadded:: 13.6 """ - if self.dispatcher.persistence: - if self.dispatcher.persistence.store_data.bot_data: - self.dispatcher.persistence.refresh_bot_data(self.bot_data) + if self.application.persistence: + if self.application.persistence.store_data.bot_data: + await self.application.persistence.refresh_bot_data(self.bot_data) if ( - self.dispatcher.persistence.store_data.chat_data + self.application.persistence.store_data.chat_data and self._chat_id_and_data is not None ): - self.dispatcher.persistence.refresh_chat_data(*self._chat_id_and_data) + await self.application.persistence.refresh_chat_data(*self._chat_id_and_data) if ( - self.dispatcher.persistence.store_data.user_data + self.application.persistence.store_data.user_data and self._user_id_and_data is not None ): - self.dispatcher.persistence.refresh_user_data(*self._user_id_and_data) + await self.application.persistence.refresh_user_data(*self._user_id_and_data) def drop_callback_data(self, callback_query: CallbackQuery) -> None: """ @@ -255,29 +257,28 @@ def from_error( cls: Type['CCT'], update: object, error: Exception, - dispatcher: 'Dispatcher[BT, CCT, UD, CD, BD, JQ, PT]', - async_args: Union[List, Tuple] = None, - async_kwargs: Dict[str, object] = None, + application: 'Application[BT, CCT, UD, CD, BD, JQ]', job: 'Job' = None, + coroutine: Coroutine = None, ) -> 'CCT': """ Constructs an instance of :class:`telegram.ext.CallbackContext` to be passed to the error handlers. - .. seealso:: :meth:`telegram.ext.Dispatcher.add_error_handler` + .. seealso:: :meth:`telegram.ext.Application.add_error_handler` Args: update (:obj:`object` | :class:`telegram.Update`): The update associated with the error. May be :obj:`None`, e.g. for errors in job callbacks. error (:obj:`Exception`): The error. - dispatcher (:class:`telegram.ext.Dispatcher`): The dispatcher associated with this + application (:class:`telegram.ext.Application`): The application associated with this context. async_args (List[:obj:`object`], optional): Positional arguments of the function that raised the error. Pass only when the raising function was run asynchronously using - :meth:`telegram.ext.Dispatcher.run_async`. + :meth:`telegram.ext.Application.run_async`. async_kwargs (Dict[:obj:`str`, :obj:`object`], optional): Keyword arguments of the function that raised the error. Pass only when the raising function was run - asynchronously using :meth:`telegram.ext.Dispatcher.run_async`. + asynchronously using :meth:`telegram.ext.Application.run_async`. job (:class:`telegram.ext.Job`, optional): The job associated with the error. .. versionadded:: 14.0 @@ -285,32 +286,33 @@ def from_error( Returns: :class:`telegram.ext.CallbackContext` """ - self = cls.from_update(update, dispatcher) + self = cls.from_update(update, application) self.error = error - self.async_args = async_args - self.async_kwargs = async_kwargs + self.coroutine = coroutine self.job = job return self @classmethod def from_update( - cls: Type['CCT'], update: object, dispatcher: 'Dispatcher[BT, CCT, UD, CD, BD, JQ, PT]' + cls: Type['CCT'], + update: object, + application: 'Application[BT, CCT, UD, CD, BD, JQ]', ) -> 'CCT': """ Constructs an instance of :class:`telegram.ext.CallbackContext` to be passed to the handlers. - .. seealso:: :meth:`telegram.ext.Dispatcher.add_handler` + .. seealso:: :meth:`telegram.ext.Application.add_handler` Args: update (:obj:`object` | :class:`telegram.Update`): The update. - dispatcher (:class:`telegram.ext.Dispatcher`): The dispatcher associated with this + application (:class:`telegram.ext.Application`): The application associated with this context. Returns: :class:`telegram.ext.CallbackContext` """ - self = cls(dispatcher) # type: ignore[arg-type] + self = cls(application) # type: ignore[arg-type] if update is not None and isinstance(update, Update): chat = update.effective_chat @@ -319,18 +321,20 @@ def from_update( if chat: self._chat_id_and_data = ( chat.id, - dispatcher.chat_data[chat.id], # pylint: disable=protected-access + application.chat_data[chat.id], # pylint: disable=protected-access ) if user: self._user_id_and_data = ( user.id, - dispatcher.user_data[user.id], # pylint: disable=protected-access + application.user_data[user.id], # pylint: disable=protected-access ) return self @classmethod def from_job( - cls: Type['CCT'], job: 'Job', dispatcher: 'Dispatcher[BT, CCT, UD, CD, BD, JQ, PT]' + cls: Type['CCT'], + job: 'Job', + application: 'Application[BT, CCT, UD, CD, BD, JQ]', ) -> 'CCT': """ Constructs an instance of :class:`telegram.ext.CallbackContext` to be passed to a @@ -340,14 +344,25 @@ def from_job( Args: job (:class:`telegram.ext.Job`): The job. - dispatcher (:class:`telegram.ext.Dispatcher`): The dispatcher associated with this + application (:class:`telegram.ext.Application`): The application associated with this context. Returns: :class:`telegram.ext.CallbackContext` """ - self = cls(dispatcher) # type: ignore[arg-type] + self = cls(application) # type: ignore[arg-type] self.job = job + + if job.chat_id: + self._chat_id_and_data = ( + job.chat_id, + application.chat_data[job.chat_id], # pylint: disable=protected-access + ) + if job.user_id: + self._user_id_and_data = ( + job.user_id, + application.user_data[job.user_id], # pylint: disable=protected-access + ) return self def update(self, data: Dict[str, object]) -> None: @@ -362,27 +377,27 @@ def update(self, data: Dict[str, object]) -> None: @property def bot(self) -> BT: """:class:`telegram.Bot`: The bot associated with this context.""" - return self._dispatcher.bot + return self._application.bot @property def job_queue(self) -> Optional['JobQueue']: """ :class:`telegram.ext.JobQueue`: The ``JobQueue`` used by the - :class:`telegram.ext.Dispatcher` and (usually) the :class:`telegram.ext.Updater` + :class:`telegram.ext.Application` and (usually) the :class:`telegram.ext.Updater` associated with this context. """ - return self._dispatcher.job_queue + return self._application.job_queue @property - def update_queue(self) -> Queue: + def update_queue(self) -> 'Queue[object]': """ - :class:`queue.Queue`: The ``Queue`` instance used by the - :class:`telegram.ext.Dispatcher` and (usually) the :class:`telegram.ext.Updater` + :class:`asyncio.Queue`: The ``Queue`` instance used by the + :class:`telegram.ext.Application` and (usually) the :class:`telegram.ext.Updater` associated with this context. """ - return self._dispatcher.update_queue + return self._application.update_queue @property def match(self) -> Optional[Match[str]]: diff --git a/telegram/ext/_callbackqueryhandler.py b/telegram/ext/_callbackqueryhandler.py index 9b135e1aca2..9c760255254 100644 --- a/telegram/ext/_callbackqueryhandler.py +++ b/telegram/ext/_callbackqueryhandler.py @@ -31,12 +31,13 @@ ) from telegram import Update +from telegram._utils.types import DVInput from telegram.ext import Handler -from telegram._utils.defaultvalue import DefaultValue, DEFAULT_FALSE -from telegram.ext._utils.types import CCT +from telegram._utils.defaultvalue import DEFAULT_TRUE +from telegram.ext._utils.types import CCT, HandlerCallback if TYPE_CHECKING: - from telegram.ext import Dispatcher + from telegram.ext import Application RT = TypeVar('RT') @@ -57,7 +58,7 @@ class CallbackQueryHandler(Handler[Update, CCT]): .. versionadded:: 13.6 Warning: - When setting ``run_async`` to :obj:`True`, you cannot rely on adding custom + When setting ``block`` to :obj:`True`, you cannot rely on adding custom attributes to :class:`telegram.ext.CallbackContext`. See its docs for more info. Args: @@ -85,8 +86,9 @@ class CallbackQueryHandler(Handler[Update, CCT]): .. versionchanged:: 13.6 Added support for arbitrary callback data. - run_async (:obj:`bool`): Determines whether the callback will run asynchronously. - Defaults to :obj:`False`. + block (:obj:`bool`, optional): Determines whether the return value of the callback should + be awaited before processing the next handler in + :meth:`telegram.ext.Application.process_update`. Defaults to :obj:`True`. Attributes: callback (:obj:`callable`): The callback function for this handler. @@ -95,7 +97,9 @@ class CallbackQueryHandler(Handler[Update, CCT]): .. versionchanged:: 13.6 Added support for arbitrary callback data. - run_async (:obj:`bool`): Determines whether the callback will run asynchronously. + block (:obj:`bool`): Determines whether the return value of the callback should be + awaited before processing the next handler in + :meth:`telegram.ext.Application.process_update`. """ @@ -103,14 +107,11 @@ class CallbackQueryHandler(Handler[Update, CCT]): def __init__( self, - callback: Callable[[Update, CCT], RT], + callback: HandlerCallback[Update, CCT, RT], pattern: Union[str, Pattern, type, Callable[[object], Optional[bool]]] = None, - run_async: Union[bool, DefaultValue] = DEFAULT_FALSE, + block: DVInput[bool] = DEFAULT_TRUE, ): - super().__init__( - callback, - run_async=run_async, - ) + super().__init__(callback, block=block) if isinstance(pattern, str): pattern = re.compile(pattern) @@ -147,7 +148,7 @@ def collect_additional_context( self, context: CCT, update: Update, - dispatcher: 'Dispatcher', + application: 'Application', check_result: Union[bool, Match], ) -> None: """Add the result of ``re.match(pattern, update.callback_query.data)`` to diff --git a/telegram/ext/_chatjoinrequesthandler.py b/telegram/ext/_chatjoinrequesthandler.py index 47b83057952..0e49f4e60e5 100644 --- a/telegram/ext/_chatjoinrequesthandler.py +++ b/telegram/ext/_chatjoinrequesthandler.py @@ -29,7 +29,7 @@ class ChatJoinRequestHandler(Handler[Update, CCT]): """Handler class to handle Telegram updates that contain a chat join request. Warning: - When setting ``run_async`` to :obj:`True`, you cannot rely on adding custom + When setting ``block`` to :obj:`True`, you cannot rely on adding custom attributes to :class:`telegram.ext.CallbackContext`. See its docs for more info. .. versionadded:: 13.8 @@ -43,12 +43,13 @@ class ChatJoinRequestHandler(Handler[Update, CCT]): The return value of the callback is usually ignored except for the special case of :class:`telegram.ext.ConversationHandler`. - run_async (:obj:`bool`): Determines whether the callback will run asynchronously. - Defaults to :obj:`False`. + block (:obj:`bool`, optional): Determines whether the return value of the callback should + be awaited before processing the next handler in + :meth:`telegram.ext.Application.process_update`. Defaults to :obj:`True`. Attributes: callback (:obj:`callable`): The callback function for this handler. - run_async (:obj:`bool`): Determines whether the callback will run asynchronously. + block (:obj:`bool`): Determines whether the callback will run asynchronously. """ diff --git a/telegram/ext/_chatmemberhandler.py b/telegram/ext/_chatmemberhandler.py index eecb3ef280e..e6d3089f631 100644 --- a/telegram/ext/_chatmemberhandler.py +++ b/telegram/ext/_chatmemberhandler.py @@ -17,12 +17,13 @@ # You should have received a copy of the GNU Lesser Public License # along with this program. If not, see [http://www.gnu.org/licenses/]. """This module contains the ChatMemberHandler classes.""" -from typing import ClassVar, TypeVar, Union, Callable +from typing import ClassVar, TypeVar from telegram import Update +from telegram._utils.types import DVInput from telegram.ext import Handler -from telegram._utils.defaultvalue import DefaultValue, DEFAULT_FALSE -from telegram.ext._utils.types import CCT +from telegram._utils.defaultvalue import DEFAULT_TRUE +from telegram.ext._utils.types import CCT, HandlerCallback RT = TypeVar('RT') @@ -33,7 +34,7 @@ class ChatMemberHandler(Handler[Update, CCT]): .. versionadded:: 13.4 Warning: - When setting ``run_async`` to :obj:`True`, you cannot rely on adding custom + When setting ``block`` to :obj:`True`, you cannot rely on adding custom attributes to :class:`telegram.ext.CallbackContext`. See its docs for more info. Args: @@ -47,15 +48,18 @@ class ChatMemberHandler(Handler[Update, CCT]): :attr:`CHAT_MEMBER` or :attr:`ANY_CHAT_MEMBER` to specify if this handler should handle only updates with :attr:`telegram.Update.my_chat_member`, :attr:`telegram.Update.chat_member` or both. Defaults to :attr:`MY_CHAT_MEMBER`. - run_async (:obj:`bool`): Determines whether the callback will run asynchronously. - Defaults to :obj:`False`. + block (:obj:`bool`, optional): Determines whether the return value of the callback should + be awaited before processing the next handler in + :meth:`telegram.ext.Application.process_update`. Defaults to :obj:`True`. Attributes: callback (:obj:`callable`): The callback function for this handler. chat_member_types (:obj:`int`, optional): Specifies if this handler should handle only updates with :attr:`telegram.Update.my_chat_member`, :attr:`telegram.Update.chat_member` or both. - run_async (:obj:`bool`): Determines whether the callback will run asynchronously. + block (:obj:`bool`): Determines whether the return value of the callback should be + awaited before processing the next handler in + :meth:`telegram.ext.Application.process_update`. """ @@ -70,14 +74,11 @@ class ChatMemberHandler(Handler[Update, CCT]): def __init__( self, - callback: Callable[[Update, CCT], RT], + callback: HandlerCallback[Update, CCT, RT], chat_member_types: int = MY_CHAT_MEMBER, - run_async: Union[bool, DefaultValue] = DEFAULT_FALSE, + block: DVInput[bool] = DEFAULT_TRUE, ): - super().__init__( - callback, - run_async=run_async, - ) + super().__init__(callback, block=block) self.chat_member_types = chat_member_types diff --git a/telegram/ext/_choseninlineresulthandler.py b/telegram/ext/_choseninlineresulthandler.py index 4831550c95f..103fa4bbb24 100644 --- a/telegram/ext/_choseninlineresulthandler.py +++ b/telegram/ext/_choseninlineresulthandler.py @@ -18,24 +18,25 @@ # along with this program. If not, see [http://www.gnu.org/licenses/]. """This module contains the ChosenInlineResultHandler class.""" import re -from typing import Optional, TypeVar, Union, Callable, TYPE_CHECKING, Pattern, Match, cast +from typing import Optional, TypeVar, Union, TYPE_CHECKING, Pattern, Match, cast from telegram import Update +from telegram._utils.defaultvalue import DEFAULT_TRUE +from telegram._utils.types import DVInput from telegram.ext import Handler -from telegram._utils.defaultvalue import DefaultValue, DEFAULT_FALSE -from telegram.ext._utils.types import CCT +from telegram.ext._utils.types import CCT, HandlerCallback RT = TypeVar('RT') if TYPE_CHECKING: - from telegram.ext import CallbackContext, Dispatcher + from telegram.ext import CallbackContext, Application class ChosenInlineResultHandler(Handler[Update, CCT]): """Handler class to handle Telegram updates that contain a chosen inline result. Warning: - When setting ``run_async`` to :obj:`True`, you cannot rely on adding custom + When setting ``block`` to :obj:`True`, you cannot rely on adding custom attributes to :class:`telegram.ext.CallbackContext`. See its docs for more info. Args: @@ -45,8 +46,9 @@ class ChosenInlineResultHandler(Handler[Update, CCT]): The return value of the callback is usually ignored except for the special case of :class:`telegram.ext.ConversationHandler`. - run_async (:obj:`bool`): Determines whether the callback will run asynchronously. - Defaults to :obj:`False`. + block (:obj:`bool`, optional): Determines whether the return value of the callback should + be awaited before processing the next handler in + :meth:`telegram.ext.Application.process_update`. Defaults to :obj:`True`. pattern (:obj:`str` | `Pattern`, optional): Regex pattern. If not :obj:`None`, ``re.match`` is used on :attr:`telegram.ChosenInlineResult.result_id` to determine if an update should be handled by this handler. This is accessible in the callback as @@ -56,7 +58,9 @@ class ChosenInlineResultHandler(Handler[Update, CCT]): Attributes: callback (:obj:`callable`): The callback function for this handler. - run_async (:obj:`bool`): Determines whether the callback will run asynchronously. + block (:obj:`bool`): Determines whether the return value of the callback should be + awaited before processing the next handler in + :meth:`telegram.ext.Application.process_update`. pattern (`Pattern`): Optional. Regex pattern to test :attr:`telegram.ChosenInlineResult.result_id` against. @@ -68,14 +72,11 @@ class ChosenInlineResultHandler(Handler[Update, CCT]): def __init__( self, - callback: Callable[[Update, 'CallbackContext'], RT], - run_async: Union[bool, DefaultValue] = DEFAULT_FALSE, + callback: HandlerCallback[Update, CCT, RT], + block: DVInput[bool] = DEFAULT_TRUE, pattern: Union[str, Pattern] = None, ): - super().__init__( - callback, - run_async=run_async, - ) + super().__init__(callback, block=block) if isinstance(pattern, str): pattern = re.compile(pattern) @@ -105,7 +106,7 @@ def collect_additional_context( self, context: 'CallbackContext', update: Update, - dispatcher: 'Dispatcher', + application: 'Application', check_result: Union[bool, Match], ) -> None: """This function adds the matched regex pattern result to diff --git a/telegram/ext/_commandhandler.py b/telegram/ext/_commandhandler.py index 5a762ba8005..aebfd21dece 100644 --- a/telegram/ext/_commandhandler.py +++ b/telegram/ext/_commandhandler.py @@ -18,16 +18,16 @@ # along with this program. If not, see [http://www.gnu.org/licenses/]. """This module contains the CommandHandler and PrefixHandler classes.""" import re -from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, TypeVar, Union +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, TypeVar, Union from telegram import MessageEntity, Update from telegram.ext import filters as filters_module, Handler -from telegram._utils.types import SLT -from telegram._utils.defaultvalue import DefaultValue, DEFAULT_FALSE -from telegram.ext._utils.types import CCT +from telegram._utils.types import SLT, DVInput +from telegram._utils.defaultvalue import DEFAULT_TRUE +from telegram.ext._utils.types import CCT, HandlerCallback if TYPE_CHECKING: - from telegram.ext import Dispatcher + from telegram.ext import Application RT = TypeVar('RT') @@ -47,7 +47,7 @@ class CommandHandler(Handler[Update, CCT]): * :class:`CommandHandler` does *not* handle (edited) channel posts. Warning: - When setting ``run_async`` to :obj:`True`, you cannot rely on adding custom + When setting ``block`` to :obj:`True`, you cannot rely on adding custom attributes to :class:`telegram.ext.CallbackContext`. See its docs for more info. Args: @@ -64,8 +64,9 @@ class CommandHandler(Handler[Update, CCT]): :class:`telegram.ext.filters.BaseFilter`. Standard filters can be found in :mod:`telegram.ext.filters`. Filters can be combined using bitwise operators (& for and, | for or, ~ for not). - run_async (:obj:`bool`): Determines whether the callback will run asynchronously. - Defaults to :obj:`False`. + block (:obj:`bool`, optional): Determines whether the return value of the callback should + be awaited before processing the next handler in + :meth:`telegram.ext.Application.process_update`. Defaults to :obj:`True`. Raises: ValueError: when command is too long or has illegal chars. @@ -77,7 +78,9 @@ class CommandHandler(Handler[Update, CCT]): callback (:obj:`callable`): The callback function for this handler. filters (:class:`telegram.ext.BaseFilter`): Optional. Only allow updates with these Filters. - run_async (:obj:`bool`): Determines whether the callback will run asynchronously. + block (:obj:`bool`): Determines whether the return value of the callback should be + awaited before processing the next handler in + :meth:`telegram.ext.Application.process_update`. """ __slots__ = ('command', 'filters') @@ -85,11 +88,11 @@ class CommandHandler(Handler[Update, CCT]): def __init__( self, command: SLT[str], - callback: Callable[[Update, CCT], RT], + callback: HandlerCallback[Update, CCT, RT], filters: filters_module.BaseFilter = None, - run_async: Union[bool, DefaultValue] = DEFAULT_FALSE, + block: DVInput[bool] = DEFAULT_TRUE, ): - super().__init__(callback, run_async=run_async) + super().__init__(callback, block=block) if isinstance(command, str): self.command = [command.lower()] @@ -144,7 +147,7 @@ def collect_additional_context( self, context: CCT, update: Update, - dispatcher: 'Dispatcher', + application: 'Application', check_result: Optional[Union[bool, Tuple[List[str], Optional[bool]]]], ) -> None: """Add text after the command to :attr:`CallbackContext.args` as list, split on single @@ -194,7 +197,7 @@ class PrefixHandler(CommandHandler): * :class:`PrefixHandler` does *not* handle (edited) channel posts. Warning: - When setting ``run_async`` to :obj:`True`, you cannot rely on adding custom + When setting ``block`` to :obj:`True`, you cannot rely on adding custom attributes to :class:`telegram.ext.CallbackContext`. See its docs for more info. Args: @@ -212,14 +215,17 @@ class PrefixHandler(CommandHandler): :class:`telegram.ext.filters.BaseFilter`. Standard filters can be found in :mod:`telegram.ext.filters`. Filters can be combined using bitwise operators (& for and, | for or, ~ for not). - run_async (:obj:`bool`): Determines whether the callback will run asynchronously. - Defaults to :obj:`False`. + block (:obj:`bool`, optional): Determines whether the return value of the callback should + be awaited before processing the next handler in + :meth:`telegram.ext.Application.process_update`. Defaults to :obj:`True`. Attributes: callback (:obj:`callable`): The callback function for this handler. filters (:class:`telegram.ext.BaseFilter`): Optional. Only allow updates with these Filters. - run_async (:obj:`bool`): Determines whether the callback will run asynchronously. + block (:obj:`bool`): Determines whether the return value of the callback should be + awaited before processing the next handler in + :meth:`telegram.ext.Application.process_update`. """ @@ -230,9 +236,9 @@ def __init__( self, prefix: SLT[str], command: SLT[str], - callback: Callable[[Update, CCT], RT], + callback: HandlerCallback[Update, CCT, RT], filters: filters_module.BaseFilter = None, - run_async: Union[bool, DefaultValue] = DEFAULT_FALSE, + block: DVInput[bool] = DEFAULT_TRUE, ): self._prefix: List[str] = [] @@ -243,7 +249,7 @@ def __init__( 'nocommand', callback, filters=filters, - run_async=run_async, + block=block, ) self.prefix = prefix # type: ignore[assignment] diff --git a/telegram/ext/_conversationhandler.py b/telegram/ext/_conversationhandler.py index a8b9c95864b..32cb0a9bea2 100644 --- a/telegram/ext/_conversationhandler.py +++ b/telegram/ext/_conversationhandler.py @@ -18,11 +18,10 @@ # along with this program. If not, see [http://www.gnu.org/licenses/]. # pylint: disable=no-self-use """This module contains the ConversationHandler.""" - +import asyncio import logging import functools import datetime -from threading import Lock from typing import ( # pylint: disable=unused-import # for the "Any" import TYPE_CHECKING, Dict, @@ -34,15 +33,15 @@ cast, ClassVar, Any, + Set, ) from telegram import Update from telegram.ext import ( - BasePersistence, CallbackContext, CallbackQueryHandler, ChosenInlineResultHandler, - DispatcherHandlerStop, + ApplicationHandlerStop, Handler, InlineQueryHandler, StringCommandHandler, @@ -50,28 +49,28 @@ TypeHandler, ) from telegram._utils.warnings import warn -from telegram.ext._utils.promise import Promise +from telegram.ext._utils.trackingdefaultdict import TrackingDefaultDict from telegram.ext._utils.types import ConversationDict from telegram.ext._utils.types import CCT if TYPE_CHECKING: - from telegram.ext import Dispatcher, Job, JobQueue -CheckUpdateType = Optional[Tuple[Tuple[int, ...], Handler, object]] + from telegram.ext import Application, Job, JobQueue +CheckUpdateType = Tuple[object, Tuple[int, ...], Handler, object] class _ConversationTimeoutContext: - __slots__ = ('conversation_key', 'update', 'dispatcher', 'callback_context') + __slots__ = ('conversation_key', 'update', 'application', 'callback_context') def __init__( self, conversation_key: Tuple[int, ...], update: Update, - dispatcher: 'Dispatcher[Any, CCT, Any, Any, Any, JobQueue, Any]', + application: 'Application[Any, CCT, Any, Any, Any, JobQueue]', callback_context: CallbackContext, ): self.conversation_key = conversation_key self.update = update - self.dispatcher = dispatcher + self.application = application self.callback_context = callback_context @@ -93,7 +92,7 @@ class ConversationHandler(Handler[Update, CCT]): Finally, ``ConversationHandler``, does *not* handle (edited) channel posts. .. _`FAQ`: https://github.com/python-telegram-bot/python-telegram-bot/wiki\ - /Frequently-Asked-Questions#what-do-the-per_-settings-in-conversationhandler-do + /Frequently-Asked-Questions#what-do-the-per_-settings-in-conversation handler-do The first collection, a ``list`` named :attr:`entry_points`, is used to initiate the conversation, for example with a :class:`telegram.ext.CommandHandler` or @@ -118,7 +117,7 @@ class ConversationHandler(Handler[Update, CCT]): the conversation ends immediately after the execution of this callback function. To end the conversation, the callback function must return :attr:`END` or ``-1``. To handle the conversation timeout, use handler :attr:`TIMEOUT` or ``-2``. - Finally, :class:`telegram.ext.DispatcherHandlerStop` can be used in conversations as described + Finally, :class:`telegram.ext.ApplicationHandlerStop` can be used in conversations as described in the corresponding documentation. Note: @@ -170,22 +169,21 @@ class ConversationHandler(Handler[Update, CCT]): from what you expect. - name (:obj:`str`, optional): The name for this conversationhandler. Required for + name (:obj:`str`, optional): The name for this conversation handler. Required for persistence. persistent (:obj:`bool`, optional): If the conversations dict for this handler should be saved. Name is required and persistence has to be set in :class:`telegram.ext.Updater` map_to_parent (Dict[:obj:`object`, :obj:`object`], optional): A :obj:`dict` that can be - used to instruct a nested conversationhandler to transition into a mapped state on - its parent conversationhandler in place of a specified nested state. - run_async (:obj:`bool`, optional): Pass :obj:`True` to *override* the - :attr:`Handler.run_async` setting of all handlers (in :attr:`entry_points`, + used to instruct a nested conversation handler to transition into a mapped state on + its parent conversation handler in place of a specified nested state. + block (:obj:`bool`, optional): Pass :obj:`False` to *overrule* the + :attr:`Handler.block` setting of all handlers (in :attr:`entry_points`, :attr:`states` and :attr:`fallbacks`). - - Note: - If set to :obj:`True`, you should not pass a handler instance, that needs to be - run synchronously in another context. + Defaults to :obj:`True`. .. versionadded:: 13.2 + .. versionchanged:: 14.0 + No longer overrides the handlers settings Raises: ValueError @@ -193,31 +191,32 @@ class ConversationHandler(Handler[Update, CCT]): Attributes: persistent (:obj:`bool`): Optional. If the conversations dict for this handler should be saved. Name is required and persistence has to be set in :class:`telegram.ext.Updater` - run_async (:obj:`bool`): If :obj:`True`, will override the - :attr:`Handler.run_async` setting of all internal handlers on initialization. + block (:obj:`bool`): Determines whether the callback will run asynchronously. .. versionadded:: 13.2 """ __slots__ = ( + '__aplication', + '_allow_reentry', + '_child_conversations', + '_conversation_timeout', + '_conversations', + '_conversations_lock', '_entry_points', - '_states', '_fallbacks', - '_allow_reentry', - '_per_user', + '_logger', + '_map_to_parent', + '_name', '_per_chat', '_per_message', - '_conversation_timeout', - '_name', - 'persistent', + '_per_user', '_persistence', - '_map_to_parent', - 'timeout_jobs', + '_states', '_timeout_jobs_lock', - '_conversations', - '_conversations_lock', - 'logger', + 'persistent', + 'timeout_jobs', ) END: ClassVar[int] = -1 @@ -241,7 +240,7 @@ def __init__( name: str = None, persistent: bool = False, map_to_parent: Dict[object, object] = None, - run_async: bool = False, + block: bool = False, ): # these imports need to be here because of circular import error otherwise from telegram.ext import ( # pylint: disable=import-outside-toplevel @@ -251,7 +250,7 @@ def __init__( PollAnswerHandler, ) - self.run_async = run_async + self.block = block self._entry_points = entry_points self._states = states @@ -263,20 +262,19 @@ def __init__( self._per_message = per_message self._conversation_timeout = conversation_timeout self._name = name - if persistent and not self.name: - raise ValueError("Conversations can't be persistent when handler is unnamed.") - self.persistent: bool = persistent - self._persistence: Optional[BasePersistence] = None - """:obj:`telegram.ext.BasePersistence`: The persistence used to store conversations. - Set by dispatcher""" self._map_to_parent = map_to_parent self.timeout_jobs: Dict[Tuple[int, ...], 'Job'] = {} - self._timeout_jobs_lock = Lock() + self._timeout_jobs_lock = asyncio.Lock() self._conversations: ConversationDict = {} - self._conversations_lock = Lock() + self._conversations_lock = asyncio.Lock() + self._child_conversations: Set['ConversationHandler'] = set() + + if persistent and not self.name: + raise ValueError("Conversations can't be persistent when handler is unnamed.") + self.persistent: bool = persistent - self.logger = logging.getLogger(__name__) + self._logger = logging.getLogger(__name__) if not any((self.per_user, self.per_chat, self.per_message)): raise ValueError("'per_user', 'per_chat' and 'per_message' can't all be 'False'") @@ -295,6 +293,10 @@ def __init__( for state_handlers in states.values(): all_handlers.extend(state_handlers) + self._child_conversations.update( + handler for handler in all_handlers if isinstance(handler, ConversationHandler) + ) + # this loop is going to warn the user about handlers which can work unexpected # in conversations @@ -302,10 +304,13 @@ def __init__( per_faq_link = ( " Read this FAQ entry to learn more about the per_* settings: " "https://github.com/python-telegram-bot/python-telegram-bot/wiki" - "/Frequently-Asked-Questions#what-do-the-per_-settings-in-conversationhandler-do." + "/Frequently-Asked-Questions#what-do-the-per_-settings-in-conversation handler-do." ) for handler in all_handlers: + if self.block: + handler.block = True + if isinstance(handler, (StringCommandHandler, StringRegexHandler)): warn( "The `ConversationHandler` only handles updates of type `telegram.Update`. " @@ -367,9 +372,6 @@ def __init__( stacklevel=2, ) - if self.run_async: - handler.run_async = True - @property def entry_points(self) -> List[Handler]: """List[:class:`telegram.ext.Handler`]: A list of ``Handler`` objects that can trigger the @@ -484,32 +486,40 @@ def map_to_parent(self, value: object) -> NoReturn: "You can not assign a new value to map_to_parent after initialization." ) - @property - def persistence(self) -> Optional[BasePersistence]: - """The persistence class as provided by the :class:`Dispatcher`.""" - return self._persistence - - @persistence.setter - def persistence(self, persistence: BasePersistence) -> None: - self._persistence = persistence - # Set persistence for nested conversations - for handlers in self.states.values(): - for handler in handlers: - if isinstance(handler, ConversationHandler): - handler.persistence = self.persistence + async def _initialize_persistence( + self, application: 'Application' + ) -> TrackingDefaultDict[Tuple[int, ...], object]: + """Initializes the persistence for this handler. While this method is marked as protected, + we expect it to be called by the Application/parent conversations. It's just protected to + hide it from users. - @property - def conversations(self) -> ConversationDict: # skipcq: PY-D0003 - return self._conversations + Args: + application (:class:`telegram.ext.Application`): The application. + + """ + if not (self.persistent and self.name and application.persistence): + raise RuntimeError( + 'This handler is not persistent, has no name or the application has no ' + 'persistence!' + ) + + def default_factory() -> NoReturn: + raise KeyError + + self._conversations = cast( + TrackingDefaultDict[Tuple[int, ...], object], + TrackingDefaultDict( + default_factory=default_factory, track_read=False, track_write=True + ), + ) + self._conversations.update(await application.persistence.get_conversations(self.name)) + + for handler in self._child_conversations: + await handler._initialize_persistence( # pylint: disable=protected-access + application=application + ) - @conversations.setter - def conversations(self, value: ConversationDict) -> None: - self._conversations = value - # Set conversations for nested conversations - for handlers in self.states.values(): - for handler in handlers: - if isinstance(handler, ConversationHandler) and self.persistence and handler.name: - handler.conversations = self.persistence.get_conversations(handler.name) + return self._conversations def _get_key(self, update: Update) -> Tuple[int, ...]: chat = update.effective_chat @@ -531,49 +541,54 @@ def _get_key(self, update: Update) -> Tuple[int, ...]: return tuple(key) - def _resolve_promise(self, state: Tuple) -> object: + def _resolve_task(self, state: Tuple[object, asyncio.Task]) -> object: old_state, new_state = state - try: - res = new_state.result(0) - res = res if res is not None else old_state - except Exception as exc: - self.logger.exception("Promise function raised exception") - self.logger.exception("%s", exc) + res = new_state.result() + res = res if res is not None else old_state + + exc = new_state.exception() + if exc: + self._logger.exception("Task function raised exception") + self._logger.exception("%s", exc) res = old_state - finally: - if res is None and old_state is None: - res = self.END + + if res is None and old_state is None: + res = self.END + return res def _schedule_job( self, - new_state: object, - dispatcher: 'Dispatcher[Any, CCT, Any, Any, Any, JobQueue, Any]', + new_state: Union[object, asyncio.Task], + application: 'Application[Any, CCT, Any, Any, Any, JobQueue]', update: Update, context: CallbackContext, conversation_key: Tuple[int, ...], ) -> None: + if isinstance(new_state, asyncio.Task): + new_state = new_state.result() + if new_state != self.END: try: # both job_queue & conversation_timeout are checked before calling _schedule_job - j_queue = dispatcher.job_queue + j_queue = application.job_queue self.timeout_jobs[conversation_key] = j_queue.run_once( self._trigger_timeout, self.conversation_timeout, # type: ignore[arg-type] context=_ConversationTimeoutContext( - conversation_key, update, dispatcher, context + conversation_key, update, application, context ), ) except Exception as exc: - self.logger.exception( + self._logger.exception( "Failed to schedule timeout job due to the following exception:" ) - self.logger.exception("%s", exc) + self._logger.exception("%s", exc) # pylint: disable=too-many-return-statements - def check_update(self, update: object) -> CheckUpdateType: + def check_update(self, update: object) -> Optional[CheckUpdateType]: """ - Determines whether an update should be handled by this conversationhandler, and if so in + Determines whether an update should be handled by this conversation handler, and if so in which state the conversation currently is. Args: @@ -597,31 +612,31 @@ def check_update(self, update: object) -> CheckUpdateType: key = self._get_key(update) with self._conversations_lock: - state = self.conversations.get(key) + state = self._conversations.get(key) # Resolve promises - if isinstance(state, tuple) and len(state) == 2 and isinstance(state[1], Promise): - self.logger.debug('waiting for promise...') + if isinstance(state, tuple) and len(state) == 2 and isinstance(state[1], asyncio.Task): + self._logger.warning('Waiting for asyncio Task to finish ...') # check if promise is finished or not - if state[1].done.wait(0): - res = self._resolve_promise(state) + if state[1].done(): + res = self._resolve_task(state) # type: ignore[arg-type] self._update_state(res, key) with self._conversations_lock: - state = self.conversations.get(key) + state = self._conversations.get(key) # if not then handle WAITING state instead else: - hdlrs = self.states.get(self.WAITING, []) - for hdlr in hdlrs: - check = hdlr.check_update(update) + handlers = self.states.get(self.WAITING, []) + for handler_ in handlers: + check = handler_.check_update(update) if check is not None and check is not False: - return key, hdlr, check + return self.WAITING, key, handler_, check return None - self.logger.debug('selecting conversation %s with state %s', str(key), str(state)) + self._logger.debug('Selecting conversation %s with state %s', str(key), str(state)) - handler = None + handler: Optional[Handler] = None # Search entry points for a match if state is None or self.allow_reentry: @@ -637,9 +652,7 @@ def check_update(self, update: object) -> CheckUpdateType: # Get the handler list for current state, if we didn't find one yet and we're still here if state is not None and not handler: - handlers = self.states.get(state) - - for candidate in handlers or []: + for candidate in self.states.get(state, []): check = candidate.check_update(update) if check is not None and check is not False: handler = candidate @@ -656,27 +669,28 @@ def check_update(self, update: object) -> CheckUpdateType: else: return None - return key, handler, check # type: ignore[return-value] + return state, key, handler, check # type: ignore[return-value] - def handle_update( # type: ignore[override] + async def handle_update( # type: ignore[override] self, update: Update, - dispatcher: 'Dispatcher', + application: 'Application', check_result: CheckUpdateType, context: CallbackContext, ) -> Optional[object]: """Send the update to the callback for the current state and Handler Args: - check_result: The result from check_update. For this handler it's a tuple of key, - handler, and the handler's check result. + check_result: The result from check_update. For this handler it's a tuple of the + conversation state, key, handler, and the handler's check result. update (:class:`telegram.Update`): Incoming telegram update. - dispatcher (:class:`telegram.ext.Dispatcher`): Dispatcher that originated the Update. + application (:class:`telegram.ext.Application`): Application that originated the + update. context (:class:`telegram.ext.CallbackContext`): The context as provided by - the dispatcher. + the application. """ - conversation_key, handler, check_result = check_result # type: ignore[assignment,misc] + current_state, conversation_key, handler, handler_check_result = check_result raise_dp_handler_stop = False with self._timeout_jobs_lock: @@ -686,19 +700,22 @@ def handle_update( # type: ignore[override] if timeout_job is not None: timeout_job.schedule_removal() try: - new_state = handler.handle_update(update, dispatcher, check_result, context) - except DispatcherHandlerStop as exception: + # TODO handle non-blocking handlers correctly + new_state: object = await handler.handle_update( + update, application, handler_check_result, context + ) + except ApplicationHandlerStop as exception: new_state = exception.state raise_dp_handler_stop = True with self._timeout_jobs_lock: if self.conversation_timeout: - if dispatcher.job_queue is not None: + if application.job_queue is not None: # Add the new timeout job - if isinstance(new_state, Promise): + if isinstance(new_state, asyncio.Task): new_state.add_done_callback( functools.partial( self._schedule_job, - dispatcher=dispatcher, + application=application, update=update, context=context, conversation_key=conversation_key, @@ -706,42 +723,38 @@ def handle_update( # type: ignore[override] ) elif new_state != self.END: self._schedule_job( - new_state, dispatcher, update, context, conversation_key + new_state, application, update, context, conversation_key ) else: warn( - "Ignoring `conversation_timeout` because the Dispatcher has no JobQueue.", + "Ignoring `conversation_timeout` because the Application has no JobQueue.", ) if isinstance(self.map_to_parent, dict) and new_state in self.map_to_parent: self._update_state(self.END, conversation_key) if raise_dp_handler_stop: - raise DispatcherHandlerStop(self.map_to_parent.get(new_state)) + raise ApplicationHandlerStop(self.map_to_parent.get(new_state)) return self.map_to_parent.get(new_state) - self._update_state(new_state, conversation_key) + if current_state != self.WAITING: + self._update_state(new_state, conversation_key) + if raise_dp_handler_stop: # Don't pass the new state here. If we're in a nested conversation, the parent is # expecting None as return value. - raise DispatcherHandlerStop() + raise ApplicationHandlerStop() return None def _update_state(self, new_state: object, key: Tuple[int, ...]) -> None: if new_state == self.END: with self._conversations_lock: - if key in self.conversations: + if key in self._conversations: # If there is no key in conversations, nothing is done. - del self.conversations[key] - if self.persistent and self.persistence and self.name: - self.persistence.update_conversation(self.name, key, None) + del self._conversations[key] - elif isinstance(new_state, Promise): + elif isinstance(new_state, asyncio.Task): with self._conversations_lock: - self.conversations[key] = (self.conversations.get(key), new_state) - if self.persistent and self.persistence and self.name: - self.persistence.update_conversation( - self.name, key, (self.conversations.get(key), new_state) - ) + self._conversations[key] = (self._conversations.get(key), new_state) elif new_state is not None: if new_state not in self.states: @@ -750,20 +763,20 @@ def _update_state(self, new_state: object, key: Tuple[int, ...]) -> None: f"ConversationHandler{' ' + self.name if self.name is not None else ''}.", ) with self._conversations_lock: - self.conversations[key] = new_state - if self.persistent and self.persistence and self.name: - self.persistence.update_conversation(self.name, key, new_state) - - def _trigger_timeout(self, context: CallbackContext) -> None: - self.logger.debug('conversation timeout was triggered!') + self._conversations[key] = new_state + async def _trigger_timeout(self, context: CallbackContext) -> None: job = cast('Job', context.job) ctxt = cast(_ConversationTimeoutContext, job.context) + self._logger.debug( + 'Conversation timeout was triggered for conversation %s!', ctxt.conversation_key + ) + callback_context = ctxt.callback_context with self._timeout_jobs_lock: - found_job = self.timeout_jobs[ctxt.conversation_key] + found_job = self.timeout_jobs.get(ctxt.conversation_key) if found_job is not job: # The timeout has been cancelled in handle_update return @@ -774,10 +787,12 @@ def _trigger_timeout(self, context: CallbackContext) -> None: check = handler.check_update(ctxt.update) if check is not None and check is not False: try: - handler.handle_update(ctxt.update, ctxt.dispatcher, check, callback_context) - except DispatcherHandlerStop: + await handler.handle_update( + ctxt.update, ctxt.application, check, callback_context + ) + except ApplicationHandlerStop: warn( - 'DispatcherHandlerStop in TIMEOUT state of ' + 'ApplicationHandlerStop in TIMEOUT state of ' 'ConversationHandler has no effect. Ignoring.', ) diff --git a/telegram/ext/_defaults.py b/telegram/ext/_defaults.py index d9e9cd9af67..294dc0c1ecb 100644 --- a/telegram/ext/_defaults.py +++ b/telegram/ext/_defaults.py @@ -23,12 +23,17 @@ import pytz from telegram._utils.defaultvalue import DEFAULT_NONE -from telegram._utils.types import ODVInput class Defaults: """Convenience Class to gather all parameters with a (user defined) default value + .. versionchanged:: 14.0 + Removed the argument and attribute ``timeout``. Specify default timeout behavior for the + networking backend directly via :class:`telegram.ext.UpdaterBuilder` or + :class:`telegram.ext.ApplicationBuilder` instead. + + Parameters: parse_mode (:obj:`str`, optional): Send Markdown or HTML, if you want Telegram apps to show bold, italic, fixed-width text or URLs in your bot's message. @@ -38,12 +43,6 @@ class Defaults: message. allow_sending_without_reply (:obj:`bool`, optional): Pass :obj:`True`, if the message should be sent even if the specified replied-to message is not found. - timeout (:obj:`int` | :obj:`float`, optional): If this value is specified, use it as the - read timeout from the server (instead of the one specified during creation of the - connection pool). - - Note: - Will *not* be used for :meth:`telegram.Bot.get_updates`! quote (:obj:`bool`, optional): If set to :obj:`True`, the reply is sent as an actual reply to the message. If ``reply_to_message_id`` is passed in ``kwargs``, this parameter will be ignored. Default: :obj:`True` in group chats and :obj:`False` in private chats. @@ -51,9 +50,9 @@ class Defaults: appearing throughout PTB, i.e. if a timezone naive date(time) object is passed somewhere, it will be assumed to be in ``tzinfo``. Must be a timezone provided by the ``pytz`` module. Defaults to UTC. - run_async (:obj:`bool`, optional): Default setting for the ``run_async`` parameter of - handlers and error handlers registered through :meth:`Dispatcher.add_handler` and - :meth:`Dispatcher.add_error_handler`. Defaults to :obj:`False`. + block (:obj:`bool`, optional): Default setting for the ``block`` parameter of + handlers and error handlers registered through :meth:`Application.add_handler` and + :meth:`Application.add_error_handler`. Defaults to :obj:`True`. protect_content (:obj:`bool`, optional): Protects the contents of the sent message from forwarding and saving. @@ -61,10 +60,9 @@ class Defaults: """ __slots__ = ( - '_timeout', '_tzinfo', '_disable_web_page_preview', - '_run_async', + '_block', '_quote', '_disable_notification', '_allow_sending_without_reply', @@ -78,12 +76,9 @@ def __init__( parse_mode: str = None, disable_notification: bool = None, disable_web_page_preview: bool = None, - # Timeout needs special treatment, since the bot methods have two different - # default values for timeout (None and 20s) - timeout: ODVInput[float] = DEFAULT_NONE, quote: bool = None, tzinfo: pytz.BaseTzInfo = pytz.utc, - run_async: bool = False, + block: bool = True, allow_sending_without_reply: bool = None, protect_content: bool = None, ): @@ -91,10 +86,9 @@ def __init__( self._disable_notification = disable_notification self._disable_web_page_preview = disable_web_page_preview self._allow_sending_without_reply = allow_sending_without_reply - self._timeout = timeout self._quote = quote self._tzinfo = tzinfo - self._run_async = run_async + self._block = block self._protect_content = protect_content # Gather all defaults that actually have a default value @@ -110,9 +104,6 @@ def __init__( value = getattr(self, kwarg) if value not in [None, DEFAULT_NONE]: self._api_defaults[kwarg] = value - # Special casing, as None is a valid default value - if self._timeout != DEFAULT_NONE: - self._api_defaults['timeout'] = self._timeout @property def api_defaults(self) -> Dict[str, Any]: # skip-cq: PY-D0003 @@ -181,18 +172,6 @@ def allow_sending_without_reply(self, value: object) -> NoReturn: "You can not assign a new value to allow_sending_without_reply after initialization." ) - @property - def timeout(self) -> ODVInput[float]: - """:obj:`int` | :obj:`float`: Optional. If this value is specified, use it as the - read timeout from the server (instead of the one specified during creation of the - connection pool). - """ - return self._timeout - - @timeout.setter - def timeout(self, value: object) -> NoReturn: - raise AttributeError("You can not assign a new value to timeout after initialization.") - @property def quote(self) -> Optional[bool]: """:obj:`bool`: Optional. If set to :obj:`True`, the reply is sent as an actual reply @@ -217,15 +196,15 @@ def tzinfo(self, value: object) -> NoReturn: raise AttributeError("You can not assign a new value to tzinfo after initialization.") @property - def run_async(self) -> bool: - """:obj:`bool`: Optional. Default setting for the ``run_async`` parameter of - handlers and error handlers registered through :meth:`Dispatcher.add_handler` and - :meth:`Dispatcher.add_error_handler`. + def block(self) -> bool: + """:obj:`bool`: Optional. Default setting for the ``block`` parameter of + handlers and error handlers registered through :meth:`Application.add_handler` and + :meth:`Application.add_error_handler`. """ - return self._run_async + return self._block - @run_async.setter - def run_async(self, value: object) -> NoReturn: + @block.setter + def block(self, value: object) -> NoReturn: raise AttributeError("You can not assign a new value to run_async after initialization.") @property @@ -250,10 +229,10 @@ def __hash__(self) -> int: self._disable_notification, self._disable_web_page_preview, self._allow_sending_without_reply, - self._timeout, self._quote, self._tzinfo, - self._run_async, + self._block, + self._protect_content, self._protect_content, ) ) diff --git a/telegram/ext/_dictpersistence.py b/telegram/ext/_dictpersistence.py index 6ea7bb17c53..7a4311763dd 100644 --- a/telegram/ext/_dictpersistence.py +++ b/telegram/ext/_dictpersistence.py @@ -55,6 +55,12 @@ class DictPersistence(BasePersistence): store_data (:class:`PersistenceInput`, optional): Specifies which kinds of data will be saved by this persistence instance. By default, all available kinds of data will be saved. + update_interval (:obj:`int` | :obj:`float:, optional): The + :class:`~telegram.ext.Application` will update + the persistence in regular intervals. This parameter specifies the time (in seconds) to + wait between two consecutive runs of updating the persistence. Defaults to 60 seconds. + + .. versionadded:: 14.0 user_data_json (:obj:`str`, optional): JSON string that will be used to reconstruct user_data on creating this persistence. Default is ``""``. chat_data_json (:obj:`str`, optional): JSON string that will be used to reconstruct @@ -94,8 +100,9 @@ def __init__( bot_data_json: str = '', conversations_json: str = '', callback_data_json: str = '', + update_interval: float = 60, ): - super().__init__(store_data=store_data) + super().__init__(store_data=store_data, update_interval=update_interval) self._user_data = None self._chat_data = None self._bot_data = None @@ -231,7 +238,7 @@ def conversations_json(self) -> str: return self._conversations_json return self._encode_conversations_to_json(self.conversations) # type: ignore[arg-type] - def get_user_data(self) -> Dict[int, Dict[object, object]]: + async def get_user_data(self) -> Dict[int, Dict[object, object]]: """Returns the user_data created from the ``user_data_json`` or an empty :obj:`dict`. Returns: @@ -241,7 +248,7 @@ def get_user_data(self) -> Dict[int, Dict[object, object]]: self._user_data = {} return self.user_data # type: ignore[return-value] - def get_chat_data(self) -> Dict[int, Dict[object, object]]: + async def get_chat_data(self) -> Dict[int, Dict[object, object]]: """Returns the chat_data created from the ``chat_data_json`` or an empty :obj:`dict`. Returns: @@ -251,7 +258,7 @@ def get_chat_data(self) -> Dict[int, Dict[object, object]]: self._chat_data = {} return self.chat_data # type: ignore[return-value] - def get_bot_data(self) -> Dict[object, object]: + async def get_bot_data(self) -> Dict[object, object]: """Returns the bot_data created from the ``bot_data_json`` or an empty :obj:`dict`. Returns: @@ -261,7 +268,7 @@ def get_bot_data(self) -> Dict[object, object]: self._bot_data = {} return self.bot_data # type: ignore[return-value] - def get_callback_data(self) -> Optional[CDCData]: + async def get_callback_data(self) -> Optional[CDCData]: """Returns the callback_data created from the ``callback_data_json`` or :obj:`None`. .. versionadded:: 13.6 @@ -276,7 +283,7 @@ def get_callback_data(self) -> Optional[CDCData]: return None return self.callback_data[0], self.callback_data[1].copy() - def get_conversations(self, name: str) -> ConversationDict: + async def get_conversations(self, name: str) -> ConversationDict: """Returns the conversations created from the ``conversations_json`` or an empty :obj:`dict`. @@ -287,7 +294,7 @@ def get_conversations(self, name: str) -> ConversationDict: self._conversations = {} return self.conversations.get(name, {}).copy() # type: ignore[union-attr] - def update_conversation( + async def update_conversation( self, name: str, key: Tuple[int, ...], new_state: Optional[object] ) -> None: """Will update the conversations for the given handler. @@ -304,12 +311,12 @@ def update_conversation( self._conversations[name][key] = new_state self._conversations_json = None - def update_user_data(self, user_id: int, data: Dict) -> None: + async def update_user_data(self, user_id: int, data: Dict) -> None: """Will update the user_data (if changed). Args: user_id (:obj:`int`): The user the data might have been changed for. - data (:obj:`dict`): The :attr:`telegram.ext.Dispatcher.user_data` ``[user_id]``. + data (:obj:`dict`): The :attr:`telegram.ext.Application.user_data` ``[user_id]``. """ if self._user_data is None: self._user_data = {} @@ -318,12 +325,12 @@ def update_user_data(self, user_id: int, data: Dict) -> None: self._user_data[user_id] = data self._user_data_json = None - def update_chat_data(self, chat_id: int, data: Dict) -> None: + async def update_chat_data(self, chat_id: int, data: Dict) -> None: """Will update the chat_data (if changed). Args: chat_id (:obj:`int`): The chat the data might have been changed for. - data (:obj:`dict`): The :attr:`telegram.ext.Dispatcher.chat_data` ``[chat_id]``. + data (:obj:`dict`): The :attr:`telegram.ext.Application.chat_data` ``[chat_id]``. """ if self._chat_data is None: self._chat_data = {} @@ -332,18 +339,18 @@ def update_chat_data(self, chat_id: int, data: Dict) -> None: self._chat_data[chat_id] = data self._chat_data_json = None - def update_bot_data(self, data: Dict) -> None: + async def update_bot_data(self, data: Dict) -> None: """Will update the bot_data (if changed). Args: - data (:obj:`dict`): The :attr:`telegram.ext.Dispatcher.bot_data`. + data (:obj:`dict`): The :attr:`telegram.ext.Application.bot_data`. """ if self._bot_data == data: return self._bot_data = data self._bot_data_json = None - def update_callback_data(self, data: CDCData) -> None: + async def update_callback_data(self, data: CDCData) -> None: """Will update the callback_data (if changed). .. versionadded:: 13.6 @@ -358,7 +365,7 @@ def update_callback_data(self, data: CDCData) -> None: self._callback_data = (data[0], data[1].copy()) self._callback_data_json = None - def drop_chat_data(self, chat_id: int) -> None: + async def drop_chat_data(self, chat_id: int) -> None: """Will delete the specified key from the :attr:`chat_data`. .. versionadded:: 14.0 @@ -371,7 +378,7 @@ def drop_chat_data(self, chat_id: int) -> None: self._chat_data.pop(chat_id, None) self._chat_data_json = None - def drop_user_data(self, user_id: int) -> None: + async def drop_user_data(self, user_id: int) -> None: """Will delete the specified key from the :attr:`user_data`. .. versionadded:: 14.0 @@ -384,28 +391,28 @@ def drop_user_data(self, user_id: int) -> None: self._user_data.pop(user_id, None) self._user_data_json = None - def refresh_user_data(self, user_id: int, user_data: Dict) -> None: + async def refresh_user_data(self, user_id: int, user_data: Dict) -> None: """Does nothing. .. versionadded:: 13.6 .. seealso:: :meth:`telegram.ext.BasePersistence.refresh_user_data` """ - def refresh_chat_data(self, chat_id: int, chat_data: Dict) -> None: + async def refresh_chat_data(self, chat_id: int, chat_data: Dict) -> None: """Does nothing. .. versionadded:: 13.6 .. seealso:: :meth:`telegram.ext.BasePersistence.refresh_chat_data` """ - def refresh_bot_data(self, bot_data: Dict) -> None: + async def refresh_bot_data(self, bot_data: Dict) -> None: """Does nothing. .. versionadded:: 13.6 .. seealso:: :meth:`telegram.ext.BasePersistence.refresh_bot_data` """ - def flush(self) -> None: + async def flush(self) -> None: """Does nothing. .. versionadded:: 14.0 @@ -431,7 +438,7 @@ def _encode_conversations_to_json(conversations: Dict[str, Dict[Tuple, object]]) return json.dumps(tmp) @staticmethod - def _decode_conversations_from_json(json_string: str) -> Dict[str, Dict[Tuple, object]]: + def _decode_conversations_from_json(json_string: str) -> Dict[str, ConversationDict]: """Helper method to decode a conversations dict (that uses tuples as keys) from a JSON-string created with :meth:`self._encode_conversations_to_json`. @@ -447,7 +454,7 @@ def _decode_conversations_from_json(json_string: str) -> Dict[str, Dict[Tuple, o conversations[handler] = {} for key, state in states.items(): conversations[handler][tuple(json.loads(key))] = state - return conversations + return conversations # type: ignore[return-value] @staticmethod def _decode_user_chat_data_from_json(data: str) -> Dict[int, Dict[object, object]]: diff --git a/telegram/ext/_dispatcher.py b/telegram/ext/_dispatcher.py deleted file mode 100644 index 9d9d2e9a5aa..00000000000 --- a/telegram/ext/_dispatcher.py +++ /dev/null @@ -1,893 +0,0 @@ -#!/usr/bin/env python -# -# A library that provides a Python interface to the Telegram Bot API -# Copyright (C) 2015-2022 -# Leandro Toledo de Souza -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Lesser Public License for more details. -# -# You should have received a copy of the GNU Lesser Public License -# along with this program. If not, see [http://www.gnu.org/licenses/]. -"""This module contains the Dispatcher class.""" -import inspect -import logging -import weakref -from collections import defaultdict -from pathlib import Path -from queue import Empty, Queue -from threading import BoundedSemaphore, Event, Lock, Thread, current_thread -from time import sleep -from typing import ( - Callable, - DefaultDict, - Dict, - List, - Optional, - Set, - Union, - Generic, - TypeVar, - TYPE_CHECKING, - Tuple, - Mapping, -) -from types import MappingProxyType -from uuid import uuid4 - -from telegram import Update -from telegram._utils.types import DVInput -from telegram.error import TelegramError -from telegram.ext import BasePersistence, ContextTypes, ExtBot -from telegram.ext._handler import Handler -from telegram.ext._callbackdatacache import CallbackDataCache -from telegram._utils.defaultvalue import DefaultValue, DEFAULT_FALSE -from telegram._utils.warnings import warn -from telegram.ext._utils.promise import Promise -from telegram.ext._utils.types import CCT, UD, CD, BD, BT, JQ, PT -from telegram.ext._utils.stack import was_called_by - -if TYPE_CHECKING: - from telegram import Message - from telegram.ext._jobqueue import Job - from telegram.ext._builders import InitDispatcherBuilder - -DEFAULT_GROUP: int = 0 - -UT = TypeVar('UT') - - -class DispatcherHandlerStop(Exception): - """ - Raise this in a handler or an error handler to prevent execution of any other handler (even in - different group). - - In order to use this exception in a :class:`telegram.ext.ConversationHandler`, pass the - optional ``state`` parameter instead of returning the next state: - - .. code-block:: python - - def callback(update, context): - ... - raise DispatcherHandlerStop(next_state) - - Note: - Has no effect, if the handler or error handler is run asynchronously. - - Args: - state (:obj:`object`, optional): The next state of the conversation. - - Attributes: - state (:obj:`object`): Optional. The next state of the conversation. - """ - - __slots__ = ('state',) - - def __init__(self, state: object = None) -> None: - super().__init__() - self.state = state - - -class Dispatcher(Generic[BT, CCT, UD, CD, BD, JQ, PT]): - """This class dispatches all kinds of updates to its registered handlers. - - Note: - This class may not be initialized directly. Use :class:`telegram.ext.DispatcherBuilder` or - :meth:`builder` (for convenience). - - .. versionchanged:: 14.0 - - * Initialization is now done through the :class:`telegram.ext.DispatcherBuilder`. - * Removed the attribute ``groups``. - - Attributes: - bot (:class:`telegram.Bot`): The bot object that should be passed to the handlers. - update_queue (:obj:`Queue`): The synchronized queue that will contain the updates. - job_queue (:class:`telegram.ext.JobQueue`): Optional. The :class:`telegram.ext.JobQueue` - instance to pass onto handler callbacks. - workers (:obj:`int`, optional): Number of maximum concurrent worker threads for the - ``@run_async`` decorator and :meth:`run_async`. - chat_data (:obj:`types.MappingProxyType`): A dictionary handlers can use to store data for - the chat. - - .. versionchanged:: 14.0 - :attr:`chat_data` is now read-only - - .. tip:: - Manually modifying :attr:`chat_data` is almost never needed and unadvisable. - - user_data (:obj:`types.MappingProxyType`): A dictionary handlers can use to store data for - the user. - - .. versionchanged:: 14.0 - :attr:`user_data` is now read-only - - .. tip:: - Manually modifying :attr:`user_data` is almost never needed and unadvisable. - - bot_data (:obj:`dict`): A dictionary handlers can use to store data for the bot. - persistence (:class:`telegram.ext.BasePersistence`): Optional. The persistence class to - store data that should be persistent over restarts. - exception_event (:class:`threading.Event`): When this event is set, the dispatcher will - stop processing updates. If this dispatcher is used together with an - :class:`telegram.ext.Updater`, then this event will be the same object as - :attr:`telegram.ext.Updater.exception_event`. - handlers (Dict[:obj:`int`, List[:class:`telegram.ext.Handler`]]): A dictionary mapping each - handler group to the list of handlers registered to that group. - - .. seealso:: - :meth:`add_handler`, :meth:`add_handlers`. - error_handlers (Dict[:obj:`callable`, :obj:`bool`]): A dict, where the keys are error - handlers and the values indicate whether they are to be run asynchronously via - :meth:`run_async`. - - .. seealso:: - :meth:`add_error_handler` - running (:obj:`bool`): Indicates if this dispatcher is running. - - .. seealso:: - :meth:`start`, :meth:`stop` - - """ - - # Allowing '__weakref__' creation here since we need it for the singleton - __slots__ = ( - 'workers', - 'persistence', - 'update_queue', - 'job_queue', - '_user_data', - 'user_data', - '_chat_data', - 'chat_data', - 'bot_data', - '_update_persistence_lock', - 'handlers', - 'error_handlers', - 'running', - '__stop_event', - 'exception_event', - '__async_queue', - '__async_threads', - 'bot', - '__weakref__', - 'context_types', - ) - - __singleton_lock = Lock() - __singleton_semaphore = BoundedSemaphore() - __singleton = None - logger = logging.getLogger(__name__) - - def __init__( - self: 'Dispatcher[BT, CCT, UD, CD, BD, JQ, PT]', - *, - bot: BT, - update_queue: Queue, - job_queue: JQ, - workers: int, - persistence: PT, - context_types: ContextTypes[CCT, UD, CD, BD], - exception_event: Event, - stack_level: int = 4, - ): - if not was_called_by( - inspect.currentframe(), Path(__file__).parent.resolve() / '_builders.py' - ): - warn( - '`Dispatcher` instances should be built via the `DispatcherBuilder`.', - stacklevel=2, - ) - - self.bot = bot - self.update_queue = update_queue - self.job_queue = job_queue - self.workers = workers - self.context_types = context_types - self.exception_event = exception_event - - if self.workers < 1: - warn( - 'Asynchronous callbacks can not be processed without at least one worker thread.', - stacklevel=stack_level, - ) - - self._user_data: DefaultDict[int, UD] = defaultdict(self.context_types.user_data) - self._chat_data: DefaultDict[int, CD] = defaultdict(self.context_types.chat_data) - # Read only mapping- - self.user_data: Mapping[int, UD] = MappingProxyType(self._user_data) - self.chat_data: Mapping[int, CD] = MappingProxyType(self._chat_data) - - self.bot_data = self.context_types.bot_data() - - self.persistence: Optional[BasePersistence] - self._update_persistence_lock = Lock() - if persistence: - if not isinstance(persistence, BasePersistence): - raise TypeError("persistence must be based on telegram.ext.BasePersistence") - - self.persistence = persistence - # This raises an exception if persistence.store_data.callback_data is True - # but self.bot is not an instance of ExtBot - so no need to check that later on - self.persistence.set_bot(self.bot) - - if self.persistence.store_data.user_data: - self._user_data.update(self.persistence.get_user_data()) - if self.persistence.store_data.chat_data: - self._chat_data.update(self.persistence.get_chat_data()) - if self.persistence.store_data.bot_data: - self.bot_data = self.persistence.get_bot_data() - if not isinstance(self.bot_data, self.context_types.bot_data): - raise ValueError( - f"bot_data must be of type {self.context_types.bot_data.__name__}" - ) - if self.persistence.store_data.callback_data: - persistent_data = self.persistence.get_callback_data() - if persistent_data is not None: - if not isinstance(persistent_data, tuple) and len(persistent_data) != 2: - raise ValueError('callback_data must be a tuple of length 2') - # Mypy doesn't know that persistence.set_bot (see above) already checks that - # self.bot is an instance of ExtBot if callback_data should be stored ... - self.bot.callback_data_cache = CallbackDataCache( # type: ignore[attr-defined] - self.bot, # type: ignore[arg-type] - self.bot.callback_data_cache.maxsize, # type: ignore[attr-defined] - persistent_data=persistent_data, - ) - else: - self.persistence = None - - self.handlers: Dict[int, List[Handler]] = {} - self.error_handlers: Dict[Callable, Union[bool, DefaultValue]] = {} - - self.running = False - self.__stop_event = Event() - self.__async_queue: Queue = Queue() - self.__async_threads: Set[Thread] = set() - - # For backward compatibility, we allow a "singleton" mode for the dispatcher. When there's - # only one instance of Dispatcher, it will be possible to use the `run_async` decorator. - with self.__singleton_lock: - # pylint: disable=consider-using-with - if self.__singleton_semaphore.acquire(blocking=False): - self._set_singleton(self) - else: - self._set_singleton(None) - - @staticmethod - def builder() -> 'InitDispatcherBuilder': - """Convenience method. Returns a new :class:`telegram.ext.DispatcherBuilder`. - - .. versionadded:: 14.0 - """ - # Unfortunately this needs to be here due to cyclical imports - from telegram.ext import DispatcherBuilder # pylint: disable=import-outside-toplevel - - return DispatcherBuilder() - - def _init_async_threads(self, base_name: str, workers: int) -> None: - base_name = f'{base_name}_' if base_name else '' - - for i in range(workers): - thread = Thread(target=self._pooled, name=f'Bot:{self.bot.id}:worker:{base_name}{i}') - self.__async_threads.add(thread) - thread.start() - - @classmethod - def _set_singleton(cls, val: Optional['Dispatcher']) -> None: - cls.logger.debug('Setting singleton dispatcher as %s', val) - cls.__singleton = weakref.ref(val) if val else None - - @classmethod - def get_instance(cls) -> 'Dispatcher': - """Get the singleton instance of this class. - - Returns: - :class:`telegram.ext.Dispatcher` - - Raises: - RuntimeError - - """ - if cls.__singleton is not None: - return cls.__singleton() # type: ignore[return-value] # pylint: disable=not-callable - raise RuntimeError(f'{cls.__name__} not initialized or multiple instances exist') - - def _pooled(self) -> None: - thr_name = current_thread().name - while 1: - promise = self.__async_queue.get() - - # If unpacking fails, the thread pool is being closed from Updater._join_async_threads - if not isinstance(promise, Promise): - self.logger.debug( - "Closing run_async thread %s/%d", thr_name, len(self.__async_threads) - ) - break - - promise.run() - - if not promise.exception: - self.update_persistence(update=promise.update) - continue - - if isinstance(promise.exception, DispatcherHandlerStop): - warn( - 'DispatcherHandlerStop is not supported with async functions; ' - f'func: {promise.pooled_function.__name__}', - ) - continue - - # Avoid infinite recursion of error handlers. - if promise.pooled_function in self.error_handlers: - self.logger.exception( - 'An error was raised and an uncaught error was raised while ' - 'handling the error with an error_handler.', - exc_info=promise.exception, - ) - continue - - # If we arrive here, an exception happened in the promise and was neither - # DispatcherHandlerStop nor raised by an error handler. So we can and must handle it - self.dispatch_error(promise.update, promise.exception, promise=promise) - - def run_async( - self, func: Callable[..., object], *args: object, update: object = None, **kwargs: object - ) -> Promise: - """ - Queue a function (with given args/kwargs) to be run asynchronously. Exceptions raised - by the function will be handled by the error handlers registered with - :meth:`add_error_handler`. - - Warning: - * If you're using ``@run_async``/:meth:`run_async` you cannot rely on adding custom - attributes to :class:`telegram.ext.CallbackContext`. See its docs for more info. - * Calling a function through :meth:`run_async` from within an error handler can lead to - an infinite error handling loop. - - Args: - func (:obj:`callable`): The function to run in the thread. - *args (:obj:`tuple`, optional): Arguments to ``func``. - update (:class:`telegram.Update` | :obj:`object`, optional): The update associated with - the functions call. If passed, it will be available in the error handlers, in case - an exception is raised by :attr:`func`. - **kwargs (:obj:`dict`, optional): Keyword arguments to ``func``. - - Returns: - Promise - - """ - promise = Promise(func, args, kwargs, update=update) - self.__async_queue.put(promise) - return promise - - def start(self, ready: Event = None) -> None: - """Thread target of thread 'dispatcher'. - - Runs in background and processes the update queue. Also starts :attr:`job_queue`, if set. - - Args: - ready (:obj:`threading.Event`, optional): If specified, the event will be set once the - dispatcher is ready. - - """ - if self.running: - self.logger.warning('already running') - if ready is not None: - ready.set() - return - - if self.exception_event.is_set(): - msg = 'reusing dispatcher after exception event is forbidden' - self.logger.error(msg) - raise TelegramError(msg) - - if self.job_queue: - self.job_queue.start() - self._init_async_threads(str(uuid4()), self.workers) - self.running = True - self.logger.debug('Dispatcher started') - - if ready is not None: - ready.set() - - while 1: - try: - # Pop update from update queue. - update = self.update_queue.get(True, 1) - except Empty: - if self.__stop_event.is_set(): - self.logger.debug('orderly stopping') - break - if self.exception_event.is_set(): - self.logger.critical('stopping due to exception in another thread') - break - continue - - self.logger.debug('Processing Update: %s', update) - self.process_update(update) - self.update_queue.task_done() - - self.running = False - self.logger.debug('Dispatcher thread stopped') - - def stop(self) -> None: - """Stops the thread and :attr:`job_queue`, if set. - Also calls :meth:`update_persistence` and :meth:`BasePersistence.flush` on - :attr:`persistence`, if set. - """ - if self.running: - self.__stop_event.set() - while self.running: - sleep(0.1) - self.__stop_event.clear() - - # async threads must be join()ed only after the dispatcher thread was joined, - # otherwise we can still have new async threads dispatched - threads = list(self.__async_threads) - total = len(threads) - - # Stop all threads in the thread pool by put()ting one non-tuple per thread - for i in range(total): - self.__async_queue.put(None) - - for i, thr in enumerate(threads): - self.logger.debug('Waiting for async thread %s/%s to end', i + 1, total) - thr.join() - self.__async_threads.remove(thr) - self.logger.debug('async thread %s/%s has ended', i + 1, total) - - if self.job_queue: - self.job_queue.stop() - self.logger.debug('JobQueue was shut down.') - - self.update_persistence() - if self.persistence: - self.persistence.flush() - - # Clear the connection pool - self.bot.request.stop() - - @property - def has_running_threads(self) -> bool: # skipcq: PY-D0003 - return self.running or bool(self.__async_threads) - - def process_update(self, update: object) -> None: - """Processes a single update and updates the persistence. - - Note: - If the update is handled by least one synchronously running handlers (i.e. - ``run_async=False``), :meth:`update_persistence` is called *once* after all handlers - synchronous handlers are done. Each asynchronously running handler will trigger - :meth:`update_persistence` on its own. - - Args: - update (:class:`telegram.Update` | :obj:`object` | \ - :class:`telegram.error.TelegramError`): - The update to process. - - """ - # An error happened while polling - if isinstance(update, TelegramError): - self.dispatch_error(None, update) - return - - context = None - handled = False - sync_modes = [] - - for handlers in self.handlers.values(): - try: - for handler in handlers: - check = handler.check_update(update) - if check is not None and check is not False: - if not context: - context = self.context_types.context.from_update(update, self) - context.refresh_data() - handled = True - sync_modes.append(handler.run_async) - handler.handle_update(update, self, check, context) - break - - # Stop processing with any other handler. - except DispatcherHandlerStop: - self.logger.debug('Stopping further handlers due to DispatcherHandlerStop') - self.update_persistence(update=update) - break - - # Dispatch any error. - except Exception as exc: - if self.dispatch_error(update, exc): - self.logger.debug('Error handler stopped further handlers.') - break - - # Update persistence, if handled - handled_only_async = all(sync_modes) - if handled: - # Respect default settings - if ( - all(mode is DEFAULT_FALSE for mode in sync_modes) - and isinstance(self.bot, ExtBot) - and self.bot.defaults - ): - handled_only_async = self.bot.defaults.run_async - # If update was only handled by async handlers, we don't need to update here - if not handled_only_async: - self.update_persistence(update=update) - - def add_handler(self, handler: Handler[UT, CCT], group: int = DEFAULT_GROUP) -> None: - """Register a handler. - - TL;DR: Order and priority counts. 0 or 1 handlers per group will be used. End handling of - update with :class:`telegram.ext.DispatcherHandlerStop`. - - A handler must be an instance of a subclass of :class:`telegram.ext.Handler`. All handlers - are organized in groups with a numeric value. The default group is 0. All groups will be - evaluated for handling an update, but only 0 or 1 handler per group will be used. If - :class:`telegram.ext.DispatcherHandlerStop` is raised from one of the handlers, no further - handlers (regardless of the group) will be called. - - The priority/order of handlers is determined as follows: - - * Priority of the group (lower group number == higher priority) - * The first handler in a group which should handle an update (see - :attr:`telegram.ext.Handler.check_update`) will be used. Other handlers from the - group will not be used. The order in which handlers were added to the group defines the - priority. - - Args: - handler (:class:`telegram.ext.Handler`): A Handler instance. - group (:obj:`int`, optional): The group identifier. Default is 0. - - """ - # Unfortunately due to circular imports this has to be here - # pylint: disable=import-outside-toplevel - from telegram.ext._conversationhandler import ConversationHandler - - if not isinstance(handler, Handler): - raise TypeError(f'handler is not an instance of {Handler.__name__}') - if not isinstance(group, int): - raise TypeError('group is not int') - # For some reason MyPy infers the type of handler is here, - # so for now we just ignore all the errors - if ( - isinstance(handler, ConversationHandler) - and handler.persistent # type: ignore[attr-defined] - and handler.name # type: ignore[attr-defined] - ): - if not self.persistence: - raise ValueError( - f"ConversationHandler {handler.name} " # type: ignore[attr-defined] - f"can not be persistent if dispatcher has no persistence" - ) - handler.persistence = self.persistence # type: ignore[attr-defined] - handler.conversations = ( # type: ignore[attr-defined] - self.persistence.get_conversations(handler.name) # type: ignore[attr-defined] - ) - - if group not in self.handlers: - self.handlers[group] = [] - self.handlers = dict(sorted(self.handlers.items())) # lower -> higher groups - - self.handlers[group].append(handler) - - def add_handlers( - self, - handlers: Union[ - Union[List[Handler], Tuple[Handler]], Dict[int, Union[List[Handler], Tuple[Handler]]] - ], - group: DVInput[int] = DefaultValue(0), - ) -> None: - """Registers multiple handlers at once. The order of the handlers in the passed - sequence(s) matters. See :meth:`add_handler` for details. - - .. versionadded:: 14.0 - .. seealso:: :meth:`add_handler` - - Args: - handlers (List[:obj:`telegram.ext.Handler`] | \ - Dict[int, List[:obj:`telegram.ext.Handler`]]): \ - Specify a sequence of handlers *or* a dictionary where the keys are groups and - values are handlers. - group (:obj:`int`, optional): Specify which group the sequence of ``handlers`` - should be added to. Defaults to ``0``. - - """ - if isinstance(handlers, dict) and not isinstance(group, DefaultValue): - raise ValueError('The `group` argument can only be used with a sequence of handlers.') - - if isinstance(handlers, dict): - for handler_group, grp_handlers in handlers.items(): - if not isinstance(grp_handlers, (list, tuple)): - raise ValueError(f'Handlers for group {handler_group} must be a list or tuple') - - for handler in grp_handlers: - self.add_handler(handler, handler_group) - - elif isinstance(handlers, (list, tuple)): - for handler in handlers: - self.add_handler(handler, DefaultValue.get_value(group)) - - else: - raise ValueError( - "The `handlers` argument must be a sequence of handlers or a " - "dictionary where the keys are groups and values are sequences of handlers." - ) - - def remove_handler(self, handler: Handler, group: int = DEFAULT_GROUP) -> None: - """Remove a handler from the specified group. - - Args: - handler (:class:`telegram.ext.Handler`): A Handler instance. - group (:obj:`object`, optional): The group identifier. Default is 0. - - """ - if handler in self.handlers[group]: - self.handlers[group].remove(handler) - if not self.handlers[group]: - del self.handlers[group] - - def drop_chat_data(self, chat_id: int) -> None: - """Used for deleting a key from the :attr:`chat_data`. - - .. versionadded:: 14.0 - - Args: - chat_id (:obj:`int`): The chat id to delete from the persistence. The entry - will be deleted even if it is not empty. - """ - self._chat_data.pop(chat_id, None) # type: ignore[arg-type] - - if self.persistence: - self.persistence.drop_chat_data(chat_id) - - def drop_user_data(self, user_id: int) -> None: - """Used for deleting a key from the :attr:`user_data`. - - .. versionadded:: 14.0 - - Args: - user_id (:obj:`int`): The user id to delete from the persistence. The entry - will be deleted even if it is not empty. - """ - self._user_data.pop(user_id, None) # type: ignore[arg-type] - - if self.persistence: - self.persistence.drop_user_data(user_id) - - def migrate_chat_data( - self, message: 'Message' = None, old_chat_id: int = None, new_chat_id: int = None - ) -> None: - """Moves the contents of :attr:`chat_data` at key old_chat_id to the key new_chat_id. - Also updates the persistence by calling :attr:`update_persistence`. - - Warning: - - * Any data stored in :attr:`chat_data` at key `new_chat_id` will be overridden - * The key `old_chat_id` of :attr:`chat_data` will be deleted - - Args: - message (:class:`Message`, optional): A message with either - :attr:`telegram.Message.migrate_from_chat_id` or - :attr:`telegram.Message.migrate_to_chat_id`. - Mutually exclusive with passing ``old_chat_id`` and ``new_chat_id`` - - .. seealso: `telegram.ext.filters.StatusUpdate.MIGRATE` - old_chat_id (:obj:`int`, optional): The old chat ID. - Mutually exclusive with passing ``message`` - new_chat_id (:obj:`int`, optional): The new chat ID. - Mutually exclusive with passing ``message`` - - """ - if message and (old_chat_id or new_chat_id): - raise ValueError("Message and chat_id pair are mutually exclusive") - if not any((message, old_chat_id, new_chat_id)): - raise ValueError("chat_id pair or message must be passed") - - if message: - if message.migrate_from_chat_id is None and message.migrate_to_chat_id is None: - raise ValueError( - "Invalid message instance. The message must have either " - "`Message.migrate_from_chat_id` or `Message.migrate_to_chat_id`." - ) - - old_chat_id = message.migrate_from_chat_id or message.chat.id - new_chat_id = message.migrate_to_chat_id or message.chat.id - - elif not (isinstance(old_chat_id, int) and isinstance(new_chat_id, int)): - raise ValueError("old_chat_id and new_chat_id must be integers") - - self._chat_data[new_chat_id] = self._chat_data[old_chat_id] - self.drop_chat_data(old_chat_id) - self.update_persistence() - - def update_persistence(self, update: object = None) -> None: - """Update :attr:`user_data`, :attr:`chat_data` and :attr:`bot_data` in :attr:`persistence`. - - Args: - update (:class:`telegram.Update`, optional): The update to process. If passed, only the - corresponding ``user_data`` and ``chat_data`` will be updated. - """ - with self._update_persistence_lock: - self.__update_persistence(update) - - def __update_persistence(self, update: object = None) -> None: - if self.persistence: - # We use list() here in order to decouple chat_ids from self.chat_data, as dict view - # objects will change, when the dict does and we want to loop over chat_ids - chat_ids = list(self.chat_data.keys()) - user_ids = list(self.user_data.keys()) - - if isinstance(update, Update): - if update.effective_chat: - chat_ids = [update.effective_chat.id] - else: - chat_ids = [] - if update.effective_user: - user_ids = [update.effective_user.id] - else: - user_ids = [] - - if self.persistence.store_data.callback_data: - try: - # Mypy doesn't know that persistence.set_bot (see above) already checks that - # self.bot is an instance of ExtBot if callback_data should be stored ... - self.persistence.update_callback_data( - self.bot.callback_data_cache.persistence_data # type: ignore[attr-defined] - ) - except Exception as exc: - self.dispatch_error(update, exc) - if self.persistence.store_data.bot_data: - try: - self.persistence.update_bot_data(self.bot_data) - except Exception as exc: - self.dispatch_error(update, exc) - if self.persistence.store_data.chat_data: - for chat_id in chat_ids: - try: - self.persistence.update_chat_data(chat_id, self.chat_data[chat_id]) - except Exception as exc: - self.dispatch_error(update, exc) - if self.persistence.store_data.user_data: - for user_id in user_ids: - try: - self.persistence.update_user_data(user_id, self.user_data[user_id]) - except Exception as exc: - self.dispatch_error(update, exc) - - def add_error_handler( - self, - callback: Callable[[object, CCT], None], - run_async: Union[bool, DefaultValue] = DEFAULT_FALSE, - ) -> None: - """Registers an error handler in the Dispatcher. This handler will receive every error - which happens in your bot. See the docs of :meth:`dispatch_error` for more details on how - errors are handled. - - Note: - Attempts to add the same callback multiple times will be ignored. - - Args: - callback (:obj:`callable`): The callback function for this error handler. Will be - called when an error is raised. Callback signature: - ``def callback(update: Update, context: CallbackContext)``. - The error that happened will be present in ``context.error``. - run_async (:obj:`bool`, optional): Whether this handlers callback should be run - asynchronously using :meth:`run_async`. Defaults to :obj:`False`. - """ - if callback in self.error_handlers: - self.logger.debug('The callback is already registered as an error handler. Ignoring.') - return - - if ( - run_async is DEFAULT_FALSE - and isinstance(self.bot, ExtBot) - and self.bot.defaults - and self.bot.defaults.run_async - ): - run_async = True - - self.error_handlers[callback] = run_async - - def remove_error_handler(self, callback: Callable[[object, CCT], None]) -> None: - """Removes an error handler. - - Args: - callback (:obj:`callable`): The error handler to remove. - - """ - self.error_handlers.pop(callback, None) - - def dispatch_error( - self, - update: Optional[object], - error: Exception, - promise: Promise = None, - job: 'Job' = None, - ) -> bool: - """Dispatches an error by passing it to all error handlers registered with - :meth:`add_error_handler`. If one of the error handlers raises - :class:`telegram.ext.DispatcherHandlerStop`, the update will not be handled by other error - handlers or handlers (even in other groups). All other exceptions raised by an error - handler will just be logged. - - .. versionchanged:: 14.0 - - * Exceptions raised by error handlers are now properly logged. - * :class:`telegram.ext.DispatcherHandlerStop` is no longer reraised but converted into - the return value. - - Args: - update (:obj:`object` | :class:`telegram.Update`): The update that caused the error. - error (:obj:`Exception`): The error that was raised. - promise (:class:`telegram._utils.Promise`, optional): The promise whose pooled function - raised the error. - job (:class:`telegram.ext.Job`, optional): The job that caused the error. - - .. versionadded:: 14.0 - - Returns: - :obj:`bool`: :obj:`True` if one of the error handlers raised - :class:`telegram.ext.DispatcherHandlerStop`. :obj:`False`, otherwise. - """ - async_args = None if not promise else promise.args - async_kwargs = None if not promise else promise.kwargs - - if self.error_handlers: - for ( - callback, - run_async, - ) in self.error_handlers.items(): # pylint: disable=redefined-outer-name - context = self.context_types.context.from_error( - update=update, - error=error, - dispatcher=self, - async_args=async_args, - async_kwargs=async_kwargs, - job=job, - ) - if run_async: - self.run_async(callback, update, context, update=update) - else: - try: - callback(update, context) - except DispatcherHandlerStop: - return True - except Exception as exc: - self.logger.exception( - 'An error was raised and an uncaught error was raised while ' - 'handling the error with an error_handler.', - exc_info=exc, - ) - return False - - self.logger.exception( - 'No error handlers are registered, logging exception.', exc_info=error - ) - return False diff --git a/telegram/ext/_extbot.py b/telegram/ext/_extbot.py index 5aed2380bab..406977e6c53 100644 --- a/telegram/ext/_extbot.py +++ b/telegram/ext/_extbot.py @@ -51,10 +51,10 @@ from telegram._utils.defaultvalue import DEFAULT_NONE, DefaultValue from telegram._utils.datetime import to_timestamp from telegram.ext._callbackdatacache import CallbackDataCache +from telegram.request import BaseRequest if TYPE_CHECKING: from telegram import InlineQueryResult, MessageEntity - from telegram.request import Request from telegram.ext import Defaults HandledTypes = TypeVar('HandledTypes', bound=Union[Message, CallbackQuery, Chat]) @@ -96,7 +96,8 @@ def __init__( token: str, base_url: str = 'https://api.telegram.org/bot', base_file_url: str = 'https://api.telegram.org/file/bot', - request: 'Request' = None, + request: BaseRequest = None, + get_updates_request: BaseRequest = None, private_key: bytes = None, private_key_password: bytes = None, defaults: 'Defaults' = None, @@ -107,6 +108,7 @@ def __init__( base_url=base_url, base_file_url=base_file_url, request=request, + get_updates_request=get_updates_request, private_key=private_key, private_key_password=private_key_password, ) @@ -127,9 +129,7 @@ def defaults(self) -> Optional['Defaults']: # This is a property because defaults shouldn't be changed at runtime return self._defaults - def _insert_defaults( - self, data: Dict[str, object], timeout: ODVInput[float] - ) -> Optional[float]: + def _insert_defaults(self, data: Dict[str, object]) -> None: """Inserts the defaults values for optional kwargs for which tg.ext.Defaults provides convenience functionality, i.e. the kwargs with a tg.utils.helpers.DefaultValue default @@ -166,17 +166,6 @@ def _insert_defaults( if media.parse_mode is DEFAULT_NONE: media.parse_mode = self.defaults.parse_mode if self.defaults else None - effective_timeout = DefaultValue.get_value(timeout) - if isinstance(timeout, DefaultValue): - # If we get here, we use Defaults.timeout, unless that's not set, which is the - # case if isinstance(self.defaults.timeout, DefaultValue) - return ( - self.defaults.timeout - if self.defaults and not isinstance(self.defaults.timeout, DefaultValue) - else effective_timeout - ) - return effective_timeout - def _replace_keyboard(self, reply_markup: Optional[ReplyMarkup]) -> Optional[ReplyMarkup]: # If the reply_markup is an inline keyboard and we allow arbitrary callback data, let the # CallbackDataCache build a new keyboard with the data replaced. Otherwise return the input @@ -246,7 +235,7 @@ def _insert_callback_data(self, obj: HandledTypes) -> HandledTypes: return obj - def _message( + async def _send_message( self, endpoint: str, data: JSONDict, @@ -254,20 +243,26 @@ def _message( disable_notification: ODVInput[bool] = DEFAULT_NONE, reply_markup: ReplyMarkup = None, allow_sending_without_reply: ODVInput[bool] = DEFAULT_NONE, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, protect_content: ODVInput[bool] = DEFAULT_NONE, ) -> Union[bool, Message]: # We override this method to call self._replace_keyboard and self._insert_callback_data. # This covers most methods that have a reply_markup - result = super()._message( + result = await super()._send_message( endpoint=endpoint, data=data, reply_to_message_id=reply_to_message_id, disable_notification=disable_notification, reply_markup=self._replace_keyboard(reply_markup), allow_sending_without_reply=allow_sending_without_reply, - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, protect_content=protect_content, ) @@ -275,20 +270,26 @@ def _message( self._insert_callback_data(result) return result - def get_updates( + async def get_updates( self, offset: int = None, limit: int = 100, - timeout: float = 0, - read_latency: float = 2.0, + timeout: int = 0, + read_timeout: float = 2, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, allowed_updates: List[str] = None, api_kwargs: JSONDict = None, ) -> List[Update]: - updates = super().get_updates( + updates = await super().get_updates( offset=offset, limit=limit, timeout=timeout, - read_latency=read_latency, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, allowed_updates=allowed_updates, api_kwargs=api_kwargs, ) @@ -356,24 +357,30 @@ def _insert_defaults_for_ilq_results(self, res: 'InlineQueryResult') -> None: self.defaults.disable_web_page_preview if self.defaults else None ) - def stop_poll( + async def stop_poll( self, chat_id: Union[int, str], message_id: int, reply_markup: InlineKeyboardMarkup = None, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, ) -> Poll: # We override this method to call self._replace_keyboard - return super().stop_poll( + return await super().stop_poll( chat_id=chat_id, message_id=message_id, reply_markup=self._replace_keyboard(reply_markup), - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, ) - def copy_message( + async def copy_message( self, chat_id: Union[int, str], from_chat_id: Union[str, int], @@ -385,12 +392,15 @@ def copy_message( reply_to_message_id: int = None, allow_sending_without_reply: DVInput[bool] = DEFAULT_NONE, reply_markup: ReplyMarkup = None, - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, protect_content: ODVInput[bool] = DEFAULT_NONE, ) -> MessageId: # We override this method to call self._replace_keyboard - return super().copy_message( + return await super().copy_message( chat_id=chat_id, from_chat_id=from_chat_id, message_id=message_id, @@ -401,19 +411,32 @@ def copy_message( reply_to_message_id=reply_to_message_id, allow_sending_without_reply=allow_sending_without_reply, reply_markup=self._replace_keyboard(reply_markup), - timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, api_kwargs=api_kwargs, protect_content=protect_content, ) - def get_chat( + async def get_chat( self, chat_id: Union[str, int], - timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, ) -> Chat: # We override this method to call self._insert_callback_data - result = super().get_chat(chat_id=chat_id, timeout=timeout, api_kwargs=api_kwargs) + result = await super().get_chat( + chat_id=chat_id, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, + api_kwargs=api_kwargs, + ) return self._insert_callback_data(result) # updated camelCase aliases diff --git a/telegram/ext/_handler.py b/telegram/ext/_handler.py index 49ec8976ccf..317e4770833 100644 --- a/telegram/ext/_handler.py +++ b/telegram/ext/_handler.py @@ -16,17 +16,16 @@ # # You should have received a copy of the GNU Lesser Public License # along with this program. If not, see [http://www.gnu.org/licenses/]. -"""This module contains the base class for handlers as used by the Dispatcher.""" +"""This module contains the base class for handlers as used by the Application.""" from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, Union, Generic +from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union, Generic -from telegram._utils.defaultvalue import DefaultValue, DEFAULT_FALSE -from telegram.ext._utils.promise import Promise -from telegram.ext._utils.types import CCT -from telegram.ext._extbot import ExtBot +from telegram._utils.defaultvalue import DEFAULT_TRUE +from telegram._utils.types import DVInput +from telegram.ext._utils.types import CCT, HandlerCallback if TYPE_CHECKING: - from telegram.ext import Dispatcher + from telegram.ext import Application RT = TypeVar('RT') UT = TypeVar('UT') @@ -36,7 +35,7 @@ class Handler(Generic[UT, CCT], ABC): """The base class for all update handlers. Create custom handlers by inheriting from it. Warning: - When setting ``run_async`` to :obj:`True`, you cannot rely on adding custom + When setting ``block`` to :obj:`True`, you cannot rely on adding custom attributes to :class:`telegram.ext.CallbackContext`. See its docs for more info. Args: @@ -46,27 +45,28 @@ class Handler(Generic[UT, CCT], ABC): The return value of the callback is usually ignored except for the special case of :class:`telegram.ext.ConversationHandler`. - run_async (:obj:`bool`): Determines whether the callback will run asynchronously. - Defaults to :obj:`False`. + block (:obj:`bool`, optional): Determines whether the return value of the callback should + be awaited before processing the next handler in + :meth:`telegram.ext.Application.process_update`. Defaults to :obj:`True`. Attributes: callback (:obj:`callable`): The callback function for this handler. - run_async (:obj:`bool`): Determines whether the callback will run asynchronously. + block (:obj:`bool`): Determines whether the callback will run asynchronously. """ __slots__ = ( 'callback', - 'run_async', + 'block', ) def __init__( self, - callback: Callable[[UT, CCT], RT], - run_async: Union[bool, DefaultValue] = DEFAULT_FALSE, + callback: HandlerCallback[UT, CCT, RT], + block: DVInput[bool] = DEFAULT_TRUE, ): self.callback = callback - self.run_async = run_async + self.block = block @abstractmethod def check_update(self, update: object) -> Optional[Union[bool, object]]: @@ -75,7 +75,7 @@ def check_update(self, update: object) -> Optional[Union[bool, object]]: this handler instance. It should always be overridden. Note: - Custom updates types can be handled by the dispatcher. Therefore, an implementation of + Custom updates types can be handled by the application. Therefore, an implementation of this method should always check the type of :attr:`update`. Args: @@ -88,13 +88,13 @@ def check_update(self, update: object) -> Optional[Union[bool, object]]: """ - def handle_update( + async def handle_update( self, update: UT, - dispatcher: 'Dispatcher', + application: 'Application', check_result: object, context: CCT, - ) -> Union[RT, Promise]: + ) -> RT: """ This method is called if it was determined that an update should indeed be handled by this instance. Calls :attr:`callback` along with its respectful @@ -104,31 +104,20 @@ def handle_update( Args: update (:obj:`str` | :class:`telegram.Update`): The update to be handled. - dispatcher (:class:`telegram.ext.Dispatcher`): The calling dispatcher. + application (:class:`telegram.ext.Application`): The calling application. check_result (:obj:`obj`): The result from :attr:`check_update`. context (:class:`telegram.ext.CallbackContext`): The context as provided by - the dispatcher. + the application. """ - run_async = self.run_async - if ( - self.run_async is DEFAULT_FALSE - and isinstance(dispatcher.bot, ExtBot) - and dispatcher.bot.defaults - and dispatcher.bot.defaults.run_async - ): - run_async = True - - self.collect_additional_context(context, update, dispatcher, check_result) - if run_async: - return dispatcher.run_async(self.callback, update, context, update=update) - return self.callback(update, context) + self.collect_additional_context(context, update, application, check_result) + return await self.callback(update, context) def collect_additional_context( self, context: CCT, update: UT, - dispatcher: 'Dispatcher', + application: 'Application', check_result: Any, ) -> None: """Prepares additional arguments for the context. Override if needed. @@ -136,7 +125,7 @@ def collect_additional_context( Args: context (:class:`telegram.ext.CallbackContext`): The context object. update (:class:`telegram.Update`): The update to gather chat/user id from. - dispatcher (:class:`telegram.ext.Dispatcher`): The calling dispatcher. + application (:class:`telegram.ext.Application`): The calling application. check_result: The result (return value) from :attr:`check_update`. """ diff --git a/telegram/ext/_inlinequeryhandler.py b/telegram/ext/_inlinequeryhandler.py index 60883d026c6..0f6ea409dfd 100644 --- a/telegram/ext/_inlinequeryhandler.py +++ b/telegram/ext/_inlinequeryhandler.py @@ -20,7 +20,6 @@ import re from typing import ( TYPE_CHECKING, - Callable, Match, Optional, Pattern, @@ -31,12 +30,13 @@ ) from telegram import Update +from telegram._utils.types import DVInput from telegram.ext import Handler -from telegram._utils.defaultvalue import DefaultValue, DEFAULT_FALSE -from telegram.ext._utils.types import CCT +from telegram._utils.defaultvalue import DEFAULT_TRUE +from telegram.ext._utils.types import CCT, HandlerCallback if TYPE_CHECKING: - from telegram.ext import Dispatcher + from telegram.ext import Application RT = TypeVar('RT') @@ -47,7 +47,7 @@ class InlineQueryHandler(Handler[Update, CCT]): documentation of the ``re`` module for more information. Warning: - * When setting ``run_async`` to :obj:`True`, you cannot rely on adding custom + * When setting ``block`` to :obj:`True`, you cannot rely on adding custom attributes to :class:`telegram.ext.CallbackContext`. See its docs for more info. * :attr:`telegram.InlineQuery.chat_type` will not be set for inline queries from secret chats and may not be set for inline queries coming from third-party clients. These @@ -67,8 +67,9 @@ class InlineQueryHandler(Handler[Update, CCT]): handle inline queries with the appropriate :attr:`telegram.InlineQuery.chat_type`. .. versionadded:: 13.5 - run_async (:obj:`bool`): Determines whether the callback will run asynchronously. - Defaults to :obj:`False`. + block (:obj:`bool`, optional): Determines whether the return value of the callback should + be awaited before processing the next handler in + :meth:`telegram.ext.Application.process_update`. Defaults to :obj:`True`. Attributes: callback (:obj:`callable`): The callback function for this handler. @@ -77,7 +78,9 @@ class InlineQueryHandler(Handler[Update, CCT]): chat_types (List[:obj:`str`], optional): List of allowed chat types. .. versionadded:: 13.5 - run_async (:obj:`bool`): Determines whether the callback will run asynchronously. + block (:obj:`bool`): Determines whether the return value of the callback should be + awaited before processing the next handler in + :meth:`telegram.ext.Application.process_update`. """ @@ -85,15 +88,12 @@ class InlineQueryHandler(Handler[Update, CCT]): def __init__( self, - callback: Callable[[Update, CCT], RT], + callback: HandlerCallback[Update, CCT, RT], pattern: Union[str, Pattern] = None, - run_async: Union[bool, DefaultValue] = DEFAULT_FALSE, + block: DVInput[bool] = DEFAULT_TRUE, chat_types: List[str] = None, ): - super().__init__( - callback, - run_async=run_async, - ) + super().__init__(callback, block=block) if isinstance(pattern, str): pattern = re.compile(pattern) @@ -130,7 +130,7 @@ def collect_additional_context( self, context: CCT, update: Update, - dispatcher: 'Dispatcher', + application: 'Application', check_result: Optional[Union[bool, Match]], ) -> None: """Add the result of ``re.match(pattern, update.inline_query.query)`` to diff --git a/telegram/ext/_jobqueue.py b/telegram/ext/_jobqueue.py index 0498325e3b7..a75435370cb 100644 --- a/telegram/ext/_jobqueue.py +++ b/telegram/ext/_jobqueue.py @@ -17,21 +17,22 @@ # You should have received a copy of the GNU Lesser Public License # along with this program. If not, see [http://www.gnu.org/licenses/]. """This module contains the classes JobQueue and Job.""" - +import asyncio import datetime import weakref -from typing import TYPE_CHECKING, Callable, Optional, Tuple, Union, cast, overload +from typing import TYPE_CHECKING, Optional, Tuple, Union, cast, overload import pytz -from apscheduler.schedulers.background import BackgroundScheduler +from apscheduler.executors.asyncio import AsyncIOExecutor +from apscheduler.schedulers.asyncio import AsyncIOScheduler from apscheduler.job import Job as APSJob from telegram._utils.types import JSONDict from telegram.ext._extbot import ExtBot +from telegram.ext._utils.types import JobCallback if TYPE_CHECKING: - from telegram.ext import Dispatcher, CallbackContext - import apscheduler.job # noqa: F401 + from telegram.ext import Application class JobQueue: @@ -39,15 +40,20 @@ class JobQueue: wrapper for the APScheduler library. Attributes: - scheduler (:class:`apscheduler.schedulers.background.BackgroundScheduler`): The APScheduler + scheduler (:class:`apscheduler.schedulers.asyncio.AsyncIOScheduler`): The scheduler. + ..versionchanged:: 14.0 + Use :class:`apscheduler.schedulers.asyncio.AsyncIOScheduler` instead of + :class:`apscheduler.schedulers.background.BackgroundScheduler` + """ - __slots__ = ('_dispatcher', 'scheduler') + __slots__ = ('_application', 'scheduler', '_executor') def __init__(self) -> None: - self._dispatcher: 'Optional[weakref.ReferenceType[Dispatcher]]' = None - self.scheduler = BackgroundScheduler(timezone=pytz.utc) + self._application: 'Optional[weakref.ReferenceType[Application]]' = None + self._executor = AsyncIOExecutor() + self.scheduler = AsyncIOScheduler(timezone=pytz.utc, executors={'default': self._executor}) def _tz_now(self) -> datetime.datetime: return datetime.datetime.now(self.scheduler.timezone) @@ -87,33 +93,35 @@ def _parse_time_input( # isinstance(time, datetime.datetime): return time - def set_dispatcher(self, dispatcher: 'Dispatcher') -> None: - """Set the dispatcher to be used by this JobQueue. + def set_application(self, application: 'Application') -> None: + """Set the application to be used by this JobQueue. Args: - dispatcher (:class:`telegram.ext.Dispatcher`): The dispatcher. + application (:class:`telegram.ext.Application`): The application. """ - self._dispatcher = weakref.ref(dispatcher) - if isinstance(dispatcher.bot, ExtBot) and dispatcher.bot.defaults: - self.scheduler.configure(timezone=dispatcher.bot.defaults.tzinfo or pytz.utc) + self._application = weakref.ref(application) + if isinstance(application.bot, ExtBot) and application.bot.defaults: + self.scheduler.configure(timezone=application.bot.defaults.tzinfo or pytz.utc) @property - def dispatcher(self) -> 'Dispatcher': - """The dispatcher this JobQueue is associated with.""" - if self._dispatcher is None: - raise RuntimeError('No dispatcher was set for this JobQueue.') - dispatcher = self._dispatcher() - if dispatcher is not None: - return dispatcher - raise RuntimeError('The dispatcher instance is no longer alive.') + def application(self) -> 'Application': + """The application this JobQueue is associated with.""" + if self._application is None: + raise RuntimeError('No application was set for this JobQueue.') + application = self._application() + if application is not None: + return application + raise RuntimeError('The application instance is no longer alive.') def run_once( self, - callback: Callable[['CallbackContext'], None], + callback: JobCallback, when: Union[float, datetime.timedelta, datetime.datetime, datetime.time], context: object = None, name: str = None, + chat_id: int = None, + user_id: int = None, job_kwargs: JSONDict = None, ) -> 'Job': """Creates a new :class:`Job` instance that runs once and adds it to the queue. @@ -138,6 +146,17 @@ def run_once( tomorrow. If the timezone (:attr:`datetime.time.tzinfo`) is :obj:`None`, the default timezone of the bot will be used. + chat_id (:obj:`int`, optional): Chat id of the chat associated with this job. If + passed, the corresponding :attr:`~telegram.ext.CallbackContext.chat_data` will + be available in the callback. + + .. versionadded:: 14.0 + + user_id (:obj:`int`, optional): User id of the user associated with this job. If + passed, the corresponding :attr:`~telegram.ext.CallbackContext.user_data` will + be available in the callback. + + .. versionadded:: 14.0 context (:obj:`object`, optional): Additional data needed for the callback function. Can be accessed through :attr:`Job.context` in the callback. Defaults to :obj:`None`. @@ -155,15 +174,15 @@ def run_once( job_kwargs = {} name = name or callback.__name__ - job = Job(callback, context, name) + job = Job(callback=callback, context=context, name=name, chat_id=chat_id, user_id=user_id) date_time = self._parse_time_input(when, shift_day=True) j = self.scheduler.add_job( - job, + job.run, name=name, trigger='date', run_date=date_time, - args=(self.dispatcher,), + args=(self.application,), timezone=date_time.tzinfo or self.scheduler.timezone, **job_kwargs, ) @@ -173,12 +192,14 @@ def run_once( def run_repeating( self, - callback: Callable[['CallbackContext'], None], + callback: JobCallback, interval: Union[float, datetime.timedelta], first: Union[float, datetime.timedelta, datetime.datetime, datetime.time] = None, last: Union[float, datetime.timedelta, datetime.datetime, datetime.time] = None, context: object = None, name: str = None, + chat_id: int = None, + user_id: int = None, job_kwargs: JSONDict = None, ) -> 'Job': """Creates a new :class:`Job` instance that runs at specified intervals and adds it to the @@ -229,6 +250,17 @@ def run_repeating( :obj:`None`. name (:obj:`str`, optional): The name of the new job. Defaults to ``callback.__name__``. + chat_id (:obj:`int`, optional): Chat id of the chat associated with this job. If + passed, the corresponding :attr:`~telegram.ext.CallbackContext.chat_data` will + be available in the callback. + + .. versionadded:: 14.0 + + user_id (:obj:`int`, optional): User id of the user associated with this job. If + passed, the corresponding :attr:`~telegram.ext.CallbackContext.user_data` will + be available in the callback. + + .. versionadded:: 14.0 job_kwargs (:obj:`dict`, optional): Arbitrary keyword arguments to pass to the :meth:`apscheduler.schedulers.base.BaseScheduler.add_job()`. @@ -241,7 +273,7 @@ def run_repeating( job_kwargs = {} name = name or callback.__name__ - job = Job(callback, context, name) + job = Job(callback=callback, context=context, name=name, chat_id=chat_id, user_id=user_id) dt_first = self._parse_time_input(first) dt_last = self._parse_time_input(last) @@ -253,9 +285,9 @@ def run_repeating( interval = interval.total_seconds() j = self.scheduler.add_job( - job, + job.run, trigger='interval', - args=(self.dispatcher,), + args=(self.application,), start_date=dt_first, end_date=dt_last, seconds=interval, @@ -268,11 +300,13 @@ def run_repeating( def run_monthly( self, - callback: Callable[['CallbackContext'], None], + callback: JobCallback, when: datetime.time, day: int, context: object = None, name: str = None, + chat_id: int = None, + user_id: int = None, job_kwargs: JSONDict = None, ) -> 'Job': """Creates a new :class:`Job` that runs on a monthly basis and adds it to the queue. @@ -299,6 +333,17 @@ def run_monthly( :obj:`None`. name (:obj:`str`, optional): The name of the new job. Defaults to ``callback.__name__``. + chat_id (:obj:`int`, optional): Chat id of the chat associated with this job. If + passed, the corresponding :attr:`~telegram.ext.CallbackContext.chat_data` will + be available in the callback. + + .. versionadded:: 14.0 + + user_id (:obj:`int`, optional): User id of the user associated with this job. If + passed, the corresponding :attr:`~telegram.ext.CallbackContext.user_data` will + be available in the callback. + + .. versionadded:: 14.0 job_kwargs (:obj:`dict`, optional): Arbitrary keyword arguments to pass to the :meth:`apscheduler.schedulers.base.BaseScheduler.add_job()`. @@ -311,12 +356,12 @@ def run_monthly( job_kwargs = {} name = name or callback.__name__ - job = Job(callback, context, name) + job = Job(callback=callback, context=context, name=name, chat_id=chat_id, user_id=user_id) j = self.scheduler.add_job( - job, + job.run, trigger='cron', - args=(self.dispatcher,), + args=(self.application,), name=name, day='last' if day == -1 else day, hour=when.hour, @@ -330,11 +375,13 @@ def run_monthly( def run_daily( self, - callback: Callable[['CallbackContext'], None], + callback: JobCallback, time: datetime.time, days: Tuple[int, ...] = tuple(range(7)), context: object = None, name: str = None, + chat_id: int = None, + user_id: int = None, job_kwargs: JSONDict = None, ) -> 'Job': """Creates a new :class:`Job` that runs on a daily basis and adds it to the queue. @@ -357,6 +404,17 @@ def run_daily( :obj:`None`. name (:obj:`str`, optional): The name of the new job. Defaults to ``callback.__name__``. + chat_id (:obj:`int`, optional): Chat id of the chat associated with this job. If + passed, the corresponding :attr:`~telegram.ext.CallbackContext.chat_data` will + be available in the callback. + + .. versionadded:: 14.0 + + user_id (:obj:`int`, optional): User id of the user associated with this job. If + passed, the corresponding :attr:`~telegram.ext.CallbackContext.user_data` will + be available in the callback. + + .. versionadded:: 14.0 job_kwargs (:obj:`dict`, optional): Arbitrary keyword arguments to pass to the :meth:`apscheduler.schedulers.base.BaseScheduler.add_job()`. @@ -369,12 +427,12 @@ def run_daily( job_kwargs = {} name = name or callback.__name__ - job = Job(callback, context, name) + job = Job(callback=callback, context=context, name=name, chat_id=chat_id, user_id=user_id) j = self.scheduler.add_job( - job, + job.run, name=name, - args=(self.dispatcher,), + args=(self.application,), trigger='cron', day_of_week=','.join([str(d) for d in days]), hour=time.hour, @@ -389,10 +447,12 @@ def run_daily( def run_custom( self, - callback: Callable[['CallbackContext'], None], + callback: JobCallback, job_kwargs: JSONDict, context: object = None, name: str = None, + chat_id: int = None, + user_id: int = None, ) -> 'Job': """Creates a new custom defined :class:`Job`. @@ -406,6 +466,17 @@ def run_custom( :obj:`None`. name (:obj:`str`, optional): The name of the new job. Defaults to ``callback.__name__``. + chat_id (:obj:`int`, optional): Chat id of the chat associated with this job. If + passed, the corresponding :attr:`~telegram.ext.CallbackContext.chat_data` will + be available in the callback. + + .. versionadded:: 14.0 + + user_id (:obj:`int`, optional): User id of the user associated with this job. If + passed, the corresponding :attr:`~telegram.ext.CallbackContext.user_data` will + be available in the callback. + + .. versionadded:: 14.0 Returns: :class:`telegram.ext.Job`: The new :class:`Job` instance that has been added to the job @@ -413,9 +484,9 @@ def run_custom( """ name = name or callback.__name__ - job = Job(callback, context, name) + job = Job(callback=callback, context=context, name=name, chat_id=chat_id, user_id=user_id) - j = self.scheduler.add_job(job, args=(self.dispatcher,), name=name, **job_kwargs) + j = self.scheduler.add_job(job, args=(self.application,), name=name, **job_kwargs) job.job = j return job @@ -425,10 +496,22 @@ def start(self) -> None: if not self.scheduler.running: self.scheduler.start() - def stop(self) -> None: - """Stops the thread.""" + async def stop(self, wait: bool = True) -> None: + """Shuts down the :class:`~telegram.ext.JobQueue`. + + Args: + wait (:obj:`bool`, optional): Whether or not to wait until all currently running jobs + have finished. Defaults to :obj:`True`. + + """ + if wait: + # Unfortunately AsyncIOExecutor just cancels them all ... + await asyncio.gather( + *self._executor._pending_futures, # pylint: disable=protected-access + return_exceptions=True, + ) if self.scheduler.running: - self.scheduler.shutdown() + self.scheduler.shutdown(wait=wait) def jobs(self) -> Tuple['Job', ...]: """Returns a tuple of all *scheduled* jobs that are currently in the :class:`JobQueue`.""" @@ -470,12 +553,24 @@ class Job: accessed through :attr:`Job.context` in the callback. Defaults to :obj:`None`. name (:obj:`str`, optional): The name of the new job. Defaults to ``callback.__name__``. job (:class:`apscheduler.job.Job`, optional): The APS Job this job is a wrapper for. + chat_id (:obj:`int`, optional): Chat id of the chat that this job is associated with. + + ..versionadded:: 14.0 + user_id (:obj:`int`, optional): User id of the user that this job is associated with. + + ..versionadded:: 14.0 Attributes: callback (:obj:`callable`): The callback function that should be executed by the new job. context (:obj:`object`): Optional. Additional data needed for the callback function. name (:obj:`str`): Optional. The name of the new job. job (:class:`apscheduler.job.Job`): Optional. The APS Job this job is a wrapper for. + chat_id (:obj:`int`): Optional. Chat id of the chat that this job is associated with. + + ..versionadded:: 14.0 + user_id (:obj:`int`): Optional. User id of the user that this job is associated with. + + ..versionadded:: 14.0 """ __slots__ = ( @@ -485,59 +580,52 @@ class Job: '_removed', '_enabled', 'job', + 'chat_id', + 'user_id', ) def __init__( self, - callback: Callable[['CallbackContext'], None], + callback: JobCallback, context: object = None, name: str = None, job: APSJob = None, + chat_id: int = None, + user_id: int = None, ): self.callback = callback self.context = context self.name = name or callback.__name__ + self.chat_id = chat_id + self.user_id = user_id self._removed = False self._enabled = False self.job = cast(APSJob, job) # skipcq: PTC-W0052 - def run(self, dispatcher: 'Dispatcher') -> None: + async def run(self, application: 'Application') -> None: """Executes the callback function independently of the jobs schedule. Also calls - :meth:`telegram.ext.Dispatcher.update_persistence`. + :meth:`telegram.ext.Application.update_persistence`. .. versionchanged:: 14.0 - Calls :meth:`telegram.ext.Dispatcher.update_persistence`. + Calls :meth:`telegram.ext.Application.update_persistence`. Args: - dispatcher (:class:`telegram.ext.Dispatcher`): The dispatcher this job is associated + application (:class:`telegram.ext.Application`): The application this job is associated with. """ + # We shield the task such that the job isn't cancelled mid-run + await asyncio.shield(self._run(application)) + + async def _run(self, application: 'Application') -> None: try: - self.callback(dispatcher.context_types.context.from_job(self, dispatcher)) + context = application.context_types.context.from_job(self, application) + await context.refresh_data() + await self.callback(context) except Exception as exc: - dispatcher.dispatch_error(None, exc, job=self) - finally: - dispatcher.update_persistence(None) - - def __call__(self, dispatcher: 'Dispatcher') -> None: - """Shortcut for:: - - job.run(dispatcher) - - Warning: - The fact that jobs are callable should be considered an implementation detail and not - as part of PTBs public API. - - .. versionadded:: 14.0 - - Args: - dispatcher (:class:`telegram.ext.Dispatcher`): The dispatcher this job is associated - with. - """ - self.run(dispatcher=dispatcher) + await application.create_task(application.dispatch_error(None, exc, job=self)) def schedule_removal(self) -> None: """ @@ -580,7 +668,7 @@ def next_t(self) -> Optional[datetime.datetime]: @classmethod def _from_aps_job(cls, job: APSJob) -> 'Job': - return job.func + return job.func.__self__ def __getattr__(self, item: str) -> object: try: diff --git a/telegram/ext/_messagehandler.py b/telegram/ext/_messagehandler.py index 8d20fbb15de..45237259d1b 100644 --- a/telegram/ext/_messagehandler.py +++ b/telegram/ext/_messagehandler.py @@ -17,16 +17,17 @@ # You should have received a copy of the GNU Lesser Public License # along with this program. If not, see [http://www.gnu.org/licenses/]. """This module contains the MessageHandler class.""" -from typing import TYPE_CHECKING, Callable, Dict, Optional, TypeVar, Union +from typing import TYPE_CHECKING, Dict, Optional, TypeVar, Union from telegram import Update +from telegram._utils.types import DVInput from telegram.ext import filters as filters_module, Handler -from telegram._utils.defaultvalue import DefaultValue, DEFAULT_FALSE +from telegram._utils.defaultvalue import DEFAULT_TRUE -from telegram.ext._utils.types import CCT +from telegram.ext._utils.types import CCT, HandlerCallback if TYPE_CHECKING: - from telegram.ext import Dispatcher + from telegram.ext import Application RT = TypeVar('RT') @@ -35,7 +36,7 @@ class MessageHandler(Handler[Update, CCT]): """Handler class to handle telegram messages. They might contain text, media or status updates. Warning: - When setting ``run_async`` to :obj:`True`, you cannot rely on adding custom + When setting ``block`` to :obj:`True`, you cannot rely on adding custom attributes to :class:`telegram.ext.CallbackContext`. See its docs for more info. Args: @@ -53,8 +54,9 @@ class MessageHandler(Handler[Update, CCT]): The return value of the callback is usually ignored except for the special case of :class:`telegram.ext.ConversationHandler`. - run_async (:obj:`bool`): Determines whether the callback will run asynchronously. - Defaults to :obj:`False`. + block (:obj:`bool`, optional): Determines whether the return value of the callback should + be awaited before processing the next handler in + :meth:`telegram.ext.Application.process_update`. Defaults to :obj:`True`. Raises: ValueError @@ -63,7 +65,9 @@ class MessageHandler(Handler[Update, CCT]): filters (:class:`telegram.ext.filters.BaseFilter`): Only allow updates with these Filters. See :mod:`telegram.ext.filters` for a full list of all available filters. callback (:obj:`callable`): The callback function for this handler. - run_async (:obj:`bool`): Determines whether the callback will run asynchronously. + block (:obj:`bool`): Determines whether the return value of the callback should be + awaited before processing the next handler in + :meth:`telegram.ext.Application.process_update`. """ @@ -72,11 +76,11 @@ class MessageHandler(Handler[Update, CCT]): def __init__( self, filters: filters_module.BaseFilter, - callback: Callable[[Update, CCT], RT], - run_async: Union[bool, DefaultValue] = DEFAULT_FALSE, + callback: HandlerCallback[Update, CCT, RT], + block: DVInput[bool] = DEFAULT_TRUE, ): - super().__init__(callback, run_async=run_async) + super().__init__(callback, block=block) self.filters = filters if filters is not None else filters_module.ALL def check_update(self, update: object) -> Optional[Union[bool, Dict[str, list]]]: @@ -97,7 +101,7 @@ def collect_additional_context( self, context: CCT, update: Update, - dispatcher: 'Dispatcher', + application: 'Application', check_result: Optional[Union[bool, Dict[str, object]]], ) -> None: """Adds possible output of data filters to the :class:`CallbackContext`.""" diff --git a/telegram/ext/_picklepersistence.py b/telegram/ext/_picklepersistence.py index bf2d6b70dc2..423808aae25 100644 --- a/telegram/ext/_picklepersistence.py +++ b/telegram/ext/_picklepersistence.py @@ -59,6 +59,12 @@ class PicklePersistence(BasePersistence[UD, CD, BD]): store_data (:class:`PersistenceInput`, optional): Specifies which kinds of data will be saved by this persistence instance. By default, all available kinds of data will be saved. + update_interval (:obj:`int` | :obj:`float:, optional): The + :class:`~telegram.ext.Application` will update + the persistence in regular intervals. This parameter specifies the time (in seconds) to + wait between two consecutive runs of updating the persistence. Defaults to 60 seconds. + + .. versionadded:: 14.0 single_file (:obj:`bool`, optional): When :obj:`False` will store 5 separate files of `filename_user_data`, `filename_bot_data`, `filename_chat_data`, `filename_callback_data` and `filename_conversations`. Default is :obj:`True`. @@ -110,6 +116,7 @@ def __init__( store_data: PersistenceInput = None, single_file: bool = True, on_flush: bool = False, + update_interval: float = 60, ): ... @@ -120,6 +127,7 @@ def __init__( store_data: PersistenceInput = None, single_file: bool = True, on_flush: bool = False, + update_interval: float = 60, context_types: ContextTypes[Any, UD, CD, BD] = None, ): ... @@ -130,9 +138,10 @@ def __init__( store_data: PersistenceInput = None, single_file: bool = True, on_flush: bool = False, + update_interval: float = 60, context_types: ContextTypes[Any, UD, CD, BD] = None, ): - super().__init__(store_data=store_data) + super().__init__(store_data=store_data, update_interval=update_interval) self.filepath = Path(filepath) self.single_file = single_file self.on_flush = on_flush @@ -193,7 +202,7 @@ def _dump_file(filepath: Path, data: object) -> None: with filepath.open("wb") as file: pickle.dump(data, file) - def get_user_data(self) -> Dict[int, UD]: + async def get_user_data(self) -> Dict[int, UD]: """Returns the user_data from the pickle file if it exists or an empty :obj:`dict`. Returns: @@ -210,7 +219,7 @@ def get_user_data(self) -> Dict[int, UD]: self._load_singlefile() return self.user_data # type: ignore[return-value] - def get_chat_data(self) -> Dict[int, CD]: + async def get_chat_data(self) -> Dict[int, CD]: """Returns the chat_data from the pickle file if it exists or an empty :obj:`dict`. Returns: @@ -227,7 +236,7 @@ def get_chat_data(self) -> Dict[int, CD]: self._load_singlefile() return self.chat_data # type: ignore[return-value] - def get_bot_data(self) -> BD: + async def get_bot_data(self) -> BD: """Returns the bot_data from the pickle file if it exists or an empty object of type :obj:`dict` | :attr:`telegram.ext.ContextTypes.bot_data`. @@ -245,7 +254,7 @@ def get_bot_data(self) -> BD: self._load_singlefile() return self.bot_data # type: ignore[return-value] - def get_callback_data(self) -> Optional[CDCData]: + async def get_callback_data(self) -> Optional[CDCData]: """Returns the callback data from the pickle file if it exists or :obj:`None`. .. versionadded:: 13.6 @@ -268,7 +277,7 @@ def get_callback_data(self) -> Optional[CDCData]: return None return self.callback_data[0], self.callback_data[1].copy() - def get_conversations(self, name: str) -> ConversationDict: + async def get_conversations(self, name: str) -> ConversationDict: """Returns the conversations from the pickle file if it exists or an empty dict. Args: @@ -288,7 +297,7 @@ def get_conversations(self, name: str) -> ConversationDict: self._load_singlefile() return self.conversations.get(name, {}).copy() # type: ignore[union-attr] - def update_conversation( + async def update_conversation( self, name: str, key: Tuple[int, ...], new_state: Optional[object] ) -> None: """Will update the conversations for the given handler and depending on :attr:`on_flush` @@ -310,12 +319,12 @@ def update_conversation( else: self._dump_singlefile() - def update_user_data(self, user_id: int, data: UD) -> None: + async def update_user_data(self, user_id: int, data: UD) -> None: """Will update the user_data and depending on :attr:`on_flush` save the pickle file. Args: user_id (:obj:`int`): The user the data might have been changed for. - data (:obj:`dict`): The :attr:`telegram.ext.Dispatcher.user_data` ``[user_id]``. + data (:obj:`dict`): The :attr:`telegram.ext.Application.user_data` ``[user_id]``. """ if self.user_data is None: self.user_data = {} @@ -328,12 +337,12 @@ def update_user_data(self, user_id: int, data: UD) -> None: else: self._dump_singlefile() - def update_chat_data(self, chat_id: int, data: CD) -> None: + async def update_chat_data(self, chat_id: int, data: CD) -> None: """Will update the chat_data and depending on :attr:`on_flush` save the pickle file. Args: chat_id (:obj:`int`): The chat the data might have been changed for. - data (:obj:`dict`): The :attr:`telegram.ext.Dispatcher.chat_data` ``[chat_id]``. + data (:obj:`dict`): The :attr:`telegram.ext.Application.chat_data` ``[chat_id]``. """ if self.chat_data is None: self.chat_data = {} @@ -346,12 +355,12 @@ def update_chat_data(self, chat_id: int, data: CD) -> None: else: self._dump_singlefile() - def update_bot_data(self, data: BD) -> None: + async def update_bot_data(self, data: BD) -> None: """Will update the bot_data and depending on :attr:`on_flush` save the pickle file. Args: data (:obj:`dict` | :attr:`telegram.ext.ContextTypes.bot_data`): The - :attr:`telegram.ext.Dispatcher.bot_data`. + :attr:`telegram.ext.Application.bot_data`. """ if self.bot_data == data: return @@ -362,7 +371,7 @@ def update_bot_data(self, data: BD) -> None: else: self._dump_singlefile() - def update_callback_data(self, data: CDCData) -> None: + async def update_callback_data(self, data: CDCData) -> None: """Will update the callback_data (if changed) and depending on :attr:`on_flush` save the pickle file. @@ -382,7 +391,7 @@ def update_callback_data(self, data: CDCData) -> None: else: self._dump_singlefile() - def drop_chat_data(self, chat_id: int) -> None: + async def drop_chat_data(self, chat_id: int) -> None: """Will delete the specified key from the :attr:`chat_data` and depending on :attr:`on_flush` save the pickle file. @@ -401,7 +410,7 @@ def drop_chat_data(self, chat_id: int) -> None: else: self._dump_singlefile() - def drop_user_data(self, user_id: int) -> None: + async def drop_user_data(self, user_id: int) -> None: """Will delete the specified key from the :attr:`user_data` and depending on :attr:`on_flush` save the pickle file. @@ -420,28 +429,28 @@ def drop_user_data(self, user_id: int) -> None: else: self._dump_singlefile() - def refresh_user_data(self, user_id: int, user_data: UD) -> None: + async def refresh_user_data(self, user_id: int, user_data: UD) -> None: """Does nothing. .. versionadded:: 13.6 .. seealso:: :meth:`telegram.ext.BasePersistence.refresh_user_data` """ - def refresh_chat_data(self, chat_id: int, chat_data: CD) -> None: + async def refresh_chat_data(self, chat_id: int, chat_data: CD) -> None: """Does nothing. .. versionadded:: 13.6 .. seealso:: :meth:`telegram.ext.BasePersistence.refresh_chat_data` """ - def refresh_bot_data(self, bot_data: BD) -> None: + async def refresh_bot_data(self, bot_data: BD) -> None: """Does nothing. .. versionadded:: 13.6 .. seealso:: :meth:`telegram.ext.BasePersistence.refresh_bot_data` """ - def flush(self) -> None: + async def flush(self) -> None: """Will save all data in memory to pickle file(s).""" if self.single_file: if ( diff --git a/telegram/ext/_pollanswerhandler.py b/telegram/ext/_pollanswerhandler.py index aa4630feed4..39b366ef2dc 100644 --- a/telegram/ext/_pollanswerhandler.py +++ b/telegram/ext/_pollanswerhandler.py @@ -29,7 +29,7 @@ class PollAnswerHandler(Handler[Update, CCT]): """Handler class to handle Telegram updates that contain a poll answer. Warning: - When setting ``run_async`` to :obj:`True`, you cannot rely on adding custom + When setting ``block`` to :obj:`True`, you cannot rely on adding custom attributes to :class:`telegram.ext.CallbackContext`. See its docs for more info. Args: @@ -39,12 +39,13 @@ class PollAnswerHandler(Handler[Update, CCT]): The return value of the callback is usually ignored except for the special case of :class:`telegram.ext.ConversationHandler`. - run_async (:obj:`bool`): Determines whether the callback will run asynchronously. - Defaults to :obj:`False`. + block (:obj:`bool`, optional): Determines whether the return value of the callback should + be awaited before processing the next handler in + :meth:`telegram.ext.Application.process_update`. Defaults to :obj:`True`. Attributes: callback (:obj:`callable`): The callback function for this handler. - run_async (:obj:`bool`): Determines whether the callback will run asynchronously. + block (:obj:`bool`): Determines whether the callback will run asynchronously. """ diff --git a/telegram/ext/_pollhandler.py b/telegram/ext/_pollhandler.py index 45bcfe5c134..8474c27ffc4 100644 --- a/telegram/ext/_pollhandler.py +++ b/telegram/ext/_pollhandler.py @@ -29,7 +29,7 @@ class PollHandler(Handler[Update, CCT]): """Handler class to handle Telegram updates that contain a poll. Warning: - When setting ``run_async`` to :obj:`True`, you cannot rely on adding custom + When setting ``block`` to :obj:`True`, you cannot rely on adding custom attributes to :class:`telegram.ext.CallbackContext`. See its docs for more info. Args: @@ -39,12 +39,13 @@ class PollHandler(Handler[Update, CCT]): The return value of the callback is usually ignored except for the special case of :class:`telegram.ext.ConversationHandler`. - run_async (:obj:`bool`): Determines whether the callback will run asynchronously. - Defaults to :obj:`False`. + block (:obj:`bool`, optional): Determines whether the return value of the callback should + be awaited before processing the next handler in + :meth:`telegram.ext.Application.process_update`. Defaults to :obj:`True`. Attributes: callback (:obj:`callable`): The callback function for this handler. - run_async (:obj:`bool`): Determines whether the callback will run asynchronously. + block (:obj:`bool`): Determines whether the callback will run asynchronously. """ diff --git a/telegram/ext/_precheckoutqueryhandler.py b/telegram/ext/_precheckoutqueryhandler.py index 37c7aae43e1..c5f0275564b 100644 --- a/telegram/ext/_precheckoutqueryhandler.py +++ b/telegram/ext/_precheckoutqueryhandler.py @@ -28,7 +28,7 @@ class PreCheckoutQueryHandler(Handler[Update, CCT]): """Handler class to handle Telegram PreCheckout callback queries. Warning: - When setting ``run_async`` to :obj:`True`, you cannot rely on adding custom + When setting ``block`` to :obj:`True`, you cannot rely on adding custom attributes to :class:`telegram.ext.CallbackContext`. See its docs for more info. Args: @@ -38,12 +38,13 @@ class PreCheckoutQueryHandler(Handler[Update, CCT]): The return value of the callback is usually ignored except for the special case of :class:`telegram.ext.ConversationHandler`. - run_async (:obj:`bool`): Determines whether the callback will run asynchronously. - Defaults to :obj:`False`. + block (:obj:`bool`, optional): Determines whether the return value of the callback should + be awaited before processing the next handler in + :meth:`telegram.ext.Application.process_update`. Defaults to :obj:`True`. Attributes: callback (:obj:`callable`): The callback function for this handler. - run_async (:obj:`bool`): Determines whether the callback will run asynchronously. + block (:obj:`bool`): Determines whether the callback will run asynchronously. """ diff --git a/telegram/ext/_shippingqueryhandler.py b/telegram/ext/_shippingqueryhandler.py index 078f3614d72..44200c1c659 100644 --- a/telegram/ext/_shippingqueryhandler.py +++ b/telegram/ext/_shippingqueryhandler.py @@ -28,7 +28,7 @@ class ShippingQueryHandler(Handler[Update, CCT]): """Handler class to handle Telegram shipping callback queries. Warning: - When setting ``run_async`` to :obj:`True`, you cannot rely on adding custom + When setting ``block`` to :obj:`True`, you cannot rely on adding custom attributes to :class:`telegram.ext.CallbackContext`. See its docs for more info. Args: @@ -38,12 +38,13 @@ class ShippingQueryHandler(Handler[Update, CCT]): The return value of the callback is usually ignored except for the special case of :class:`telegram.ext.ConversationHandler`. - run_async (:obj:`bool`): Determines whether the callback will run asynchronously. - Defaults to :obj:`False`. + block (:obj:`bool`, optional): Determines whether the return value of the callback should + be awaited before processing the next handler in + :meth:`telegram.ext.Application.process_update`. Defaults to :obj:`True`. Attributes: callback (:obj:`callable`): The callback function for this handler. - run_async (:obj:`bool`): Determines whether the callback will run asynchronously. + block (:obj:`bool`): Determines whether the callback will run asynchronously. """ diff --git a/telegram/ext/_stringcommandhandler.py b/telegram/ext/_stringcommandhandler.py index 556e217e1b5..30b1a28faa7 100644 --- a/telegram/ext/_stringcommandhandler.py +++ b/telegram/ext/_stringcommandhandler.py @@ -18,16 +18,15 @@ # along with this program. If not, see [http://www.gnu.org/licenses/]. """This module contains the StringCommandHandler class.""" -from typing import TYPE_CHECKING, Callable, List, Optional, TypeVar, Union +from typing import TYPE_CHECKING, List, Optional +from telegram._utils.types import DVInput from telegram.ext import Handler -from telegram._utils.defaultvalue import DefaultValue, DEFAULT_FALSE -from telegram.ext._utils.types import CCT +from telegram._utils.defaultvalue import DEFAULT_TRUE +from telegram.ext._utils.types import CCT, HandlerCallback, RT if TYPE_CHECKING: - from telegram.ext import Dispatcher - -RT = TypeVar('RT') + from telegram.ext import Application class StringCommandHandler(Handler[str, CCT]): @@ -41,7 +40,7 @@ class StringCommandHandler(Handler[str, CCT]): put in the queue. For example to send messages with the bot using command line or API. Warning: - When setting ``run_async`` to :obj:`True`, you cannot rely on adding custom + When setting ``block`` to :obj:`True`, you cannot rely on adding custom attributes to :class:`telegram.ext.CallbackContext`. See its docs for more info. Args: @@ -52,13 +51,16 @@ class StringCommandHandler(Handler[str, CCT]): The return value of the callback is usually ignored except for the special case of :class:`telegram.ext.ConversationHandler`. - run_async (:obj:`bool`): Determines whether the callback will run asynchronously. - Defaults to :obj:`False`. + block (:obj:`bool`, optional): Determines whether the return value of the callback should + be awaited before processing the next handler in + :meth:`telegram.ext.Application.process_update`. Defaults to :obj:`True`. Attributes: command (:obj:`str`): The command this handler should listen for. callback (:obj:`callable`): The callback function for this handler. - run_async (:obj:`bool`): Determines whether the callback will run asynchronously. + block (:obj:`bool`): Determines whether the return value of the callback should be + awaited before processing the next handler in + :meth:`telegram.ext.Application.process_update`. """ @@ -67,13 +69,10 @@ class StringCommandHandler(Handler[str, CCT]): def __init__( self, command: str, - callback: Callable[[str, CCT], RT], - run_async: Union[bool, DefaultValue] = DEFAULT_FALSE, + callback: HandlerCallback[str, CCT, RT], + block: DVInput[bool] = DEFAULT_TRUE, ): - super().__init__( - callback, - run_async=run_async, - ) + super().__init__(callback, block=block) self.command = command def check_update(self, update: object) -> Optional[List[str]]: @@ -96,7 +95,7 @@ def collect_additional_context( self, context: CCT, update: str, - dispatcher: 'Dispatcher', + application: 'Application', check_result: Optional[List[str]], ) -> None: """Add text after the command to :attr:`CallbackContext.args` as list, split on single diff --git a/telegram/ext/_stringregexhandler.py b/telegram/ext/_stringregexhandler.py index 919e56f947a..3c6ff57b33a 100644 --- a/telegram/ext/_stringregexhandler.py +++ b/telegram/ext/_stringregexhandler.py @@ -19,14 +19,15 @@ """This module contains the StringRegexHandler class.""" import re -from typing import TYPE_CHECKING, Callable, Match, Optional, Pattern, TypeVar, Union +from typing import TYPE_CHECKING, Match, Optional, Pattern, TypeVar, Union +from telegram._utils.types import DVInput from telegram.ext import Handler -from telegram.ext._utils.types import CCT -from telegram._utils.defaultvalue import DefaultValue, DEFAULT_FALSE +from telegram.ext._utils.types import CCT, HandlerCallback +from telegram._utils.defaultvalue import DEFAULT_TRUE if TYPE_CHECKING: - from telegram.ext import Dispatcher + from telegram.ext import Application RT = TypeVar('RT') @@ -42,7 +43,7 @@ class StringRegexHandler(Handler[str, CCT]): put in the queue. For example to send messages with the bot using command line or API. Warning: - When setting ``run_async`` to :obj:`True`, you cannot rely on adding custom + When setting ``block`` to :obj:`True`, you cannot rely on adding custom attributes to :class:`telegram.ext.CallbackContext`. See its docs for more info. Args: @@ -53,13 +54,16 @@ class StringRegexHandler(Handler[str, CCT]): The return value of the callback is usually ignored except for the special case of :class:`telegram.ext.ConversationHandler`. - run_async (:obj:`bool`): Determines whether the callback will run asynchronously. - Defaults to :obj:`False`. + block (:obj:`bool`, optional): Determines whether the return value of the callback should + be awaited before processing the next handler in + :meth:`telegram.ext.Application.process_update`. Defaults to :obj:`True`. Attributes: pattern (:obj:`str` | :obj:`Pattern`): The regex pattern. callback (:obj:`callable`): The callback function for this handler. - run_async (:obj:`bool`): Determines whether the callback will run asynchronously. + block (:obj:`bool`): Determines whether the return value of the callback should be + awaited before processing the next handler in + :meth:`telegram.ext.Application.process_update`. """ @@ -68,13 +72,10 @@ class StringRegexHandler(Handler[str, CCT]): def __init__( self, pattern: Union[str, Pattern], - callback: Callable[[str, CCT], RT], - run_async: Union[bool, DefaultValue] = DEFAULT_FALSE, + callback: HandlerCallback[str, CCT, RT], + block: DVInput[bool] = DEFAULT_TRUE, ): - super().__init__( - callback, - run_async=run_async, - ) + super().__init__(callback, block=block) if isinstance(pattern, str): pattern = re.compile(pattern) @@ -101,7 +102,7 @@ def collect_additional_context( self, context: CCT, update: str, - dispatcher: 'Dispatcher', + application: 'Application', check_result: Optional[Match], ) -> None: """Add the result of ``re.match(pattern, update)`` to :attr:`CallbackContext.matches` as diff --git a/telegram/ext/_typehandler.py b/telegram/ext/_typehandler.py index eb4dc272060..7565dc837d8 100644 --- a/telegram/ext/_typehandler.py +++ b/telegram/ext/_typehandler.py @@ -18,11 +18,12 @@ # along with this program. If not, see [http://www.gnu.org/licenses/]. """This module contains the TypeHandler class.""" -from typing import Callable, Type, TypeVar, Union +from typing import Type, TypeVar +from telegram._utils.types import DVInput from telegram.ext import Handler -from telegram.ext._utils.types import CCT -from telegram._utils.defaultvalue import DefaultValue, DEFAULT_FALSE +from telegram.ext._utils.types import CCT, HandlerCallback +from telegram._utils.defaultvalue import DEFAULT_TRUE RT = TypeVar('RT') UT = TypeVar('UT') @@ -32,7 +33,7 @@ class TypeHandler(Handler[UT, CCT]): """Handler class to handle updates of custom types. Warning: - When setting ``run_async`` to :obj:`True`, you cannot rely on adding custom + When setting ``block`` to :obj:`True`, you cannot rely on adding custom attributes to :class:`telegram.ext.CallbackContext`. See its docs for more info. Args: @@ -46,14 +47,17 @@ class TypeHandler(Handler[UT, CCT]): :class:`telegram.ext.ConversationHandler`. strict (:obj:`bool`, optional): Use ``type`` instead of ``isinstance``. Default is :obj:`False` - run_async (:obj:`bool`): Determines whether the callback will run asynchronously. - Defaults to :obj:`False`. + block (:obj:`bool`, optional): Determines whether the return value of the callback should + be awaited before processing the next handler in + :meth:`telegram.ext.Application.process_update`. Defaults to :obj:`True`. Attributes: type (:obj:`type`): The ``type`` of updates this handler should process. callback (:obj:`callable`): The callback function for this handler. strict (:obj:`bool`): Use ``type`` instead of ``isinstance``. Default is :obj:`False`. - run_async (:obj:`bool`): Determines whether the callback will run asynchronously. + block (:obj:`bool`): Determines whether the return value of the callback should be + awaited before processing the next handler in + :meth:`telegram.ext.Application.process_update`. """ @@ -62,14 +66,11 @@ class TypeHandler(Handler[UT, CCT]): def __init__( self, type: Type[UT], # pylint: disable=redefined-builtin - callback: Callable[[UT, CCT], RT], + callback: HandlerCallback[UT, CCT, RT], strict: bool = False, - run_async: Union[bool, DefaultValue] = DEFAULT_FALSE, + block: DVInput[bool] = DEFAULT_TRUE, ): - super().__init__( - callback, - run_async=run_async, - ) + super().__init__(callback, block=block) self.type = type # pylint: disable=assigning-non-slot self.strict = strict # pylint: disable=assigning-non-slot diff --git a/telegram/ext/_updater.py b/telegram/ext/_updater.py index 689a59e7b19..82573fad0f2 100644 --- a/telegram/ext/_updater.py +++ b/telegram/ext/_updater.py @@ -17,174 +17,146 @@ # You should have received a copy of the GNU Lesser Public License # along with this program. If not, see [http://www.gnu.org/licenses/]. """This module contains the class Updater, which tries to make creating Telegram bots intuitive.""" -import inspect +import asyncio import logging import ssl -import signal from pathlib import Path -from queue import Queue -from threading import Event, Lock, Thread, current_thread -from time import sleep +from types import TracebackType from typing import ( - Any, Callable, List, Optional, - Tuple, Union, - no_type_check, - Generic, TypeVar, TYPE_CHECKING, + Coroutine, + Type, ) -from telegram.error import InvalidToken, RetryAfter, TimedOut, Unauthorized, TelegramError -from telegram._utils.warnings import warn -from telegram.ext import Dispatcher +from telegram._utils.defaultvalue import DEFAULT_NONE +from telegram._utils.types import ODVInput +from telegram.error import InvalidToken, RetryAfter, TimedOut, Forbidden, TelegramError from telegram.ext._utils.webhookhandler import WebhookAppClass, WebhookServer -from telegram.ext._utils.stack import was_called_by -from telegram.ext._utils.types import BT if TYPE_CHECKING: - from telegram.ext._builders import InitUpdaterBuilder + from telegram import Bot -DT = TypeVar('DT', bound=Union[None, Dispatcher]) +_UpdaterType = TypeVar('_UpdaterType', bound="Updater") -class Updater(Generic[BT, DT]): - """ - This class, which employs the :class:`telegram.ext.Dispatcher`, provides a frontend to - :class:`telegram.Bot` to the programmer, so they can focus on coding the bot. Its purpose is to - receive the updates from Telegram and to deliver them to said dispatcher. It also runs in a - separate thread, so the user can interact with the bot, for example on the command line. The - dispatcher supports handlers for different kinds of data: Updates from Telegram, basic text - commands and even arbitrary types. The updater can be started as a polling service or, for - production, use a webhook to receive updates. This is achieved using the WebhookServer and - WebhookHandler classes. - - Note: - This class may not be initialized directly. Use :class:`telegram.ext.UpdaterBuilder` or - :meth:`builder` (for convenience). +class Updater: + """This class fetches updates for the bot either via long polling or by starting a webhook + server. Received updates are enqueued into the :attr:`update_queue` and may be fetched from + there to handle them appropriately. .. versionchanged:: 14.0 - * Initialization is now done through the :class:`telegram.ext.UpdaterBuilder`. - * Renamed ``user_sig_handler`` to :attr:`user_signal_handler`. - * Removed the attributes ``job_queue``, and ``persistence`` - use the corresponding - attributes of :attr:`dispatcher` instead. + * Removed argument and attribute ``user_sig_handler`` + * The only arguments and attributes are now :attr:`bot` and :attr:`update_queue` as now + the sole purpose of this class is to fetch updates. The entry point to a PTB application + is now :class:`telegram.ext.Application`. Attributes: bot (:class:`telegram.Bot`): The bot used with this Updater. - user_signal_handler (:obj:`function`): Optional. Function to be called when a signal is - received. - - .. versionchanged:: 14.0 - Renamed ``user_sig_handler`` to ``user_signal_handler``. - update_queue (:obj:`Queue`): Queue for the updates. - dispatcher (:class:`telegram.ext.Dispatcher`): Optional. Dispatcher that handles the - updates and dispatches them to the handlers. - running (:obj:`bool`): Indicates if the updater is running. - exception_event (:class:`threading.Event`): When an unhandled exception happens while - fetching updates, this event will be set. If :attr:`dispatcher` is not :obj:`None`, it - is the same object as :attr:`telegram.ext.Dispatcher.exception_event`. + update_queue (:class:`asyncio.Queue`): Queue for the updates. - .. versionadded:: 14.0 + Args: + bot (:class:`telegram.Bot`): The bot used with this Updater. + update_queue (:class:`asyncio.Queue`): Queue for the updates. """ __slots__ = ( - 'dispatcher', - 'user_signal_handler', 'bot', - 'logger', + '_logger', 'update_queue', - 'exception_event', 'last_update_id', - 'running', - 'is_idle', - 'httpd', + '_running', + '_httpd', '__lock', - '__threads', + '__asyncio_tasks', ) def __init__( - self: 'Updater[BT, DT]', - *, - user_signal_handler: Callable[[int, object], Any] = None, - dispatcher: DT = None, - bot: BT = None, - update_queue: Queue = None, - exception_event: Event = None, + self, + bot: 'Bot', + update_queue: asyncio.Queue, ): - if not was_called_by( - inspect.currentframe(), Path(__file__).parent.resolve() / '_builders.py' - ): - warn( - '`Updater` instances should be built via the `UpdaterBuilder`.', - stacklevel=2, - ) - - self.user_signal_handler = user_signal_handler - self.dispatcher = dispatcher - if self.dispatcher: - self.bot = self.dispatcher.bot - self.update_queue = self.dispatcher.update_queue - self.exception_event = self.dispatcher.exception_event - else: - self.bot = bot - self.update_queue = update_queue - self.exception_event = exception_event + self.bot = bot + self.update_queue = update_queue self.last_update_id = 0 - self.running = False - self.is_idle = False - self.httpd = None - self.__lock = Lock() - self.__threads: List[Thread] = [] - self.logger = logging.getLogger(__name__) + self._running = False + self._httpd: Optional[WebhookServer] = None + self.__lock = asyncio.Lock() + self.__asyncio_tasks: List[asyncio.Task] = [] + self._logger = logging.getLogger(__name__) - @staticmethod - def builder() -> 'InitUpdaterBuilder': - """Convenience method. Returns a new :class:`telegram.ext.UpdaterBuilder`. + @property + def running(self) -> bool: + return self._running - .. versionadded:: 14.0 - """ - # Unfortunately this needs to be here due to cyclical imports - from telegram.ext import UpdaterBuilder # pylint: disable=import-outside-toplevel + async def initialize(self) -> None: + await self.bot.initialize() + + async def shutdown(self) -> None: + await self.bot.shutdown() + self._logger.debug('Shut down of Updater complete') + + async def __aenter__(self: _UpdaterType) -> _UpdaterType: + try: + await self.initialize() + return self + except Exception as exc: + await self.shutdown() + raise exc - return UpdaterBuilder() + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + # Make sure not to return `True` so that exceptions are not suppressed + # https://docs.python.org/3/reference/datamodel.html?#object.__aexit__ + await self.shutdown() - def _init_thread(self, target: Callable, name: str, *args: object, **kwargs: object) -> None: - thr = Thread( - target=self._thread_wrapper, - name=f"Bot:{self.bot.id}:{name}", - args=(target,) + args, - kwargs=kwargs, + def _init_task( + self, target: Callable[..., Coroutine], name: str, *args: object, **kwargs: object + ) -> None: + task = asyncio.create_task( + coro=self._task_wrapper(target, name, *args, **kwargs), + # TODO: Add this once we drop py3.7 + # name=f"Updater:{self.bot.id}:{name}", ) - thr.start() - self.__threads.append(thr) + self.__asyncio_tasks.append(task) - def _thread_wrapper(self, target: Callable, *args: object, **kwargs: object) -> None: - thr_name = current_thread().name - self.logger.debug('%s - started', thr_name) + async def _task_wrapper( + self, target: Callable, name: str, *args: object, **kwargs: object + ) -> None: + self._logger.debug('%s - started', name) try: - target(*args, **kwargs) + await target(*args, **kwargs) except Exception: - self.exception_event.set() - self.logger.exception('unhandled exception in %s', thr_name) - raise - self.logger.debug('%s - ended', thr_name) + self._logger.exception('Unhandled exception in %s.', name) + self._logger.debug('%s - ended', name) - def start_polling( + # TODO: Probably drop `pool_connect` timeout again, because we probably want to just make + # sure that `getUpdates` always gets a connection without waiting + async def start_polling( self, poll_interval: float = 0.0, - timeout: float = 10, + timeout: int = 10, bootstrap_retries: int = -1, - read_latency: float = 2.0, + read_timeout: float = 2, + write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, allowed_updates: List[str] = None, drop_pending_updates: bool = None, - ) -> Optional[Queue]: + error_callback: Callable[[TelegramError], None] = None, + ) -> asyncio.Queue: """Starts polling updates from Telegram. .. versionchanged:: 14.0 @@ -207,60 +179,126 @@ def start_polling( allowed_updates (List[:obj:`str`], optional): Passed to :meth:`telegram.Bot.get_updates`. - read_latency (:obj:`float` | :obj:`int`, optional): Grace time in seconds for receiving + read_timeout (:obj:`float` | :obj:`int`, optional): Grace time in seconds for receiving the reply from server. Will be added to the ``timeout`` value and used as the read timeout from server (Default: ``2``). + error_callback (Callable[[:exc:`telegram.error.TelegramError`], :obj:`None`], \ + optional): Callback to handle :exc:`telegram.error.TelegramError` s that occur + while calling :meth:`telegram.Bot.get_updates` during polling. Defaults to + :obj:`None`, in which case errors will be logged. Returns: - :obj:`Queue`: The update queue that can be filled from the main thread. + :class:`asyncio.Queue`: The update queue that can be filled from the main thread. """ - with self.__lock: - if not self.running: - self.running = True - - # Create & start threads - dispatcher_ready = Event() - polling_ready = Event() - - if self.dispatcher: - self._init_thread(self.dispatcher.start, "dispatcher", ready=dispatcher_ready) - self._init_thread( - self._start_polling, - "updater", - poll_interval, - timeout, - read_latency, - bootstrap_retries, - drop_pending_updates, - allowed_updates, - ready=polling_ready, - ) + async with self.__lock: + if self.running: + return self.update_queue - self.logger.debug('Waiting for polling to start') - polling_ready.wait() - if self.dispatcher: - self.logger.debug('Waiting for Dispatcher to start') - dispatcher_ready.wait() + self._running = True + + # Create & start tasks + polling_ready = asyncio.Event() + + self._init_task( + self._start_polling, + "Polling Background task", + poll_interval, + timeout, + read_timeout, + write_timeout, + connect_timeout, + pool_timeout, + bootstrap_retries, + drop_pending_updates, + allowed_updates, + ready=polling_ready, + error_callback=error_callback, + ) - # Return the update queue so the main thread can insert updates - return self.update_queue - return None + self._logger.debug('Waiting for polling to start') + await polling_ready.wait() + self._logger.debug('Polling to started') + + return self.update_queue + + async def _start_polling( + self, + poll_interval: float, + timeout: int, + read_timeout: Optional[float], + write_timeout: Optional[float], + connect_timeout: Optional[float], + pool_timeout: Optional[float], + bootstrap_retries: int, + drop_pending_updates: bool, + allowed_updates: Optional[List[str]], + ready: asyncio.Event = None, + error_callback: Callable[[TelegramError], None] = None, + ) -> None: + # Target of task 'updater.start_polling()'. Runs in background, pulls + # updates from Telegram and inserts them in the update queue of the + # Application. + + self._logger.debug('Updater started (polling)') + + await self._bootstrap( + bootstrap_retries, + drop_pending_updates=drop_pending_updates, + webhook_url='', + allowed_updates=None, + ) + + self._logger.debug('Bootstrap done') + + async def polling_action_cb() -> bool: + updates = await self.bot.get_updates( + self.last_update_id, + timeout=timeout, + read_timeout=read_timeout, + connect_timeout=connect_timeout, + write_timeout=write_timeout, + pool_timeout=pool_timeout, + allowed_updates=allowed_updates, + ) + + if updates: + if not self.running: + self._logger.debug('Updates ignored and will be pulled again on restart') + else: + for update in updates: + await self.update_queue.put(update) + self.last_update_id = updates[-1].update_id + 1 + + return True + + def default_error_callback(exc: TelegramError) -> None: + self._logger.exception('Exception happened while polling for updates.', exc_info=exc) - def start_webhook( + if ready is not None: + ready.set() + + await self._network_loop_retry( + action_cb=polling_action_cb, + onerr_cb=error_callback or default_error_callback, + description='getting Updates', + interval=poll_interval, + ) + + async def start_webhook( self, listen: str = '127.0.0.1', port: int = 80, url_path: str = '', - cert: str = None, - key: str = None, + cert: Union[str, Path] = None, + key: Union[str, Path] = None, bootstrap_retries: int = 0, webhook_url: str = None, allowed_updates: List[str] = None, drop_pending_updates: bool = None, ip_address: str = None, max_connections: int = 40, - ) -> Optional[Queue]: + ) -> asyncio.Queue: """ Starts a small http server to listen for updates via webhook. If :attr:`cert` and :attr:`key` are not provided, the webhook will be started directly on @@ -271,7 +309,6 @@ def start_webhook( .. versionchanged:: 13.4 :meth:`start_webhook` now *always* calls :meth:`telegram.Bot.set_webhook`, so pass ``webhook_url`` instead of calling ``updater.bot.set_webhook(webhook_url)`` manually. - .. versionchanged:: 14.0 Removed the ``clean`` argument in favor of ``drop_pending_updates`` and removed the deprecated argument ``force_event_loop``. @@ -281,11 +318,10 @@ def start_webhook( port (:obj:`int`, optional): Port the bot should be listening on. Must be one of :attr:`telegram.constants.SUPPORTED_WEBHOOK_PORTS`. Defaults to ``80``. url_path (:obj:`str`, optional): Path inside url. - cert (:obj:`str`, optional): Path to the SSL certificate file. - key (:obj:`str`, optional): Path to the SSL key file. + cert (:class:`pathlib.Path` | :obj:`str`, optional): Path to the SSL certificate file. + key (:class:`pathlib.Path` | :obj:`str`, optional): Path to the SSL key file. drop_pending_updates (:obj:`bool`, optional): Whether to clean any pending updates on Telegram servers before actually starting to poll. Default is :obj:`False`. - .. versionadded :: 13.4 bootstrap_retries (:obj:`int`, optional): Whether the bootstrapping phase of the :class:`telegram.ext.Updater` will retry on failures on the Telegram server. @@ -293,192 +329,66 @@ def start_webhook( * < 0 - retry indefinitely (default) * 0 - no retries * > 0 - retry up to X times - webhook_url (:obj:`str`, optional): Explicitly specify the webhook url. Useful behind NAT, reverse proxy, etc. Default is derived from ``listen``, ``port`` & ``url_path``. ip_address (:obj:`str`, optional): Passed to :meth:`telegram.Bot.set_webhook`. - .. versionadded :: 13.4 allowed_updates (List[:obj:`str`], optional): Passed to :meth:`telegram.Bot.set_webhook`. max_connections (:obj:`int`, optional): Passed to :meth:`telegram.Bot.set_webhook`. - .. versionadded:: 13.6 - Returns: :obj:`Queue`: The update queue that can be filled from the main thread. - """ - with self.__lock: - if not self.running: - self.running = True - - # Create & start threads - webhook_ready = Event() - dispatcher_ready = Event() - - if self.dispatcher: - self._init_thread(self.dispatcher.start, "dispatcher", dispatcher_ready) - self._init_thread( - self._start_webhook, - "updater", - listen, - port, - url_path, - cert, - key, - bootstrap_retries, - drop_pending_updates, - webhook_url, - allowed_updates, - ready=webhook_ready, - ip_address=ip_address, - max_connections=max_connections, - ) - - self.logger.debug('Waiting for webhook to start') - webhook_ready.wait() - if self.dispatcher: - self.logger.debug('Waiting for Dispatcher to start') - dispatcher_ready.wait() - - # Return the update queue so the main thread can insert updates + async with self.__lock: + if self.running: return self.update_queue - return None - - @no_type_check - def _start_polling( - self, - poll_interval, - timeout, - read_latency, - bootstrap_retries, - drop_pending_updates, - allowed_updates, - ready=None, - ): # pragma: no cover - # Thread target of thread 'updater'. Runs in background, pulls - # updates from Telegram and inserts them in the update queue of the - # Dispatcher. - self.logger.debug('Updater thread started (polling)') - - self._bootstrap( - bootstrap_retries, - drop_pending_updates=drop_pending_updates, - webhook_url='', - allowed_updates=None, - ) + self._running = True - self.logger.debug('Bootstrap done') + # Create & start tasks + webhook_ready = asyncio.Event() - def polling_action_cb(): - updates = self.bot.get_updates( - self.last_update_id, - timeout=timeout, - read_latency=read_latency, + await self._start_webhook( + listen=listen, + port=port, + url_path=url_path, + cert=cert, + key=key, + bootstrap_retries=bootstrap_retries, + drop_pending_updates=drop_pending_updates, + webhook_url=webhook_url, allowed_updates=allowed_updates, + ready=webhook_ready, + ip_address=ip_address, + max_connections=max_connections, ) - if updates: - if not self.running: - self.logger.debug('Updates ignored and will be pulled again on restart') - else: - for update in updates: - self.update_queue.put(update) - self.last_update_id = updates[-1].update_id + 1 - - return True - - def polling_onerr_cb(exc): - # Put the error into the update queue and let the Dispatcher - # broadcast it - self.update_queue.put(exc) - - if ready is not None: - ready.set() - - self._network_loop_retry( - polling_action_cb, polling_onerr_cb, 'getting Updates', poll_interval - ) - - @no_type_check - def _network_loop_retry(self, action_cb, onerr_cb, description, interval): - """Perform a loop calling `action_cb`, retrying after network errors. - - Stop condition for loop: `self.running` evaluates :obj:`False` or return value of - `action_cb` evaluates :obj:`False`. - - Args: - action_cb (:obj:`callable`): Network oriented callback function to call. - onerr_cb (:obj:`callable`): Callback to call when TelegramError is caught. Receives the - exception object as a parameter. - description (:obj:`str`): Description text to use for logs and exception raised. - interval (:obj:`float` | :obj:`int`): Interval to sleep between each call to - `action_cb`. - - """ - self.logger.debug('Start network loop retry %s', description) - cur_interval = interval - while self.running: - try: - if not action_cb(): - break - except RetryAfter as exc: - self.logger.info('%s', exc) - cur_interval = 0.5 + exc.retry_after - except TimedOut as toe: - self.logger.debug('Timed out %s: %s', description, toe) - # If failure is due to timeout, we should retry asap. - cur_interval = 0 - except InvalidToken as pex: - self.logger.error('Invalid token; aborting') - raise pex - except TelegramError as telegram_exc: - self.logger.error('Error while %s: %s', description, telegram_exc) - onerr_cb(telegram_exc) - cur_interval = self._increase_poll_interval(cur_interval) - else: - cur_interval = interval - - if cur_interval: - sleep(cur_interval) + self._logger.debug('Waiting for webhook server to start') + await webhook_ready.wait() + self._logger.debug('Webhook server started') - @staticmethod - def _increase_poll_interval(current_interval: float) -> float: - # increase waiting times on subsequent errors up to 30secs - if current_interval == 0: - current_interval = 1 - elif current_interval < 30: - current_interval *= 1.5 - else: - current_interval = min(30.0, current_interval) - return current_interval + # Return the update queue so the main thread can insert updates + return self.update_queue - @no_type_check - def _start_webhook( + async def _start_webhook( self, - listen, - port, - url_path, - cert, - key, - bootstrap_retries, - drop_pending_updates, - webhook_url, - allowed_updates, - ready=None, - ip_address=None, + listen: str, + port: int, + url_path: str, + bootstrap_retries: int, + allowed_updates: Optional[List[str]], + cert: Union[str, Path] = None, + key: Union[str, Path] = None, + drop_pending_updates: bool = None, + webhook_url: str = None, + ready: asyncio.Event = None, + ip_address: str = None, max_connections: int = 40, - ): - self.logger.debug('Updater thread started (webhook)') - - # Note that we only use the SSL certificate for the WebhookServer, if the key is also - # present. This is because the WebhookServer may not actually be in charge of performing - # the SSL handshake, e.g. in case a reverse proxy is used - use_ssl = cert is not None and key is not None + ) -> None: + self._logger.debug('Updater thread started (webhook)') if not url_path.startswith('/'): url_path = f'/{url_path}' @@ -488,35 +398,39 @@ def _start_webhook( # Form SSL Context # An SSLError is raised if the private key does not match with the certificate - if use_ssl: + # Note that we only use the SSL certificate for the WebhookServer, if the key is also + # present. This is because the WebhookServer may not actually be in charge of performing + # the SSL handshake, e.g. in case a reverse proxy is used + if cert is not None and key is not None: try: - ssl_ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) - ssl_ctx.load_cert_chain(cert, key) + ssl_ctx: Optional[ssl.SSLContext] = ssl.create_default_context( + ssl.Purpose.CLIENT_AUTH + ) + ssl_ctx.load_cert_chain(cert, key) # type: ignore[union-attr] except ssl.SSLError as exc: raise TelegramError('Invalid SSL Certificate') from exc else: ssl_ctx = None # Create and start server - self.httpd = WebhookServer(listen, port, app, ssl_ctx) + self._httpd = WebhookServer(listen, port, app, ssl_ctx) if not webhook_url: webhook_url = self._gen_webhook_url(listen, port, url_path) # We pass along the cert to the webhook if present. if cert is not None: - with open(cert, 'rb') as cert_file: - self._bootstrap( - cert=cert_file, - max_retries=bootstrap_retries, - drop_pending_updates=drop_pending_updates, - webhook_url=webhook_url, - allowed_updates=allowed_updates, - ip_address=ip_address, - max_connections=max_connections, - ) + await self._bootstrap( + cert=cert, + max_retries=bootstrap_retries, + drop_pending_updates=drop_pending_updates, + webhook_url=webhook_url, + allowed_updates=allowed_updates, + ip_address=ip_address, + max_connections=max_connections, + ) else: - self._bootstrap( + await self._bootstrap( max_retries=bootstrap_retries, drop_pending_updates=drop_pending_updates, webhook_url=webhook_url, @@ -525,38 +439,100 @@ def _start_webhook( max_connections=max_connections, ) - self.httpd.serve_forever(ready=ready) + await self._httpd.serve_forever(ready=ready) @staticmethod def _gen_webhook_url(listen: str, port: int, url_path: str) -> str: return f'https://{listen}:{port}{url_path}' - @no_type_check - def _bootstrap( + async def _network_loop_retry( + self, + action_cb: Callable[..., Coroutine], + onerr_cb: Callable[[TelegramError], None], + description: str, + interval: float, + ) -> None: + """Perform a loop calling `action_cb`, retrying after network errors. + + Stop condition for loop: `self.running` evaluates :obj:`False` or return value of + `action_cb` evaluates :obj:`False`. + + Args: + action_cb (:obj:`callable`): Network oriented callback function to call. + onerr_cb (:obj:`callable`): Callback to call when TelegramError is caught. Receives the + exception object as a parameter. + description (:obj:`str`): Description text to use for logs and exception raised. + interval (:obj:`float` | :obj:`int`): Interval to sleep between each call to + `action_cb`. + + """ + self._logger.debug('Start network loop retry %s', description) + cur_interval = interval + while self.running: + try: + try: + if not await action_cb(): + break + except RetryAfter as exc: + self._logger.info('%s', exc) + cur_interval = 0.5 + exc.retry_after + except TimedOut as toe: + self._logger.debug('Timed out %s: %s', description, toe) + # If failure is due to timeout, we should retry asap. + cur_interval = 0 + except InvalidToken as pex: + self._logger.error('Invalid token; aborting') + raise pex + except TelegramError as telegram_exc: + self._logger.error('Error while %s: %s', description, telegram_exc) + onerr_cb(telegram_exc) + cur_interval = self._increase_poll_interval(cur_interval) + else: + cur_interval = interval + + if cur_interval: + await asyncio.sleep(cur_interval) + + except asyncio.CancelledError: + self._logger.debug('Network loop retry %s was cancelled', description) + break + + @staticmethod + def _increase_poll_interval(current_interval: float) -> float: + # increase waiting times on subsequent errors up to 30secs + if current_interval == 0: + current_interval = 1 + elif current_interval < 30: + current_interval *= 1.5 + else: + current_interval = min(30.0, current_interval) + return current_interval + + async def _bootstrap( self, - max_retries, - drop_pending_updates, - webhook_url, - allowed_updates, - cert=None, - bootstrap_interval=5, - ip_address=None, + max_retries: int, + webhook_url: Optional[str], + allowed_updates: Optional[List[str]], + drop_pending_updates: bool = None, + cert: Union[str, Path] = None, + bootstrap_interval: float = 5, + ip_address: str = None, max_connections: int = 40, - ): + ) -> None: retries = [0] - def bootstrap_del_webhook(): - self.logger.debug('Deleting webhook') + async def bootstrap_del_webhook() -> bool: + self._logger.debug('Deleting webhook') if drop_pending_updates: - self.logger.debug('Dropping pending updates from Telegram server') - self.bot.delete_webhook(drop_pending_updates=drop_pending_updates) + self._logger.debug('Dropping pending updates from Telegram server') + await self.bot.delete_webhook(drop_pending_updates=drop_pending_updates) return False - def bootstrap_set_webhook(): - self.logger.debug('Setting webhook') + async def bootstrap_set_webhook() -> bool: + self._logger.debug('Setting webhook') if drop_pending_updates: - self.logger.debug('Dropping pending updates from Telegram server') - self.bot.set_webhook( + self._logger.debug('Dropping pending updates from Telegram server') + await self.bot.set_webhook( url=webhook_url, certificate=cert, allowed_updates=allowed_updates, @@ -566,14 +542,14 @@ def bootstrap_set_webhook(): ) return False - def bootstrap_onerr_cb(exc): - if not isinstance(exc, Unauthorized) and (max_retries < 0 or retries[0] < max_retries): + def bootstrap_onerr_cb(exc: Exception) -> None: + if not isinstance(exc, Forbidden) and (max_retries < 0 or retries[0] < max_retries): retries[0] += 1 - self.logger.warning( + self._logger.warning( 'Failed bootstrap phase; try=%s max_retries=%s', retries[0], max_retries ) else: - self.logger.error('Failed bootstrap phase after %s retries (%s)', retries[0], exc) + self._logger.error('Failed bootstrap phase after %s retries (%s)', retries[0], exc) raise exc # Dropping pending updates from TG can be efficiently done with the drop_pending_updates @@ -581,7 +557,7 @@ def bootstrap_onerr_cb(exc): # sure that no webhook is configured in case of polling, so we just always call # delete_webhook for polling if drop_pending_updates or not webhook_url: - self._network_loop_retry( + await self._network_loop_retry( bootstrap_del_webhook, bootstrap_onerr_cb, 'bootstrap del webhook', @@ -592,94 +568,35 @@ def bootstrap_onerr_cb(exc): # Restore/set webhook settings, if needed. Again, we don't know ahead if a webhook is set, # so we set it anyhow. if webhook_url: - self._network_loop_retry( + await self._network_loop_retry( bootstrap_set_webhook, bootstrap_onerr_cb, 'bootstrap set webhook', bootstrap_interval, ) - def stop(self) -> None: - """Stops the polling/webhook thread, the dispatcher and the job queue.""" - with self.__lock: - if self.running or (self.dispatcher and self.dispatcher.has_running_threads): - self.logger.debug( - 'Stopping Updater %s...', 'and Dispatcher ' if self.dispatcher else '' - ) - - self.running = False - - self._stop_httpd() - self._stop_dispatcher() - self._join_threads() - - # Clear the connection pool only if the bot is managed by the Updater - # Otherwise `dispatcher.stop()` already does that - if not self.dispatcher: - self.bot.request.stop() - - @no_type_check - def _stop_httpd(self) -> None: - if self.httpd: - self.logger.debug( - 'Waiting for current webhook connection to be ' - 'closed... Send a Telegram message to the bot to exit ' - 'immediately.' - ) - self.httpd.shutdown() - self.httpd = None - - @no_type_check - def _stop_dispatcher(self) -> None: - if self.dispatcher: - self.logger.debug('Requesting Dispatcher to stop...') - self.dispatcher.stop() - - @no_type_check - def _join_threads(self) -> None: - for thr in self.__threads: - self.logger.debug('Waiting for %s thread to end', thr.name) - thr.join() - self.logger.debug('%s thread has ended', thr.name) - self.__threads = [] - - @no_type_check - def _signal_handler(self, signum, frame) -> None: - self.is_idle = False - if self.running: - self.logger.info( - 'Received signal %s (%s), stopping...', - signum, - # signal.Signals is undocumented for some reason see - # https://github.com/python/typeshed/pull/555#issuecomment-247874222 - # https://bugs.python.org/issue28206 - signal.Signals(signum), # pylint: disable=no-member - ) - self.stop() - if self.user_signal_handler: - self.user_signal_handler(signum, frame) - else: - self.logger.warning('Exiting immediately!') - # pylint: disable=import-outside-toplevel, protected-access - import os - - os._exit(1) + async def stop(self) -> None: + """Stops the polling/webhook.""" + async with self.__lock: + if self.running: + self._logger.debug('Stopping Updater') - def idle( - self, stop_signals: Union[List, Tuple] = (signal.SIGINT, signal.SIGTERM, signal.SIGABRT) - ) -> None: - """Blocks until one of the signals are received and stops the updater. + self._running = False - Args: - stop_signals (:obj:`list` | :obj:`tuple`): List containing signals from the signal - module that should be subscribed to. :meth:`Updater.stop()` will be called on - receiving one of those signals. Defaults to (``SIGINT``, ``SIGTERM``, ``SIGABRT``). + await self._stop_httpd() + await self._join_tasks() - """ - for sig in stop_signals: - signal.signal(sig, self._signal_handler) + self._logger.debug('Updater.stop() is complete') - self.is_idle = True + async def _stop_httpd(self) -> None: + if self._httpd: + self._logger.debug('Waiting for current webhook connection to be closed.') + await self._httpd.shutdown() + self._httpd = None - while self.is_idle: - sleep(1) + async def _join_tasks(self) -> None: + self._logger.debug('Stopping Background tasks') + for task in self.__asyncio_tasks: + task.cancel() + await asyncio.gather(*self.__asyncio_tasks) + self.__asyncio_tasks = [] diff --git a/telegram/ext/_utils/promise.py b/telegram/ext/_utils/promise.py deleted file mode 100644 index 549f50b057b..00000000000 --- a/telegram/ext/_utils/promise.py +++ /dev/null @@ -1,148 +0,0 @@ -#!/usr/bin/env python -# -# A library that provides a Python interface to the Telegram Bot API -# Copyright (C) 2015-2022 -# Leandro Toledo de Souza -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Lesser Public License for more details. -# -# You should have received a copy of the GNU Lesser Public License -# along with this program. If not, see [http://www.gnu.org/licenses/]. -"""This module contains the Promise class.""" - -import logging -from threading import Event -from typing import Callable, List, Optional, Tuple, TypeVar, Union - -from telegram._utils.types import JSONDict - -RT = TypeVar('RT') - - -logger = logging.getLogger(__name__) - - -class Promise: - """A simple Promise implementation for use with the run_async decorator, DelayQueue etc. - - .. versionchanged:: 14.0 - Removed the argument and attribute ``error_handler``. - - Args: - pooled_function (:obj:`callable`): The callable that will be called concurrently. - args (:obj:`list` | :obj:`tuple`): Positional arguments for :attr:`pooled_function`. - kwargs (:obj:`dict`): Keyword arguments for :attr:`pooled_function`. - update (:class:`telegram.Update` | :obj:`object`, optional): The update this promise is - associated with. - - Attributes: - pooled_function (:obj:`callable`): The callable that will be called concurrently. - args (:obj:`list` | :obj:`tuple`): Positional arguments for :attr:`pooled_function`. - kwargs (:obj:`dict`): Keyword arguments for :attr:`pooled_function`. - done (:obj:`threading.Event`): Is set when the result is available. - update (:class:`telegram.Update` | :obj:`object`): Optional. The update this promise is - associated with. - - """ - - __slots__ = ( - 'pooled_function', - 'args', - 'kwargs', - 'update', - 'done', - '_done_callback', - '_result', - '_exception', - ) - - def __init__( - self, - pooled_function: Callable[..., RT], - args: Union[List, Tuple], - kwargs: JSONDict, - update: object = None, - ): - self.pooled_function = pooled_function - self.args = args - self.kwargs = kwargs - self.update = update - self.done = Event() - self._done_callback: Optional[Callable] = None - self._result: Optional[RT] = None - self._exception: Optional[Exception] = None - - def run(self) -> None: - """Calls the :attr:`pooled_function` callable.""" - try: - self._result = self.pooled_function(*self.args, **self.kwargs) - - except Exception as exc: - self._exception = exc - - finally: - self.done.set() - if self._exception is None and self._done_callback: - try: - self._done_callback(self.result()) - except Exception as exc: - logger.warning( - "`done_callback` of a Promise raised the following exception." - " The exception won't be handled by error handlers." - ) - logger.warning("Full traceback:", exc_info=exc) - - def __call__(self) -> None: - self.run() - - def result(self, timeout: float = None) -> Optional[RT]: - """Return the result of the ``Promise``. - - Args: - timeout (:obj:`float`, optional): Maximum time in seconds to wait for the result to be - calculated. ``None`` means indefinite. Default is ``None``. - - Returns: - Returns the return value of :attr:`pooled_function` or ``None`` if the ``timeout`` - expires. - - Raises: - object exception raised by :attr:`pooled_function`. - """ - self.done.wait(timeout=timeout) - if self._exception is not None: - raise self._exception # pylint: disable=raising-bad-type - return self._result - - def add_done_callback(self, callback: Callable) -> None: - """ - Callback to be run when :class:`telegram.ext._utils.promise.Promise` becomes done. - - Note: - Callback won't be called if :attr:`pooled_function` - raises an exception. - - Args: - callback (:obj:`callable`): The callable that will be called when promise is done. - callback will be called by passing ``Promise.result()`` as only positional argument. - - """ - if self.done.wait(0): - callback(self.result()) - else: - self._done_callback = callback - - @property - def exception(self) -> Optional[Exception]: - """The exception raised by :attr:`pooled_function` or ``None`` if no exception has been - raised (yet). - """ - return self._exception diff --git a/telegram/ext/_utils/trackingdefaultdict.py b/telegram/ext/_utils/trackingdefaultdict.py new file mode 100644 index 00000000000..1b48f4b4206 --- /dev/null +++ b/telegram/ext/_utils/trackingdefaultdict.py @@ -0,0 +1,219 @@ +#!/usr/bin/env python +# +# A library that provides a Python interface to the Telegram Bot API +# Copyright (C) 2015-2022 +# Leandro Toledo de Souza +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser Public License for more details. +# +# You should have received a copy of the GNU Lesser Public License +# along with this program. If not, see [http://www.gnu.org/licenses/]. +"""This module contains subclasses of :class:`collections.defaultdict` that keeps track of the +keys that where accessed. + +.. versionadded:: 14.0 + +Warning: + Contents of this module are intended to be used internally by the library and *not* by the + user. Changes to this module are not considered breaking changes and may not be documented in + the changelog. +""" +from typing import ( + TypeVar, + DefaultDict, + Callable, + Set, + ClassVar, + Iterator, + Optional, + Union, + Tuple, + overload, + MutableMapping, + List, + Mapping, +) +from collections import defaultdict + +from telegram._utils.defaultvalue import DEFAULT_NONE, DefaultValue + +_VT = TypeVar('_VT') +_KT = TypeVar('_KT') +_T = TypeVar('_T') + + +# TODO: Implement tests for this class that cover all methods implemented by (Mutable)Mapping and +# check if they give the correct behavior in terms of keeping track on the access. This includes +# __eq__ & access through Key/ItemViews +# We should also test that all this behavior stays the same when accessing the mapping through +# a MappingProxyType +# For methods like `pop`, `get`, `setdefault`, we should also check that we have the same +# behavior as defaultdict + + +class TrackingDefaultDict(MutableMapping[_KT, _VT]): + """DefaultDict that keeps track of which keys where accessed. + + Note: + * ``key in tdd`` is not considered reading + * ``setdefault()`` is considered both reading and writing depending on + whether or not the key is present + * ``pop`` is only considered writing, since the value is deleted instead of being changed + + Args: + default_factory (Callable): Default factory for missing entries + track_read (:obj:`bool`): Whether read access should be tracked. Deleting entries is + not considered reading. + track_write (:obj:`bool`): Whether write access should be tracked. Deleting entries is + considered writing. + """ + + DELETED: ClassVar = object() + """Special marker indicating that an entry was deleted.""" + + __slots__ = ('_data', '_write_access_keys', '_read_access_keys', 'track_read', 'track_write') + + def __init__(self, default_factory: Callable[[], _VT], track_read: bool, track_write: bool): + # The default_factory argument for defaultdict is positional only! + self._data: DefaultDict[_KT, _VT] = defaultdict(default_factory) + self.track_read = track_read + self.track_write = track_write + self._write_access_keys: Set[_KT] = set() + self._read_access_keys: Set[_KT] = set() + + def __track_read(self, key: Union[_KT, Set[_KT]]) -> None: + if self.track_read: + if isinstance(key, set): + self._read_access_keys |= key + else: + self._read_access_keys.add(key) + + def __track_write(self, key: Union[_KT, Set[_KT]]) -> None: + if self.track_write: + if isinstance(key, set): + self._write_access_keys |= key + else: + self._write_access_keys.add(key) + + def __repr__(self) -> str: + return repr(self._data) + + def __str__(self) -> str: + return str(self._data) + + def __eq__(self, other: object) -> bool: + return other == self._data + + def pop_accessed_read_keys(self) -> Set[_KT]: + """Returns all keys that were read-accessed since the last time this method was called.""" + if not self.track_read: + raise RuntimeError('Not tracking read access!') + + out = self._read_access_keys + self._read_access_keys = set() + return out + + def pop_accessed_write_keys(self) -> Set[_KT]: + """Returns all keys that were write-accessed since the last time this method was called.""" + if not self.track_write: + raise RuntimeError('Not tracking write access!') + + out = self._write_access_keys + self._write_access_keys = set() + return out + + def pop_accessed_read_items(self) -> List[Tuple[_KT, _VT]]: + """ + Returns all keys & corresponding values as set of tuples that were read-accessed since + the last time this method was called. + """ + keys = self.pop_accessed_read_keys() + return [(key, self._data[key]) for key in keys] + + def pop_accessed_write_items(self) -> List[Tuple[_KT, _VT]]: + """ + Returns all keys & corresponding values as set of tuples that were write-accessed since + the last time this method was called. If a key was deleted, the value will be + :attr:`DELETED`. + """ + keys = self.pop_accessed_write_keys() + return [(key, self._data[key] if key in self._data else self.DELETED) for key in keys] + + # Implement abstract interface + + def __getitem__(self, key: _KT) -> _VT: + item = self._data[key] + self.__track_read(key) + return item + + def __setitem__(self, key: _KT, value: _VT) -> None: + self._data[key] = value + self.__track_write(key) + + def __delitem__(self, key: _KT) -> None: + del self._data[key] + self.__track_write(key) + + def __iter__(self) -> Iterator[_KT]: + for key in self._data: + self.__track_read(key) + yield key + + def __len__(self) -> int: + return len(self._data) + + def update_no_track(self, mapping: Mapping[_KT, _VT]) -> None: + return self._data.update(mapping) + + # Override some methods so that they fit better with the read/write access book keeping + + def __contains__(self, key: object) -> bool: + return key in self._data + + # Mypy seems a bit inconsistent about what it wants as types for `default` and return value + # so we just ignore a bit + def pop( # type: ignore[override] + self, key: _KT, default: _VT = DEFAULT_NONE # type: ignore[assignment] + ) -> _VT: + self.__track_write(key) + if isinstance(default, DefaultValue): + return self._data.pop(key) + return self._data.pop(key, default=default) + + def clear(self) -> None: + self.__track_write(set(self._data.keys())) + self._data.clear() + + # Mypy seems a bit inconsistent about what it wants as types for `default` and return value + # so we just ignore a bit + def setdefault(self: 'TrackingDefaultDict[_KT, _T]', key: _KT, default: _T = None) -> _T: + if key in self._data: + self.__track_read(key) + return self._data[key] + + self.__track_write(key) + self._data[key] = default # type: ignore[assignment] + return default # type: ignore[return-value] + + # Overriding to comply with the behavior of `defaultdict` + + @overload + def get(self, key: _KT) -> Optional[_VT]: # pylint: disable=arguments-differ + ... + + @overload + def get(self, key: _KT, default: _T) -> _T: # pylint: disable=signature-differs + ... + + def get(self, key: _KT, default: _T = None) -> Optional[Union[_VT, _T]]: + if key in self: + return self[key] + return default diff --git a/telegram/ext/_utils/types.py b/telegram/ext/_utils/types.py index 105091d7622..0c899404d09 100644 --- a/telegram/ext/_utils/types.py +++ b/telegram/ext/_utils/types.py @@ -25,14 +25,43 @@ user. Changes to this module are not considered breaking changes and may not be documented in the changelog. """ -from typing import TypeVar, TYPE_CHECKING, Tuple, List, Dict, Any, Optional, Union +from typing import ( + TypeVar, + TYPE_CHECKING, + Tuple, + List, + Dict, + Any, + Union, + Callable, + Coroutine, + MutableMapping, +) if TYPE_CHECKING: - from telegram.ext import CallbackContext, JobQueue, BasePersistence # noqa: F401 + from telegram.ext import CallbackContext, JobQueue, BasePersistence, Updater # noqa: F401 from telegram import Bot +CCT = TypeVar('CCT', bound='CallbackContext') +"""An instance of :class:`telegram.ext.CallbackContext` or a custom subclass. + +.. versionadded:: 13.6 +""" -ConversationDict = Dict[Tuple[int, ...], Optional[object]] +RT = TypeVar('RT') +UT = TypeVar('UT') +HandlerCallback = Callable[[UT, CCT], Coroutine[Any, Any, RT]] +"""Type of a handler callback + + .. versionadded:: 14.0 +""" +JobCallback = Callable[[CCT], Coroutine[Any, Any, Any]] +"""Type of a job callback + + .. versionadded:: 14.0 +""" + +ConversationDict = MutableMapping[Tuple[int, ...], object] """Dict[Tuple[:obj:`int`, ...], Optional[:obj:`object`]]: Dicts as maintained by the :class:`telegram.ext.ConversationHandler`. @@ -47,11 +76,6 @@ .. versionadded:: 13.6 """ -CCT = TypeVar('CCT', bound='CallbackContext') -"""An instance of :class:`telegram.ext.CallbackContext` or a custom subclass. - -.. versionadded:: 13.6 -""" BT = TypeVar('BT', bound='Bot') """Type of the bot. @@ -76,7 +100,3 @@ """Type of the job queue. .. versionadded:: 14.0""" -PT = TypeVar('PT', bound=Union[None, 'BasePersistence']) -"""Type of the persistence. - -.. versionadded:: 14.0""" diff --git a/telegram/ext/_utils/webhookhandler.py b/telegram/ext/_utils/webhookhandler.py index 5bdb70b40eb..1b7631c37f0 100644 --- a/telegram/ext/_utils/webhookhandler.py +++ b/telegram/ext/_utils/webhookhandler.py @@ -17,21 +17,18 @@ # You should have received a copy of the GNU Lesser Public License # along with this program. If not, see [http://www.gnu.org/licenses/]. # pylint: disable=missing-module-docstring - +import asyncio import logging -from queue import Queue +from http import HTTPStatus from ssl import SSLContext -from threading import Event, Lock -from typing import TYPE_CHECKING, Any, Optional +from types import TracebackType +from typing import TYPE_CHECKING, Optional, Type import tornado.web -from tornado import httputil from tornado.httpserver import HTTPServer -from tornado.ioloop import IOLoop from telegram import Update from telegram.ext import ExtBot -from telegram._utils.types import JSONDict if TYPE_CHECKING: from telegram import Bot @@ -43,55 +40,53 @@ class WebhookServer: + """Thin wrapper around ``tornado.httpserver.HTTPServer``.""" + __slots__ = ( - 'http_server', + '_http_server', 'listen', 'port', - 'loop', - 'logger', + '_logger', 'is_running', - 'server_lock', - 'shutdown_lock', + '_server_lock', + '_shutdown_lock', ) def __init__( - self, listen: str, port: int, webhook_app: 'WebhookAppClass', ssl_ctx: SSLContext + self, listen: str, port: int, webhook_app: 'WebhookAppClass', ssl_ctx: Optional[SSLContext] ): - self.http_server = HTTPServer(webhook_app, ssl_options=ssl_ctx) + self._http_server = HTTPServer(webhook_app, ssl_options=ssl_ctx) self.listen = listen self.port = port - self.loop: Optional[IOLoop] = None - self.logger = logging.getLogger(__name__) + self._logger = logging.getLogger(__name__) self.is_running = False - self.server_lock = Lock() - self.shutdown_lock = Lock() + self._server_lock = asyncio.Lock() + self._shutdown_lock = asyncio.Lock() - def serve_forever(self, ready: Event = None) -> None: - with self.server_lock: - IOLoop().make_current() - self.is_running = True - self.logger.debug('Webhook Server started.') - self.loop = IOLoop.current() - self.http_server.listen(self.port, address=self.listen) + async def serve_forever(self, ready: asyncio.Event = None) -> None: + async with self._server_lock: + self._http_server.listen(self.port, address=self.listen) + self.is_running = True if ready is not None: ready.set() - self.loop.start() - self.logger.debug('Webhook Server stopped.') - self.is_running = False + self._logger.debug('Webhook Server started.') - def shutdown(self) -> None: - with self.shutdown_lock: + async def shutdown(self) -> None: + async with self._shutdown_lock: if not self.is_running: - self.logger.warning('Webhook Server already stopped.') + self._logger.warning('Webhook Server already stopped.') return - self.loop.add_callback(self.loop.stop) # type: ignore + self.is_running = False + self._http_server.stop() + await self._http_server.close_all_connections() + self._logger.debug('Webhook Server stopped') # pylint: disable=unused-argument def handle_error(self, request: object, client_address: str) -> None: """Handle an error gracefully.""" - self.logger.debug( + self._logger.debug( 'Exception happened during processing of request from %s', client_address, exc_info=True, @@ -99,75 +94,72 @@ def handle_error(self, request: object, client_address: str) -> None: class WebhookAppClass(tornado.web.Application): - def __init__(self, webhook_path: str, bot: 'Bot', update_queue: Queue): + """Application used in the Webserver""" + + def __init__(self, webhook_path: str, bot: 'Bot', update_queue: asyncio.Queue): self.shared_objects = {"bot": bot, "update_queue": update_queue} - handlers = [(rf"{webhook_path}/?", WebhookHandler, self.shared_objects)] # noqa + handlers = [(rf"{webhook_path}/?", TelegramHandler, self.shared_objects)] # noqa tornado.web.Application.__init__(self, handlers) # type: ignore - def log_request(self, handler: tornado.web.RequestHandler) -> None: # skipcq: PTC-W0049 - pass + def log_request(self, handler: tornado.web.RequestHandler) -> None: + """Overrides the default implementation since we have our own logging setup.""" -# WebhookHandler, process webhook calls # pylint: disable=abstract-method -class WebhookHandler(tornado.web.RequestHandler): - SUPPORTED_METHODS = ["POST"] # type: ignore +class TelegramHandler(tornado.web.RequestHandler): + """Handler that processes incoming requests from Telegram""" - def __init__( - self, - application: tornado.web.Application, - request: httputil.HTTPServerRequest, - **kwargs: JSONDict, - ): - super().__init__(application, request, **kwargs) - self.logger = logging.getLogger(__name__) + __slots__ = ('bot', 'update_queue', '_logger') + + SUPPORTED_METHODS = ("POST",) # type: ignore[assignment] - def initialize(self, bot: 'Bot', update_queue: Queue) -> None: + def initialize(self, bot: 'Bot', update_queue: asyncio.Queue) -> None: + """Initialize for each request - that's the interface provided by tornado""" # pylint: disable=attribute-defined-outside-init self.bot = bot self.update_queue = update_queue + self._logger = logging.getLogger(__name__) def set_default_headers(self) -> None: + """Sets default headers""" self.set_header("Content-Type", 'application/json; charset="utf-8"') - def post(self) -> None: - self.logger.debug('Webhook triggered') + async def post(self) -> None: + """Handle incoming POST request""" + self._logger.debug('Webhook triggered') self._validate_post() + json_string = self.request.body.decode() data = json.loads(json_string) - self.set_status(200) - self.logger.debug('Webhook received data: %s', json_string) + self.set_status(HTTPStatus.OK) + self._logger.debug('Webhook received data: %s', json_string) + update = Update.de_json(data, self.bot) if update: - self.logger.debug('Received Update with ID %d on Webhook', update.update_id) + self._logger.debug('Received Update with ID %d on Webhook', update.update_id) + # handle arbitrary callback data, if necessary if isinstance(self.bot, ExtBot): self.bot.insert_callback_data(update) - self.update_queue.put(update) + + await self.update_queue.put(update) def _validate_post(self) -> None: + """Only accept requests with content type JSON""" ct_header = self.request.headers.get("Content-Type", None) if ct_header != 'application/json': - raise tornado.web.HTTPError(403) - - def write_error(self, status_code: int, **kwargs: Any) -> None: - """Log an arbitrary message. + raise tornado.web.HTTPError(HTTPStatus.FORBIDDEN) - This is used by all other logging functions. - - It overrides ``BaseHTTPRequestHandler.log_message``, which logs to ``sys.stderr``. - - The first argument, FORMAT, is a format string for the message to be logged. If the format - string contains any % escapes requiring parameters, they should be specified as subsequent - arguments (it's just like printf!). - - The client ip is prefixed to every message. - - """ - super().write_error(status_code, **kwargs) - self.logger.debug( - "%s - - %s", + def log_exception( + self, + typ: Optional[Type[BaseException]], + value: Optional[BaseException], + tb: Optional[TracebackType], + ) -> None: + """Override the default logging and instead use our custom logging.""" + self._logger.debug( + "%s - %s", self.request.remote_ip, - "Exception in WebhookHandler", - exc_info=kwargs['exc_info'], + "Exception in TelegramHandler", + exc_info=(typ, value, tb) if typ and value and tb else value, ) diff --git a/telegram/request.py b/telegram/request.py deleted file mode 100644 index 20cc485d5ef..00000000000 --- a/telegram/request.py +++ /dev/null @@ -1,405 +0,0 @@ -#!/usr/bin/env python -# -# A library that provides a Python interface to the Telegram Bot API -# Copyright (C) 2015-2022 -# Leandro Toledo de Souza -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Lesser Public License for more details. -# -# You should have received a copy of the GNU Lesser Public License -# along with this program. If not, see [http://www.gnu.org/licenses/]. -"""This module contains the Request class which handles the communication with the Telegram -servers. -""" - -__all__ = ('Request',) - -import logging -import os -import socket -import sys -import warnings -from pathlib import Path - -try: - import ujson as json -except ImportError: - import json # type: ignore[no-redef] - -from typing import Any, Union - -import certifi - -try: - from telegram.vendor.ptb_urllib3 import urllib3 - from telegram.vendor.ptb_urllib3.urllib3.contrib import appengine - from telegram.vendor.ptb_urllib3.urllib3.connection import HTTPConnection - from telegram.vendor.ptb_urllib3.urllib3.fields import RequestField - from telegram.vendor.ptb_urllib3.urllib3.util.timeout import Timeout -except ImportError: # pragma: no cover - try: - import urllib3 # type: ignore[no-redef] - from urllib3.contrib import appengine # type: ignore[no-redef] - from urllib3.connection import HTTPConnection # type: ignore[no-redef] - from urllib3.fields import RequestField # type: ignore[no-redef] - from urllib3.util.timeout import Timeout # type: ignore[no-redef] - - warnings.warn( - 'python-telegram-bot is using upstream urllib3. This is allowed but not ' - 'supported by python-telegram-bot maintainers.' - ) - except ImportError: - warnings.warn( - "python-telegram-bot wasn't properly installed. Please refer to README.rst on " - "how to properly install." - ) - raise - -# pylint: disable=ungrouped-imports -from telegram import InputFile -from telegram.error import ( - TelegramError, - BadRequest, - ChatMigrated, - Conflict, - InvalidToken, - NetworkError, - RetryAfter, - TimedOut, - Unauthorized, -) -from telegram._utils.types import JSONDict, FilePathInput - - -# pylint: disable=unused-argument -def _render_part(self: RequestField, name: str, value: str) -> str: - r""" - Monkey patch urllib3.urllib3.fields.RequestField to make it *not* support RFC2231 compliant - Content-Disposition headers since telegram servers don't understand it. Instead just escape - \\ and " and replace any \n and \r with a space. - - """ - value = value.replace('\\', '\\\\').replace('"', '\\"') - value = value.replace('\r', ' ').replace('\n', ' ') - return f'{name}="{value}"' - - -RequestField._render_part = _render_part # type: ignore # pylint: disable=protected-access - -logging.getLogger('telegram.vendor.ptb_urllib3.urllib3').setLevel(logging.WARNING) - -USER_AGENT = 'Python Telegram Bot (https://github.com/python-telegram-bot/python-telegram-bot)' - - -class Request: - """Helper class for python-telegram-bot which provides methods to perform POST & GET towards - Telegram servers. - - Args: - con_pool_size (:obj:`int`): Number of connections to keep in the connection pool. - proxy_url (:obj:`str`): The URL to the proxy server. For example: `http://127.0.0.1:3128`. - urllib3_proxy_kwargs (:obj:`dict`): Arbitrary arguments passed as-is to - :obj:`urllib3.ProxyManager`. This value will be ignored if :attr:`proxy_url` is not - set. - connect_timeout (:obj:`int` | :obj:`float`): The maximum amount of time (in seconds) to - wait for a connection attempt to a server to succeed. :obj:`None` will set an - infinite timeout for connection attempts. Defaults to ``5.0``. - read_timeout (:obj:`int` | :obj:`float`): The maximum amount of time (in seconds) to wait - between consecutive read operations for a response from the server. :obj:`None` will - set an infinite timeout. This value is usually overridden by the various - :class:`telegram.Bot` methods. Defaults to ``5.0``. - - """ - - __slots__ = ('_connect_timeout', '_con_pool_size', '_con_pool') - - def __init__( - self, - con_pool_size: int = 1, - proxy_url: str = None, - urllib3_proxy_kwargs: JSONDict = None, - connect_timeout: float = 5.0, - read_timeout: float = 5.0, - ): - if urllib3_proxy_kwargs is None: - urllib3_proxy_kwargs = {} - - self._connect_timeout = connect_timeout - - sockopts = HTTPConnection.default_socket_options + [ - (socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) - ] - - # TODO: Support other platforms like mac and windows. - if 'linux' in sys.platform: - sockopts.append( - (socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 120) # pylint: disable=no-member - ) - sockopts.append( - (socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 30) # pylint: disable=no-member - ) - sockopts.append( - (socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 8) # pylint: disable=no-member - ) - - self._con_pool_size = con_pool_size - - kwargs = dict( - maxsize=con_pool_size, - cert_reqs='CERT_REQUIRED', - ca_certs=certifi.where(), - socket_options=sockopts, - timeout=urllib3.Timeout(connect=self._connect_timeout, read=read_timeout, total=None), - ) - - # Set a proxy according to the following order: - # * proxy defined in proxy_url (+ urllib3_proxy_kwargs) - # * proxy set in `HTTPS_PROXY` env. var. - # * proxy set in `https_proxy` env. var. - # * None (if no proxy is configured) - - if not proxy_url: - proxy_url = os.environ.get('HTTPS_PROXY') or os.environ.get('https_proxy') - - self._con_pool: Union[ - urllib3.PoolManager, - appengine.AppEngineManager, - 'SOCKSProxyManager', # noqa: F821 - urllib3.ProxyManager, - ] = None # type: ignore - if not proxy_url: - if appengine.is_appengine_sandbox(): - # Use URLFetch service if running in App Engine - self._con_pool = appengine.AppEngineManager() - else: - self._con_pool = urllib3.PoolManager(**kwargs) - else: - kwargs.update(urllib3_proxy_kwargs) - if proxy_url.startswith('socks'): - try: - # pylint: disable=import-outside-toplevel - from telegram.vendor.ptb_urllib3.urllib3.contrib.socks import SOCKSProxyManager - except ImportError as exc: - raise RuntimeError('PySocks is missing') from exc - self._con_pool = SOCKSProxyManager(proxy_url, **kwargs) - else: - mgr = urllib3.proxy_from_url(proxy_url, **kwargs) - if mgr.proxy.auth: - # TODO: what about other auth types? - auth_hdrs = urllib3.make_headers(proxy_basic_auth=mgr.proxy.auth) - mgr.proxy_headers.update(auth_hdrs) - - self._con_pool = mgr - - @property - def con_pool_size(self) -> int: - """The size of the connection pool used.""" - return self._con_pool_size - - def stop(self) -> None: - """Performs cleanup on shutdown.""" - self._con_pool.clear() # type: ignore - - @staticmethod - def _parse(json_data: bytes) -> Union[JSONDict, bool]: - """Try and parse the JSON returned from Telegram. - - Returns: - dict: A JSON parsed as Python dict with results - on error this dict will be empty. - - """ - decoded_s = json_data.decode('utf-8', 'replace') - try: - data = json.loads(decoded_s) - except ValueError as exc: - raise TelegramError('Invalid server response') from exc - - if not data.get('ok'): # pragma: no cover - description = data.get('description') - parameters = data.get('parameters') - if parameters: - migrate_to_chat_id = parameters.get('migrate_to_chat_id') - if migrate_to_chat_id: - raise ChatMigrated(migrate_to_chat_id) - retry_after = parameters.get('retry_after') - if retry_after: - raise RetryAfter(retry_after) - if description: - return description - - return data['result'] - - def _request_wrapper(self, *args: object, **kwargs: Any) -> bytes: - """Wraps urllib3 request for handling known exceptions. - - Args: - args: unnamed arguments, passed to urllib3 request. - kwargs: keyword arguments, passed to urllib3 request. - - Returns: - bytes: A non-parsed JSON text. - - Raises: - TelegramError - - """ - # Make sure to hint Telegram servers that we reuse connections by sending - # "Connection: keep-alive" in the HTTP headers. - if 'headers' not in kwargs: - kwargs['headers'] = {} - kwargs['headers']['connection'] = 'keep-alive' - # Also set our user agent - kwargs['headers']['user-agent'] = USER_AGENT - - try: - resp = self._con_pool.request(*args, **kwargs) - except urllib3.exceptions.TimeoutError as error: - raise TimedOut() from error - except urllib3.exceptions.HTTPError as error: - # HTTPError must come last as its the base urllib3 exception class - # TODO: do something smart here; for now just raise NetworkError - raise NetworkError(f'urllib3 HTTPError {error}') from error - - if 200 <= resp.status <= 299: - # 200-299 range are HTTP success statuses - return resp.data - - try: - message = str(self._parse(resp.data)) - except ValueError: - message = 'Unknown HTTPError' - - if resp.status in (401, 403): - raise Unauthorized(message) - if resp.status == 400: - raise BadRequest(message) - if resp.status == 404: - raise InvalidToken() - if resp.status == 409: - raise Conflict(message) - if resp.status == 413: - raise NetworkError( - 'File too large. Check telegram api limits ' - 'https://core.telegram.org/bots/api#senddocument' - ) - if resp.status == 502: - raise NetworkError('Bad Gateway') - raise NetworkError(f'{message} ({resp.status})') - - def post(self, url: str, data: JSONDict, timeout: float = None) -> Union[JSONDict, bool]: - """Request an URL. - - Args: - url (:obj:`str`): The web location we want to retrieve. - data (Dict[:obj:`str`, :obj:`str` | :obj:`int`], optional): A dict of key/value pairs. - timeout (:obj:`int` | :obj:`float`, optional): If this value is specified, use it as - the read timeout from the server (instead of the one specified during creation of - the connection pool). - - Returns: - A JSON object. - - """ - urlopen_kwargs = {} - - if timeout is not None: - urlopen_kwargs['timeout'] = Timeout(read=timeout, connect=self._connect_timeout) - - if data is None: - data = {} - - # Are we uploading files? - files = False - - # pylint: disable=too-many-nested-blocks - for key, val in data.copy().items(): - if isinstance(val, InputFile): - # Convert the InputFile to urllib3 field format - data[key] = val.field_tuple - files = True - elif isinstance(val, (float, int)): - # Urllib3 doesn't like floats it seems - data[key] = str(val) - elif key == 'media': - files = True - # List of media - if isinstance(val, list): - # Attach and set val to attached name for all - media = [] - for med in val: - media_dict = med.to_dict() - media.append(media_dict) - if isinstance(med.media, InputFile): - data[med.media.attach] = med.media.field_tuple # type: ignore[index] - # if the file has a thumb, we also need to attach it to the data - if "thumb" in media_dict: - data[med.thumb.attach] = med.thumb.field_tuple - data[key] = json.dumps(media) - # Single media - else: - # Attach and set val to attached name - media_dict = val.to_dict() - if isinstance(val.media, InputFile): - data[val.media.attach] = val.media.field_tuple # type: ignore[index] - # if the file has a thumb, we also need to attach it to the data - if "thumb" in media_dict: - data[val.thumb.attach] = val.thumb.field_tuple - data[key] = json.dumps(media_dict) - elif isinstance(val, list): - # In case we're sending files, we need to json-dump lists manually - # As we can't know if that's the case, we just json-dump here - data[key] = json.dumps(val) - - # Use multipart upload if we're uploading files, otherwise use JSON - if files: - result = self._request_wrapper('POST', url, fields=data, **urlopen_kwargs) - else: - result = self._request_wrapper( - 'POST', - url, - body=json.dumps(data).encode('utf-8'), - headers={'Content-Type': 'application/json'}, - **urlopen_kwargs, - ) - - return self._parse(result) - - def retrieve(self, url: str, timeout: float = None) -> bytes: - """Retrieve the contents of a file by its URL. - - Args: - url (:obj:`str`): The web location we want to retrieve. - timeout (:obj:`int` | :obj:`float`): If this value is specified, use it as the read - timeout from the server (instead of the one specified during creation of the - connection pool). - - """ - urlopen_kwargs = {} - if timeout is not None: - urlopen_kwargs['timeout'] = Timeout(read=timeout, connect=self._connect_timeout) - - return self._request_wrapper('GET', url, **urlopen_kwargs) - - def download(self, url: str, filepath: FilePathInput, timeout: float = None) -> None: - """Download a file by its URL. - - Args: - url (:obj:`str`): The web location we want to retrieve. - filepath (:obj:`pathlib.Path` | :obj:`str`): The filepath to download the file to. - timeout (:obj:`int` | :obj:`float`, optional): If this value is specified, use it as - the read timeout from the server (instead of the one specified during creation of - the connection pool). - - .. versionchanged:: 14.0 - The ``filepath`` parameter now also accepts :obj:`pathlib.Path` objects as argument. - - """ - Path(filepath).write_bytes(self.retrieve(url, timeout)) diff --git a/telegram/request/__init__.py b/telegram/request/__init__.py new file mode 100644 index 00000000000..98e9a676740 --- /dev/null +++ b/telegram/request/__init__.py @@ -0,0 +1,24 @@ +# +# A library that provides a Python interface to the Telegram Bot API +# Copyright (C) 2015-2022 +# Leandro Toledo de Souza +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser Public License for more details. +# +# You should have received a copy of the GNU Lesser Public License +# along with this program. If not, see [http://www.gnu.org/licenses/]. +"""This module contains classes that handle the networking backend of ``python-telegram-bot``.""" + +from ._requestdata import RequestData +from ._baserequest import BaseRequest +from ._httpxrequest import HTTPXRequest + +__all__ = ('BaseRequest', 'HTTPXRequest', 'RequestData') diff --git a/telegram/request/_baserequest.py b/telegram/request/_baserequest.py new file mode 100644 index 00000000000..260e3ace1f5 --- /dev/null +++ b/telegram/request/_baserequest.py @@ -0,0 +1,345 @@ +#!/usr/bin/env python +# +# A library that provides a Python interface to the Telegram Bot API +# Copyright (C) 2015-2022 +# Leandro Toledo de Souza +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser Public License for more details. +# +# You should have received a copy of the GNU Lesser Public License +# along with this program. If not, see [http://www.gnu.org/licenses/]. +"""This module contains an abstract class to make POST and GET requests.""" +import abc +import traceback +from contextlib import AbstractAsyncContextManager +from http import HTTPStatus +from types import TracebackType +from typing import Union, Tuple, Type, Optional, ClassVar, TypeVar + +try: + import ujson as json +except ImportError: + import json # type: ignore[no-redef] + +from telegram._version import __version__ as ptb_ver +from telegram.request import RequestData + +from telegram.error import ( + TelegramError, + BadRequest, + ChatMigrated, + Conflict, + InvalidToken, + NetworkError, + RetryAfter, + Forbidden, +) +from telegram._utils.types import JSONDict, ODVInput +from telegram._utils.defaultvalue import DEFAULT_NONE as _DEFAULT_NONE + +RT = TypeVar('RT', bound='BaseRequest') + + +class BaseRequest( + AbstractAsyncContextManager, + abc.ABC, +): + """Abstract interface class that allows python-telegram-bot to make requests to the Bot API. + Can be implemented via different asyncio HTTP libraries. An implementation of this class + must implement all abstract methods and properties. In addition, :attr:`connection_pool_size` + can optionally be overridden. + + Instances of this class can be used as asyncio context managers, where + + .. code:: python + + async with request_object: + # code + + is roughly equivalent to + + .. code:: python + + try: + await request_object.initialize() + # code + finally: + await request_object.stop() + """ + + __slots__ = () + + USER_AGENT: ClassVar[str] = f'python-telegram-bot v{ptb_ver} (https://python-telegram-bot.org)' + """:obj:`str`: A description that can be used as user agent for requests made to the Bot API. + """ + DEFAULT_NONE: ClassVar = _DEFAULT_NONE + """:class:`object`: A special object that indicates that an argument of a function was not + explicitly passed. Used for the timeout parameters of :meth:`post` and :meth:`do_request`. + + Example: + When calling ``request.post(url)``, ``request`` should use the default timeouts set on + initialization. When calling ``request.post(url, connect_timeout=5, read_timeout=None)``, + ``request`` should use ``5`` for the connect timeout and :obj:`None` for the read timeout. + + Use ``if parameter is (not) BaseRequest.DEFAULT_NONE:`` to check if the parameter was set. + """ + + async def __aenter__(self: RT) -> RT: + try: + await self.initialize() + return self + except Exception as exc: + await self.shutdown() + raise exc + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + # Make sure not to return `True` so that exceptions are not suppressed + # https://docs.python.org/3/reference/datamodel.html?#object.__aexit__ + await self.shutdown() + + @abc.abstractmethod + async def initialize(self) -> None: + """Initialize resources used by this class. Must be implemented by a subclass.""" + + @abc.abstractmethod + async def shutdown(self) -> None: + """Stop & clear resources used by this class. Must be implemented by a subclass.""" + + async def post( + self, + url: str, + request_data: RequestData = None, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, + ) -> Union[JSONDict, bool]: + """Makes a request to the Bot API handles the return code and parses the answer. + + Warning: + This method will be called by the methods of :class:`Bot` and should *not* be called + manually. + + + Args: + url (:obj:`str`): The URL to request. + request_data (:class:`telegram.request.RequestData`, optional): An object containing + information about parameters and files to upload for the request. + connect_timeout (:obj:`float`, optional): If passed, specifies the maximum amount of + time (in seconds) to wait for a connection attempt to a server to succeed instead + of the time specified during creating of this object. + read_timeout (:obj:`float`, optional): If passed, specifies the maximum amount of time + (in seconds) to wait for a response from Telegram's server instead + of the time specified during creating of this object. + write_timeout (:obj:`float`, optional): If passed, specifies the maximum amount of time + (in seconds) to wait for a write operation to complete (in terms of a network + socket; i.e. POSTing a request or uploading a file) instead + of the time specified during creating of this object. + pool_timeout (:obj:`float`, optional): If passed, specifies the maximum amount of time + (in seconds) to wait for a connection to become available instead + of the time specified during creating of this object. + + Returns: + Dict[:obj:`str`, ...]: The JSON response of the Bot API. + + """ + result = await self._request_wrapper( + url=url, + method='POST', + request_data=request_data, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, + ) + json_data = self._parse_json_response(result) + # For successful requests, the results are in the 'result' entry + # see https://core.telegram.org/bots/api#making-requests + return json_data['result'] + + async def retrieve( + self, + url: str, + read_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, + ) -> bytes: + """Retrieve the contents of a file by its URL. + + Warning: + This method will be called by the methods of :class:`Bot` and should *not* be called + manually. + + Args: + url (:obj:`str`): The web location we want to retrieve. + timeout (:obj:`float`, optional): If this value is specified, use it as the read + timeout from the server (instead of the one specified during creation of the + connection pool). + + Returns: + :obj:`bytes`: The files contents. + + """ + return await self._request_wrapper( + url=url, + method='GET', + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, + ) + + async def _request_wrapper( + self, + url: str, + method: str, + request_data: RequestData = None, + read_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, + ) -> bytes: + """Wraps the real implementation request method. + + Performs the following tasks: + * Handle the various HTTP response codes. + * Parse the Telegram server response. + + Args: + url (:obj:`str`): The URL to request. + method (:obj:`str`): HTTP method (i.e. 'POST', 'GET', etc.). + url (:obj:`str`): The request's URL. + request_data (:class:`telegram.request.RequestData`, optional): An object containing + information about parameters and files to upload for the request. + read_timeout: Timeout for waiting to server's response. + + Returns: + bytes: The payload part of the HTTP server response. + + Raises: + TelegramError + + """ + # TGs response also has the fields 'ok' and 'error_code'. + # However, we rather rely on the HTTP status code for now. + + try: + code, payload = await self.do_request( + url=url, + method=method, + request_data=request_data, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, + ) + except TelegramError as exc: + raise exc + except Exception as exc: + traceback.print_tb(exc.__traceback__) + raise NetworkError(f"Unknown error in HTTP implementation: {repr(exc)}") from exc + + if HTTPStatus.OK <= code <= 299: + # 200-299 range are HTTP success statuses + return payload + + response_data = self._parse_json_response(payload) + + # In some special cases, we ca raise more informative exceptions: + # see https://core.telegram.org/bots/api#responseparameters and + # https://core.telegram.org/bots/api#making-requests + parameters = response_data.get('parameters') + if parameters: + migrate_to_chat_id = parameters.get('migrate_to_chat_id') + if migrate_to_chat_id: + raise ChatMigrated(migrate_to_chat_id) + retry_after = parameters.get('retry_after') + if retry_after: + raise RetryAfter(retry_after) + + description = response_data.get('description') + if description: + message = description + else: + message = 'Unknown HTTPError' + + if code == HTTPStatus.FORBIDDEN: + raise Forbidden(message) + if code in (HTTPStatus.NOT_FOUND, HTTPStatus.UNAUTHORIZED): + # TG returns 404 Not found for + # 1) malformed tokens + # 2) correct tokens but non-existing method, e.g. api.tg.org/botTOKEN/unkonwnMethod + # We can basically rule out 2) since we don't let users make requests manually + # TG returns 401 Unauthorized for correctly formatted tokens that are not valid + raise InvalidToken(message) + if code == HTTPStatus.BAD_REQUEST: + raise BadRequest(message) + if code == HTTPStatus.CONFLICT: + raise Conflict(message) + if code == HTTPStatus.BAD_GATEWAY: + raise NetworkError(description or 'Bad Gateway') + raise NetworkError(f'{message} ({code})') + + @staticmethod + def _parse_json_response(json_payload: bytes) -> JSONDict: + """Try and parse the JSON returned from Telegram. + + Returns: + dict: A JSON parsed as Python dict with results. + + Raises: + TelegramError: If the data could not be json_loaded + """ + decoded_s = json_payload.decode('utf-8', 'replace') + try: + return json.loads(decoded_s) + except ValueError as exc: + raise TelegramError('Invalid server response') from exc + + @abc.abstractmethod + async def do_request( + self, + url: str, + method: str, + request_data: RequestData = None, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, + ) -> Tuple[int, bytes]: + """Makes a request to the Bot API. Must be implemented by a subclass. + + Warning: + This method will be called by :meth:`post` and :meth:`retrieve`. It should *not* be + called manually. + + Args: + url (:obj:`str`): The URL to request. + method (:obj:`str`): HTTP method (i.e. ``'POST'``, ``'GET'``, etc.). + request_data (:class:`telegram.request.RequestData`, optional): An object containing + information about parameters and files to upload for the request. + read_timeout (:obj:`float`, optional): If this value is specified, use it as the read + timeout from the server (instead of the one specified during creation of the + connection pool). + write_timeout (:obj:`float`, optional): If this value is specified, use it as the write + timeout from the server (instead of the one specified during creation of the + connection pool). + + Returns: + Tuple[:obj:`int`, :obj:`bytes`]: The HTTP return code & the payload part of the server + response. + """ diff --git a/telegram/request/_httpxrequest.py b/telegram/request/_httpxrequest.py new file mode 100644 index 00000000000..6654eb89b2f --- /dev/null +++ b/telegram/request/_httpxrequest.py @@ -0,0 +1,211 @@ +# +# A library that provides a Python interface to the Telegram Bot API +# Copyright (C) 2015-2022 +# Leandro Toledo de Souza +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser Public License for more details. +# +# You should have received a copy of the GNU Lesser Public License +# along with this program. If not, see [http://www.gnu.org/licenses/]. +"""This module contains methods to make POST and GET requests using the httpx library.""" +import asyncio +import logging +from typing import Tuple, Optional + +import httpx + +from telegram._utils.defaultvalue import DefaultValue +from telegram._utils.types import ODVInput +from telegram.error import TimedOut, NetworkError +from telegram.request import BaseRequest, RequestData + + +# Note to future devs: +# Proxies are currently only tested manually. The httpx development docs have a nice guide on that: +# https://www.python-httpx.org/contributing/#development-proxy-setup (also saved on archive.org) +# That also works with socks5. Just pass `--mode socks5` to mitmproxy + +_logger = logging.getLogger(__name__) + + +class HTTPXRequest(BaseRequest): + """Implementation of :class:`~telegram.request.BaseRequest` using the library + `httpx `_. + + Args: + connection_pool_size (:obj:`int`, optional): Number of connections to keep in the + connection pool. Defaults to :obj:`1`. + + Note: + Independent of the value, one additional connection will be reserved for + :meth:`telegram.Bot.get_updates`. + proxy_url (:obj:`str`, optional): The URL to the proxy server. For example + ``'http://127.0.0.1:3128'`` or ``'socks5://127.0.0.1:3128'``. Defaults to :obj:`None`. + + Note: + * The proxy URL can also be set via the environment variables ``HTTPS_PROXY`` or + ``ALL_PROXY``. See `the docs`_ of ``httpx`` for more info. + * For Socks5 support, additional dependencies are required. Make sure to install + PTB via ``pip install python-telegram-bot[socks]`` in this case. + * Socks5 proxies can not be set via environment variables. + + .. _the docs: https://www.python-httpx.org/environment_variables/#proxies + connect_timeout (:obj:`float`, optional): The maximum amount of time (in seconds) to wait + for a connection attempt to a server to succeed. :obj:`None` will set an infinite + timeout for connection attempts. Defaults to ``5.0``. + read_timeout (:obj:`float`, optional): The maximum amount of time (in seconds) to wait for + a response from Telegram's server. :obj:`None` will set an infinite timeout. This value + is usually overridden by the various methods of :class:`telegram.Bot`. Defaults to + ``5.0``. + write_timeout (:obj:`float`, optional): The maximum amount of time (in seconds) to wait for + a write operation to complete (in terms of a network socket; i.e. POSTing a request or + uploading a file).:obj:`None` will set an infinite timeout. Defaults to ``5.0``. + pool_timeout (:obj:`float`, optional): The maximum amount of time (in seconds) to wait for + a connection from the connection pool becoming available. :obj:`None` will set an + infinite timeout. Defaults to :obj:`None`. + + Warning: + With a finite pool timeout, you must expect :exc:`telegram.error.TimeOut` + exceptions to be thrown when more requests are made simultaneously than there are + connections in the connection pool! + """ + + __slots__ = ('_client', '__pool_semaphore') + + def __init__( + self, + connection_pool_size: int = 1, + proxy_url: str = None, + connect_timeout: Optional[float] = 5.0, + read_timeout: Optional[float] = 5.0, + write_timeout: Optional[float] = 5.0, + pool_timeout: Optional[float] = 1.0, + ): + self.__pool_semaphore = asyncio.BoundedSemaphore(connection_pool_size) + self._pool_timeout = pool_timeout + + timeout = httpx.Timeout( + connect=connect_timeout, + read=read_timeout, + write=write_timeout, + pool=1, + ) + limits = httpx.Limits( + max_connections=connection_pool_size, + max_keepalive_connections=connection_pool_size, + ) + + self._client = httpx.AsyncClient( + timeout=timeout, + proxies=proxy_url, + limits=limits, + ) + + async def initialize(self) -> None: + """See :meth:`BaseRequest.initialize`.""" + + async def shutdown(self) -> None: + """See :meth:`BaseRequest.stop`.""" + await self._client.aclose() + + async def do_request( + self, + url: str, + method: str, + request_data: RequestData = None, + connect_timeout: ODVInput[float] = BaseRequest.DEFAULT_NONE, + read_timeout: ODVInput[float] = BaseRequest.DEFAULT_NONE, + write_timeout: ODVInput[float] = BaseRequest.DEFAULT_NONE, + pool_timeout: ODVInput[float] = BaseRequest.DEFAULT_NONE, + ) -> Tuple[int, bytes]: + """See :meth:`BaseRequest.do_request`.""" + if isinstance(pool_timeout, DefaultValue): + pool_timeout = self._pool_timeout + + if pool_timeout != 0 and self.__pool_semaphore.locked(): + _logger.debug( + 'All connections in the pool are currently busy. Waiting pool_timeout=%s for ' + 'a connection to become available.', + pool_timeout, + ) + + try: + await asyncio.wait_for(self.__pool_semaphore.acquire(), timeout=pool_timeout) + except asyncio.TimeoutError as exc: + raise TimedOut('Pool timeout') from exc + + try: + out = await self._do_request( + url=url, + method=method, + request_data=request_data, + connect_timeout=connect_timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + ) + return out + finally: + self.__pool_semaphore.release() + + async def _do_request( + self, + url: str, + method: str, + request_data: RequestData = None, + connect_timeout: ODVInput[float] = BaseRequest.DEFAULT_NONE, + read_timeout: ODVInput[float] = BaseRequest.DEFAULT_NONE, + write_timeout: ODVInput[float] = BaseRequest.DEFAULT_NONE, + ) -> Tuple[int, bytes]: + timeout = httpx.Timeout( + connect=self._client.timeout.connect, + read=self._client.timeout.read, + write=self._client.timeout.write, + pool=1, + ) + if not isinstance(read_timeout, DefaultValue): + timeout.read = read_timeout + if not isinstance(write_timeout, DefaultValue): + timeout.write = write_timeout + if not isinstance(connect_timeout, DefaultValue): + timeout.connect = connect_timeout + + # TODO p0: On Linux, use setsockopt to properly set socket level keepalive. + # (socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 120) + # (socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 30) + # (socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 8) + # TODO p4: Support setsockopt on lesser platforms than Linux. + + files = request_data.multipart_data if request_data else None + data = request_data.json_parameters if request_data else None + + try: + res = await self._client.request( + method=method, + url=url, + headers={'User-Agent': self.USER_AGENT}, + timeout=timeout, + files=files, + data=data, + ) + except httpx.TimeoutException as err: + if isinstance(err, httpx.PoolTimeout): + _logger.critical( + 'All connections in the connection pool are occupied. Request was *not* sent ' + 'to Telegram. Adjust connection pool size!', + ) + raise TimedOut('Pool timeout') from err + raise TimedOut from err + except httpx.HTTPError as err: + # HTTPError must come last as its the base httpx exception class + # TODO p4: do something smart here; for now just raise NetworkError + raise NetworkError(f'httpx HTTPError: {err}') from err + + return res.status_code, res.content diff --git a/telegram/request/_requestdata.py b/telegram/request/_requestdata.py new file mode 100644 index 00000000000..f84a9eb1259 --- /dev/null +++ b/telegram/request/_requestdata.py @@ -0,0 +1,109 @@ +#!/usr/bin/env python +# +# A library that provides a Python interface to the Telegram Bot API +# Copyright (C) 2015-2022 +# Leandro Toledo de Souza +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser Public License for more details. +# +# You should have received a copy of the GNU Lesser Public License +# along with this program. If not, see [http://www.gnu.org/licenses/]. +"""This module contains an class that holds a parameters of a request to the Bot API.""" +from typing import List, Dict, Any, Union +from urllib.parse import urlencode + +from telegram._utils.types import UploadFileDict +from telegram.request._requestparameter import RequestParameter + +try: + import ujson as json +except ImportError: + import json # type: ignore[no-redef] # noqa: F723 + + +class RequestData: + """Instances of this class collect the data needed for one request to the Bot API, including + all parameters and files to be sent along with the request. + + .. versionadded:: 14.0 + + Warning: + How exactly instances of this will are created should be considered an implementation + detail and not part of PTBs public API. Users should exclusively rely on the documented + attributes, properties and methods. + + Attributes: + contains_files (:obj:`bool`): Whether this object contains files to be uploaded via + ``multipart/form-data``. + """ + + __slots__ = ('_parameters', 'contains_files') + + def __init__( + self, + parameters: List[RequestParameter] = None, + ): + self._parameters = parameters or [] + self.contains_files = any(param.input_files for param in self._parameters) + + @property + def parameters(self) -> Dict[str, Union[str, int, List, Dict]]: + """Gives the parameters as mapping of parameter name to the parameter value, which can be + a single object of type :obj:`int`, :obj:`float`, :obj:`str` or :obj:`bool` or any + (possibly nested) composition of lists, tuples and dictionaries, where each entry, key + and value is of one of the mentioned types. + """ + return {param.name: param.value for param in self._parameters} # type: ignore[misc] + + @property + def json_parameters(self) -> Dict[str, str]: + """Gives the parameters as mapping of parameter name to the respective JSON encoded + value. + """ + return {param.name: param.json_value for param in self._parameters} + + def url_encoded_parameters(self, encode_kwargs: Dict[str, Any] = None) -> str: + """Encodes the parameters with :meth:`urllib.parse.urlencode`. + + Args: + encode_kwargs (Dict[:obj:`str`, any], optional): Additional keyword arguments to pass + along to :meth:`urllib.parse.urlencode`. + """ + if encode_kwargs: + return urlencode(self.json_parameters, **encode_kwargs) + return urlencode(self.json_parameters) + + def parametrized_url(self, url: str, encode_kwargs: Dict[str, Any] = None) -> str: + """Shortcut for attaching the return value of :meth:`url_encoded_parameters` to the + :attr:`url`. + + Args: + url (:obj:`str`): The URL the parameters will be attached to. + encode_kwargs (Dict[:obj:`str`, any], optional): Additional keyword arguments to pass + along to :meth:`urllib.parse.urlencode`. + """ + url_parameters = self.url_encoded_parameters(encode_kwargs=encode_kwargs) + return f'{url}?{url_parameters}' + + @property + def json_payload(self) -> bytes: + """The parameters as UTF-8 encoded JSON payload.""" + return json.dumps(self.json_parameters).encode('utf-8') + + @property + def multipart_data(self) -> UploadFileDict: + """Gives the files contained in this object as mapping of part name to encoded content.""" + multipart_data: UploadFileDict = {} + for param in self._parameters: + m_data = param.multipart_data + if m_data: + multipart_data.update(m_data) + return multipart_data diff --git a/telegram/request/_requestparameter.py b/telegram/request/_requestparameter.py new file mode 100644 index 00000000000..da20cabe700 --- /dev/null +++ b/telegram/request/_requestparameter.py @@ -0,0 +1,135 @@ +#!/usr/bin/env python +# +# A library that provides a Python interface to the Telegram Bot API +# Copyright (C) 2015-2022 +# Leandro Toledo de Souza +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser Public License for more details. +# +# You should have received a copy of the GNU Lesser Public License +# along with this program. If not, see [http://www.gnu.org/licenses/]. +"""This module contains an class that describes a single parameter of a request to the Bot API.""" +from dataclasses import dataclass +from datetime import datetime +from enum import Enum +from typing import Optional, List, Tuple + +from telegram import InputFile, InputMedia, TelegramObject +from telegram._utils.datetime import to_timestamp +from telegram._utils.types import UploadFileDict + +try: + import ujson as json +except ImportError: + import json # type: ignore[no-redef] # noqa: F723 + + +@dataclass(repr=False, eq=False, order=False, frozen=True) +class RequestParameter: + """Instances of this class represent a single parameter to be sent along with a request to + the Bot API. + + .. versionadded:: 14.0 + + Warning: + This class intended is to be used internally by the library and *not* by the user. Changes + to this class are not considered breaking changes and may not be documented in the + changelog. + + Args: + name (:obj:`str`): The name of the parameter. + value (:obj:`object`): The value of the parameter. Must be JSON-dumpable. + input_files (List[:class:`telegram.InputFile`, optional): A list of files that should be + uploaded along with this parameter. + + Attributes: + name (:obj:`str`): The name of the parameter. + value (:obj:`object`): The value of the parameter. + input_files (List[:class:`telegram.InputFile` | :obj:`None`): A list of files that should + be uploaded along with this parameter. + """ + + __slots__ = ('name', 'value', 'input_files') + + name: str + value: object + input_files: Optional[List[InputFile]] + + @property + def json_value(self) -> str: + """The JSON dumped :attr:`value`""" + if isinstance(self.value, str): + return self.value + return json.dumps(self.value) + + @property + def multipart_data(self) -> Optional[UploadFileDict]: + """A dict with the file data to upload, if any.""" + if not self.input_files: + return None + return {input_file.attach_name: input_file.field_tuple for input_file in self.input_files} + + @staticmethod + def _value_and_input_file_from_input( # pylint: disable=too-many-return-statements + value: object, + ) -> Tuple[object, List[InputFile]]: + """Converts `value` into something that we can json-dump. If `value` contains a file to be + uploaded, it will be returned as second return value and the corresponding attach:// value + will be returned as first return value. + Note that we use this for *all* files to be uploaded. This is not documented in the + official API, but has been confirmed to be supported in the official Bot API repository. + See https://github.com/tdlib/telegram-bot-api/issues/167 + """ + if isinstance(value, datetime): + return to_timestamp(value), [] + if isinstance(value, Enum): + return value.value, [] + if isinstance(value, InputFile): + return value.attach_uri, [ + value, + ] + if isinstance(value, InputMedia) and isinstance(value.media, InputFile): + # We call to_dict and change the returned dict instead of overriding + # value.media in case the same value is reused for another request + data = value.to_dict() + data['media'] = value.media.attach_uri + + thumb = data.get('thumb', None) + if isinstance(thumb, InputFile): + data['thumb'] = thumb.attach_uri + return data, [value.media, thumb] + + return data, [value.media] + if isinstance(value, TelegramObject): + # Needs to be last, because InputMedia is a subclass of TelegramObject + return value.to_dict(), [] + return value, [] + + @classmethod + def from_input(cls, key: str, value: object) -> 'RequestParameter': + """Builds an instance of this class for a given key-value pair that represents the raw + input as passed along from a method of :class:`telegram.Bot`. + """ + if isinstance(value, list): + param_values = [] + input_files = [] + for obj in value: + param_value, input_file = cls._value_and_input_file_from_input(obj) + param_values.append(param_value) + input_files.extend(input_file) + return RequestParameter( + name=key, value=param_values, input_files=input_files if input_files else None + ) + + param_value, input_files = cls._value_and_input_file_from_input(value) + return RequestParameter( + name=key, value=param_value, input_files=input_files if input_files else None + ) diff --git a/telegram/vendor/__init__.py b/telegram/vendor/__init__.py deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/telegram/vendor/ptb_urllib3 b/telegram/vendor/ptb_urllib3 deleted file mode 160000 index 1954df03958..00000000000 --- a/telegram/vendor/ptb_urllib3 +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 1954df03958b164483282330b3a58092c070bc7a diff --git a/tests/bots.py b/tests/bots.py index 590c190579b..7311ad26f47 100644 --- a/tests/bots.py +++ b/tests/bots.py @@ -21,9 +21,6 @@ import base64 import os import random -import pytest -from telegram.request import Request -from telegram.error import RetryAfter, TimedOut # Provide some public fallbacks so it's easy for contributors to run tests on their local machine # These bots are only able to talk in our test chats, so they are quite useless for other @@ -70,19 +67,3 @@ def get(name, fallback): def get_bot(): return {k: get(k, v) for k, v in random.choice(FALLBACKS).items()} - - -# Patch request to xfail on flood control errors and TimedOut errors -original_request_wrapper = Request._request_wrapper - - -def patient_request_wrapper(*args, **kwargs): - try: - return original_request_wrapper(*args, **kwargs) - except RetryAfter as e: - pytest.xfail(f'Not waiting for flood control: {e}') - except TimedOut as e: - pytest.xfail(f'Ignoring TimedOut error: {e}') - - -Request._request_wrapper = patient_request_wrapper diff --git a/tests/conftest.py b/tests/conftest.py index bd7b5276164..6a061b41342 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -16,6 +16,7 @@ # # You should have received a copy of the GNU Lesser Public License # along with this program. If not, see [http://www.gnu.org/licenses/]. +import asyncio import datetime import functools import inspect @@ -27,7 +28,7 @@ from queue import Queue from threading import Thread, Event from time import sleep -from typing import Callable, List, Iterable, Any +from typing import Callable, List, Iterable, Any, Dict from types import MappingProxyType import pytest @@ -51,19 +52,21 @@ InputTextMessageContent, InlineQueryResultCachedPhoto, InputMediaPhoto, - InputMedia, ) +from telegram._utils.types import ODVInput +from telegram.constants import InputMediaType from telegram.ext import ( - Dispatcher, + Application, Defaults, ExtBot, - DispatcherBuilder, - UpdaterBuilder, + ApplicationBuilder, + Updater, ) from telegram.ext.filters import UpdateFilter, MessageFilter -from telegram.error import BadRequest +from telegram.error import BadRequest, TimedOut, RetryAfter from telegram._utils.defaultvalue import DefaultValue, DEFAULT_NONE -from telegram.request import Request +from telegram.request import RequestData +from telegram.request._httpxrequest import HTTPXRequest from tests.bots import get_bot @@ -92,14 +95,53 @@ def env_var_2_bool(env_var: object) -> bool: return env_var.lower().strip() == 'true' +# Redefine the event_loop fixture to have a session scope. Otherwise `bot` fixture can't be +# session. See https://github.com/pytest-dev/pytest-asyncio/issues/68 for more details. +@pytest.fixture(scope='session') +def event_loop(request): + loop = asyncio.get_event_loop_policy().new_event_loop() + yield loop + # loop.close() # instead of closing here, do that at the every end of the test session + + +# Related to the above, see https://stackoverflow.com/a/67307042/10606962 +def pytest_sessionfinish(session, exitstatus): + asyncio.get_event_loop().close() + + @pytest.fixture(scope='session') def bot_info(): return get_bot() -# Below Dict* classes are used to monkeypatch attributes since parent classes don't have __dict__ -class DictRequest(Request): - pass +# Below classes are used to monkeypatch attributes since parent classes don't have __dict__ + + +class TestHttpxRequest(HTTPXRequest): + async def _request_wrapper( + self, + method: str, + url: str, + request_data: RequestData = None, + read_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = DEFAULT_NONE, + pool_timeout: ODVInput[float] = DEFAULT_NONE, + ) -> bytes: + try: + return await super()._request_wrapper( + method=method, + url=url, + request_data=request_data, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, + ) + except RetryAfter as e: + pytest.xfail(f'Not waiting for flood control: {e}') + except TimedOut as e: + pytest.xfail(f'Ignoring TimedOut error: {e}') class DictExtBot(ExtBot): @@ -110,45 +152,48 @@ class DictBot(Bot): pass -class DictDispatcher(Dispatcher): +class DictApplication(Application): pass @pytest.fixture(scope='session') -def bot(bot_info): - return DictExtBot(bot_info['token'], private_key=PRIVATE_KEY, request=DictRequest(8)) +@pytest.mark.asyncio +async def bot(bot_info): + async with DictExtBot( + bot_info['token'], + private_key=PRIVATE_KEY, + request=TestHttpxRequest(8), + get_updates_request=TestHttpxRequest(1), + ) as _bot: + yield _bot @pytest.fixture(scope='session') -def raw_bot(bot_info): - return DictBot(bot_info['token'], private_key=PRIVATE_KEY, request=DictRequest(8)) - - -DEFAULT_BOTS = {} +@pytest.mark.asyncio +async def raw_bot(bot_info): + async with DictBot( + bot_info['token'], + private_key=PRIVATE_KEY, + request=TestHttpxRequest(8), + get_updates_request=TestHttpxRequest(1), + ) as _bot: + yield _bot @pytest.fixture(scope='function') -def default_bot(request, bot_info): +async def default_bot(request, bot_info): param = request.param if hasattr(request, 'param') else {} - defaults = Defaults(**param) - default_bot = DEFAULT_BOTS.get(defaults) - if default_bot: - return default_bot - default_bot = make_bot(bot_info, **{'defaults': defaults}) - DEFAULT_BOTS[defaults] = default_bot - return default_bot + default_bot = make_bot(bot_info, defaults=Defaults(**param)) + async with default_bot: + yield default_bot @pytest.fixture(scope='function') -def tz_bot(timezone, bot_info): - defaults = Defaults(tzinfo=timezone) - default_bot = DEFAULT_BOTS.get(defaults) - if default_bot: - return default_bot - default_bot = make_bot(bot_info, **{'defaults': defaults}) - DEFAULT_BOTS[defaults] = default_bot - return default_bot +async def tz_bot(timezone, bot_info): + default_bot = make_bot(bot_info, defaults=Defaults(tzinfo=timezone)) + async with default_bot: + yield default_bot @pytest.fixture(scope='session') @@ -172,47 +217,50 @@ def provider_token(bot_info): def create_dp(bot): - # Dispatcher is heavy to init (due to many threads and such) so we have a single session - # scoped one here, but before each test, reset it (dp fixture below) - dispatcher = DispatcherBuilder().bot(bot).workers(2).dispatcher_class(DictDispatcher).build() - thr = Thread(target=dispatcher.start) + # Application is heavy to init (due to many threads and such) so we have a single session + # scoped one here, but before each test, reset it (app fixture below) + application = ( + ApplicationBuilder().bot(bot).workers(2).application_class(DictApplication).build() + ) + # TODO: Do we need the thread? + thr = Thread(target=application.start) thr.start() sleep(2) - yield dispatcher + yield application sleep(1) - if dispatcher.running: - dispatcher.stop() + if application.running: + application.stop() thr.join() @pytest.fixture(scope='session') -def _dp(bot): +def _app(bot): yield from create_dp(bot) @pytest.fixture(scope='function') -def dp(_dp): - # Reset the dispatcher first - while not _dp.update_queue.empty(): - _dp.update_queue.get(False) - _dp._chat_data = defaultdict(dict) - _dp._user_data = defaultdict(dict) - _dp.chat_data = MappingProxyType(_dp._chat_data) # Rebuild the mapping so it updates - _dp.user_data = MappingProxyType(_dp._user_data) - _dp.bot_data = {} - _dp.handlers = {} - _dp.error_handlers = {} - _dp.exception_event = Event() - _dp.__stop_event = Event() - _dp.__async_queue = Queue() - _dp.__async_threads = set() - _dp.persistence = None - yield _dp +def app(_app): + # Reset the application first + # TODO: consider just using the builder pattern to build a new object + while not _app.update_queue.empty(): + _app.update_queue.get(False) + _app._chat_data = defaultdict(dict) + _app._user_data = defaultdict(dict) + _app.chat_data = MappingProxyType(_app._chat_data) # Rebuild the mapping so it updates + _app.user_data = MappingProxyType(_app._user_data) + _app.bot_data = {} + _app.handlers = {} + _app.error_handlers = {} + _app.__stop_event = Event() + _app.__async_queue = Queue() + _app.__async_threads = set() + _app.persistence = None + yield _app @pytest.fixture(scope='function') def updater(bot): - up = UpdaterBuilder().bot(bot).workers(2).build() + up = Updater(bot=bot, update_queue=asyncio.Queue()) yield up if up.running: up.stop() @@ -249,27 +297,37 @@ def make_bot(bot_info, **kwargs): """ Tests are executed on tg.ext.ExtBot, as that class only extends the functionality of tg.bot """ - return ExtBot(bot_info['token'], private_key=PRIVATE_KEY, request=DictRequest(), **kwargs) + _bot = ExtBot( + bot_info['token'], + private_key=PRIVATE_KEY, + request=TestHttpxRequest(8), + get_updates_request=TestHttpxRequest(1), + **kwargs, + ) + return _bot CMD_PATTERN = re.compile(r'/[\da-z_]{1,32}(?:@\w{1,32})?') DATE = datetime.datetime.now() -def make_message(text, **kwargs): +async def make_message(text, **kwargs): """ Testing utility factory to create a fake ``telegram.Message`` with reasonable defaults for mimicking a real message. :param text: (str) message text :return: a (fake) ``telegram.Message`` """ + bot = kwargs.pop('bot', None) + if bot is None: + bot = make_bot(get_bot()) return Message( message_id=1, from_user=kwargs.pop('user', User(id=1, first_name='', is_bot=False)), date=kwargs.pop('date', DATE), chat=kwargs.pop('chat', Chat(id=1, type='')), text=text, - bot=kwargs.pop('bot', make_bot(get_bot())), + bot=bot, **kwargs, ) @@ -389,13 +447,13 @@ def _mro_slots(_class): return _mro_slots -def expect_bad_request(func, message, reason): +async def expect_bad_request(func, message, reason): """ Wrapper for testing bot functions expected to result in an :class:`telegram.error.BadRequest`. Makes it XFAIL, if the specified error message is present. Args: - func: The callable to be executed. + func: The awaitable to be executed. message: The expected message of the bad request error. If another message is present, the error will be reraised. reason: Explanation for the XFAIL. @@ -404,7 +462,7 @@ def expect_bad_request(func, message, reason): On success, returns the return value of :attr:`func` """ try: - return func() + return await func() except BadRequest as e: if message in str(e): pytest.xfail(f'{reason}. {e}') @@ -474,7 +532,7 @@ def check_shortcut_signature( return True -def check_shortcut_call( +async def check_shortcut_call( shortcut_method: Callable, bot: ExtBot, bot_method_name: str, @@ -484,7 +542,7 @@ def check_shortcut_call( """ Checks that a shortcut passes all the existing arguments to the underlying bot method. Use as:: - assert check_shortcut_call(message.reply_text, message.bot, 'send_message') + assert await check_shortcut_call(message.reply_text, message.bot, 'send_message') Args: shortcut_method: The shortcut method, e.g. `message.reply_text` @@ -513,7 +571,7 @@ def check_shortcut_call( # auto_pagination: Special casing for InlineQuery.answer kwargs = {name: name for name in shortcut_signature.parameters if name != 'auto_pagination'} - def make_assertion(**kw): + async def make_assertion(**kw): # name == value makes sure that # a) we receive non-None input for all parameters # b) we receive the correct input for each kwarg @@ -534,7 +592,7 @@ def make_assertion(**kw): setattr(bot, bot_method_name, make_assertion) try: - shortcut_method(**kwargs) + await shortcut_method(**kwargs) except Exception as exc: raise exc finally: @@ -595,7 +653,7 @@ def build_kwargs(signature: inspect.Signature, default_kwargs, dfv: Any = DEFAUL return kws -def check_defaults_handling( +async def check_defaults_handling( method: Callable, bot: ExtBot, return_value=None, @@ -615,10 +673,8 @@ def check_defaults_handling( kwargs_need_default = [ kwarg for kwarg, value in shortcut_signature.parameters.items() - if isinstance(value.default, DefaultValue) + if isinstance(value.default, DefaultValue) and not kwarg.endswith('_timeout') ] - # shortcut_signature.parameters['timeout'] is of type DefaultValue - method_timeout = shortcut_signature.parameters['timeout'].default.value defaults_no_custom_defaults = Defaults() kwargs = {kwarg: 'custom_default' for kwarg in inspect.signature(Defaults).parameters.keys()} @@ -627,11 +683,10 @@ def check_defaults_handling( expected_return_values = [None, []] if return_value is None else [return_value] - def make_assertion(_, data, timeout=DEFAULT_NONE, df_value=DEFAULT_NONE): - # Check timeout first - expected_timeout = method_timeout if df_value is DEFAULT_NONE else df_value - if timeout != expected_timeout: - pytest.fail(f'Got value {timeout} for "timeout", expected {expected_timeout}') + async def make_assertion( + url, request_data: RequestData, df_value=DEFAULT_NONE, *args, **kwargs + ): + data = request_data.parameters # Check regular arguments that need defaults for arg in (dkw for dkw in kwargs_need_default if dkw != 'timeout'): @@ -647,8 +702,8 @@ def make_assertion(_, data, timeout=DEFAULT_NONE, df_value=DEFAULT_NONE): pytest.fail(f'Got value {value} for argument {arg} instead of {df_value}') # Check InputMedia (parse_mode can have a default) - def check_input_media(m: InputMedia): - parse_mode = m.parse_mode + def check_input_media(m: Dict): + parse_mode = m.get('parse_mode', None) if df_value is DEFAULT_NONE: if parse_mode is not None: pytest.fail('InputMedia has non-None parse_mode') @@ -659,7 +714,7 @@ def check_input_media(m: InputMedia): media = data.pop('media', None) if media: - if isinstance(media, InputMedia): + if isinstance(media, dict) and isinstance(media.get('type', None), InputMediaType): check_input_media(media) else: for m in media: @@ -732,13 +787,13 @@ def check_input_media(m: InputMedia): ) assertion_callback = functools.partial(make_assertion, df_value=default_value) setattr(bot.request, 'post', assertion_callback) - assert method(**kwargs) in expected_return_values + assert await method(**kwargs) in expected_return_values # 2: test that we get the manually passed non-None value kwargs = build_kwargs(shortcut_signature, kwargs_need_default, dfv='non-None-value') assertion_callback = functools.partial(make_assertion, df_value='non-None-value') setattr(bot.request, 'post', assertion_callback) - assert method(**kwargs) in expected_return_values + assert await method(**kwargs) in expected_return_values # 3: test that we get the manually passed None value kwargs = build_kwargs( @@ -748,7 +803,7 @@ def check_input_media(m: InputMedia): ) assertion_callback = functools.partial(make_assertion, df_value=None) setattr(bot.request, 'post', assertion_callback) - assert method(**kwargs) in expected_return_values + assert await method(**kwargs) in expected_return_values except Exception as exc: raise exc finally: diff --git a/tests/data/text_file.txt b/tests/data/text_file.txt index c5ed74441c4..e938b357bbf 100644 --- a/tests/data/text_file.txt +++ b/tests/data/text_file.txt @@ -1 +1 @@ -PTB Rocks! \ No newline at end of file +PTB Rocks! ⅞ \ No newline at end of file diff --git a/tests/test_animation.py b/tests/test_animation.py index b07394eea0a..2061717d7b1 100644 --- a/tests/test_animation.py +++ b/tests/test_animation.py @@ -25,6 +25,7 @@ from telegram import PhotoSize, Animation, Voice, MessageEntity, Bot from telegram.error import BadRequest, TelegramError from telegram.helpers import escape_markdown +from telegram.request import RequestData from tests.conftest import ( check_shortcut_call, check_shortcut_signature, @@ -41,11 +42,12 @@ def animation_file(): @pytest.fixture(scope='class') -def animation(bot, chat_id): +@pytest.mark.asyncio +async def animation(bot, chat_id): with data_file('game.gif').open('rb') as f: thumb = data_file('thumb.jpg') - return bot.send_animation( - chat_id, animation=f, timeout=50, thumb=thumb.open('rb') + return ( + await bot.send_animation(chat_id, animation=f, read_timeout=50, thumb=thumb.open('rb')) ).animation @@ -81,8 +83,9 @@ def test_expected_values(self, animation): assert isinstance(animation.thumb, PhotoSize) @flaky(3, 1) - def test_send_all_args(self, bot, chat_id, animation_file, animation, thumb_file): - message = bot.send_animation( + @pytest.mark.asyncio + async def test_send_all_args(self, bot, chat_id, animation_file, animation, thumb_file): + message = await bot.send_animation( chat_id, animation_file, duration=self.duration, @@ -108,29 +111,36 @@ def test_send_all_args(self, bot, chat_id, animation_file, animation, thumb_file assert message.has_protected_content @flaky(3, 1) - def test_send_animation_custom_filename(self, bot, chat_id, animation_file, monkeypatch): - def make_assertion(url, data, **kwargs): - return data['animation'].filename == 'custom_filename' + @pytest.mark.asyncio + async def test_send_animation_custom_filename(self, bot, chat_id, animation_file, monkeypatch): + async def make_assertion(url, request_data: RequestData, *args, **kwargs): + return list(request_data.multipart_data.values())[0][0] == 'custom_filename' monkeypatch.setattr(bot.request, 'post', make_assertion) - assert bot.send_animation(chat_id, animation_file, filename='custom_filename') + assert await bot.send_animation(chat_id, animation_file, filename='custom_filename') monkeypatch.delattr(bot.request, 'post') @flaky(3, 1) - def test_get_and_download(self, bot, animation): - new_file = bot.get_file(animation.file_id) + @pytest.mark.asyncio + async def test_get_and_download(self, bot, animation): + path = Path('game.gif') + if path.is_file(): + path.unlink() + + new_file = await bot.get_file(animation.file_id) assert new_file.file_id == animation.file_id assert new_file.file_path.startswith('https://') - new_filepath: Path = new_file.download('game.gif') + new_filepath = await new_file.download('game.gif') assert new_filepath.is_file() @flaky(3, 1) - def test_send_animation_url_file(self, bot, chat_id, animation): - message = bot.send_animation( + @pytest.mark.asyncio + async def test_send_animation_url_file(self, bot, chat_id, animation): + message = await bot.send_animation( chat_id=chat_id, animation=self.animation_file_url, caption=self.caption ) @@ -149,14 +159,15 @@ def test_send_animation_url_file(self, bot, chat_id, animation): assert message.animation.mime_type == animation.mime_type @flaky(3, 1) - def test_send_animation_caption_entities(self, bot, chat_id, animation): + @pytest.mark.asyncio + async def test_send_animation_caption_entities(self, bot, chat_id, animation): test_string = 'Italic Bold Code' entities = [ MessageEntity(MessageEntity.ITALIC, 0, 6), MessageEntity(MessageEntity.ITALIC, 7, 4), MessageEntity(MessageEntity.ITALIC, 12, 4), ] - message = bot.send_animation( + message = await bot.send_animation( chat_id, animation, caption=test_string, caption_entities=entities ) @@ -165,20 +176,24 @@ def test_send_animation_caption_entities(self, bot, chat_id, animation): @flaky(3, 1) @pytest.mark.parametrize('default_bot', [{'parse_mode': 'Markdown'}], indirect=True) - def test_send_animation_default_parse_mode_1(self, default_bot, chat_id, animation_file): + @pytest.mark.asyncio + async def test_send_animation_default_parse_mode_1(self, default_bot, chat_id, animation_file): test_string = 'Italic Bold Code' test_markdown_string = '_Italic_ *Bold* `Code`' - message = default_bot.send_animation(chat_id, animation_file, caption=test_markdown_string) + message = await default_bot.send_animation( + chat_id, animation_file, caption=test_markdown_string + ) assert message.caption_markdown == test_markdown_string assert message.caption == test_string @flaky(3, 1) @pytest.mark.parametrize('default_bot', [{'parse_mode': 'Markdown'}], indirect=True) - def test_send_animation_default_parse_mode_2(self, default_bot, chat_id, animation_file): + @pytest.mark.asyncio + async def test_send_animation_default_parse_mode_2(self, default_bot, chat_id, animation_file): test_markdown_string = '_Italic_ *Bold* `Code`' - message = default_bot.send_animation( + message = await default_bot.send_animation( chat_id, animation_file, caption=test_markdown_string, parse_mode=None ) assert message.caption == test_markdown_string @@ -186,27 +201,29 @@ def test_send_animation_default_parse_mode_2(self, default_bot, chat_id, animati @flaky(3, 1) @pytest.mark.parametrize('default_bot', [{'parse_mode': 'Markdown'}], indirect=True) - def test_send_animation_default_parse_mode_3(self, default_bot, chat_id, animation_file): + @pytest.mark.asyncio + async def test_send_animation_default_parse_mode_3(self, default_bot, chat_id, animation_file): test_markdown_string = '_Italic_ *Bold* `Code`' - message = default_bot.send_animation( + message = await default_bot.send_animation( chat_id, animation_file, caption=test_markdown_string, parse_mode='HTML' ) assert message.caption == test_markdown_string assert message.caption_markdown == escape_markdown(test_markdown_string) - def test_send_animation_local_files(self, monkeypatch, bot, chat_id): + @pytest.mark.asyncio + async def test_send_animation_local_files(self, monkeypatch, bot, chat_id): # For just test that the correct paths are passed as we have no local bot API set up test_flag = False file = data_file('telegram.jpg') expected = file.as_uri() - def make_assertion(_, data, *args, **kwargs): + async def make_assertion(_, data, *args, **kwargs): nonlocal test_flag test_flag = data.get('animation') == expected and data.get('thumb') == expected monkeypatch.setattr(bot, '_post', make_assertion) - bot.send_animation(chat_id, file, thumb=file) + await bot.send_animation(chat_id, file, thumb=file) assert test_flag monkeypatch.delattr(bot, '_post') @@ -220,13 +237,14 @@ def make_assertion(_, data, *args, **kwargs): ], indirect=['default_bot'], ) - def test_send_animation_default_allow_sending_without_reply( + @pytest.mark.asyncio + async def test_send_animation_default_allow_sending_without_reply( self, default_bot, chat_id, animation, custom ): - reply_to_message = default_bot.send_message(chat_id, 'test') - reply_to_message.delete() + reply_to_message = await default_bot.send_message(chat_id, 'test') + await reply_to_message.delete() if custom is not None: - message = default_bot.send_animation( + message = await default_bot.send_animation( chat_id, animation, allow_sending_without_reply=custom, @@ -234,36 +252,41 @@ def test_send_animation_default_allow_sending_without_reply( ) assert message.reply_to_message is None elif default_bot.defaults.allow_sending_without_reply: - message = default_bot.send_animation( + message = await default_bot.send_animation( chat_id, animation, reply_to_message_id=reply_to_message.message_id ) assert message.reply_to_message is None else: with pytest.raises(BadRequest, match='message not found'): - default_bot.send_animation( + await default_bot.send_animation( chat_id, animation, reply_to_message_id=reply_to_message.message_id ) @flaky(3, 1) + @pytest.mark.asyncio @pytest.mark.parametrize('default_bot', [{'protect_content': True}], indirect=True) - def test_send_animation_default_protect_content(self, default_bot, chat_id, animation): - animation_protected = default_bot.send_animation(chat_id, animation) + async def test_send_animation_default_protect_content(self, default_bot, chat_id, animation): + animation_protected = await default_bot.send_animation(chat_id, animation) assert animation_protected.has_protected_content - ani_unprotected = default_bot.send_animation(chat_id, animation, protect_content=False) + ani_unprotected = await default_bot.send_animation( + chat_id, animation, protect_content=False + ) assert not ani_unprotected.has_protected_content @flaky(3, 1) - def test_resend(self, bot, chat_id, animation): - message = bot.send_animation(chat_id, animation.file_id) + @pytest.mark.asyncio + async def test_resend(self, bot, chat_id, animation): + message = await bot.send_animation(chat_id, animation.file_id) assert message.animation == animation - def test_send_with_animation(self, monkeypatch, bot, chat_id, animation): - def test(url, data, **kwargs): - return data['animation'] == animation.file_id + @pytest.mark.asyncio + async def test_send_with_animation(self, monkeypatch, bot, chat_id, animation): + async def make_assertion(url, request_data: RequestData, *args, **kwargs): + return request_data.json_parameters['animation'] == animation.file_id - monkeypatch.setattr(bot.request, 'post', test) - message = bot.send_animation(animation=animation, chat_id=chat_id) + monkeypatch.setattr(bot.request, 'post', make_assertion) + message = await bot.send_animation(animation=animation, chat_id=chat_id) assert message def test_de_json(self, bot, animation): @@ -300,31 +323,35 @@ def test_to_dict(self, animation): assert animation_dict['file_size'] == animation.file_size @flaky(3, 1) - def test_error_send_empty_file(self, bot, chat_id): + @pytest.mark.asyncio + async def test_error_send_empty_file(self, bot, chat_id): animation_file = open(os.devnull, 'rb') with pytest.raises(TelegramError): - bot.send_animation(chat_id=chat_id, animation=animation_file) + await bot.send_animation(chat_id=chat_id, animation=animation_file) @flaky(3, 1) - def test_error_send_empty_file_id(self, bot, chat_id): + @pytest.mark.asyncio + async def test_error_send_empty_file_id(self, bot, chat_id): with pytest.raises(TelegramError): - bot.send_animation(chat_id=chat_id, animation='') + await bot.send_animation(chat_id=chat_id, animation='') - def test_error_send_without_required_args(self, bot, chat_id): + @pytest.mark.asyncio + async def test_error_send_without_required_args(self, bot, chat_id): with pytest.raises(TypeError): - bot.send_animation(chat_id=chat_id) + await bot.send_animation(chat_id=chat_id) - def test_get_file_instance_method(self, monkeypatch, animation): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_get_file_instance_method(self, monkeypatch, animation): + async def make_assertion(*_, **kwargs): return kwargs['file_id'] == animation.file_id assert check_shortcut_signature(Animation.get_file, Bot.get_file, ['file_id'], []) - assert check_shortcut_call(animation.get_file, animation.get_bot(), 'get_file') - assert check_defaults_handling(animation.get_file, animation.get_bot()) + assert await check_shortcut_call(animation.get_file, animation.get_bot(), 'get_file') + assert await check_defaults_handling(animation.get_file, animation.get_bot()) monkeypatch.setattr(animation.get_bot(), 'get_file', make_assertion) - assert animation.get_file() + assert await animation.get_file() def test_equality(self): a = Animation( diff --git a/tests/test_audio.py b/tests/test_audio.py index a8f29745626..4e696a976da 100644 --- a/tests/test_audio.py +++ b/tests/test_audio.py @@ -25,6 +25,7 @@ from telegram import Audio, Voice, MessageEntity, Bot from telegram.error import TelegramError from telegram.helpers import escape_markdown +from telegram.request import RequestData from tests.conftest import ( check_shortcut_call, check_shortcut_signature, @@ -40,10 +41,13 @@ def audio_file(): @pytest.fixture(scope='class') -def audio(bot, chat_id): +@pytest.mark.asyncio +async def audio(bot, chat_id): with data_file('telegram.mp3').open('rb') as f: thumb = data_file('thumb.jpg') - return bot.send_audio(chat_id, audio=f, timeout=50, thumb=thumb.open('rb')).audio + return ( + await bot.send_audio(chat_id, audio=f, read_timeout=50, thumb=thumb.open('rb')) + ).audio class TestAudio: @@ -87,8 +91,9 @@ def test_expected_values(self, audio): assert audio.thumb.height == self.thumb_height @flaky(3, 1) - def test_send_all_args(self, bot, chat_id, audio_file, thumb_file): - message = bot.send_audio( + @pytest.mark.asyncio + async def test_send_all_args(self, bot, chat_id, audio_file, thumb_file): + message = await bot.send_audio( chat_id, audio=audio_file, caption=self.caption, @@ -120,30 +125,39 @@ def test_send_all_args(self, bot, chat_id, audio_file, thumb_file): assert message.has_protected_content @flaky(3, 1) - def test_send_audio_custom_filename(self, bot, chat_id, audio_file, monkeypatch): - def make_assertion(url, data, **kwargs): - return data['audio'].filename == 'custom_filename' + @pytest.mark.asyncio + async def test_send_audio_custom_filename(self, bot, chat_id, audio_file, monkeypatch): + async def make_assertion(url, request_data: RequestData, *args, **kwargs): + return list(request_data.multipart_data.values())[0][0] == 'custom_filename' monkeypatch.setattr(bot.request, 'post', make_assertion) - assert bot.send_audio(chat_id, audio_file, filename='custom_filename') + assert await bot.send_audio(chat_id, audio_file, filename='custom_filename') @flaky(3, 1) - def test_get_and_download(self, bot, audio): - new_file = bot.get_file(audio.file_id) + @pytest.mark.asyncio + async def test_get_and_download(self, bot, audio): + path = Path('telegram.mp3') + if path.is_file(): + path.unlink() + + new_file = await bot.get_file(audio.file_id) assert new_file.file_size == self.file_size assert new_file.file_id == audio.file_id assert new_file.file_unique_id == audio.file_unique_id assert str(new_file.file_path).startswith('https://') - new_file.download('telegram.mp3') + await new_file.download('telegram.mp3') - assert Path('telegram.mp3').is_file() + assert path.is_file() @flaky(3, 1) - def test_send_mp3_url_file(self, bot, chat_id, audio): - message = bot.send_audio(chat_id=chat_id, audio=self.audio_file_url, caption=self.caption) + @pytest.mark.asyncio + async def test_send_mp3_url_file(self, bot, chat_id, audio): + message = await bot.send_audio( + chat_id=chat_id, audio=self.audio_file_url, caption=self.caption + ) assert message.caption == self.caption @@ -157,48 +171,55 @@ def test_send_mp3_url_file(self, bot, chat_id, audio): assert message.audio.file_size == audio.file_size @flaky(3, 1) - def test_resend(self, bot, chat_id, audio): - message = bot.send_audio(chat_id=chat_id, audio=audio.file_id) + @pytest.mark.asyncio + async def test_resend(self, bot, chat_id, audio): + message = await bot.send_audio(chat_id=chat_id, audio=audio.file_id) assert message.audio == audio - def test_send_with_audio(self, monkeypatch, bot, chat_id, audio): - def test(url, data, **kwargs): - return data['audio'] == audio.file_id + @pytest.mark.asyncio + async def test_send_with_audio(self, monkeypatch, bot, chat_id, audio): + async def make_assertion(url, request_data: RequestData, *args, **kwargs): + return request_data.json_parameters['audio'] == audio.file_id - monkeypatch.setattr(bot.request, 'post', test) - message = bot.send_audio(audio=audio, chat_id=chat_id) + monkeypatch.setattr(bot.request, 'post', make_assertion) + message = await bot.send_audio(audio=audio, chat_id=chat_id) assert message @flaky(3, 1) - def test_send_audio_caption_entities(self, bot, chat_id, audio): + @pytest.mark.asyncio + async def test_send_audio_caption_entities(self, bot, chat_id, audio): test_string = 'Italic Bold Code' entities = [ MessageEntity(MessageEntity.ITALIC, 0, 6), MessageEntity(MessageEntity.ITALIC, 7, 4), MessageEntity(MessageEntity.ITALIC, 12, 4), ] - message = bot.send_audio(chat_id, audio, caption=test_string, caption_entities=entities) + message = await bot.send_audio( + chat_id, audio, caption=test_string, caption_entities=entities + ) assert message.caption == test_string assert message.caption_entities == entities @flaky(3, 1) @pytest.mark.parametrize('default_bot', [{'parse_mode': 'Markdown'}], indirect=True) - def test_send_audio_default_parse_mode_1(self, default_bot, chat_id, audio_file): + @pytest.mark.asyncio + async def test_send_audio_default_parse_mode_1(self, default_bot, chat_id, audio_file): test_string = 'Italic Bold Code' test_markdown_string = '_Italic_ *Bold* `Code`' - message = default_bot.send_audio(chat_id, audio_file, caption=test_markdown_string) + message = await default_bot.send_audio(chat_id, audio_file, caption=test_markdown_string) assert message.caption_markdown == test_markdown_string assert message.caption == test_string @flaky(3, 1) @pytest.mark.parametrize('default_bot', [{'parse_mode': 'Markdown'}], indirect=True) - def test_send_audio_default_parse_mode_2(self, default_bot, chat_id, audio_file): + @pytest.mark.asyncio + async def test_send_audio_default_parse_mode_2(self, default_bot, chat_id, audio_file): test_markdown_string = '_Italic_ *Bold* `Code`' - message = default_bot.send_audio( + message = await default_bot.send_audio( chat_id, audio_file, caption=test_markdown_string, parse_mode=None ) assert message.caption == test_markdown_string @@ -206,35 +227,38 @@ def test_send_audio_default_parse_mode_2(self, default_bot, chat_id, audio_file) @flaky(3, 1) @pytest.mark.parametrize('default_bot', [{'parse_mode': 'Markdown'}], indirect=True) - def test_send_audio_default_parse_mode_3(self, default_bot, chat_id, audio_file): + @pytest.mark.asyncio + async def test_send_audio_default_parse_mode_3(self, default_bot, chat_id, audio_file): test_markdown_string = '_Italic_ *Bold* `Code`' - message = default_bot.send_audio( + message = await default_bot.send_audio( chat_id, audio_file, caption=test_markdown_string, parse_mode='HTML' ) assert message.caption == test_markdown_string assert message.caption_markdown == escape_markdown(test_markdown_string) @flaky(3, 1) + @pytest.mark.asyncio @pytest.mark.parametrize('default_bot', [{'protect_content': True}], indirect=True) - def test_send_audio_default_protect_content(self, default_bot, chat_id, audio): - protected_audio = default_bot.send_audio(chat_id, audio) + async def test_send_audio_default_protect_content(self, default_bot, chat_id, audio): + protected_audio = await default_bot.send_audio(chat_id, audio) assert protected_audio.has_protected_content - unprotected = default_bot.send_audio(chat_id, audio, protect_content=False) + unprotected = await default_bot.send_audio(chat_id, audio, protect_content=False) assert not unprotected.has_protected_content - def test_send_audio_local_files(self, monkeypatch, bot, chat_id): + @pytest.mark.asyncio + async def test_send_audio_local_files(self, monkeypatch, bot, chat_id): # For just test that the correct paths are passed as we have no local bot API set up test_flag = False file = data_file('telegram.jpg') expected = file.as_uri() - def make_assertion(_, data, *args, **kwargs): + async def make_assertion(_, data, *args, **kwargs): nonlocal test_flag test_flag = data.get('audio') == expected and data.get('thumb') == expected monkeypatch.setattr(bot, '_post', make_assertion) - bot.send_audio(chat_id, file, thumb=file) + await bot.send_audio(chat_id, file, thumb=file) assert test_flag monkeypatch.delattr(bot, '_post') @@ -275,31 +299,35 @@ def test_to_dict(self, audio): assert audio_dict['file_name'] == audio.file_name @flaky(3, 1) - def test_error_send_empty_file(self, bot, chat_id): + @pytest.mark.asyncio + async def test_error_send_empty_file(self, bot, chat_id): audio_file = open(os.devnull, 'rb') with pytest.raises(TelegramError): - bot.send_audio(chat_id=chat_id, audio=audio_file) + await bot.send_audio(chat_id=chat_id, audio=audio_file) @flaky(3, 1) - def test_error_send_empty_file_id(self, bot, chat_id): + @pytest.mark.asyncio + async def test_error_send_empty_file_id(self, bot, chat_id): with pytest.raises(TelegramError): - bot.send_audio(chat_id=chat_id, audio='') + await bot.send_audio(chat_id=chat_id, audio='') - def test_error_send_without_required_args(self, bot, chat_id): + @pytest.mark.asyncio + async def test_error_send_without_required_args(self, bot, chat_id): with pytest.raises(TypeError): - bot.send_audio(chat_id=chat_id) + await bot.send_audio(chat_id=chat_id) - def test_get_file_instance_method(self, monkeypatch, audio): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_get_file_instance_method(self, monkeypatch, audio): + async def make_assertion(*_, **kwargs): return kwargs['file_id'] == audio.file_id assert check_shortcut_signature(Audio.get_file, Bot.get_file, ['file_id'], []) - assert check_shortcut_call(audio.get_file, audio.get_bot(), 'get_file') - assert check_defaults_handling(audio.get_file, audio.get_bot()) + assert await check_shortcut_call(audio.get_file, audio.get_bot(), 'get_file') + assert await check_defaults_handling(audio.get_file, audio.get_bot()) monkeypatch.setattr(audio._bot, 'get_file', make_assertion) - assert audio.get_file() + assert await audio.get_file() def test_equality(self, audio): a = Audio(audio.file_id, audio.file_unique_id, audio.duration) diff --git a/tests/test_bot.py b/tests/test_bot.py index 72e100c750e..d020c884402 100644 --- a/tests/test_bot.py +++ b/tests/test_bot.py @@ -17,6 +17,7 @@ # You should have received a copy of the GNU Lesser Public License # along with this program. If not, see [http://www.gnu.org/licenses/]. import datetime +import asyncio import inspect import logging import time @@ -58,6 +59,7 @@ from telegram._utils.datetime import from_timestamp, to_timestamp from telegram._utils.defaultvalue import DefaultValue from telegram.helpers import escape_markdown +from telegram.request import RequestData, BaseRequest from tests.conftest import ( expect_bad_request, check_defaults_handling, @@ -87,11 +89,12 @@ class BotSubClass(Bot): @pytest.fixture(scope='class') -def message(bot, chat_id): - to_reply_to = bot.send_message( +@pytest.mark.asyncio +async def message(bot, chat_id): + to_reply_to = await bot.send_message( chat_id, 'Text', disable_web_page_preview=True, disable_notification=True ) - return bot.send_message( + return await bot.send_message( chat_id, 'Text', reply_to_message_id=to_reply_to.message_id, @@ -101,9 +104,10 @@ def message(bot, chat_id): @pytest.fixture(scope='class') -def media_message(bot, chat_id): +@pytest.mark.asyncio +async def media_message(bot, chat_id): with data_file('telegram.ogg').open('rb') as f: - return bot.send_voice(chat_id, voice=f, caption='my caption', timeout=10) + return await bot.send_voice(chat_id, voice=f, caption='my caption', read_timeout=10) @pytest.fixture(scope='class') @@ -137,8 +141,13 @@ def inline_results(): @pytest.fixture(scope='function') -def inst(request, bot_info, default_bot): - return Bot(bot_info['token']) if request.param == 'bot' else default_bot +@pytest.mark.asyncio +async def inst(request, bot_info, default_bot): + if request.param == 'bot': + async with Bot(bot_info['token']) as _bot: + yield _bot + else: + yield default_bot class TestBot: @@ -146,6 +155,12 @@ class TestBot: Most are executed on tg.ext.ExtBot, as that class only extends the functionality of tg.bot """ + test_flag = None + + @pytest.fixture(scope='function', autouse=True) + def reset(self): + self.test_flag = None + @pytest.mark.parametrize('inst', ['bot', "default_bot"], indirect=True) def test_slot_behaviour(self, inst, mro_slots): for attr in inst.__slots__: @@ -164,13 +179,72 @@ def test_slot_behaviour(self, inst, mro_slots): '1234:abcd 1234', ], ) - def test_invalid_token(self, token): + @pytest.mark.asyncio + async def test_invalid_token(self, token): with pytest.raises(InvalidToken, match='Invalid token'): Bot(token) - def test_log_decorator(self, bot, caplog): - with caplog.at_level(logging.DEBUG): - bot.get_me() + @pytest.mark.asyncio + async def test_initialize_and_stop(self, bot, monkeypatch): + async def initialize(*args, **kwargs): + self.test_flag = ['initialize'] + + async def stop(*args, **kwargs): + self.test_flag.append('stop') + + temp_bot = Bot(token=bot.token) + orig_stop = temp_bot.request.shutdown + + try: + monkeypatch.setattr(temp_bot.request, 'initialize', initialize) + monkeypatch.setattr(temp_bot.request, 'shutdown', stop) + await temp_bot.initialize() + assert self.test_flag == ['initialize'] + assert temp_bot.bot == bot.bot + + await temp_bot.shutdown() + assert self.test_flag == ['initialize', 'stop'] + finally: + await orig_stop() + + @pytest.mark.asyncio + async def test_context_manager(self, monkeypatch, bot): + async def initialize(): + self.test_flag = ['initialize'] + + async def shutdown(*args): + self.test_flag.append('stop') + + monkeypatch.setattr(bot, 'initialize', initialize) + monkeypatch.setattr(bot, 'shutdown', shutdown) + + async with bot: + pass + + assert self.test_flag == ['initialize', 'stop'] + + @pytest.mark.asyncio + async def test_context_manager_exception_on_init(self, monkeypatch, bot): + async def initialize(): + raise RuntimeError('initialize') + + async def shutdown(): + self.test_flag = 'stop' + + monkeypatch.setattr(bot, 'initialize', initialize) + monkeypatch.setattr(bot, 'shutdown', shutdown) + + with pytest.raises(RuntimeError, match='initialize'): + async with bot: + pass + + assert self.test_flag == 'stop' + + @pytest.mark.asyncio + async def test_log_decorator(self, bot, caplog): + # Second argument makes sure that we ignore logs from e.g. httpx + with caplog.at_level(logging.DEBUG, logger='telegram'): + await bot.get_me() assert len(caplog.records) == 3 assert caplog.records[0].getMessage().startswith('Entering: get_me') assert caplog.records[-1].getMessage().startswith('Exiting: get_me') @@ -179,29 +253,37 @@ def test_log_decorator(self, bot, caplog): 'acd_in,maxsize,acd', [(True, 1024, True), (False, 1024, False), (0, 0, True), (None, None, True)], ) - def test_callback_data_maxsize(self, bot, acd_in, maxsize, acd): - bot = ExtBot(bot.token, arbitrary_callback_data=acd_in) - assert bot.arbitrary_callback_data == acd - assert bot.callback_data_cache.maxsize == maxsize + @pytest.mark.asyncio + async def test_callback_data_maxsize(self, bot, acd_in, maxsize, acd): + async with ExtBot(bot.token, arbitrary_callback_data=acd_in) as acd_bot: + assert acd_bot.arbitrary_callback_data == acd + assert acd_bot.callback_data_cache.maxsize == maxsize @flaky(3, 1) - def test_invalid_token_server_response(self, monkeypatch): + @pytest.mark.asyncio + async def test_invalid_token_server_response(self, monkeypatch): monkeypatch.setattr('telegram.Bot._validate_token', lambda x, y: '') - bot = Bot('12') with pytest.raises(InvalidToken): - bot.get_me() + async with Bot('12') as bot: + await bot.get_me() - def test_unknown_kwargs(self, bot, monkeypatch): - def post(url, data, timeout): - assert data['unknown_kwarg_1'] == 7 - assert data['unknown_kwarg_2'] == 5 + @pytest.mark.asyncio + async def test_unknown_kwargs(self, bot, monkeypatch): + async def post(url, request_data: RequestData, *args, **kwargs): + data = request_data.json_parameters + if not all([data['unknown_kwarg_1'] == '7', data['unknown_kwarg_2'] == '5']): + pytest.fail('got wrong parameters') + return True monkeypatch.setattr(bot.request, 'post', post) - bot.send_message(123, 'text', api_kwargs={'unknown_kwarg_1': 7, 'unknown_kwarg_2': 5}) + await bot.send_message( + 123, 'text', api_kwargs={'unknown_kwarg_1': 7, 'unknown_kwarg_2': 5} + ) @flaky(3, 1) - def test_get_me_and_properties(self, bot): - get_me_bot = bot.get_me() + @pytest.mark.asyncio + async def test_get_me_and_properties(self, bot: Bot): + get_me_bot = await bot.get_me() assert isinstance(get_me_bot, User) assert get_me_bot.id == bot.id @@ -214,24 +296,49 @@ def test_get_me_and_properties(self, bot): assert get_me_bot.supports_inline_queries == bot.supports_inline_queries assert f'https://t.me/{get_me_bot.username}' == bot.link - def test_equality(self): - a = Bot(FALLBACKS[0]["token"]) - b = Bot(FALLBACKS[0]["token"]) - c = Bot(FALLBACKS[1]["token"]) - d = Update(123456789) + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'attribute', + [ + 'id', + 'username', + 'first_name', + 'last_name', + 'name', + 'can_join_groups', + 'can_read_all_group_messages', + 'supports_inline_queries', + 'link', + ], + ) + async def test_get_me_and_properties_not_initialized(self, bot: Bot, attribute): + bot = Bot(token=bot.token) + try: + with pytest.raises(RuntimeError, match='not properly initialized'): + bot[attribute] + finally: + await bot.shutdown() + + @pytest.mark.asyncio + async def test_equality(self): + async with Bot(FALLBACKS[0]["token"]) as a, Bot(FALLBACKS[0]["token"]) as b, Bot( + FALLBACKS[1]["token"] + ) as c: + d = Update(123456789) - assert a == b - assert hash(a) == hash(b) - assert a is not b + assert a == b + assert hash(a) == hash(b) + assert a is not b - assert a != c - assert hash(a) != hash(c) + assert a != c + assert hash(a) != hash(c) - assert a != d - assert hash(a) != hash(d) + assert a != d + assert hash(a) != hash(d) @flaky(3, 1) - def test_to_dict(self, bot): + @pytest.mark.asyncio + async def test_to_dict(self, bot): to_dict_bot = bot.to_dict() assert isinstance(to_dict_bot, dict) @@ -258,10 +365,13 @@ def test_to_dict(self, bot): 'getUpdates', 'get_bot', 'set_bot', + 'initialize', + 'shutdown', ] ], ) - def test_defaults_handling(self, bot_method_name, bot, raw_bot, monkeypatch): + @pytest.mark.asyncio + async def test_defaults_handling(self, bot_method_name, bot, raw_bot, monkeypatch): """ Here we check that the bot methods handle tg.ext.Defaults correctly. This has two parts: @@ -281,27 +391,84 @@ def test_defaults_handling(self, bot_method_name, bot, raw_bot, monkeypatch): Finally, there are some tests for Defaults.{parse_mode, quote, allow_sending_without_reply} at the appropriate places, as those are the only things we can actually check. """ - # Check that ExtBot does the right thing - bot_method = getattr(bot, bot_method_name) - assert check_defaults_handling(bot_method, bot) + try: + # Check that ExtBot does the right thing + bot_method = getattr(bot, bot_method_name) + assert await check_defaults_handling(bot_method, bot) + + # check that tg.Bot does the right thing + # make_assertion basically checks everything that happens in + # Bot._insert_defaults and Bot._insert_defaults_for_ilq_results + async def make_assertion(url, request_data: RequestData, *args, **kwargs): + json_data = request_data.parameters + + # Check regular kwargs + for k, v in json_data.items(): + if isinstance(v, DefaultValue): + pytest.fail(f'Parameter {k} was passed as DefaultValue to request') + elif isinstance(v, InputMedia) and isinstance(v.parse_mode, DefaultValue): + pytest.fail(f'Parameter {k} has a DefaultValue parse_mode') + # Check InputMedia + elif k == 'media' and isinstance(v, list): + if any(isinstance(med.get('parse_mode'), DefaultValue) for med in v): + pytest.fail('One of the media items has a DefaultValue parse_mode') + + # Check inline query results + if bot_method_name.lower().replace('_', '') == 'answerinlinequery': + for result_dict in json_data['results']: + if isinstance(result_dict.get('parse_mode'), DefaultValue): + pytest.fail('InlineQueryResult has DefaultValue parse_mode') + imc = result_dict.get('input_message_content') + if imc and isinstance(imc.get('parse_mode'), DefaultValue): + pytest.fail( + 'InlineQueryResult is InputMessageContext with DefaultValue ' + 'parse_mode ' + ) + if imc and isinstance(imc.get('disable_web_page_preview'), DefaultValue): + pytest.fail( + 'InlineQueryResult is InputMessageContext with DefaultValue ' + 'disable_web_page_preview ' + ) + # Check datetime conversion + until_date = json_data.pop('until_date', None) + if until_date and until_date != 946684800: + pytest.fail('Naive until_date was not interpreted as UTC') + + if bot_method_name in ['get_file', 'getFile']: + # The get_file methods try to check if the result is a local file + return File(file_id='result', file_unique_id='result').to_dict() + + method = getattr(raw_bot, bot_method_name) + signature = inspect.signature(method) + kwargs_need_default = [ + kwarg + for kwarg, value in signature.parameters.items() + if isinstance(value.default, DefaultValue) + ] + monkeypatch.setattr(raw_bot.request, 'post', make_assertion) + await method(**build_kwargs(inspect.signature(method), kwargs_need_default)) + finally: + await bot.get_me() # because running the mock-get_me messages with bot.bot & friends # check that tg.Bot does the right thing # make_assertion basically checks everything that happens in # Bot._insert_defaults and Bot._insert_defaults_for_ilq_results - def make_assertion(_, data, timeout=None): + async def make_assertion(url, request_data: RequestData, *args, **kwargs): + data = request_data.parameters + # Check regular kwargs for k, v in data.items(): if isinstance(v, DefaultValue): pytest.fail(f'Parameter {k} was passed as DefaultValue to request') elif isinstance(v, InputMedia) and isinstance(v.parse_mode, DefaultValue): pytest.fail(f'Parameter {k} has a DefaultValue parse_mode') + # Check InputMedia elif k == 'media' and isinstance(v, list): - if any(isinstance(med.parse_mode, DefaultValue) for med in v): - pytest.fail('One of the media items has a DefaultValue parse_mode') - # Check timeout - if isinstance(timeout, DefaultValue): - pytest.fail('Parameter timeout was passed as DefaultValue to request') + for med in v: + if isinstance(med.get('parse_mode', None), DefaultValue): + pytest.fail('One of the media items has a DefaultValue parse_mode') + # Check inline query results if bot_method_name.lower().replace('_', '') == 'answerinlinequery': for result_dict in data['results']: @@ -334,7 +501,7 @@ def make_assertion(_, data, timeout=None): if isinstance(value.default, DefaultValue) ] monkeypatch.setattr(raw_bot.request, 'post', make_assertion) - method(**build_kwargs(inspect.signature(method), kwargs_need_default)) + await method(**build_kwargs(inspect.signature(method), kwargs_need_default)) def test_ext_bot_signature(self): """ @@ -373,8 +540,9 @@ def test_ext_bot_signature(self): ), f'Wrong parameter kind for parameter {param_name} of method {name}' @flaky(3, 1) - def test_forward_message(self, bot, chat_id, message): - forward_message = bot.forward_message( + @pytest.mark.asyncio + async def test_forward_message(self, bot, chat_id, message): + forward_message = await bot.forward_message( chat_id, from_chat_id=chat_id, message_id=message.message_id ) @@ -382,39 +550,49 @@ def test_forward_message(self, bot, chat_id, message): assert forward_message.forward_from.username == message.from_user.username assert isinstance(forward_message.forward_date, dtm.datetime) - def test_forward_protected_message(self, bot, message, chat_id): - to_forward_protected = bot.send_message(chat_id, 'cant forward me', protect_content=True) + @pytest.mark.asyncio + async def test_forward_protected_message(self, bot, message, chat_id): + to_forward_protected = await bot.send_message( + chat_id, 'cant forward me', protect_content=True + ) assert to_forward_protected.has_protected_content with pytest.raises(BadRequest, match="can't be forwarded"): - to_forward_protected.forward(chat_id) + await to_forward_protected.forward(chat_id) - to_forward_unprotected = bot.send_message(chat_id, 'forward me', protect_content=False) + to_forward_unprotected = await bot.send_message( + chat_id, 'forward me', protect_content=False + ) assert not to_forward_unprotected.has_protected_content - forwarded_but_now_protected = to_forward_unprotected.forward(chat_id, protect_content=True) + forwarded_but_now_protected = await to_forward_unprotected.forward( + chat_id, protect_content=True + ) assert forwarded_but_now_protected.has_protected_content with pytest.raises(BadRequest, match="can't be forwarded"): - forwarded_but_now_protected.forward(chat_id) + await forwarded_but_now_protected.forward(chat_id) @flaky(3, 1) - def test_delete_message(self, bot, chat_id): - message = bot.send_message(chat_id, text='will be deleted') - time.sleep(2) + @pytest.mark.asyncio + async def test_delete_message(self, bot, chat_id): + message = await bot.send_message(chat_id, text='will be deleted') + await asyncio.sleep(2) - assert bot.delete_message(chat_id=chat_id, message_id=message.message_id) is True + assert await bot.delete_message(chat_id=chat_id, message_id=message.message_id) is True @flaky(3, 1) - def test_delete_message_old_message(self, bot, chat_id): + @pytest.mark.asyncio + async def test_delete_message_old_message(self, bot, chat_id): with pytest.raises(BadRequest): # Considering that the first message is old enough - bot.delete_message(chat_id=chat_id, message_id=1) + await bot.delete_message(chat_id=chat_id, message_id=1) # send_photo, send_audio, send_document, send_sticker, send_video, send_voice, send_video_note, # send_media_group and send_animation are tested in their respective test modules. No need to # duplicate here. @flaky(3, 1) - def test_send_venue(self, bot, chat_id): + @pytest.mark.asyncio + async def test_send_venue(self, bot, chat_id): longitude = -46.788279 latitude = -23.691288 title = 'title' @@ -424,7 +602,7 @@ def test_send_venue(self, bot, chat_id): google_place_id = 'google_place id' google_place_type = 'google_place type' - message = bot.send_venue( + message = await bot.send_venue( chat_id=chat_id, title=title, address=address, @@ -446,7 +624,7 @@ def test_send_venue(self, bot, chat_id): assert message.venue.google_place_type is None assert message.has_protected_content - message = bot.send_venue( + message = await bot.send_venue( chat_id=chat_id, title=title, address=address, @@ -469,11 +647,12 @@ def test_send_venue(self, bot, chat_id): assert message.has_protected_content @flaky(3, 1) - def test_send_contact(self, bot, chat_id): + @pytest.mark.asyncio + async def test_send_contact(self, bot, chat_id): phone_number = '+11234567890' first_name = 'Leandro' last_name = 'Toledo' - message = bot.send_contact( + message = await bot.send_contact( chat_id=chat_id, phone_number=phone_number, first_name=first_name, @@ -502,16 +681,17 @@ def test_send_contact(self, bot, chat_id): ).to_dict(), ], ) - def test_send_and_stop_poll(self, bot, super_group_id, reply_markup): + @pytest.mark.asyncio + async def test_send_and_stop_poll(self, bot, super_group_id, reply_markup): question = 'Is this a test?' answers = ['Yes', 'No', 'Maybe'] - message = bot.send_poll( + message = await bot.send_poll( chat_id=super_group_id, question=question, options=answers, is_anonymous=False, allows_multiple_answers=True, - timeout=60, + read_timeout=60, protect_content=True, ) @@ -528,11 +708,11 @@ def test_send_and_stop_poll(self, bot, super_group_id, reply_markup): # Since only the poll and not the complete message is returned, we can't check that the # reply_markup is correct. So we just test that sending doesn't give an error. - poll = bot.stop_poll( + poll = await bot.stop_poll( chat_id=super_group_id, message_id=message.message_id, reply_markup=reply_markup, - timeout=60, + read_timeout=60, ) assert isinstance(poll, Poll) assert poll.is_closed @@ -549,7 +729,7 @@ def test_send_and_stop_poll(self, bot, super_group_id, reply_markup): explanation_entities = [ MessageEntity(MessageEntity.TEXT_LINK, 0, 14, url='https://google.com') ] - message_quiz = bot.send_poll( + message_quiz = await bot.send_poll( chat_id=super_group_id, question=question, options=answers, @@ -566,8 +746,11 @@ def test_send_and_stop_poll(self, bot, super_group_id, reply_markup): assert message_quiz.poll.explanation_entities == explanation_entities @flaky(3, 1) - @pytest.mark.parametrize(['open_period', 'close_date'], [(5, None), (None, True)]) - def test_send_open_period(self, bot, super_group_id, open_period, close_date): + @pytest.mark.parametrize( + ['open_period', 'close_date'], [(5, None), (None, True)], ids=['open_period', 'close_date'] + ) + @pytest.mark.asyncio + async def test_send_open_period(self, bot, super_group_id, open_period, close_date): question = 'Is this a test?' answers = ['Yes', 'No', 'Maybe'] reply_markup = InlineKeyboardMarkup.from_button( @@ -575,30 +758,31 @@ def test_send_open_period(self, bot, super_group_id, open_period, close_date): ) if close_date: - close_date = dtm.datetime.utcnow() + dtm.timedelta(seconds=5) + close_date = dtm.datetime.utcnow() + dtm.timedelta(seconds=5.1) - message = bot.send_poll( + message = await bot.send_poll( chat_id=super_group_id, question=question, options=answers, is_anonymous=False, allows_multiple_answers=True, - timeout=60, + read_timeout=60, open_period=open_period, close_date=close_date, ) - time.sleep(5.1) - new_message = bot.edit_message_reply_markup( + await asyncio.sleep(5.2) + new_message = await bot.edit_message_reply_markup( chat_id=super_group_id, message_id=message.message_id, reply_markup=reply_markup, - timeout=60, + read_timeout=60, ) assert new_message.poll.id == message.poll.id assert new_message.poll.is_closed @flaky(3, 1) - def test_send_close_date_default_tz(self, tz_bot, super_group_id): + @pytest.mark.asyncio + async def test_send_close_date_default_tz(self, tz_bot, super_group_id): question = 'Is this a test?' answers = ['Yes', 'No', 'Maybe'] reply_markup = InlineKeyboardMarkup.from_button( @@ -608,12 +792,12 @@ def test_send_close_date_default_tz(self, tz_bot, super_group_id): aware_close_date = dtm.datetime.now(tz=tz_bot.defaults.tzinfo) + dtm.timedelta(seconds=5) close_date = aware_close_date.replace(tzinfo=None) - msg = tz_bot.send_poll( # The timezone returned from this is always converted to UTC + msg = await tz_bot.send_poll( # The timezone returned from this is always converted to UTC chat_id=super_group_id, question=question, options=answers, close_date=close_date, - timeout=60, + read_timeout=60, ) # Sometimes there can be a few seconds delay, so don't let the test fail due to that- msg.poll.close_date = msg.poll.close_date.astimezone(aware_close_date.tzinfo) @@ -621,24 +805,25 @@ def test_send_close_date_default_tz(self, tz_bot, super_group_id): time.sleep(5.1) - new_message = tz_bot.edit_message_reply_markup( + new_message = await tz_bot.edit_message_reply_markup( chat_id=super_group_id, message_id=msg.message_id, reply_markup=reply_markup, - timeout=60, + read_timeout=60, ) assert new_message.poll.id == msg.poll.id assert new_message.poll.is_closed @flaky(3, 1) - def test_send_poll_explanation_entities(self, bot, chat_id): + @pytest.mark.asyncio + async def test_send_poll_explanation_entities(self, bot, chat_id): test_string = 'Italic Bold Code' entities = [ MessageEntity(MessageEntity.ITALIC, 0, 6), MessageEntity(MessageEntity.ITALIC, 7, 4), MessageEntity(MessageEntity.ITALIC, 12, 4), ] - message = bot.send_poll( + message = await bot.send_poll( chat_id, 'question', options=['a', 'b'], @@ -653,13 +838,14 @@ def test_send_poll_explanation_entities(self, bot, chat_id): @flaky(3, 1) @pytest.mark.parametrize('default_bot', [{'parse_mode': 'Markdown'}], indirect=True) - def test_send_poll_default_parse_mode(self, default_bot, super_group_id): + @pytest.mark.asyncio + async def test_send_poll_default_parse_mode(self, default_bot, super_group_id): explanation = 'Italic Bold Code' explanation_markdown = '_Italic_ *Bold* `Code`' question = 'Is this a test?' answers = ['Yes', 'No', 'Maybe'] - message = default_bot.send_poll( + message = await default_bot.send_poll( chat_id=super_group_id, question=question, options=answers, @@ -675,7 +861,7 @@ def test_send_poll_default_parse_mode(self, default_bot, super_group_id): MessageEntity(MessageEntity.CODE, 12, 4), ] - message = default_bot.send_poll( + message = await default_bot.send_poll( chat_id=super_group_id, question=question, options=answers, @@ -688,7 +874,7 @@ def test_send_poll_default_parse_mode(self, default_bot, super_group_id): assert message.poll.explanation == explanation_markdown assert message.poll.explanation_entities == [] - message = default_bot.send_poll( + message = await default_bot.send_poll( chat_id=super_group_id, question=question, options=answers, @@ -711,13 +897,16 @@ def test_send_poll_default_parse_mode(self, default_bot, super_group_id): ], indirect=['default_bot'], ) - def test_send_poll_default_allow_sending_without_reply(self, default_bot, chat_id, custom): + @pytest.mark.asyncio + async def test_send_poll_default_allow_sending_without_reply( + self, default_bot, chat_id, custom + ): question = 'Is this a test?' answers = ['Yes', 'No', 'Maybe'] - reply_to_message = default_bot.send_message(chat_id, 'test') - reply_to_message.delete() + reply_to_message = await default_bot.send_message(chat_id, 'test') + await reply_to_message.delete() if custom is not None: - message = default_bot.send_poll( + message = await default_bot.send_poll( chat_id, question=question, options=answers, @@ -726,7 +915,7 @@ def test_send_poll_default_allow_sending_without_reply(self, default_bot, chat_i ) assert message.reply_to_message is None elif default_bot.defaults.allow_sending_without_reply: - message = default_bot.send_poll( + message = await default_bot.send_poll( chat_id, question=question, options=answers, @@ -735,7 +924,7 @@ def test_send_poll_default_allow_sending_without_reply(self, default_bot, chat_i assert message.reply_to_message is None else: with pytest.raises(BadRequest, match='message not found'): - default_bot.send_poll( + await default_bot.send_poll( chat_id, question=question, options=answers, @@ -743,17 +932,21 @@ def test_send_poll_default_allow_sending_without_reply(self, default_bot, chat_i ) @flaky(3, 1) + @pytest.mark.asyncio @pytest.mark.parametrize('default_bot', [{'protect_content': True}], indirect=True) - def test_send_poll_default_protect_content(self, chat_id, default_bot): - protected_poll = default_bot.send_poll(chat_id, 'Test', ['1', '2']) + async def test_send_poll_default_protect_content(self, chat_id, default_bot): + protected_poll = await default_bot.send_poll(chat_id, 'Test', ['1', '2']) assert protected_poll.has_protected_content - unprotect_poll = default_bot.send_poll(chat_id, 'test', ['1', '2'], protect_content=False) + unprotect_poll = await default_bot.send_poll( + chat_id, 'test', ['1', '2'], protect_content=False + ) assert not unprotect_poll.has_protected_content @flaky(3, 1) @pytest.mark.parametrize('emoji', Dice.ALL_EMOJI + [None]) - def test_send_dice(self, bot, chat_id, emoji): - message = bot.send_dice(chat_id, emoji=emoji, protect_content=True) + @pytest.mark.asyncio + async def test_send_dice(self, bot, chat_id, emoji): + message = await bot.send_dice(chat_id, emoji=emoji, protect_content=True) assert message.dice assert message.has_protected_content @@ -772,32 +965,38 @@ def test_send_dice(self, bot, chat_id, emoji): ], indirect=['default_bot'], ) - def test_send_dice_default_allow_sending_without_reply(self, default_bot, chat_id, custom): - reply_to_message = default_bot.send_message(chat_id, 'test') - reply_to_message.delete() + @pytest.mark.asyncio + async def test_send_dice_default_allow_sending_without_reply( + self, default_bot, chat_id, custom + ): + reply_to_message = await default_bot.send_message(chat_id, 'test') + await reply_to_message.delete() if custom is not None: - message = default_bot.send_dice( + message = await default_bot.send_dice( chat_id, allow_sending_without_reply=custom, reply_to_message_id=reply_to_message.message_id, ) assert message.reply_to_message is None elif default_bot.defaults.allow_sending_without_reply: - message = default_bot.send_dice( + message = await default_bot.send_dice( chat_id, reply_to_message_id=reply_to_message.message_id, ) assert message.reply_to_message is None else: with pytest.raises(BadRequest, match='message not found'): - default_bot.send_dice(chat_id, reply_to_message_id=reply_to_message.message_id) + await default_bot.send_dice( + chat_id, reply_to_message_id=reply_to_message.message_id + ) @flaky(3, 1) + @pytest.mark.asyncio @pytest.mark.parametrize('default_bot', [{'protect_content': True}], indirect=True) - def test_send_dice_default_protect_content(self, chat_id, default_bot): - protected_dice = default_bot.send_dice(chat_id) + async def test_send_dice_default_protect_content(self, chat_id, default_bot): + protected_dice = await default_bot.send_dice(chat_id) assert protected_dice.has_protected_content - unprotected_dice = default_bot.send_dice(chat_id, protect_content=False) + unprotected_dice = await default_bot.send_dice(chat_id, protect_content=False) assert not unprotected_dice.has_protected_content @flaky(3, 1) @@ -817,16 +1016,18 @@ def test_send_dice_default_protect_content(self, chat_id, default_bot): ChatAction.CHOOSE_STICKER, ], ) - def test_send_chat_action(self, bot, chat_id, chat_action): - assert bot.send_chat_action(chat_id, chat_action) + @pytest.mark.asyncio + async def test_send_chat_action(self, bot, chat_id, chat_action): + assert await bot.send_chat_action(chat_id, chat_action) with pytest.raises(BadRequest, match='Wrong parameter action'): - bot.send_chat_action(chat_id, 'unknown action') + await bot.send_chat_action(chat_id, 'unknown action') # TODO: Needs improvement. We need incoming inline query to test answer. - def test_answer_inline_query(self, monkeypatch, bot): + @pytest.mark.asyncio + async def test_answer_inline_query(self, monkeypatch, bot): # For now just test that our internals pass the correct data - def test(url, data, *args, **kwargs): - return data == { + async def make_assertion(url, request_data: RequestData, *args, **kwargs): + return request_data.parameters == { 'cache_time': 300, 'results': [ { @@ -849,13 +1050,13 @@ def test(url, data, *args, **kwargs): 'switch_pm_text': 'switch pm', } - monkeypatch.setattr(bot.request, 'post', test) + monkeypatch.setattr(bot.request, 'post', make_assertion) results = [ InlineQueryResultArticle('11', 'first', InputTextMessageContent('first')), InlineQueryResultArticle('12', 'second', InputTextMessageContent('second')), ] - assert bot.answer_inline_query( + assert await bot.answer_inline_query( 1234, results=results, cache_time=300, @@ -866,9 +1067,10 @@ def test(url, data, *args, **kwargs): ) monkeypatch.delattr(bot.request, 'post') - def test_answer_inline_query_no_default_parse_mode(self, monkeypatch, bot): - def test(url, data, *args, **kwargs): - return data == { + @pytest.mark.asyncio + async def test_answer_inline_query_no_default_parse_mode(self, monkeypatch, bot): + async def make_assertion(url, request_data: RequestData, *args, **kwargs): + return request_data.parameters == { 'cache_time': 300, 'results': [ { @@ -889,7 +1091,7 @@ def test(url, data, *args, **kwargs): 'switch_pm_text': 'switch pm', } - monkeypatch.setattr(bot.request, 'post', test) + monkeypatch.setattr(bot.request, 'post', make_assertion) results = [ InlineQueryResultDocument( id='123', @@ -901,7 +1103,7 @@ def test(url, data, *args, **kwargs): ) ] - assert bot.answer_inline_query( + assert await bot.answer_inline_query( 1234, results=results, cache_time=300, @@ -912,9 +1114,10 @@ def test(url, data, *args, **kwargs): ) @pytest.mark.parametrize('default_bot', [{'parse_mode': 'Markdown'}], indirect=True) - def test_answer_inline_query_default_parse_mode(self, monkeypatch, default_bot): - def test(url, data, *args, **kwargs): - return data == { + @pytest.mark.asyncio + async def test_answer_inline_query_default_parse_mode(self, monkeypatch, default_bot): + async def make_assertion(url, request_data: RequestData, *args, **kwargs): + return request_data.parameters == { 'cache_time': 300, 'results': [ { @@ -936,7 +1139,7 @@ def test(url, data, *args, **kwargs): 'switch_pm_text': 'switch pm', } - monkeypatch.setattr(default_bot.request, 'post', test) + monkeypatch.setattr(default_bot.request, 'post', make_assertion) results = [ InlineQueryResultDocument( id='123', @@ -948,7 +1151,7 @@ def test(url, data, *args, **kwargs): ) ] - assert default_bot.answer_inline_query( + assert await default_bot.answer_inline_query( 1234, results=results, cache_time=300, @@ -958,9 +1161,10 @@ def test(url, data, *args, **kwargs): switch_pm_parameter='start_pm', ) - def test_answer_inline_query_current_offset_error(self, bot, inline_results): + @pytest.mark.asyncio + async def test_answer_inline_query_current_offset_error(self, bot, inline_results): with pytest.raises(ValueError, match=('`current_offset` and `next_offset`')): - bot.answer_inline_query( + await bot.answer_inline_query( 1234, results=inline_results, next_offset=42, current_offset=51 ) @@ -972,7 +1176,8 @@ def test_answer_inline_query_current_offset_error(self, bot, inline_results): (5, 3, 251, ''), ], ) - def test_answer_inline_query_current_offset_1( + @pytest.mark.asyncio + async def test_answer_inline_query_current_offset_1( self, monkeypatch, bot, @@ -983,7 +1188,8 @@ def test_answer_inline_query_current_offset_1( expected_next_offset, ): # For now just test that our internals pass the correct data - def make_assertion(url, data, *args, **kwargs): + async def make_assertion(url, request_data: RequestData, *args, **kwargs): + data = request_data.parameters results = data['results'] length_matches = len(results) == num_results ids_match = all(int(res['id']) == id_offset + i for i, res in enumerate(results)) @@ -992,11 +1198,15 @@ def make_assertion(url, data, *args, **kwargs): monkeypatch.setattr(bot.request, 'post', make_assertion) - assert bot.answer_inline_query(1234, results=inline_results, current_offset=current_offset) + assert await bot.answer_inline_query( + 1234, results=inline_results, current_offset=current_offset + ) - def test_answer_inline_query_current_offset_2(self, monkeypatch, bot, inline_results): + @pytest.mark.asyncio + async def test_answer_inline_query_current_offset_2(self, monkeypatch, bot, inline_results): # For now just test that our internals pass the correct data - def make_assertion(url, data, *args, **kwargs): + async def make_assertion(url, request_data: RequestData, *args, **kwargs): + data = request_data.parameters results = data['results'] length_matches = len(results) == InlineQueryLimit.RESULTS ids_match = all(int(res['id']) == 1 + i for i, res in enumerate(results)) @@ -1005,11 +1215,12 @@ def make_assertion(url, data, *args, **kwargs): monkeypatch.setattr(bot.request, 'post', make_assertion) - assert bot.answer_inline_query(1234, results=inline_results, current_offset=0) + assert await bot.answer_inline_query(1234, results=inline_results, current_offset=0) inline_results = inline_results[:30] - def make_assertion(url, data, *args, **kwargs): + async def make_assertion(url, request_data: RequestData, *args, **kwargs): + data = request_data.parameters results = data['results'] length_matches = len(results) == 30 ids_match = all(int(res['id']) == 1 + i for i, res in enumerate(results)) @@ -1018,11 +1229,13 @@ def make_assertion(url, data, *args, **kwargs): monkeypatch.setattr(bot.request, 'post', make_assertion) - assert bot.answer_inline_query(1234, results=inline_results, current_offset=0) + assert await bot.answer_inline_query(1234, results=inline_results, current_offset=0) - def test_answer_inline_query_current_offset_callback(self, monkeypatch, bot, caplog): + @pytest.mark.asyncio + async def test_answer_inline_query_current_offset_callback(self, monkeypatch, bot, caplog): # For now just test that our internals pass the correct data - def make_assertion(url, data, *args, **kwargs): + async def make_assertion(url, request_data: RequestData, *args, **kwargs): + data = request_data.parameters results = data['results'] length = len(results) == 5 ids = all(int(res['id']) == 6 + i for i, res in enumerate(results)) @@ -1031,9 +1244,12 @@ def make_assertion(url, data, *args, **kwargs): monkeypatch.setattr(bot.request, 'post', make_assertion) - assert bot.answer_inline_query(1234, results=inline_results_callback, current_offset=1) + assert await bot.answer_inline_query( + 1234, results=inline_results_callback, current_offset=1 + ) - def make_assertion(url, data, *args, **kwargs): + async def make_assertion(url, request_data: RequestData, *args, **kwargs): + data = request_data.parameters results = data['results'] length = results == [] next_offset = data['next_offset'] == '' @@ -1041,25 +1257,30 @@ def make_assertion(url, data, *args, **kwargs): monkeypatch.setattr(bot.request, 'post', make_assertion) - assert bot.answer_inline_query(1234, results=inline_results_callback, current_offset=6) + assert await bot.answer_inline_query( + 1234, results=inline_results_callback, current_offset=6 + ) @flaky(3, 1) - def test_get_user_profile_photos(self, bot, chat_id): - user_profile_photos = bot.get_user_profile_photos(chat_id) + @pytest.mark.asyncio + async def test_get_user_profile_photos(self, bot, chat_id): + user_profile_photos = await bot.get_user_profile_photos(chat_id) assert user_profile_photos.photos[0][0].file_size == 5403 @flaky(3, 1) - def test_get_one_user_profile_photo(self, bot, chat_id): - user_profile_photos = bot.get_user_profile_photos(chat_id, offset=0, limit=1) + @pytest.mark.asyncio + async def test_get_one_user_profile_photo(self, bot, chat_id): + user_profile_photos = await bot.get_user_profile_photos(chat_id, offset=0, limit=1) assert user_profile_photos.photos[0][0].file_size == 5403 # get_file is tested multiple times in the test_*media* modules. # Here we only test the behaviour for bot apis in local mode - def test_get_file_local_mode(self, bot, monkeypatch): + @pytest.mark.asyncio + async def test_get_file_local_mode(self, bot, monkeypatch): path = str(data_file('game.gif')) - def _post(*args, **kwargs): + async def _post(*args, **kwargs): return { 'file_id': None, 'file_unique_id': None, @@ -1069,60 +1290,68 @@ def _post(*args, **kwargs): monkeypatch.setattr(bot, '_post', _post) - resulting_path = bot.get_file('file_id').file_path + resulting_path = (await bot.get_file('file_id')).file_path assert bot.token not in resulting_path assert resulting_path == path monkeypatch.delattr(bot, '_post') # TODO: Needs improvement. No feasible way to test until bots can add members. - def test_ban_chat_member(self, monkeypatch, bot): - def test(url, data, *args, **kwargs): - chat_id = data['chat_id'] == 2 - user_id = data['user_id'] == 32 - until_date = data.get('until_date', 1577887200) == 1577887200 - revoke_msgs = data.get('revoke_messages', True) is True + @pytest.mark.asyncio + async def test_ban_chat_member(self, monkeypatch, bot): + async def make_assertion(url, request_data: RequestData, *args, **kwargs): + data = request_data.json_parameters + chat_id = data['chat_id'] == '2' + user_id = data['user_id'] == '32' + until_date = data.get('until_date', '1577887200') == '1577887200' + revoke_msgs = data.get('revoke_messages', 'true') == 'true' return chat_id and user_id and until_date and revoke_msgs - monkeypatch.setattr(bot.request, 'post', test) + monkeypatch.setattr(bot.request, 'post', make_assertion) until = from_timestamp(1577887200) - assert bot.ban_chat_member(2, 32) - assert bot.ban_chat_member(2, 32, until_date=until) - assert bot.ban_chat_member(2, 32, until_date=1577887200) - assert bot.ban_chat_member(2, 32, revoke_messages=True) + assert await bot.ban_chat_member(2, 32) + assert await bot.ban_chat_member(2, 32, until_date=until) + assert await bot.ban_chat_member(2, 32, until_date=1577887200) + assert await bot.ban_chat_member(2, 32, revoke_messages=True) monkeypatch.delattr(bot.request, 'post') - def test_ban_chat_member_default_tz(self, monkeypatch, tz_bot): + @pytest.mark.asyncio + async def test_ban_chat_member_default_tz(self, monkeypatch, tz_bot): until = dtm.datetime(2020, 1, 11, 16, 13) until_timestamp = to_timestamp(until, tzinfo=tz_bot.defaults.tzinfo) - def test(url, data, *args, **kwargs): + async def make_assertion(url, request_data: RequestData, *args, **kwargs): + data = request_data.parameters chat_id = data['chat_id'] == 2 user_id = data['user_id'] == 32 until_date = data.get('until_date', until_timestamp) == until_timestamp return chat_id and user_id and until_date - monkeypatch.setattr(tz_bot.request, 'post', test) + monkeypatch.setattr(tz_bot.request, 'post', make_assertion) - assert tz_bot.ban_chat_member(2, 32) - assert tz_bot.ban_chat_member(2, 32, until_date=until) - assert tz_bot.ban_chat_member(2, 32, until_date=until_timestamp) + assert await tz_bot.ban_chat_member(2, 32) + assert await tz_bot.ban_chat_member(2, 32, until_date=until) + assert await tz_bot.ban_chat_member(2, 32, until_date=until_timestamp) - def test_ban_chat_sender_chat(self, monkeypatch, bot): + @pytest.mark.asyncio + async def test_ban_chat_sender_chat(self, monkeypatch, bot): # For now, we just test that we pass the correct data to TG - def make_assertion(url, data, *args, **kwargs): + async def make_assertion(url, request_data: RequestData, *args, **kwargs): + data = request_data.parameters chat_id = data['chat_id'] == 2 sender_chat_id = data['sender_chat_id'] == 32 return chat_id and sender_chat_id monkeypatch.setattr(bot.request, 'post', make_assertion) - assert bot.ban_chat_sender_chat(2, 32) + assert await bot.ban_chat_sender_chat(2, 32) monkeypatch.delattr(bot.request, 'post') # TODO: Needs improvement. @pytest.mark.parametrize('only_if_banned', [True, False, None]) - def test_unban_chat_member(self, monkeypatch, bot, only_if_banned): - def make_assertion(url, data, *args, **kwargs): + @pytest.mark.asyncio + async def test_unban_chat_member(self, monkeypatch, bot, only_if_banned): + async def make_assertion(url, request_data: RequestData, *args, **kwargs): + data = request_data.parameters chat_id = data['chat_id'] == 2 user_id = data['user_id'] == 32 o_i_b = data.get('only_if_banned', None) == only_if_banned @@ -1130,42 +1359,49 @@ def make_assertion(url, data, *args, **kwargs): monkeypatch.setattr(bot.request, 'post', make_assertion) - assert bot.unban_chat_member(2, 32, only_if_banned=only_if_banned) + assert await bot.unban_chat_member(2, 32, only_if_banned=only_if_banned) - def test_unban_chat_sender_chat(self, monkeypatch, bot): - def make_assertion(url, data, *args, **kwargs): - chat_id = data['chat_id'] == 2 - sender_chat_id = data['sender_chat_id'] == 32 + @pytest.mark.asyncio + async def test_unban_chat_sender_chat(self, monkeypatch, bot): + async def make_assertion(url, request_data: RequestData, *args, **kwargs): + data = request_data.json_parameters + chat_id = data['chat_id'] == '2' + sender_chat_id = data['sender_chat_id'] == '32' return chat_id and sender_chat_id monkeypatch.setattr(bot.request, 'post', make_assertion) - assert bot.unbanChatSenderChat(2, 32) - - def test_set_chat_permissions(self, monkeypatch, bot, chat_permissions): - def test(url, data, *args, **kwargs): - chat_id = data['chat_id'] == 2 - permissions = data['permissions'] == chat_permissions.to_dict() + assert await bot.unban_chat_sender_chat(2, 32) + + @pytest.mark.asyncio + async def test_set_chat_permissions(self, monkeypatch, bot, chat_permissions): + async def make_assertion(url, request_data: RequestData, *args, **kwargs): + data = request_data.json_parameters + chat_id = data['chat_id'] == '2' + permissions = data['permissions'] == chat_permissions.to_json() return chat_id and permissions - monkeypatch.setattr(bot.request, 'post', test) + monkeypatch.setattr(bot.request, 'post', make_assertion) - assert bot.set_chat_permissions(2, chat_permissions) + assert await bot.set_chat_permissions(2, chat_permissions) - def test_set_chat_administrator_custom_title(self, monkeypatch, bot): - def test(url, data, *args, **kwargs): + @pytest.mark.asyncio + async def test_set_chat_administrator_custom_title(self, monkeypatch, bot): + async def make_assertion(url, request_data: RequestData, *args, **kwargs): + data = request_data.parameters chat_id = data['chat_id'] == 2 user_id = data['user_id'] == 32 custom_title = data['custom_title'] == 'custom_title' return chat_id and user_id and custom_title - monkeypatch.setattr(bot.request, 'post', test) - assert bot.set_chat_administrator_custom_title(2, 32, 'custom_title') + monkeypatch.setattr(bot.request, 'post', make_assertion) + assert await bot.set_chat_administrator_custom_title(2, 32, 'custom_title') # TODO: Needs improvement. Need an incoming callbackquery to test - def test_answer_callback_query(self, monkeypatch, bot): + @pytest.mark.asyncio + async def test_answer_callback_query(self, monkeypatch, bot): # For now just test that our internals pass the correct data - def test(url, data, *args, **kwargs): - return data == { + async def make_assertion(url, request_data: RequestData, *args, **kwargs): + return request_data.parameters == { 'callback_query_id': 23, 'show_alert': True, 'url': 'no_url', @@ -1173,15 +1409,16 @@ def test(url, data, *args, **kwargs): 'text': 'answer', } - monkeypatch.setattr(bot.request, 'post', test) + monkeypatch.setattr(bot.request, 'post', make_assertion) - assert bot.answer_callback_query( + assert await bot.answer_callback_query( 23, text='answer', show_alert=True, url='no_url', cache_time=1 ) @flaky(3, 1) - def test_edit_message_text(self, bot, message): - message = bot.edit_message_text( + @pytest.mark.asyncio + async def test_edit_message_text(self, bot, message): + message = await bot.edit_message_text( text='new_text', chat_id=message.chat_id, message_id=message.message_id, @@ -1192,14 +1429,15 @@ def test_edit_message_text(self, bot, message): assert message.text == 'new_text' @flaky(3, 1) - def test_edit_message_text_entities(self, bot, message): + @pytest.mark.asyncio + async def test_edit_message_text_entities(self, bot, message): test_string = 'Italic Bold Code' entities = [ MessageEntity(MessageEntity.ITALIC, 0, 6), MessageEntity(MessageEntity.ITALIC, 7, 4), MessageEntity(MessageEntity.ITALIC, 12, 4), ] - message = bot.edit_message_text( + message = await bot.edit_message_text( text=test_string, chat_id=message.chat_id, message_id=message.message_id, @@ -1211,11 +1449,12 @@ def test_edit_message_text_entities(self, bot, message): @flaky(3, 1) @pytest.mark.parametrize('default_bot', [{'parse_mode': 'Markdown'}], indirect=True) - def test_edit_message_text_default_parse_mode(self, default_bot, message): + @pytest.mark.asyncio + async def test_edit_message_text_default_parse_mode(self, default_bot, message): test_string = 'Italic Bold Code' test_markdown_string = '_Italic_ *Bold* `Code`' - message = default_bot.edit_message_text( + message = await default_bot.edit_message_text( text=test_markdown_string, chat_id=message.chat_id, message_id=message.message_id, @@ -1224,7 +1463,7 @@ def test_edit_message_text_default_parse_mode(self, default_bot, message): assert message.text_markdown == test_markdown_string assert message.text == test_string - message = default_bot.edit_message_text( + message = await default_bot.edit_message_text( text=test_markdown_string, chat_id=message.chat_id, message_id=message.message_id, @@ -1234,13 +1473,13 @@ def test_edit_message_text_default_parse_mode(self, default_bot, message): assert message.text == test_markdown_string assert message.text_markdown == escape_markdown(test_markdown_string) - message = default_bot.edit_message_text( + message = await default_bot.edit_message_text( text=test_markdown_string, chat_id=message.chat_id, message_id=message.message_id, disable_web_page_preview=True, ) - message = default_bot.edit_message_text( + message = await default_bot.edit_message_text( text=test_markdown_string, chat_id=message.chat_id, message_id=message.message_id, @@ -1251,12 +1490,14 @@ def test_edit_message_text_default_parse_mode(self, default_bot, message): assert message.text_markdown == escape_markdown(test_markdown_string) @pytest.mark.skip(reason='need reference to an inline message') - def test_edit_message_text_inline(self): + @pytest.mark.asyncio + async def test_edit_message_text_inline(self): pass @flaky(3, 1) - def test_edit_message_caption(self, bot, media_message): - message = bot.edit_message_caption( + @pytest.mark.asyncio + async def test_edit_message_caption(self, bot, media_message): + message = await bot.edit_message_caption( caption='new_caption', chat_id=media_message.chat_id, message_id=media_message.message_id, @@ -1265,14 +1506,15 @@ def test_edit_message_caption(self, bot, media_message): assert message.caption == 'new_caption' @flaky(3, 1) - def test_edit_message_caption_entities(self, bot, media_message): + @pytest.mark.asyncio + async def test_edit_message_caption_entities(self, bot, media_message): test_string = 'Italic Bold Code' entities = [ MessageEntity(MessageEntity.ITALIC, 0, 6), MessageEntity(MessageEntity.ITALIC, 7, 4), MessageEntity(MessageEntity.ITALIC, 12, 4), ] - message = bot.edit_message_caption( + message = await bot.edit_message_caption( caption=test_string, chat_id=media_message.chat_id, message_id=media_message.message_id, @@ -1286,11 +1528,12 @@ def test_edit_message_caption_entities(self, bot, media_message): @flaky(3, 1) @pytest.mark.parametrize('default_bot', [{'parse_mode': 'Markdown'}], indirect=True) - def test_edit_message_caption_default_parse_mode(self, default_bot, media_message): + @pytest.mark.asyncio + async def test_edit_message_caption_default_parse_mode(self, default_bot, media_message): test_string = 'Italic Bold Code' test_markdown_string = '_Italic_ *Bold* `Code`' - message = default_bot.edit_message_caption( + message = await default_bot.edit_message_caption( caption=test_markdown_string, chat_id=media_message.chat_id, message_id=media_message.message_id, @@ -1298,7 +1541,7 @@ def test_edit_message_caption_default_parse_mode(self, default_bot, media_messag assert message.caption_markdown == test_markdown_string assert message.caption == test_string - message = default_bot.edit_message_caption( + message = await default_bot.edit_message_caption( caption=test_markdown_string, chat_id=media_message.chat_id, message_id=media_message.message_id, @@ -1307,12 +1550,12 @@ def test_edit_message_caption_default_parse_mode(self, default_bot, media_messag assert message.caption == test_markdown_string assert message.caption_markdown == escape_markdown(test_markdown_string) - message = default_bot.edit_message_caption( + message = await default_bot.edit_message_caption( caption=test_markdown_string, chat_id=media_message.chat_id, message_id=media_message.message_id, ) - message = default_bot.edit_message_caption( + message = await default_bot.edit_message_caption( caption=test_markdown_string, chat_id=media_message.chat_id, message_id=media_message.message_id, @@ -1322,8 +1565,9 @@ def test_edit_message_caption_default_parse_mode(self, default_bot, media_messag assert message.caption_markdown == escape_markdown(test_markdown_string) @flaky(3, 1) - def test_edit_message_caption_with_parse_mode(self, bot, media_message): - message = bot.edit_message_caption( + @pytest.mark.asyncio + async def test_edit_message_caption_with_parse_mode(self, bot, media_message): + message = await bot.edit_message_caption( caption='new *caption*', parse_mode='Markdown', chat_id=media_message.chat_id, @@ -1332,44 +1576,51 @@ def test_edit_message_caption_with_parse_mode(self, bot, media_message): assert message.caption == 'new caption' - def test_edit_message_caption_without_required(self, bot): + @pytest.mark.asyncio + async def test_edit_message_caption_without_required(self, bot): with pytest.raises(ValueError, match='Both chat_id and message_id are required when'): - bot.edit_message_caption(caption='new_caption') + await bot.edit_message_caption(caption='new_caption') @pytest.mark.skip(reason='need reference to an inline message') - def test_edit_message_caption_inline(self): + @pytest.mark.asyncio + async def test_edit_message_caption_inline(self): pass @flaky(3, 1) - def test_edit_reply_markup(self, bot, message): + @pytest.mark.asyncio + async def test_edit_reply_markup(self, bot, message): new_markup = InlineKeyboardMarkup([[InlineKeyboardButton(text='test', callback_data='1')]]) - message = bot.edit_message_reply_markup( + message = await bot.edit_message_reply_markup( chat_id=message.chat_id, message_id=message.message_id, reply_markup=new_markup ) assert message is not True - def test_edit_message_reply_markup_without_required(self, bot): + @pytest.mark.asyncio + async def test_edit_message_reply_markup_without_required(self, bot): new_markup = InlineKeyboardMarkup([[InlineKeyboardButton(text='test', callback_data='1')]]) with pytest.raises(ValueError, match='Both chat_id and message_id are required when'): - bot.edit_message_reply_markup(reply_markup=new_markup) + await bot.edit_message_reply_markup(reply_markup=new_markup) @pytest.mark.skip(reason='need reference to an inline message') - def test_edit_reply_markup_inline(self): + @pytest.mark.asyncio + async def test_edit_reply_markup_inline(self): pass # TODO: Actually send updates to the test bot so this can be tested properly @flaky(3, 1) - def test_get_updates(self, bot): - bot.delete_webhook() # make sure there is no webhook set if webhook tests failed - updates = bot.get_updates(timeout=1) + @pytest.mark.asyncio + async def test_get_updates(self, bot): + await bot.delete_webhook() # make sure there is no webhook set if webhook tests failed + updates = await bot.get_updates(timeout=1) assert isinstance(updates, list) if updates: assert isinstance(updates[0], Update) - def test_get_updates_invalid_callback_data(self, bot, monkeypatch): - def post(*args, **kwargs): + @pytest.mark.asyncio + async def test_get_updates_invalid_callback_data(self, bot, monkeypatch): + async def post(*args, **kwargs): return [ Update( 17, @@ -1391,9 +1642,9 @@ def post(*args, **kwargs): bot.arbitrary_callback_data = True try: - monkeypatch.setattr(bot.request, 'post', post) - bot.delete_webhook() # make sure there is no webhook set if webhook tests failed - updates = bot.get_updates(timeout=1) + await bot.delete_webhook() # make sure there is no webhook set if webhook tests failed + monkeypatch.setattr(BaseRequest, 'post', post) + updates = await bot.get_updates(timeout=1) assert isinstance(updates, list) assert len(updates) == 1 @@ -1404,91 +1655,101 @@ def post(*args, **kwargs): bot.arbitrary_callback_data = False @flaky(3, 1) - @pytest.mark.xfail - def test_set_webhook_get_webhook_info_and_delete_webhook(self, bot): + @pytest.mark.asyncio + async def test_set_webhook_get_webhook_info_and_delete_webhook(self, bot): url = 'https://python-telegram-bot.org/test/webhook' max_connections = 7 allowed_updates = ['message'] - bot.set_webhook( + await bot.set_webhook( url, max_connections=max_connections, allowed_updates=allowed_updates, - ip_address='127.0.0.1', - ) - time.sleep(2) - live_info = bot.get_webhook_info() - time.sleep(6) - bot.delete_webhook() - time.sleep(2) - info = bot.get_webhook_info() + ip_address='192.0.2.142', + ) + await asyncio.sleep(2) + live_info = await bot.get_webhook_info() + await asyncio.sleep(6) + await bot.delete_webhook() + await asyncio.sleep(2) + info = await bot.get_webhook_info() assert info.url == '' assert live_info.url == url assert live_info.max_connections == max_connections assert live_info.allowed_updates == allowed_updates - assert live_info.ip_address == '127.0.0.1' + assert live_info.ip_address == '198.51.100.142' @pytest.mark.parametrize('drop_pending_updates', [True, False]) - def test_set_webhook_delete_webhook_drop_pending_updates( + @pytest.mark.asyncio + async def test_set_webhook_delete_webhook_drop_pending_updates( self, bot, drop_pending_updates, monkeypatch ): - def assertion(url, data, *args, **kwargs): + async def make_assertion(url, request_data: RequestData, *args, **kwargs): + data = request_data.json_parameters return bool(data.get('drop_pending_updates')) == drop_pending_updates - monkeypatch.setattr(bot.request, 'post', assertion) + monkeypatch.setattr(bot.request, 'post', make_assertion) - assert bot.set_webhook('', drop_pending_updates=drop_pending_updates) - assert bot.delete_webhook(drop_pending_updates=drop_pending_updates) + assert await bot.set_webhook('', drop_pending_updates=drop_pending_updates) + assert await bot.delete_webhook(drop_pending_updates=drop_pending_updates) @flaky(3, 1) - def test_leave_chat(self, bot): + @pytest.mark.asyncio + async def test_leave_chat(self, bot): with pytest.raises(BadRequest, match='Chat not found'): - bot.leave_chat(-123456) + await bot.leave_chat(-123456) with pytest.raises(NetworkError, match='Chat not found'): - bot.leave_chat(-123456) + await bot.leave_chat(-123456) @flaky(3, 1) - def test_get_chat(self, bot, super_group_id): - chat = bot.get_chat(super_group_id) + @pytest.mark.asyncio + async def test_get_chat(self, bot, super_group_id): + chat = await bot.get_chat(super_group_id) assert chat.type == 'supergroup' assert chat.title == f'>>> telegram.Bot(test) @{bot.username}' assert chat.id == int(super_group_id) @flaky(3, 1) - def test_get_chat_administrators(self, bot, channel_id): - admins = bot.get_chat_administrators(channel_id) + @pytest.mark.asyncio + async def test_get_chat_administrators(self, bot, channel_id): + admins = await bot.get_chat_administrators(channel_id) assert isinstance(admins, list) for a in admins: assert a.status in ('administrator', 'creator') @flaky(3, 1) - def test_get_chat_member_count(self, bot, channel_id): - count = bot.get_chat_member_count(channel_id) + @pytest.mark.asyncio + async def test_get_chat_member_count(self, bot, channel_id): + count = await bot.get_chat_member_count(channel_id) assert isinstance(count, int) assert count > 3 @flaky(3, 1) - def test_get_chat_member(self, bot, channel_id, chat_id): - chat_member = bot.get_chat_member(channel_id, chat_id) + @pytest.mark.asyncio + async def test_get_chat_member(self, bot, channel_id, chat_id): + chat_member = await bot.get_chat_member(channel_id, chat_id) assert chat_member.status == 'administrator' assert chat_member.user.first_name == 'PTB' assert chat_member.user.last_name == 'Test user' @pytest.mark.skip(reason="Not implemented since we need a supergroup with many members") - def test_set_chat_sticker_set(self): + @pytest.mark.asyncio + async def test_set_chat_sticker_set(self): pass @pytest.mark.skip(reason="Not implemented since we need a supergroup with many members") - def test_delete_chat_sticker_set(self): + @pytest.mark.asyncio + async def test_delete_chat_sticker_set(self): pass @flaky(3, 1) - def test_send_game(self, bot, chat_id): + @pytest.mark.asyncio + async def test_send_game(self, bot, chat_id): game_short_name = 'test_game' - message = bot.send_game(chat_id, game_short_name, protect_content=True) + message = await bot.send_game(chat_id, game_short_name, protect_content=True) assert message.game assert message.game.description == ( @@ -1510,12 +1771,15 @@ def test_send_game(self, bot, chat_id): ], indirect=['default_bot'], ) - def test_send_game_default_allow_sending_without_reply(self, default_bot, chat_id, custom): + @pytest.mark.asyncio + async def test_send_game_default_allow_sending_without_reply( + self, default_bot, chat_id, custom + ): game_short_name = 'test_game' - reply_to_message = default_bot.send_message(chat_id, 'test') - reply_to_message.delete() + reply_to_message = await default_bot.send_message(chat_id, 'test') + await reply_to_message.delete() if custom is not None: - message = default_bot.send_game( + message = await default_bot.send_game( chat_id, game_short_name, allow_sending_without_reply=custom, @@ -1523,7 +1787,7 @@ def test_send_game_default_allow_sending_without_reply(self, default_bot, chat_i ) assert message.reply_to_message is None elif default_bot.defaults.allow_sending_without_reply: - message = default_bot.send_game( + message = await default_bot.send_game( chat_id, game_short_name, reply_to_message_id=reply_to_message.message_id, @@ -1531,28 +1795,30 @@ def test_send_game_default_allow_sending_without_reply(self, default_bot, chat_i assert message.reply_to_message is None else: with pytest.raises(BadRequest, match='message not found'): - default_bot.send_game( + await default_bot.send_game( chat_id, game_short_name, reply_to_message_id=reply_to_message.message_id ) @flaky(3, 1) + @pytest.mark.asyncio @pytest.mark.parametrize( 'default_bot,val', [({'protect_content': True}, True), ({'protect_content': False}, None)], indirect=['default_bot'], ) - def test_send_game_default_protect_content(self, default_bot, chat_id, val): - protected = default_bot.send_game(chat_id, 'test_game', protect_content=val) + async def test_send_game_default_protect_content(self, default_bot, chat_id, val): + protected = await default_bot.send_game(chat_id, 'test_game', protect_content=val) assert protected.has_protected_content is val @xfail - def test_set_game_score_1(self, bot, chat_id): + @pytest.mark.asyncio + async def test_set_game_score_1(self, bot, chat_id): # NOTE: numbering of methods assures proper order between test_set_game_scoreX methods # First, test setting a score. game_short_name = 'test_game' - game = bot.send_game(chat_id, game_short_name) + game = await bot.send_game(chat_id, game_short_name) - message = bot.set_game_score( + message = await bot.set_game_score( user_id=chat_id, score=BASE_GAME_SCORE, # Score value is relevant for other set_game_score_* tests! chat_id=game.chat_id, @@ -1565,15 +1831,16 @@ def test_set_game_score_1(self, bot, chat_id): assert message.game.text != game.game.text @xfail - def test_set_game_score_2(self, bot, chat_id): + @pytest.mark.asyncio + async def test_set_game_score_2(self, bot, chat_id): # NOTE: numbering of methods assures proper order between test_set_game_scoreX methods # Test setting a score higher than previous game_short_name = 'test_game' - game = bot.send_game(chat_id, game_short_name) + game = await bot.send_game(chat_id, game_short_name) score = BASE_GAME_SCORE + 1 - message = bot.set_game_score( + message = await bot.set_game_score( user_id=chat_id, score=score, chat_id=game.chat_id, @@ -1587,30 +1854,32 @@ def test_set_game_score_2(self, bot, chat_id): assert message.game.text == game.game.text @xfail - def test_set_game_score_3(self, bot, chat_id): + @pytest.mark.asyncio + async def test_set_game_score_3(self, bot, chat_id): # NOTE: numbering of methods assures proper order between test_set_game_scoreX methods # Test setting a score lower than previous (should raise error) game_short_name = 'test_game' - game = bot.send_game(chat_id, game_short_name) + game = await bot.send_game(chat_id, game_short_name) score = BASE_GAME_SCORE # Even a score equal to previous raises an error. with pytest.raises(BadRequest, match='Bot_score_not_modified'): - bot.set_game_score( + await bot.set_game_score( user_id=chat_id, score=score, chat_id=game.chat_id, message_id=game.message_id ) @xfail - def test_set_game_score_4(self, bot, chat_id): + @pytest.mark.asyncio + async def test_set_game_score_4(self, bot, chat_id): # NOTE: numbering of methods assures proper order between test_set_game_scoreX methods # Test force setting a lower score game_short_name = 'test_game' - game = bot.send_game(chat_id, game_short_name) + game = await bot.send_game(chat_id, game_short_name) time.sleep(2) score = BASE_GAME_SCORE - 10 - message = bot.set_game_score( + message = await bot.set_game_score( user_id=chat_id, score=score, chat_id=game.chat_id, @@ -1624,25 +1893,27 @@ def test_set_game_score_4(self, bot, chat_id): # For some reason the returned message doesn't contain the updated score. need to fetch # the game again... (the service message is also absent when running the test suite) - game2 = bot.send_game(chat_id, game_short_name) + game2 = await bot.send_game(chat_id, game_short_name) assert str(score) in game2.game.text @xfail - def test_get_game_high_scores(self, bot, chat_id): + @pytest.mark.asyncio + async def test_get_game_high_scores(self, bot, chat_id): # We need a game to get the scores for game_short_name = 'test_game' - game = bot.send_game(chat_id, game_short_name) - high_scores = bot.get_game_high_scores(chat_id, game.chat_id, game.message_id) + game = await bot.send_game(chat_id, game_short_name) + high_scores = await bot.get_game_high_scores(chat_id, game.chat_id, game.message_id) # We assume that the other game score tests ran within 20 sec assert high_scores[0].score == BASE_GAME_SCORE - 10 # send_invoice is tested in test_invoice # TODO: Needs improvement. Need incoming shipping queries to test - def test_answer_shipping_query_ok(self, monkeypatch, bot): + @pytest.mark.asyncio + async def test_answer_shipping_query_ok(self, monkeypatch, bot): # For now just test that our internals pass the correct data - def test(url, data, *args, **kwargs): - return data == { + async def make_assertion(url, request_data: RequestData, *args, **kwargs): + return request_data.parameters == { 'shipping_query_id': 1, 'ok': True, 'shipping_options': [ @@ -1650,100 +1921,108 @@ def test(url, data, *args, **kwargs): ], } - monkeypatch.setattr(bot.request, 'post', test) + monkeypatch.setattr(bot.request, 'post', make_assertion) shipping_options = ShippingOption(1, 'option1', [LabeledPrice('price', 100)]) - assert bot.answer_shipping_query(1, True, shipping_options=[shipping_options]) + assert await bot.answer_shipping_query(1, True, shipping_options=[shipping_options]) - def test_answer_shipping_query_error_message(self, monkeypatch, bot): + @pytest.mark.asyncio + async def test_answer_shipping_query_error_message(self, monkeypatch, bot): # For now just test that our internals pass the correct data - def test(url, data, *args, **kwargs): - return data == { + async def make_assertion(url, request_data: RequestData, *args, **kwargs): + return request_data.parameters == { 'shipping_query_id': 1, 'error_message': 'Not enough fish', 'ok': False, } - monkeypatch.setattr(bot.request, 'post', test) - assert bot.answer_shipping_query(1, False, error_message='Not enough fish') + monkeypatch.setattr(bot.request, 'post', make_assertion) + assert await bot.answer_shipping_query(1, False, error_message='Not enough fish') - def test_answer_shipping_query_errors(self, monkeypatch, bot): + @pytest.mark.asyncio + async def test_answer_shipping_query_errors(self, monkeypatch, bot): shipping_options = ShippingOption(1, 'option1', [LabeledPrice('price', 100)]) with pytest.raises(TelegramError, match='should not be empty and there should not be'): - bot.answer_shipping_query(1, True, error_message='Not enough fish') + await bot.answer_shipping_query(1, True, error_message='Not enough fish') with pytest.raises(TelegramError, match='should not be empty and there should not be'): - bot.answer_shipping_query(1, False) + await bot.answer_shipping_query(1, False) with pytest.raises(TelegramError, match='should not be empty and there should not be'): - bot.answer_shipping_query(1, False, shipping_options=shipping_options) + await bot.answer_shipping_query(1, False, shipping_options=shipping_options) with pytest.raises(TelegramError, match='should not be empty and there should not be'): - bot.answer_shipping_query(1, True) + await bot.answer_shipping_query(1, True) with pytest.raises(AssertionError): - bot.answer_shipping_query(1, True, shipping_options=[]) + await bot.answer_shipping_query(1, True, shipping_options=[]) # TODO: Needs improvement. Need incoming pre checkout queries to test - def test_answer_pre_checkout_query_ok(self, monkeypatch, bot): + @pytest.mark.asyncio + async def test_answer_pre_checkout_query_ok(self, monkeypatch, bot): # For now just test that our internals pass the correct data - def test(url, data, *args, **kwargs): - return data == {'pre_checkout_query_id': 1, 'ok': True} + async def make_assertion(url, request_data: RequestData, *args, **kwargs): + return request_data.parameters == {'pre_checkout_query_id': 1, 'ok': True} - monkeypatch.setattr(bot.request, 'post', test) - assert bot.answer_pre_checkout_query(1, True) + monkeypatch.setattr(bot.request, 'post', make_assertion) + assert await bot.answer_pre_checkout_query(1, True) - def test_answer_pre_checkout_query_error_message(self, monkeypatch, bot): + @pytest.mark.asyncio + async def test_answer_pre_checkout_query_error_message(self, monkeypatch, bot): # For now just test that our internals pass the correct data - def test(url, data, *args, **kwargs): - return data == { + async def make_assertion(url, request_data: RequestData, *args, **kwargs): + return request_data.parameters == { 'pre_checkout_query_id': 1, 'error_message': 'Not enough fish', 'ok': False, } - monkeypatch.setattr(bot.request, 'post', test) - assert bot.answer_pre_checkout_query(1, False, error_message='Not enough fish') + monkeypatch.setattr(bot.request, 'post', make_assertion) + assert await bot.answer_pre_checkout_query(1, False, error_message='Not enough fish') - def test_answer_pre_checkout_query_errors(self, monkeypatch, bot): + @pytest.mark.asyncio + async def test_answer_pre_checkout_query_errors(self, monkeypatch, bot): with pytest.raises(TelegramError, match='should not be'): - bot.answer_pre_checkout_query(1, True, error_message='Not enough fish') + await bot.answer_pre_checkout_query(1, True, error_message='Not enough fish') with pytest.raises(TelegramError, match='should not be empty'): - bot.answer_pre_checkout_query(1, False) + await bot.answer_pre_checkout_query(1, False) @flaky(3, 1) - def test_restrict_chat_member(self, bot, channel_id, chat_permissions): + @pytest.mark.asyncio + async def test_restrict_chat_member(self, bot, channel_id, chat_permissions): # TODO: Add bot to supergroup so this can be tested properly with pytest.raises(BadRequest, match='Method is available only for supergroups'): - assert bot.restrict_chat_member( + assert await bot.restrict_chat_member( channel_id, 95205500, chat_permissions, until_date=dtm.datetime.utcnow() ) - def test_restrict_chat_member_default_tz( + @pytest.mark.asyncio + async def test_restrict_chat_member_default_tz( self, monkeypatch, tz_bot, channel_id, chat_permissions ): until = dtm.datetime(2020, 1, 11, 16, 13) until_timestamp = to_timestamp(until, tzinfo=tz_bot.defaults.tzinfo) - def test(url, data, *args, **kwargs): - return data.get('until_date', until_timestamp) == until_timestamp + async def make_assertion(url, request_data: RequestData, *args, **kwargs): + return request_data.parameters.get('until_date', until_timestamp) == until_timestamp - monkeypatch.setattr(tz_bot.request, 'post', test) + monkeypatch.setattr(tz_bot.request, 'post', make_assertion) - assert tz_bot.restrict_chat_member(channel_id, 95205500, chat_permissions) - assert tz_bot.restrict_chat_member( + assert await tz_bot.restrict_chat_member(channel_id, 95205500, chat_permissions) + assert await tz_bot.restrict_chat_member( channel_id, 95205500, chat_permissions, until_date=until ) - assert tz_bot.restrict_chat_member( + assert await tz_bot.restrict_chat_member( channel_id, 95205500, chat_permissions, until_date=until_timestamp ) @flaky(3, 1) - def test_promote_chat_member(self, bot, channel_id, monkeypatch): + @pytest.mark.asyncio + async def test_promote_chat_member(self, bot, channel_id, monkeypatch): # TODO: Add bot to supergroup so this can be tested properly / give bot perms with pytest.raises(BadRequest, match='Not enough rights'): - assert bot.promote_chat_member( + assert await bot.promote_chat_member( channel_id, 95205500, is_anonymous=True, @@ -1760,7 +2039,7 @@ def test_promote_chat_member(self, bot, channel_id, monkeypatch): ) # Test that we pass the correct params to TG - def make_assertion(*args, **_): + async def make_assertion(*args, **_): data = args[1] return ( data.get('chat_id') == channel_id @@ -1779,7 +2058,7 @@ def make_assertion(*args, **_): ) monkeypatch.setattr(bot, '_post', make_assertion) - assert bot.promote_chat_member( + assert await bot.promote_chat_member( channel_id, 95205500, is_anonymous=1, @@ -1796,34 +2075,39 @@ def make_assertion(*args, **_): ) @flaky(3, 1) - def test_export_chat_invite_link(self, bot, channel_id): + @pytest.mark.asyncio + async def test_export_chat_invite_link(self, bot, channel_id): # Each link is unique apparently - invite_link = bot.export_chat_invite_link(channel_id) + invite_link = await bot.export_chat_invite_link(channel_id) assert isinstance(invite_link, str) assert invite_link != '' - def test_create_edit_invite_link_mutually_exclusive_arguments(self, bot, channel_id): + @pytest.mark.asyncio + async def test_create_edit_invite_link_mutually_exclusive_arguments(self, bot, channel_id): data = {'chat_id': channel_id, 'member_limit': 17, 'creates_join_request': True} with pytest.raises(ValueError, match="`member_limit` can't be specified"): - bot.create_chat_invite_link(**data) + await bot.create_chat_invite_link(**data) data.update({'invite_link': 'https://invite.link'}) with pytest.raises(ValueError, match="`member_limit` can't be specified"): - bot.edit_chat_invite_link(**data) + await bot.edit_chat_invite_link(**data) @flaky(3, 1) - def test_edit_revoke_chat_invite_link_passing_link_objects(self, bot, channel_id): - invite_link = bot.create_chat_invite_link(chat_id=channel_id) + @pytest.mark.asyncio + async def test_edit_revoke_chat_invite_link_passing_link_objects(self, bot, channel_id): + invite_link = await bot.create_chat_invite_link(chat_id=channel_id) assert invite_link.name is None - edited_link = bot.edit_chat_invite_link( + edited_link = await bot.edit_chat_invite_link( chat_id=channel_id, invite_link=invite_link, name='some_name' ) assert edited_link == invite_link assert edited_link.name == 'some_name' - revoked_link = bot.revoke_chat_invite_link(chat_id=channel_id, invite_link=edited_link) + revoked_link = await bot.revoke_chat_invite_link( + chat_id=channel_id, invite_link=edited_link + ) assert revoked_link.invite_link == edited_link.invite_link assert revoked_link.is_revoked is True assert revoked_link.name == 'some_name' @@ -1831,27 +2115,31 @@ def test_edit_revoke_chat_invite_link_passing_link_objects(self, bot, channel_id @flaky(3, 1) @pytest.mark.parametrize('creates_join_request', [True, False]) @pytest.mark.parametrize('name', [None, 'name']) - def test_create_chat_invite_link_basics(self, bot, creates_join_request, name, channel_id): + @pytest.mark.asyncio + async def test_create_chat_invite_link_basics( + self, bot, creates_join_request, name, channel_id + ): data = {} if creates_join_request: data['creates_join_request'] = True if name: data['name'] = name - invite_link = bot.create_chat_invite_link(chat_id=channel_id, **data) + invite_link = await bot.create_chat_invite_link(chat_id=channel_id, **data) assert invite_link.member_limit is None assert invite_link.expire_date is None assert invite_link.creates_join_request == creates_join_request assert invite_link.name == name - revoked_link = bot.revoke_chat_invite_link( + revoked_link = await bot.revoke_chat_invite_link( chat_id=channel_id, invite_link=invite_link.invite_link ) assert revoked_link.is_revoked @flaky(3, 1) @pytest.mark.parametrize('datetime', argvalues=[True, False], ids=['datetime', 'integer']) - def test_advanced_chat_invite_links(self, bot, channel_id, datetime): + @pytest.mark.asyncio + async def test_advanced_chat_invite_links(self, bot, channel_id, datetime): # we are testing this all in one function in order to save api calls timestamp = dtm.datetime.utcnow() add_seconds = dtm.timedelta(0, 70) @@ -1859,7 +2147,7 @@ def test_advanced_chat_invite_links(self, bot, channel_id, datetime): expire_time = time_in_future if datetime else to_timestamp(time_in_future) aware_time_in_future = pytz.UTC.localize(time_in_future) - invite_link = bot.create_chat_invite_link( + invite_link = await bot.create_chat_invite_link( channel_id, expire_date=expire_time, member_limit=10 ) assert invite_link.invite_link != '' @@ -1872,7 +2160,7 @@ def test_advanced_chat_invite_links(self, bot, channel_id, datetime): expire_time = time_in_future if datetime else to_timestamp(time_in_future) aware_time_in_future = pytz.UTC.localize(time_in_future) - edited_invite_link = bot.edit_chat_invite_link( + edited_invite_link = await bot.edit_chat_invite_link( channel_id, invite_link.invite_link, expire_date=expire_time, @@ -1884,7 +2172,7 @@ def test_advanced_chat_invite_links(self, bot, channel_id, datetime): assert edited_invite_link.name == 'NewName' assert edited_invite_link.member_limit == 20 - edited_invite_link = bot.edit_chat_invite_link( + edited_invite_link = await bot.edit_chat_invite_link( channel_id, invite_link.invite_link, name='EvenNewerName', @@ -1896,18 +2184,21 @@ def test_advanced_chat_invite_links(self, bot, channel_id, datetime): assert edited_invite_link.creates_join_request is True assert edited_invite_link.member_limit is None - revoked_invite_link = bot.revoke_chat_invite_link(channel_id, invite_link.invite_link) + revoked_invite_link = await bot.revoke_chat_invite_link( + channel_id, invite_link.invite_link + ) assert revoked_invite_link.invite_link == invite_link.invite_link assert revoked_invite_link.is_revoked is True @flaky(3, 1) - def test_advanced_chat_invite_links_default_tzinfo(self, tz_bot, channel_id): + @pytest.mark.asyncio + async def test_advanced_chat_invite_links_default_tzinfo(self, tz_bot, channel_id): # we are testing this all in one function in order to save api calls add_seconds = dtm.timedelta(0, 70) aware_expire_date = dtm.datetime.now(tz=tz_bot.defaults.tzinfo) + add_seconds time_in_future = aware_expire_date.replace(tzinfo=None) - invite_link = tz_bot.create_chat_invite_link( + invite_link = await tz_bot.create_chat_invite_link( channel_id, expire_date=time_in_future, member_limit=10 ) assert invite_link.invite_link != '' @@ -1919,7 +2210,7 @@ def test_advanced_chat_invite_links_default_tzinfo(self, tz_bot, channel_id): aware_expire_date += add_seconds time_in_future = aware_expire_date.replace(tzinfo=None) - edited_invite_link = tz_bot.edit_chat_invite_link( + edited_invite_link = await tz_bot.edit_chat_invite_link( channel_id, invite_link.invite_link, expire_date=time_in_future, @@ -1931,7 +2222,7 @@ def test_advanced_chat_invite_links_default_tzinfo(self, tz_bot, channel_id): assert edited_invite_link.name == 'NewName' assert edited_invite_link.member_limit == 20 - edited_invite_link = tz_bot.edit_chat_invite_link( + edited_invite_link = await tz_bot.edit_chat_invite_link( channel_id, invite_link.invite_link, name='EvenNewerName', @@ -1943,143 +2234,176 @@ def test_advanced_chat_invite_links_default_tzinfo(self, tz_bot, channel_id): assert edited_invite_link.creates_join_request is True assert edited_invite_link.member_limit is None - revoked_invite_link = tz_bot.revoke_chat_invite_link(channel_id, invite_link.invite_link) + revoked_invite_link = await tz_bot.revoke_chat_invite_link( + channel_id, invite_link.invite_link + ) assert revoked_invite_link.invite_link == invite_link.invite_link assert revoked_invite_link.is_revoked is True @flaky(3, 1) - def test_approve_chat_join_request(self, bot, chat_id, channel_id): + @pytest.mark.asyncio + async def test_approve_chat_join_request(self, bot, chat_id, channel_id): # TODO: Need incoming join request to properly test # Since we can't create join requests on the fly, we just tests the call to TG # by checking that it complains about approving a user who is already in the chat with pytest.raises(BadRequest, match='User_already_participant'): - bot.approve_chat_join_request(chat_id=channel_id, user_id=chat_id) + await bot.approve_chat_join_request(chat_id=channel_id, user_id=chat_id) @flaky(3, 1) - def test_decline_chat_join_request(self, bot, chat_id, channel_id): + @pytest.mark.asyncio + async def test_decline_chat_join_request(self, bot, chat_id, channel_id): # TODO: Need incoming join request to properly test # Since we can't create join requests on the fly, we just tests the call to TG # by checking that it complains about declining a user who is already in the chat with pytest.raises(BadRequest, match='User_already_participant'): - bot.decline_chat_join_request(chat_id=channel_id, user_id=chat_id) + await bot.decline_chat_join_request(chat_id=channel_id, user_id=chat_id) @flaky(3, 1) - def test_set_chat_photo(self, bot, channel_id): - def func(): - assert bot.set_chat_photo(channel_id, f) + @pytest.mark.asyncio + async def test_set_chat_photo(self, bot, channel_id): + async def func(): + assert await bot.set_chat_photo(channel_id, f) with data_file('telegram_test_channel.jpg').open('rb') as f: - expect_bad_request(func, 'Type of file mismatch', 'Telegram did not accept the file.') + await expect_bad_request( + func, 'Type of file mismatch', 'Telegram did not accept the file.' + ) - def test_set_chat_photo_local_files(self, monkeypatch, bot, chat_id): + @pytest.mark.asyncio + async def test_set_chat_photo_local_files(self, monkeypatch, bot, chat_id): # For just test that the correct paths are passed as we have no local bot API set up test_flag = False file = data_file('telegram.jpg') expected = file.as_uri() - def make_assertion(_, data, *args, **kwargs): + async def make_assertion(_, data, *args, **kwargs): nonlocal test_flag test_flag = data.get('photo') == expected monkeypatch.setattr(bot, '_post', make_assertion) - bot.set_chat_photo(chat_id, file) + await bot.set_chat_photo(chat_id, file) assert test_flag @flaky(3, 1) - def test_delete_chat_photo(self, bot, channel_id): - def func(): - assert bot.delete_chat_photo(channel_id) + @pytest.mark.asyncio + async def test_delete_chat_photo(self, bot, channel_id): + async def func(): + assert await bot.delete_chat_photo(channel_id) - expect_bad_request(func, 'Chat_not_modified', 'Chat photo was not set.') + await expect_bad_request(func, 'Chat_not_modified', 'Chat photo was not set.') @flaky(3, 1) - def test_set_chat_title(self, bot, channel_id): - assert bot.set_chat_title(channel_id, '>>> telegram.Bot() - Tests') + @pytest.mark.asyncio + async def test_set_chat_title(self, bot, channel_id): + assert await bot.set_chat_title(channel_id, '>>> telegram.Bot() - Tests') @flaky(3, 1) - def test_set_chat_description(self, bot, channel_id): - assert bot.set_chat_description(channel_id, 'Time: ' + str(time.time())) + @pytest.mark.asyncio + async def test_set_chat_description(self, bot, channel_id): + assert await bot.set_chat_description(channel_id, 'Time: ' + str(time.time())) @flaky(3, 1) - def test_pin_and_unpin_message(self, bot, super_group_id): - message1 = bot.send_message(super_group_id, text="test_pin_message_1") - message2 = bot.send_message(super_group_id, text="test_pin_message_2") - message3 = bot.send_message(super_group_id, text="test_pin_message_3") + @pytest.mark.asyncio + async def test_pin_and_unpin_message(self, bot, super_group_id): + message1 = await bot.send_message(super_group_id, text="test_pin_message_1") + message2 = await bot.send_message(super_group_id, text="test_pin_message_2") + message3 = await bot.send_message(super_group_id, text="test_pin_message_3") - assert bot.pin_chat_message( - chat_id=super_group_id, message_id=message1.message_id, disable_notification=True + assert await bot.pin_chat_message( + chat_id=super_group_id, + message_id=message1.message_id, + disable_notification=True, + read_timeout=10, ) time.sleep(1) - bot.pin_chat_message( - chat_id=super_group_id, message_id=message2.message_id, disable_notification=True + await bot.pin_chat_message( + chat_id=super_group_id, + message_id=message2.message_id, + disable_notification=True, + read_timeout=10, ) - time.sleep(1) - bot.pin_chat_message( - chat_id=super_group_id, message_id=message3.message_id, disable_notification=True + await bot.pin_chat_message( + chat_id=super_group_id, + message_id=message3.message_id, + disable_notification=True, + read_timeout=10, ) time.sleep(1) - chat = bot.get_chat(super_group_id) + chat = await bot.get_chat(super_group_id) assert chat.pinned_message == message3 - assert bot.unpin_chat_message(super_group_id, message_id=message2.message_id) - assert bot.unpin_chat_message(super_group_id) + assert await bot.unpin_chat_message( + super_group_id, + message_id=message2.message_id, + read_timeout=10, + ) + assert await bot.unpin_chat_message( + super_group_id, + read_timeout=10, + ) - assert bot.unpin_all_chat_messages(super_group_id) + assert await bot.unpin_all_chat_messages( + super_group_id, + read_timeout=10, + ) # get_sticker_set, upload_sticker_file, create_new_sticker_set, add_sticker_to_set, # set_sticker_position_in_set and delete_sticker_from_set are tested in the # test_sticker module. - def test_timeout_propagation_explicit(self, monkeypatch, bot, chat_id): - - from telegram.vendor.ptb_urllib3.urllib3.util.timeout import Timeout - - class OkException(Exception): + @pytest.mark.asyncio + async def test_timeout_propagation_explicit(self, monkeypatch, bot, chat_id): + # Use BaseException that's not a subclass of Exception such that + # OkException should not be caught anywhere + class OkException(BaseException): pass - TIMEOUT = 500 + timeout = 42 - def request_wrapper(*args, **kwargs): - obj = kwargs.get('timeout') - if isinstance(obj, Timeout) and obj._read == TIMEOUT: + async def do_request(*args, **kwargs): + obj = kwargs.get('read_timeout') + if obj == timeout: raise OkException - return b'{"ok": true, "result": []}' + return 200, b'{"ok": true, "result": []}' - monkeypatch.setattr('telegram.request.Request._request_wrapper', request_wrapper) + monkeypatch.setattr(bot.request, 'do_request', do_request) # Test file uploading with pytest.raises(OkException): - bot.send_photo(chat_id, data_file('telegram.jpg').open('rb'), timeout=TIMEOUT) + await bot.send_photo( + chat_id, data_file('telegram.jpg').open('rb'), read_timeout=timeout + ) # Test JSON submission with pytest.raises(OkException): - bot.get_chat_administrators(chat_id, timeout=TIMEOUT) - - def test_timeout_propagation_implicit(self, monkeypatch, bot, chat_id): + await bot.get_chat_administrators(chat_id, read_timeout=timeout) - from telegram.vendor.ptb_urllib3.urllib3.util.timeout import Timeout - - class OkException(Exception): + @pytest.mark.asyncio + async def test_timeout_propagation_implicit(self, monkeypatch, bot, chat_id): + # Use BaseException that's not a subclass of Exception such that + # OkException should not be caught anywhere + class OkException(BaseException): pass - def request_wrapper(*args, **kwargs): - obj = kwargs.get('timeout') - if isinstance(obj, Timeout) and obj._read == 20: + async def do_request(*args, **kwargs): + obj = kwargs.get('read_timeout') + if obj == 20: raise OkException - return b'{"ok": true, "result": []}' + return 200, b'{"ok": true, "result": []}' - monkeypatch.setattr('telegram.request.Request._request_wrapper', request_wrapper) + monkeypatch.setattr(bot.request, 'do_request', do_request) # Test file uploading with pytest.raises(OkException): - bot.send_photo(chat_id, data_file('telegram.jpg').open('rb')) + await bot.send_photo(chat_id, data_file('telegram.jpg').open('rb')) @flaky(3, 1) - def test_send_message_entities(self, bot, chat_id): + @pytest.mark.asyncio + async def test_send_message_entities(self, bot, chat_id): test_string = 'Italic Bold Code Spoiler' entities = [ MessageEntity(MessageEntity.ITALIC, 0, 6), @@ -2087,35 +2411,37 @@ def test_send_message_entities(self, bot, chat_id): MessageEntity(MessageEntity.ITALIC, 12, 4), MessageEntity(MessageEntity.SPOILER, 17, 7), ] - message = bot.send_message(chat_id=chat_id, text=test_string, entities=entities) + message = await bot.send_message(chat_id=chat_id, text=test_string, entities=entities) assert message.text == test_string assert message.entities == entities @flaky(3, 1) @pytest.mark.parametrize('default_bot', [{'parse_mode': 'Markdown'}], indirect=True) - def test_send_message_default_parse_mode(self, default_bot, chat_id): + @pytest.mark.asyncio + async def test_send_message_default_parse_mode(self, default_bot, chat_id): test_string = 'Italic Bold Code' test_markdown_string = '_Italic_ *Bold* `Code`' - message = default_bot.send_message(chat_id, test_markdown_string) + message = await default_bot.send_message(chat_id, test_markdown_string) assert message.text_markdown == test_markdown_string assert message.text == test_string - message = default_bot.send_message(chat_id, test_markdown_string, parse_mode=None) + message = await default_bot.send_message(chat_id, test_markdown_string, parse_mode=None) assert message.text == test_markdown_string assert message.text_markdown == escape_markdown(test_markdown_string) - message = default_bot.send_message(chat_id, test_markdown_string, parse_mode='HTML') + message = await default_bot.send_message(chat_id, test_markdown_string, parse_mode='HTML') assert message.text == test_markdown_string assert message.text_markdown == escape_markdown(test_markdown_string) @flaky(3, 1) + @pytest.mark.asyncio @pytest.mark.parametrize('default_bot', [{'protect_content': True}], indirect=True) - def test_send_message_default_protect_content(self, default_bot, chat_id): - to_check = default_bot.send_message(chat_id, "test") + async def test_send_message_default_protect_content(self, default_bot, chat_id): + to_check = await default_bot.send_message(chat_id, "test") assert to_check.has_protected_content - no_protect = default_bot.send_message(chat_id, "test", protect_content=False) + no_protect = await default_bot.send_message(chat_id, "test", protect_content=False) assert not no_protect.has_protected_content @flaky(3, 1) @@ -2128,11 +2454,14 @@ def test_send_message_default_protect_content(self, default_bot, chat_id): ], indirect=['default_bot'], ) - def test_send_message_default_allow_sending_without_reply(self, default_bot, chat_id, custom): - reply_to_message = default_bot.send_message(chat_id, 'test') - reply_to_message.delete() + @pytest.mark.asyncio + async def test_send_message_default_allow_sending_without_reply( + self, default_bot, chat_id, custom + ): + reply_to_message = await default_bot.send_message(chat_id, 'test') + await reply_to_message.delete() if custom is not None: - message = default_bot.send_message( + message = await default_bot.send_message( chat_id, 'test', allow_sending_without_reply=custom, @@ -2140,103 +2469,119 @@ def test_send_message_default_allow_sending_without_reply(self, default_bot, cha ) assert message.reply_to_message is None elif default_bot.defaults.allow_sending_without_reply: - message = default_bot.send_message( + message = await default_bot.send_message( chat_id, 'test', reply_to_message_id=reply_to_message.message_id ) assert message.reply_to_message is None else: with pytest.raises(BadRequest, match='message not found'): - default_bot.send_message( + await default_bot.send_message( chat_id, 'test', reply_to_message_id=reply_to_message.message_id ) @flaky(3, 1) - def test_set_and_get_my_commands(self, bot): + @pytest.mark.asyncio + async def test_set_and_get_my_commands(self, bot): commands = [BotCommand('cmd1', 'descr1'), ['cmd2', 'descr2']] - bot.set_my_commands([]) - assert bot.get_my_commands() == [] - assert bot.set_my_commands(commands) + await bot.set_my_commands([]) + assert await bot.get_my_commands() == [] + assert await bot.set_my_commands(commands) - for i, bc in enumerate(bot.get_my_commands()): + for i, bc in enumerate(await bot.get_my_commands()): assert bc.command == f'cmd{i+1}' assert bc.description == f'descr{i+1}' @flaky(3, 1) - def test_get_set_delete_my_commands_with_scope(self, bot, super_group_id, chat_id): + @pytest.mark.asyncio + async def test_get_set_delete_my_commands_with_scope(self, bot, super_group_id, chat_id): group_cmds = [BotCommand('group_cmd', 'visible to this supergroup only')] private_cmds = [BotCommand('private_cmd', 'visible to this private chat only')] group_scope = BotCommandScopeChat(super_group_id) private_scope = BotCommandScopeChat(chat_id) # Set supergroup command list with lang code and check if the same can be returned from api - bot.set_my_commands(group_cmds, scope=group_scope, language_code='en') - gotten_group_cmds = bot.get_my_commands(scope=group_scope, language_code='en') + await bot.set_my_commands(group_cmds, scope=group_scope, language_code='en') + gotten_group_cmds = await bot.get_my_commands(scope=group_scope, language_code='en') assert len(gotten_group_cmds) == len(group_cmds) assert gotten_group_cmds[0].command == group_cmds[0].command # Set private command list and check if same can be returned from the api - bot.set_my_commands(private_cmds, scope=private_scope) - gotten_private_cmd = bot.get_my_commands(scope=private_scope) + await bot.set_my_commands(private_cmds, scope=private_scope) + gotten_private_cmd = await bot.get_my_commands(scope=private_scope) assert len(gotten_private_cmd) == len(private_cmds) assert gotten_private_cmd[0].command == private_cmds[0].command # Delete command list from that supergroup and private chat- - bot.delete_my_commands(private_scope) - bot.delete_my_commands(group_scope, 'en') + await bot.delete_my_commands(private_scope) + await bot.delete_my_commands(group_scope, 'en') # Check if its been deleted- - deleted_priv_cmds = bot.get_my_commands(scope=private_scope) - deleted_grp_cmds = bot.get_my_commands(scope=group_scope, language_code='en') + deleted_priv_cmds = await bot.get_my_commands(scope=private_scope) + deleted_grp_cmds = await bot.get_my_commands(scope=group_scope, language_code='en') assert len(deleted_grp_cmds) == 0 == len(group_cmds) - 1 assert len(deleted_priv_cmds) == 0 == len(private_cmds) - 1 - bot.delete_my_commands() # Delete commands from default scope - assert len(bot.get_my_commands()) == 0 + await bot.delete_my_commands() # Delete commands from default scope + assert len(await bot.get_my_commands()) == 0 - def test_log_out(self, monkeypatch, bot): + @pytest.mark.asyncio + async def test_log_out(self, monkeypatch, bot): # We don't actually make a request as to not break the test setup - def assertion(url, data, *args, **kwargs): - return data == {} and url.split('/')[-1] == 'logOut' + async def assertion(url, request_data: RequestData, *args, **kwargs): + return request_data.json_parameters == {} and url.split('/')[-1] == 'logOut' monkeypatch.setattr(bot.request, 'post', assertion) - assert bot.log_out() + assert await bot.log_out() - def test_close(self, monkeypatch, bot): + @pytest.mark.asyncio + async def test_close(self, monkeypatch, bot): # We don't actually make a request as to not break the test setup - def assertion(url, data, *args, **kwargs): - return data == {} and url.split('/')[-1] == 'close' + async def assertion(url, request_data: RequestData, *args, **kwargs): + return request_data.json_parameters == {} and url.split('/')[-1] == 'close' monkeypatch.setattr(bot.request, 'post', assertion) - assert bot.close() + assert await bot.close() @flaky(3, 1) @pytest.mark.parametrize('json_keyboard', [True, False]) @pytest.mark.parametrize('caption', ["Test", '', None]) - def test_copy_message(self, monkeypatch, bot, chat_id, media_message, json_keyboard, caption): + @pytest.mark.asyncio + async def test_copy_message( + self, monkeypatch, bot, chat_id, media_message, json_keyboard, caption + ): keyboard = InlineKeyboardMarkup( [[InlineKeyboardButton(text="test", callback_data="test2")]] ) - def post(url, data, timeout): - assert data["chat_id"] == chat_id - assert data["from_chat_id"] == chat_id - assert data["message_id"] == media_message.message_id - assert data.get("caption") == caption - assert data["parse_mode"] == ParseMode.HTML - assert data["reply_to_message_id"] == media_message.message_id - assert data["reply_markup"] == keyboard.to_json() - assert data["disable_notification"] is True - assert data["caption_entities"] == [MessageEntity(MessageEntity.BOLD, 0, 4)] - assert data['protect_content'] is True + async def post(url, request_data: RequestData, *args, **kwargs): + data = request_data.parameters + if not all( + [ + data["chat_id"] == chat_id, + data["from_chat_id"] == chat_id, + data["message_id"] == media_message.message_id, + data.get("caption") == caption, + data["parse_mode"] == ParseMode.HTML, + data["reply_to_message_id"] == media_message.message_id, + data["reply_markup"] == keyboard.to_json() + if json_keyboard + else keyboard.to_dict(), + data["disable_notification"] is True, + data["caption_entities"] + == [MessageEntity(MessageEntity.BOLD, 0, 4).to_dict()], + data['protect_content'] is True, + ] + ): + pytest.fail('I got wrong parameters in post') return data monkeypatch.setattr(bot.request, 'post', post) - bot.copy_message( + await bot.copy_message( chat_id, from_chat_id=chat_id, message_id=media_message.message_id, @@ -2250,12 +2595,13 @@ def post(url, data, timeout): ) @flaky(3, 1) - def test_copy_message_without_reply(self, bot, chat_id, media_message): + @pytest.mark.asyncio + async def test_copy_message_without_reply(self, bot, chat_id, media_message): keyboard = InlineKeyboardMarkup( [[InlineKeyboardButton(text="test", callback_data="test2")]] ) - returned = bot.copy_message( + returned = await bot.copy_message( chat_id, from_chat_id=chat_id, message_id=media_message.message_id, @@ -2266,7 +2612,9 @@ def test_copy_message_without_reply(self, bot, chat_id, media_message): ) # we send a temp message which replies to the returned message id in order to get a # message object - temp_message = bot.send_message(chat_id, "test", reply_to_message_id=returned.message_id) + temp_message = await bot.send_message( + chat_id, "test", reply_to_message_id=returned.message_id + ) message = temp_message.reply_to_message assert message.chat_id == int(chat_id) assert message.caption == "Test" @@ -2283,12 +2631,13 @@ def test_copy_message_without_reply(self, bot, chat_id, media_message): ], indirect=['default_bot'], ) - def test_copy_message_with_default(self, default_bot, chat_id, media_message): - reply_to_message = default_bot.send_message(chat_id, 'test') - reply_to_message.delete() + @pytest.mark.asyncio + async def test_copy_message_with_default(self, default_bot, chat_id, media_message): + reply_to_message = await default_bot.send_message(chat_id, 'test') + await reply_to_message.delete() if not default_bot.defaults.allow_sending_without_reply: with pytest.raises(BadRequest, match='not found'): - default_bot.copy_message( + await default_bot.copy_message( chat_id, from_chat_id=chat_id, message_id=media_message.message_id, @@ -2296,7 +2645,7 @@ def test_copy_message_with_default(self, default_bot, chat_id, media_message): reply_to_message_id=reply_to_message.message_id, ) return - returned = default_bot.copy_message( + returned = await default_bot.copy_message( chat_id, from_chat_id=chat_id, message_id=media_message.message_id, @@ -2305,7 +2654,7 @@ def test_copy_message_with_default(self, default_bot, chat_id, media_message): ) # we send a temp message which replies to the returned message id in order to get a # message object - temp_message = default_bot.send_message( + temp_message = await default_bot.send_message( chat_id, "test", reply_to_message_id=returned.message_id ) message = temp_message.reply_to_message @@ -2314,7 +2663,8 @@ def test_copy_message_with_default(self, default_bot, chat_id, media_message): else: assert len(message.caption_entities) == 0 - def test_replace_callback_data_send_message(self, bot, chat_id): + @pytest.mark.asyncio + async def test_replace_callback_data_send_message(self, bot, chat_id): try: bot.arbitrary_callback_data = True replace_button = InlineKeyboardButton(text='replace', callback_data='replace_test') @@ -2327,7 +2677,9 @@ def test_replace_callback_data_send_message(self, bot, chat_id): no_replace_button, ] ) - message = bot.send_message(chat_id=chat_id, text='test', reply_markup=reply_markup) + message = await bot.send_message( + chat_id=chat_id, text='test', reply_markup=reply_markup + ) inline_keyboard = message.reply_markup.inline_keyboard assert inline_keyboard[0][1] == no_replace_button @@ -2340,8 +2692,9 @@ def test_replace_callback_data_send_message(self, bot, chat_id): bot.callback_data_cache.clear_callback_data() bot.callback_data_cache.clear_callback_queries() - def test_replace_callback_data_stop_poll_and_repl_to_message(self, bot, chat_id): - poll_message = bot.send_poll(chat_id=chat_id, question='test', options=['1', '2']) + @pytest.mark.asyncio + async def test_replace_callback_data_stop_poll_and_repl_to_message(self, bot, chat_id): + poll_message = await bot.send_poll(chat_id=chat_id, question='test', options=['1', '2']) try: bot.arbitrary_callback_data = True replace_button = InlineKeyboardButton(text='replace', callback_data='replace_test') @@ -2354,8 +2707,8 @@ def test_replace_callback_data_stop_poll_and_repl_to_message(self, bot, chat_id) no_replace_button, ] ) - poll_message.stop_poll(reply_markup=reply_markup) - helper_message = poll_message.reply_text('temp', quote=True) + await poll_message.stop_poll(reply_markup=reply_markup) + helper_message = await poll_message.reply_text('temp', quote=True) message = helper_message.reply_to_message inline_keyboard = message.reply_markup.inline_keyboard @@ -2369,10 +2722,11 @@ def test_replace_callback_data_stop_poll_and_repl_to_message(self, bot, chat_id) bot.callback_data_cache.clear_callback_data() bot.callback_data_cache.clear_callback_queries() - def test_replace_callback_data_copy_message(self, bot, chat_id): + @pytest.mark.asyncio + async def test_replace_callback_data_copy_message(self, bot, chat_id): """This also tests that data is inserted into the buttons of message.reply_to_message where message is the return value of a bot method""" - original_message = bot.send_message(chat_id=chat_id, text='original') + original_message = await bot.send_message(chat_id=chat_id, text='original') try: bot.arbitrary_callback_data = True replace_button = InlineKeyboardButton(text='replace', callback_data='replace_test') @@ -2385,8 +2739,8 @@ def test_replace_callback_data_copy_message(self, bot, chat_id): no_replace_button, ] ) - message_id = original_message.copy(chat_id=chat_id, reply_markup=reply_markup) - helper_message = bot.send_message( + message_id = await original_message.copy(chat_id=chat_id, reply_markup=reply_markup) + helper_message = await bot.send_message( chat_id=chat_id, reply_to_message_id=message_id.message_id, text='temp' ) message = helper_message.reply_to_message @@ -2403,17 +2757,16 @@ def test_replace_callback_data_copy_message(self, bot, chat_id): bot.callback_data_cache.clear_callback_queries() # TODO: Needs improvement. We need incoming inline query to test answer. - def test_replace_callback_data_answer_inline_query(self, monkeypatch, bot, chat_id): + @pytest.mark.asyncio + async def test_replace_callback_data_answer_inline_query(self, monkeypatch, bot, chat_id): # For now just test that our internals pass the correct data - def make_assertion( + async def make_assertion( endpoint, data=None, - timeout=None, - api_kwargs=None, + *args, + **kwargs, ): - inline_keyboard = InlineKeyboardMarkup.de_json( - data['results'][0]['reply_markup'], bot - ).inline_keyboard + inline_keyboard = data['results'][0]['reply_markup'].inline_keyboard assertion_1 = inline_keyboard[0][1] == no_replace_button assertion_2 = inline_keyboard[0][0] != replace_button keyboard, button = ( @@ -2424,7 +2777,7 @@ def make_assertion( bot.callback_data_cache._keyboard_data[keyboard].button_data[button] == 'replace_test' ) - assertion_4 = 'reply_markup' not in data['results'][1] + assertion_4 = data['results'][1].reply_markup is None return assertion_1 and assertion_2 and assertion_3 and assertion_4 try: @@ -2453,46 +2806,48 @@ def make_assertion( ), ] - assert bot.answer_inline_query(chat_id, results=results) + assert await bot.answer_inline_query(chat_id, results=results) finally: bot.arbitrary_callback_data = False bot.callback_data_cache.clear_callback_data() bot.callback_data_cache.clear_callback_queries() - def test_get_chat_arbitrary_callback_data(self, super_group_id, bot): + @pytest.mark.asyncio + async def test_get_chat_arbitrary_callback_data(self, super_group_id, bot): try: bot.arbitrary_callback_data = True reply_markup = InlineKeyboardMarkup.from_button( InlineKeyboardButton(text='text', callback_data='callback_data') ) - message = bot.send_message( + message = await bot.send_message( super_group_id, text='get_chat_arbitrary_callback_data', reply_markup=reply_markup ) - message.pin() + await message.pin() keyboard = list(bot.callback_data_cache._keyboard_data)[0] data = list(bot.callback_data_cache._keyboard_data[keyboard].button_data.values())[0] assert data == 'callback_data' - chat = bot.get_chat(super_group_id) + chat = await bot.get_chat(super_group_id) assert chat.pinned_message == message assert chat.pinned_message.reply_markup == reply_markup finally: bot.arbitrary_callback_data = False bot.callback_data_cache.clear_callback_data() bot.callback_data_cache.clear_callback_queries() - bot.unpin_all_chat_messages(super_group_id) + await bot.unpin_all_chat_messages(super_group_id) # In the following tests we check that get_updates inserts callback data correctly if necessary # The same must be done in the webhook updater. This is tested over at test_updater.py, but # here we test more extensively. - def test_arbitrary_callback_data_no_insert(self, monkeypatch, bot): + @pytest.mark.asyncio + async def test_arbitrary_callback_data_no_insert(self, monkeypatch, bot): """Updates that don't need insertion shouldn.t fail obviously""" - def post(*args, **kwargs): + async def post(*args, **kwargs): update = Update( 17, poll=Poll( @@ -2510,9 +2865,9 @@ def post(*args, **kwargs): try: bot.arbitrary_callback_data = True - monkeypatch.setattr(bot.request, 'post', post) - bot.delete_webhook() # make sure there is no webhook set if webhook tests failed - updates = bot.get_updates(timeout=1) + monkeypatch.setattr(BaseRequest, 'post', post) + await bot.delete_webhook() # make sure there is no webhook set if webhook tests failed + updates = await bot.get_updates(timeout=1) assert len(updates) == 1 assert updates[0].update_id == 17 @@ -2523,7 +2878,8 @@ def post(*args, **kwargs): @pytest.mark.parametrize( 'message_type', ['channel_post', 'edited_channel_post', 'message', 'edited_message'] ) - def test_arbitrary_callback_data_pinned_message_reply_to_message( + @pytest.mark.asyncio + async def test_arbitrary_callback_data_pinned_message_reply_to_message( self, super_group_id, bot, monkeypatch, message_type ): bot.arbitrary_callback_data = True @@ -2537,7 +2893,7 @@ def test_arbitrary_callback_data_pinned_message_reply_to_message( # We do to_dict -> de_json to make sure those aren't the same objects message.pinned_message = Message.de_json(message.to_dict(), bot) - def post(*args, **kwargs): + async def post(*args, **kwargs): update = Update( 17, **{ @@ -2553,9 +2909,9 @@ def post(*args, **kwargs): return [update.to_dict()] try: - monkeypatch.setattr(bot.request, 'post', post) - bot.delete_webhook() # make sure there is no webhook set if webhook tests failed - updates = bot.get_updates(timeout=1) + monkeypatch.setattr(BaseRequest, 'post', post) + await bot.delete_webhook() # make sure there is no webhook set if webhook tests failed + updates = await bot.get_updates(timeout=1) assert isinstance(updates, list) assert len(updates) == 1 @@ -2580,12 +2936,13 @@ def post(*args, **kwargs): bot.callback_data_cache.clear_callback_data() bot.callback_data_cache.clear_callback_queries() - def test_arbitrary_callback_data_get_chat_no_pinned_message(self, super_group_id, bot): + @pytest.mark.asyncio + async def test_arbitrary_callback_data_get_chat_no_pinned_message(self, super_group_id, bot): bot.arbitrary_callback_data = True - bot.unpin_all_chat_messages(super_group_id) + await bot.unpin_all_chat_messages(super_group_id) try: - chat = bot.get_chat(super_group_id) + chat = await bot.get_chat(super_group_id) assert isinstance(chat, Chat) assert int(chat.id) == int(super_group_id) @@ -2597,7 +2954,8 @@ def test_arbitrary_callback_data_get_chat_no_pinned_message(self, super_group_id 'message_type', ['channel_post', 'edited_channel_post', 'message', 'edited_message'] ) @pytest.mark.parametrize('self_sender', [True, False]) - def test_arbitrary_callback_data_via_bot( + @pytest.mark.asyncio + async def test_arbitrary_callback_data_via_bot( self, super_group_id, bot, monkeypatch, self_sender, message_type ): bot.arbitrary_callback_data = True @@ -2614,13 +2972,13 @@ def test_arbitrary_callback_data_via_bot( via_bot=bot.bot if self_sender else User(1, 'first', False), ) - def post(*args, **kwargs): + async def post(*args, **kwargs): return [Update(17, **{message_type: message}).to_dict()] try: - monkeypatch.setattr(bot.request, 'post', post) - bot.delete_webhook() # make sure there is no webhook set if webhook tests failed - updates = bot.get_updates(timeout=1) + monkeypatch.setattr(BaseRequest, 'post', post) + await bot.delete_webhook() # make sure there is no webhook set if webhook tests failed + updates = await bot.get_updates(timeout=1) assert isinstance(updates, list) assert len(updates) == 1 @@ -2652,7 +3010,7 @@ def test_camel_case_bot(self): if ( function_name.startswith("_") or not callable(function) - or function_name == "to_dict" + or function_name in ["to_dict", "do_init", "do_teardown"] ): continue camel_case_function = getattr(Bot, to_camel_case(function_name), False) diff --git a/tests/test_builders.py b/tests/test_builders.py deleted file mode 100644 index 9fff1ae1de0..00000000000 --- a/tests/test_builders.py +++ /dev/null @@ -1,279 +0,0 @@ -#!/usr/bin/env python -# -# A library that provides a Python interface to the Telegram Bot API -# Copyright (C) 2021 -# Leandro Toledo de Souza -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Lesser Public License for more details. -# -# You should have received a copy of the GNU Lesser Public License -# along with this program. If not, see [http://www.gnu.org/licenses/]. - -""" -We mainly test on UpdaterBuilder because it has all methods that DispatcherBuilder already has -""" -from pathlib import Path -from random import randint -from threading import Event - -import pytest - -from telegram.request import Request -from .conftest import PRIVATE_KEY - -from telegram.ext import ( - UpdaterBuilder, - DispatcherBuilder, - Defaults, - JobQueue, - PicklePersistence, - ContextTypes, - Dispatcher, - Updater, -) -from telegram.ext._builders import _BOT_CHECKS, _DISPATCHER_CHECKS, _BaseBuilder - - -@pytest.fixture( - scope='function', - params=[{'class': UpdaterBuilder}, {'class': DispatcherBuilder}], - ids=['UpdaterBuilder', 'DispatcherBuilder'], -) -def builder(request): - return request.param['class']() - - -class TestBuilder: - @pytest.mark.parametrize('workers', [randint(1, 100) for _ in range(10)]) - def test_get_connection_pool_size(self, workers): - assert _BaseBuilder._get_connection_pool_size(workers) == workers + 4 - - @pytest.mark.parametrize( - 'method, description', _BOT_CHECKS, ids=[entry[0] for entry in _BOT_CHECKS] - ) - def test_mutually_exclusive_for_bot(self, builder, method, description): - if getattr(builder, method, None) is None: - pytest.skip(f'{builder.__class__} has no method called {method}') - - # First that e.g. `bot` can't be set if `request` was already set - # We pass the private key since `private_key` is the only method that doesn't just save - # the passed value - getattr(builder, method)(Path('tests/data/private.key')) - with pytest.raises(RuntimeError, match=f'`bot` may only be set, if no {description}'): - builder.bot(None) - - # Now test that `request` can't be set if `bot` was already set - builder = builder.__class__() - builder.bot(None) - with pytest.raises(RuntimeError, match=f'`{method}` may only be set, if no bot instance'): - getattr(builder, method)(None) - - @pytest.mark.parametrize( - 'method, description', _DISPATCHER_CHECKS, ids=[entry[0] for entry in _DISPATCHER_CHECKS] - ) - def test_mutually_exclusive_for_dispatcher(self, builder, method, description): - if isinstance(builder, DispatcherBuilder): - pytest.skip('This test is only relevant for UpdaterBuilder') - - if getattr(builder, method, None) is None: - pytest.skip(f'{builder.__class__} has no method called {method}') - - # First that e.g. `dispatcher` can't be set if `bot` was already set - # We pass the private key since `private_key` is the only method that doesn't just save - # the passed value - getattr(builder, method)(Path('tests/data/private.key')) - with pytest.raises( - RuntimeError, match=f'`dispatcher` may only be set, if no {description}' - ): - builder.dispatcher(None) - - # Now test that `bot` can't be set if `dispatcher` was already set - builder = builder.__class__() - builder.dispatcher(1) - with pytest.raises( - RuntimeError, match=f'`{method}` may only be set, if no Dispatcher instance' - ): - getattr(builder, method)(None) - - # Finally test that `bot` *can* be set if `dispatcher` was set to None - builder = builder.__class__() - builder.dispatcher(None) - if method != 'dispatcher_class': - # We pass the private key since `private_key` is the only method that doesn't just save - # the passed value - getattr(builder, method)(Path('tests/data/private.key')) - else: - with pytest.raises( - RuntimeError, match=f'`{method}` may only be set, if no Dispatcher instance' - ): - getattr(builder, method)(None) - - def test_mutually_exclusive_for_request(self, builder): - builder.request(None) - with pytest.raises( - RuntimeError, match='`request_kwargs` may only be set, if no Request instance' - ): - builder.request_kwargs(None) - - builder = builder.__class__() - builder.request_kwargs(None) - with pytest.raises(RuntimeError, match='`request` may only be set, if no request_kwargs'): - builder.request(None) - - def test_build_without_token(self, builder): - with pytest.raises(RuntimeError, match='No bot token was set.'): - builder.build() - - def test_build_custom_bot(self, builder, bot): - builder.bot(bot) - obj = builder.build() - assert obj.bot is bot - - if isinstance(obj, Updater): - assert obj.dispatcher.bot is bot - assert obj.dispatcher.job_queue.dispatcher is obj.dispatcher - assert obj.exception_event is obj.dispatcher.exception_event - - def test_build_custom_dispatcher(self, dp): - updater = UpdaterBuilder().dispatcher(dp).build() - assert updater.dispatcher is dp - assert updater.bot is updater.dispatcher.bot - assert updater.exception_event is dp.exception_event - - def test_build_no_dispatcher(self, bot): - updater = UpdaterBuilder().dispatcher(None).token(bot.token).build() - assert updater.dispatcher is None - assert updater.bot.token == bot.token - assert updater.bot.request.con_pool_size == 8 - assert isinstance(updater.exception_event, Event) - - def test_all_bot_args_custom(self, builder, bot): - defaults = Defaults() - request = Request(8) - builder.token(bot.token).base_url('base_url').base_file_url('base_file_url').private_key( - PRIVATE_KEY - ).defaults(defaults).arbitrary_callback_data(42).request(request) - built_bot = builder.build().bot - - assert built_bot.token == bot.token - assert built_bot.base_url == 'base_url' + bot.token - assert built_bot.base_file_url == 'base_file_url' + bot.token - assert built_bot.defaults is defaults - assert built_bot.request is request - assert built_bot.callback_data_cache.maxsize == 42 - - builder = builder.__class__() - builder.token(bot.token).request_kwargs({'connect_timeout': 42}) - built_bot = builder.build().bot - - assert built_bot.token == bot.token - assert built_bot.request._connect_timeout == 42 - - def test_all_dispatcher_args_custom(self, dp): - builder = DispatcherBuilder() - - job_queue = JobQueue() - persistence = PicklePersistence('filename') - context_types = ContextTypes() - builder.bot(dp.bot).update_queue(dp.update_queue).exception_event( - dp.exception_event - ).job_queue(job_queue).persistence(persistence).context_types(context_types).workers(3) - dispatcher = builder.build() - - assert dispatcher.bot is dp.bot - assert dispatcher.update_queue is dp.update_queue - assert dispatcher.exception_event is dp.exception_event - assert dispatcher.job_queue is job_queue - assert dispatcher.job_queue.dispatcher is dispatcher - assert dispatcher.persistence is persistence - assert dispatcher.context_types is context_types - assert dispatcher.workers == 3 - - def test_all_updater_args_custom(self, dp): - updater = ( - UpdaterBuilder() - .dispatcher(None) - .bot(dp.bot) - .exception_event(dp.exception_event) - .update_queue(dp.update_queue) - .user_signal_handler(42) - .build() - ) - - assert updater.dispatcher is None - assert updater.bot is dp.bot - assert updater.exception_event is dp.exception_event - assert updater.update_queue is dp.update_queue - assert updater.user_signal_handler == 42 - - def test_connection_pool_size_with_workers(self, bot, builder): - obj = builder.token(bot.token).workers(42).build() - dispatcher = obj if isinstance(obj, Dispatcher) else obj.dispatcher - assert dispatcher.workers == 42 - assert dispatcher.bot.request.con_pool_size == 46 - - def test_connection_pool_size_warning(self, bot, builder, recwarn): - builder.token(bot.token).workers(42).request_kwargs({'con_pool_size': 1}) - obj = builder.build() - dispatcher = obj if isinstance(obj, Dispatcher) else obj.dispatcher - assert dispatcher.workers == 42 - assert dispatcher.bot.request.con_pool_size == 1 - - assert len(recwarn) == 1 - message = str(recwarn[-1].message) - assert 'smaller (1)' in message - assert 'recommended value of 46.' in message - assert recwarn[-1].filename == __file__, "wrong stacklevel" - - def test_custom_classes(self, bot, builder): - class CustomDispatcher(Dispatcher): - def __init__(self, arg, **kwargs): - super().__init__(**kwargs) - self.arg = arg - - class CustomUpdater(Updater): - def __init__(self, arg, **kwargs): - super().__init__(**kwargs) - self.arg = arg - - builder.dispatcher_class(CustomDispatcher, kwargs={'arg': 2}).token(bot.token) - if isinstance(builder, UpdaterBuilder): - builder.updater_class(CustomUpdater, kwargs={'arg': 1}) - - obj = builder.build() - - if isinstance(builder, UpdaterBuilder): - assert isinstance(obj, CustomUpdater) - assert obj.arg == 1 - assert isinstance(obj.dispatcher, CustomDispatcher) - assert obj.dispatcher.arg == 2 - else: - assert isinstance(obj, CustomDispatcher) - assert obj.arg == 2 - - @pytest.mark.parametrize('input_type', ('bytes', 'str', 'Path')) - def test_all_private_key_input_types(self, builder, bot, input_type): - private_key = Path('tests/data/private.key') - password = Path('tests/data/private_key.password') - - if input_type == 'bytes': - private_key = private_key.read_bytes() - password = password.read_bytes() - if input_type == 'str': - private_key = str(private_key) - password = str(password) - - builder.token(bot.token).private_key( - private_key=private_key, - password=password, - ) - bot = builder.build().bot - assert bot.private_key diff --git a/tests/test_callbackcontext.py b/tests/test_callbackcontext.py deleted file mode 100644 index b1114746ba6..00000000000 --- a/tests/test_callbackcontext.py +++ /dev/null @@ -1,228 +0,0 @@ -#!/usr/bin/env python -# -# A library that provides a Python interface to the Telegram Bot API -# Copyright (C) 2015-2022 -# Leandro Toledo de Souza -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Lesser Public License for more details. -# -# You should have received a copy of the GNU Lesser Public License -# along with this program. If not, see [http://www.gnu.org/licenses/]. - -import pytest - -from telegram import ( - Update, - Message, - Chat, - User, - Bot, - InlineKeyboardMarkup, - InlineKeyboardButton, - CallbackQuery, -) -from telegram.ext import CallbackContext -from telegram.error import TelegramError - -""" -CallbackContext.refresh_data is tested in TestBasePersistence -""" - - -class TestCallbackContext: - def test_slot_behaviour(self, dp, mro_slots, recwarn): - c = CallbackContext(dp) - for attr in c.__slots__: - assert getattr(c, attr, 'err') != 'err', f"got extra slot '{attr}'" - assert not c.__dict__, f"got missing slot(s): {c.__dict__}" - assert len(mro_slots(c)) == len(set(mro_slots(c))), "duplicate slot" - c.args = c.args - assert len(recwarn) == 0, recwarn.list - - def test_from_job(self, dp): - job = dp.job_queue.run_once(lambda x: x, 10) - - callback_context = CallbackContext.from_job(job, dp) - - assert callback_context.job is job - assert callback_context.chat_data is None - assert callback_context.user_data is None - assert callback_context.bot_data is dp.bot_data - assert callback_context.bot is dp.bot - assert callback_context.job_queue is dp.job_queue - assert callback_context.update_queue is dp.update_queue - - def test_from_update(self, dp): - update = Update( - 0, message=Message(0, None, Chat(1, 'chat'), from_user=User(1, 'user', False)) - ) - - callback_context = CallbackContext.from_update(update, dp) - - assert callback_context.chat_data == {} - assert callback_context.user_data == {} - assert callback_context.bot_data is dp.bot_data - assert callback_context.bot is dp.bot - assert callback_context.job_queue is dp.job_queue - assert callback_context.update_queue is dp.update_queue - - callback_context_same_user_chat = CallbackContext.from_update(update, dp) - - callback_context.bot_data['test'] = 'bot' - callback_context.chat_data['test'] = 'chat' - callback_context.user_data['test'] = 'user' - - assert callback_context_same_user_chat.bot_data is callback_context.bot_data - assert callback_context_same_user_chat.chat_data is callback_context.chat_data - assert callback_context_same_user_chat.user_data is callback_context.user_data - - update_other_user_chat = Update( - 0, message=Message(0, None, Chat(2, 'chat'), from_user=User(2, 'user', False)) - ) - - callback_context_other_user_chat = CallbackContext.from_update(update_other_user_chat, dp) - - assert callback_context_other_user_chat.bot_data is callback_context.bot_data - assert callback_context_other_user_chat.chat_data is not callback_context.chat_data - assert callback_context_other_user_chat.user_data is not callback_context.user_data - - def test_from_update_not_update(self, dp): - callback_context = CallbackContext.from_update(None, dp) - - assert callback_context.chat_data is None - assert callback_context.user_data is None - assert callback_context.bot_data is dp.bot_data - assert callback_context.bot is dp.bot - assert callback_context.job_queue is dp.job_queue - assert callback_context.update_queue is dp.update_queue - - callback_context = CallbackContext.from_update('', dp) - - assert callback_context.chat_data is None - assert callback_context.user_data is None - assert callback_context.bot_data is dp.bot_data - assert callback_context.bot is dp.bot - assert callback_context.job_queue is dp.job_queue - assert callback_context.update_queue is dp.update_queue - - def test_from_error(self, dp): - error = TelegramError('test') - - update = Update( - 0, message=Message(0, None, Chat(1, 'chat'), from_user=User(1, 'user', False)) - ) - - callback_context = CallbackContext.from_error(update, error, dp) - - assert callback_context.error is error - assert callback_context.chat_data == {} - assert callback_context.user_data == {} - assert callback_context.bot_data is dp.bot_data - assert callback_context.bot is dp.bot - assert callback_context.job_queue is dp.job_queue - assert callback_context.update_queue is dp.update_queue - assert callback_context.async_args is None - assert callback_context.async_kwargs is None - - def test_from_error_async_params(self, dp): - error = TelegramError('test') - - args = [1, '2'] - kwargs = {'one': 1, 2: 'two'} - - callback_context = CallbackContext.from_error( - None, error, dp, async_args=args, async_kwargs=kwargs - ) - - assert callback_context.error is error - assert callback_context.async_args is args - assert callback_context.async_kwargs is kwargs - - def test_match(self, dp): - callback_context = CallbackContext(dp) - - assert callback_context.match is None - - callback_context.matches = ['test', 'blah'] - - assert callback_context.match == 'test' - - def test_data_assignment(self, dp): - update = Update( - 0, message=Message(0, None, Chat(1, 'chat'), from_user=User(1, 'user', False)) - ) - - callback_context = CallbackContext.from_update(update, dp) - - with pytest.raises(AttributeError): - callback_context.bot_data = {"test": 123} - with pytest.raises(AttributeError): - callback_context.user_data = {} - with pytest.raises(AttributeError): - callback_context.chat_data = "test" - - def test_dispatcher_attribute(self, dp): - callback_context = CallbackContext(dp) - assert callback_context.dispatcher == dp - - def test_drop_callback_data_exception(self, bot, dp): - non_ext_bot = Bot(bot.token) - update = Update( - 0, message=Message(0, None, Chat(1, 'chat'), from_user=User(1, 'user', False)) - ) - - callback_context = CallbackContext.from_update(update, dp) - - with pytest.raises(RuntimeError, match='This telegram.ext.ExtBot instance does not'): - callback_context.drop_callback_data(None) - - try: - dp.bot = non_ext_bot - with pytest.raises(RuntimeError, match='telegram.Bot does not allow for'): - callback_context.drop_callback_data(None) - finally: - dp.bot = bot - - def test_drop_callback_data(self, dp, monkeypatch, chat_id): - monkeypatch.setattr(dp.bot, 'arbitrary_callback_data', True) - - update = Update( - 0, message=Message(0, None, Chat(1, 'chat'), from_user=User(1, 'user', False)) - ) - - callback_context = CallbackContext.from_update(update, dp) - dp.bot.send_message( - chat_id=chat_id, - text='test', - reply_markup=InlineKeyboardMarkup.from_button( - InlineKeyboardButton('test', callback_data='callback_data') - ), - ) - keyboard_uuid = dp.bot.callback_data_cache.persistence_data[0][0][0] - button_uuid = list(dp.bot.callback_data_cache.persistence_data[0][0][2])[0] - callback_data = keyboard_uuid + button_uuid - callback_query = CallbackQuery( - id='1', - from_user=None, - chat_instance=None, - data=callback_data, - ) - dp.bot.callback_data_cache.process_callback_query(callback_query) - - try: - assert len(dp.bot.callback_data_cache.persistence_data[0]) == 1 - assert list(dp.bot.callback_data_cache.persistence_data[1]) == ['1'] - - callback_context.drop_callback_data(callback_query) - assert dp.bot.callback_data_cache.persistence_data == ([], {}) - finally: - dp.bot.callback_data_cache.clear_callback_data() - dp.bot.callback_data_cache.clear_callback_queries() diff --git a/tests/test_callbackdatacache.py b/tests/test_callbackdatacache.py deleted file mode 100644 index 1d97022d29c..00000000000 --- a/tests/test_callbackdatacache.py +++ /dev/null @@ -1,381 +0,0 @@ -#!/usr/bin/env python -# -# A library that provides a Python interface to the Telegram Bot API -# Copyright (C) 2015-2022 -# Leandro Toledo de Souza -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Lesser Public License for more details. -# -# You should have received a copy of the GNU Lesser Public License -# along with this program. If not, see [http://www.gnu.org/licenses/]. -import time -from copy import deepcopy -from datetime import datetime -from uuid import uuid4 - -import pytest -import pytz - -from telegram import InlineKeyboardButton, InlineKeyboardMarkup, CallbackQuery, Message, User -from telegram.ext._callbackdatacache import ( - CallbackDataCache, - _KeyboardData, - InvalidCallbackData, -) - - -@pytest.fixture(scope='function') -def callback_data_cache(bot): - return CallbackDataCache(bot) - - -class TestInvalidCallbackData: - def test_slot_behaviour(self, mro_slots): - invalid_callback_data = InvalidCallbackData() - for attr in invalid_callback_data.__slots__: - assert getattr(invalid_callback_data, attr, 'err') != 'err', f"got extra slot '{attr}'" - assert len(mro_slots(invalid_callback_data)) == len( - set(mro_slots(invalid_callback_data)) - ), "duplicate slot" - - -class TestKeyboardData: - def test_slot_behaviour(self, mro_slots): - keyboard_data = _KeyboardData('uuid') - for attr in keyboard_data.__slots__: - assert getattr(keyboard_data, attr, 'err') != 'err', f"got extra slot '{attr}'" - assert len(mro_slots(keyboard_data)) == len( - set(mro_slots(keyboard_data)) - ), "duplicate slot" - - -class TestCallbackDataCache: - def test_slot_behaviour(self, callback_data_cache, mro_slots): - for attr in callback_data_cache.__slots__: - attr = ( - f"_CallbackDataCache{attr}" - if attr.startswith('__') and not attr.endswith('__') - else attr - ) - assert getattr(callback_data_cache, attr, 'err') != 'err', f"got extra slot '{attr}'" - assert len(mro_slots(callback_data_cache)) == len( - set(mro_slots(callback_data_cache)) - ), "duplicate slot" - - @pytest.mark.parametrize('maxsize', [1, 5, 2048]) - def test_init_maxsize(self, maxsize, bot): - assert CallbackDataCache(bot).maxsize == 1024 - cdc = CallbackDataCache(bot, maxsize=maxsize) - assert cdc.maxsize == maxsize - assert cdc.bot is bot - - def test_init_and_access__persistent_data(self, bot): - keyboard_data = _KeyboardData('123', 456, {'button': 678}) - persistent_data = ([keyboard_data.to_tuple()], {'id': '123'}) - cdc = CallbackDataCache(bot, persistent_data=persistent_data) - - assert cdc.maxsize == 1024 - assert dict(cdc._callback_queries) == {'id': '123'} - assert list(cdc._keyboard_data.keys()) == ['123'] - assert cdc._keyboard_data['123'].keyboard_uuid == '123' - assert cdc._keyboard_data['123'].access_time == 456 - assert cdc._keyboard_data['123'].button_data == {'button': 678} - - assert cdc.persistence_data == persistent_data - - def test_process_keyboard(self, callback_data_cache): - changing_button_1 = InlineKeyboardButton('changing', callback_data='some data 1') - changing_button_2 = InlineKeyboardButton('changing', callback_data='some data 2') - non_changing_button = InlineKeyboardButton('non-changing', url='https://ptb.org') - reply_markup = InlineKeyboardMarkup.from_row( - [non_changing_button, changing_button_1, changing_button_2] - ) - - out = callback_data_cache.process_keyboard(reply_markup) - assert out.inline_keyboard[0][0] is non_changing_button - assert out.inline_keyboard[0][1] != changing_button_1 - assert out.inline_keyboard[0][2] != changing_button_2 - - keyboard_1, button_1 = callback_data_cache.extract_uuids( - out.inline_keyboard[0][1].callback_data - ) - keyboard_2, button_2 = callback_data_cache.extract_uuids( - out.inline_keyboard[0][2].callback_data - ) - assert keyboard_1 == keyboard_2 - assert ( - callback_data_cache._keyboard_data[keyboard_1].button_data[button_1] == 'some data 1' - ) - assert ( - callback_data_cache._keyboard_data[keyboard_2].button_data[button_2] == 'some data 2' - ) - - def test_process_keyboard_no_changing_button(self, callback_data_cache): - reply_markup = InlineKeyboardMarkup.from_button( - InlineKeyboardButton('non-changing', url='https://ptb.org') - ) - assert callback_data_cache.process_keyboard(reply_markup) is reply_markup - - def test_process_keyboard_full(self, bot): - cdc = CallbackDataCache(bot, maxsize=1) - changing_button_1 = InlineKeyboardButton('changing', callback_data='some data 1') - changing_button_2 = InlineKeyboardButton('changing', callback_data='some data 2') - non_changing_button = InlineKeyboardButton('non-changing', url='https://ptb.org') - reply_markup = InlineKeyboardMarkup.from_row( - [non_changing_button, changing_button_1, changing_button_2] - ) - - out1 = cdc.process_keyboard(reply_markup) - assert len(cdc.persistence_data[0]) == 1 - out2 = cdc.process_keyboard(reply_markup) - assert len(cdc.persistence_data[0]) == 1 - - keyboard_1, button_1 = cdc.extract_uuids(out1.inline_keyboard[0][1].callback_data) - keyboard_2, button_2 = cdc.extract_uuids(out2.inline_keyboard[0][2].callback_data) - assert cdc.persistence_data[0][0][0] != keyboard_1 - assert cdc.persistence_data[0][0][0] == keyboard_2 - - @pytest.mark.parametrize('data', [True, False]) - @pytest.mark.parametrize('message', [True, False]) - @pytest.mark.parametrize('invalid', [True, False]) - def test_process_callback_query(self, callback_data_cache, data, message, invalid): - """This also tests large parts of process_message""" - changing_button_1 = InlineKeyboardButton('changing', callback_data='some data 1') - changing_button_2 = InlineKeyboardButton('changing', callback_data='some data 2') - non_changing_button = InlineKeyboardButton('non-changing', url='https://ptb.org') - reply_markup = InlineKeyboardMarkup.from_row( - [non_changing_button, changing_button_1, changing_button_2] - ) - - out = callback_data_cache.process_keyboard(reply_markup) - if invalid: - callback_data_cache.clear_callback_data() - - effective_message = Message(message_id=1, date=None, chat=None, reply_markup=out) - effective_message.reply_to_message = deepcopy(effective_message) - effective_message.pinned_message = deepcopy(effective_message) - cq_id = uuid4().hex - callback_query = CallbackQuery( - cq_id, - from_user=None, - chat_instance=None, - # not all CallbackQueries have callback_data - data=out.inline_keyboard[0][1].callback_data if data else None, - # CallbackQueries from inline messages don't have the message attached, so we test that - message=effective_message if message else None, - ) - callback_data_cache.process_callback_query(callback_query) - - if not invalid: - if data: - assert callback_query.data == 'some data 1' - # make sure that we stored the mapping CallbackQuery.id -> keyboard_uuid correctly - assert len(callback_data_cache._keyboard_data) == 1 - assert ( - callback_data_cache._callback_queries[cq_id] - == list(callback_data_cache._keyboard_data.keys())[0] - ) - else: - assert callback_query.data is None - if message: - for msg in ( - callback_query.message, - callback_query.message.reply_to_message, - callback_query.message.pinned_message, - ): - assert msg.reply_markup == reply_markup - else: - if data: - assert isinstance(callback_query.data, InvalidCallbackData) - else: - assert callback_query.data is None - if message: - for msg in ( - callback_query.message, - callback_query.message.reply_to_message, - callback_query.message.pinned_message, - ): - assert isinstance( - msg.reply_markup.inline_keyboard[0][1].callback_data, - InvalidCallbackData, - ) - assert isinstance( - msg.reply_markup.inline_keyboard[0][2].callback_data, - InvalidCallbackData, - ) - - @pytest.mark.parametrize('pass_from_user', [True, False]) - @pytest.mark.parametrize('pass_via_bot', [True, False]) - def test_process_message_wrong_sender(self, pass_from_user, pass_via_bot, callback_data_cache): - reply_markup = InlineKeyboardMarkup.from_button( - InlineKeyboardButton('test', callback_data='callback_data') - ) - user = User(1, 'first', False) - message = Message( - 1, - None, - None, - from_user=user if pass_from_user else None, - via_bot=user if pass_via_bot else None, - reply_markup=reply_markup, - ) - callback_data_cache.process_message(message) - if pass_from_user or pass_via_bot: - # Here we can determine that the message is not from our bot, so no replacing - assert message.reply_markup.inline_keyboard[0][0].callback_data == 'callback_data' - else: - # Here we have no chance to know, so InvalidCallbackData - assert isinstance( - message.reply_markup.inline_keyboard[0][0].callback_data, InvalidCallbackData - ) - - @pytest.mark.parametrize('pass_from_user', [True, False]) - def test_process_message_inline_mode(self, pass_from_user, callback_data_cache): - """Check that via_bot tells us correctly that our bot sent the message, even if - from_user is not our bot.""" - reply_markup = InlineKeyboardMarkup.from_button( - InlineKeyboardButton('test', callback_data='callback_data') - ) - user = User(1, 'first', False) - message = Message( - 1, - None, - None, - from_user=user if pass_from_user else None, - via_bot=callback_data_cache.bot.bot, - reply_markup=callback_data_cache.process_keyboard(reply_markup), - ) - callback_data_cache.process_message(message) - # Here we can determine that the message is not from our bot, so no replacing - assert message.reply_markup.inline_keyboard[0][0].callback_data == 'callback_data' - - def test_process_message_no_reply_markup(self, callback_data_cache): - message = Message(1, None, None) - callback_data_cache.process_message(message) - assert message.reply_markup is None - - def test_drop_data(self, callback_data_cache): - changing_button_1 = InlineKeyboardButton('changing', callback_data='some data 1') - changing_button_2 = InlineKeyboardButton('changing', callback_data='some data 2') - reply_markup = InlineKeyboardMarkup.from_row([changing_button_1, changing_button_2]) - - out = callback_data_cache.process_keyboard(reply_markup) - callback_query = CallbackQuery( - '1', - from_user=None, - chat_instance=None, - data=out.inline_keyboard[0][1].callback_data, - ) - callback_data_cache.process_callback_query(callback_query) - - assert len(callback_data_cache.persistence_data[1]) == 1 - assert len(callback_data_cache.persistence_data[0]) == 1 - - callback_data_cache.drop_data(callback_query) - assert len(callback_data_cache.persistence_data[1]) == 0 - assert len(callback_data_cache.persistence_data[0]) == 0 - - def test_drop_data_missing_data(self, callback_data_cache): - changing_button_1 = InlineKeyboardButton('changing', callback_data='some data 1') - changing_button_2 = InlineKeyboardButton('changing', callback_data='some data 2') - reply_markup = InlineKeyboardMarkup.from_row([changing_button_1, changing_button_2]) - - out = callback_data_cache.process_keyboard(reply_markup) - callback_query = CallbackQuery( - '1', - from_user=None, - chat_instance=None, - data=out.inline_keyboard[0][1].callback_data, - ) - - with pytest.raises(KeyError, match='CallbackQuery was not found in cache.'): - callback_data_cache.drop_data(callback_query) - - callback_data_cache.process_callback_query(callback_query) - callback_data_cache.clear_callback_data() - callback_data_cache.drop_data(callback_query) - assert callback_data_cache.persistence_data == ([], {}) - - @pytest.mark.parametrize('method', ('callback_data', 'callback_queries')) - def test_clear_all(self, callback_data_cache, method): - changing_button_1 = InlineKeyboardButton('changing', callback_data='some data 1') - changing_button_2 = InlineKeyboardButton('changing', callback_data='some data 2') - reply_markup = InlineKeyboardMarkup.from_row([changing_button_1, changing_button_2]) - - for i in range(100): - out = callback_data_cache.process_keyboard(reply_markup) - callback_query = CallbackQuery( - str(i), - from_user=None, - chat_instance=None, - data=out.inline_keyboard[0][1].callback_data, - ) - callback_data_cache.process_callback_query(callback_query) - - if method == 'callback_data': - callback_data_cache.clear_callback_data() - # callback_data was cleared, callback_queries weren't - assert len(callback_data_cache.persistence_data[0]) == 0 - assert len(callback_data_cache.persistence_data[1]) == 100 - else: - callback_data_cache.clear_callback_queries() - # callback_queries were cleared, callback_data wasn't - assert len(callback_data_cache.persistence_data[0]) == 100 - assert len(callback_data_cache.persistence_data[1]) == 0 - - @pytest.mark.parametrize('time_method', ['time', 'datetime', 'defaults']) - def test_clear_cutoff(self, callback_data_cache, time_method, tz_bot): - # Fill the cache with some fake data - for i in range(50): - reply_markup = InlineKeyboardMarkup.from_button( - InlineKeyboardButton('changing', callback_data=str(i)) - ) - out = callback_data_cache.process_keyboard(reply_markup) - callback_query = CallbackQuery( - str(i), - from_user=None, - chat_instance=None, - data=out.inline_keyboard[0][0].callback_data, - ) - callback_data_cache.process_callback_query(callback_query) - - # sleep a bit before saving the time cutoff, to make test more reliable - time.sleep(0.1) - if time_method == 'time': - cutoff = time.time() - elif time_method == 'datetime': - cutoff = datetime.now(pytz.utc) - else: - cutoff = datetime.now(tz_bot.defaults.tzinfo).replace(tzinfo=None) - callback_data_cache.bot = tz_bot - time.sleep(0.1) - - # more fake data after the time cutoff - for i in range(50, 100): - reply_markup = InlineKeyboardMarkup.from_button( - InlineKeyboardButton('changing', callback_data=str(i)) - ) - out = callback_data_cache.process_keyboard(reply_markup) - callback_query = CallbackQuery( - str(i), - from_user=None, - chat_instance=None, - data=out.inline_keyboard[0][0].callback_data, - ) - callback_data_cache.process_callback_query(callback_query) - - callback_data_cache.clear_callback_data(time_cutoff=cutoff) - assert len(callback_data_cache.persistence_data[0]) == 50 - assert len(callback_data_cache.persistence_data[1]) == 100 - callback_data = [ - list(data[2].values())[0] for data in callback_data_cache.persistence_data[0] - ] - assert callback_data == list(str(i) for i in range(50, 100)) diff --git a/tests/test_callbackquery.py b/tests/test_callbackquery.py index 6098444a58f..7e3dcb0c22f 100644 --- a/tests/test_callbackquery.py +++ b/tests/test_callbackquery.py @@ -50,11 +50,6 @@ class TestCallbackQuery: inline_message_id = 'inline_message_id' game_short_name = 'the_game' - def test_slot_behaviour(self, callback_query, mro_slots): - for attr in callback_query.__slots__: - assert getattr(callback_query, attr, 'err') != 'err', f"got extra slot '{attr}'" - assert len(mro_slots(callback_query)) == len(set(mro_slots(callback_query))), "same slot" - @staticmethod def skip_params(callback_query: CallbackQuery): if callback_query.inline_message_id: @@ -79,6 +74,11 @@ def check_passed_ids(callback_query: CallbackQuery, kwargs): message_id = kwargs['message_id'] == callback_query.message.message_id return id_ and chat_id and message_id + def test_slot_behaviour(self, callback_query, mro_slots): + for attr in callback_query.__slots__: + assert getattr(callback_query, attr, 'err') != 'err', f"got extra slot '{attr}'" + assert len(mro_slots(callback_query)) == len(set(mro_slots(callback_query))), "same slot" + def test_de_json(self, bot): json_dict = { 'id': self.id_, @@ -113,24 +113,26 @@ def test_to_dict(self, callback_query): assert callback_query_dict['data'] == callback_query.data assert callback_query_dict['game_short_name'] == callback_query.game_short_name - def test_answer(self, monkeypatch, callback_query): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_answer(self, monkeypatch, callback_query): + async def make_assertion(*_, **kwargs): return kwargs['callback_query_id'] == callback_query.id assert check_shortcut_signature( CallbackQuery.answer, Bot.answer_callback_query, ['callback_query_id'], [] ) - assert check_shortcut_call( + assert await check_shortcut_call( callback_query.answer, callback_query.get_bot(), 'answer_callback_query' ) - assert check_defaults_handling(callback_query.answer, callback_query.get_bot()) + assert await check_defaults_handling(callback_query.answer, callback_query.get_bot()) monkeypatch.setattr(callback_query.get_bot(), 'answer_callback_query', make_assertion) # TODO: PEP8 - assert callback_query.answer() + assert await callback_query.answer() - def test_edit_message_text(self, monkeypatch, callback_query): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_edit_message_text(self, monkeypatch, callback_query): + async def make_assertion(*_, **kwargs): text = kwargs['text'] == 'test' ids = self.check_passed_ids(callback_query, kwargs) return ids and text @@ -141,21 +143,24 @@ def make_assertion(*_, **kwargs): ['inline_message_id', 'message_id', 'chat_id'], [], ) - assert check_shortcut_call( + assert await check_shortcut_call( callback_query.edit_message_text, callback_query.get_bot(), 'edit_message_text', skip_params=self.skip_params(callback_query), shortcut_kwargs=self.shortcut_kwargs(callback_query), ) - assert check_defaults_handling(callback_query.edit_message_text, callback_query.get_bot()) + assert await check_defaults_handling( + callback_query.edit_message_text, callback_query.get_bot() + ) monkeypatch.setattr(callback_query.get_bot(), 'edit_message_text', make_assertion) - assert callback_query.edit_message_text(text='test') - assert callback_query.edit_message_text('test') + assert await callback_query.edit_message_text(text='test') + assert await callback_query.edit_message_text('test') - def test_edit_message_caption(self, monkeypatch, callback_query): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_edit_message_caption(self, monkeypatch, callback_query): + async def make_assertion(*_, **kwargs): caption = kwargs['caption'] == 'new caption' ids = self.check_passed_ids(callback_query, kwargs) return ids and caption @@ -166,23 +171,24 @@ def make_assertion(*_, **kwargs): ['inline_message_id', 'message_id', 'chat_id'], [], ) - assert check_shortcut_call( + assert await check_shortcut_call( callback_query.edit_message_caption, callback_query.get_bot(), 'edit_message_caption', skip_params=self.skip_params(callback_query), shortcut_kwargs=self.shortcut_kwargs(callback_query), ) - assert check_defaults_handling( + assert await check_defaults_handling( callback_query.edit_message_caption, callback_query.get_bot() ) monkeypatch.setattr(callback_query.get_bot(), 'edit_message_caption', make_assertion) - assert callback_query.edit_message_caption(caption='new caption') - assert callback_query.edit_message_caption('new caption') + assert await callback_query.edit_message_caption(caption='new caption') + assert await callback_query.edit_message_caption('new caption') - def test_edit_message_reply_markup(self, monkeypatch, callback_query): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_edit_message_reply_markup(self, monkeypatch, callback_query): + async def make_assertion(*_, **kwargs): reply_markup = kwargs['reply_markup'] == [['1', '2']] ids = self.check_passed_ids(callback_query, kwargs) return ids and reply_markup @@ -193,23 +199,24 @@ def make_assertion(*_, **kwargs): ['inline_message_id', 'message_id', 'chat_id'], [], ) - assert check_shortcut_call( + assert await check_shortcut_call( callback_query.edit_message_reply_markup, callback_query.get_bot(), 'edit_message_reply_markup', skip_params=self.skip_params(callback_query), shortcut_kwargs=self.shortcut_kwargs(callback_query), ) - assert check_defaults_handling( + assert await check_defaults_handling( callback_query.edit_message_reply_markup, callback_query.get_bot() ) monkeypatch.setattr(callback_query.get_bot(), 'edit_message_reply_markup', make_assertion) - assert callback_query.edit_message_reply_markup(reply_markup=[['1', '2']]) - assert callback_query.edit_message_reply_markup([['1', '2']]) + assert await callback_query.edit_message_reply_markup(reply_markup=[['1', '2']]) + assert await callback_query.edit_message_reply_markup([['1', '2']]) - def test_edit_message_media(self, monkeypatch, callback_query): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_edit_message_media(self, monkeypatch, callback_query): + async def make_assertion(*_, **kwargs): message_media = kwargs.get('media') == [['1', '2']] ids = self.check_passed_ids(callback_query, kwargs) return ids and message_media @@ -220,21 +227,24 @@ def make_assertion(*_, **kwargs): ['inline_message_id', 'message_id', 'chat_id'], [], ) - assert check_shortcut_call( + assert await check_shortcut_call( callback_query.edit_message_media, callback_query.get_bot(), 'edit_message_media', skip_params=self.skip_params(callback_query), shortcut_kwargs=self.shortcut_kwargs(callback_query), ) - assert check_defaults_handling(callback_query.edit_message_media, callback_query.get_bot()) + assert await check_defaults_handling( + callback_query.edit_message_media, callback_query.get_bot() + ) monkeypatch.setattr(callback_query.get_bot(), 'edit_message_media', make_assertion) - assert callback_query.edit_message_media(media=[['1', '2']]) - assert callback_query.edit_message_media([['1', '2']]) + assert await callback_query.edit_message_media(media=[['1', '2']]) + assert await callback_query.edit_message_media([['1', '2']]) - def test_edit_message_live_location(self, monkeypatch, callback_query): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_edit_message_live_location(self, monkeypatch, callback_query): + async def make_assertion(*_, **kwargs): latitude = kwargs.get('latitude') == 1 longitude = kwargs.get('longitude') == 2 ids = self.check_passed_ids(callback_query, kwargs) @@ -246,23 +256,24 @@ def make_assertion(*_, **kwargs): ['inline_message_id', 'message_id', 'chat_id'], [], ) - assert check_shortcut_call( + assert await check_shortcut_call( callback_query.edit_message_live_location, callback_query.get_bot(), 'edit_message_live_location', skip_params=self.skip_params(callback_query), shortcut_kwargs=self.shortcut_kwargs(callback_query), ) - assert check_defaults_handling( + assert await check_defaults_handling( callback_query.edit_message_live_location, callback_query.get_bot() ) monkeypatch.setattr(callback_query.get_bot(), 'edit_message_live_location', make_assertion) - assert callback_query.edit_message_live_location(latitude=1, longitude=2) - assert callback_query.edit_message_live_location(1, 2) + assert await callback_query.edit_message_live_location(latitude=1, longitude=2) + assert await callback_query.edit_message_live_location(1, 2) - def test_stop_message_live_location(self, monkeypatch, callback_query): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_stop_message_live_location(self, monkeypatch, callback_query): + async def make_assertion(*_, **kwargs): ids = self.check_passed_ids(callback_query, kwargs) return ids @@ -272,22 +283,23 @@ def make_assertion(*_, **kwargs): ['inline_message_id', 'message_id', 'chat_id'], [], ) - assert check_shortcut_call( + assert await check_shortcut_call( callback_query.stop_message_live_location, callback_query.get_bot(), 'stop_message_live_location', skip_params=self.skip_params(callback_query), shortcut_kwargs=self.shortcut_kwargs(callback_query), ) - assert check_defaults_handling( + assert await check_defaults_handling( callback_query.stop_message_live_location, callback_query.get_bot() ) monkeypatch.setattr(callback_query.get_bot(), 'stop_message_live_location', make_assertion) - assert callback_query.stop_message_live_location() + assert await callback_query.stop_message_live_location() - def test_set_game_score(self, monkeypatch, callback_query): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_set_game_score(self, monkeypatch, callback_query): + async def make_assertion(*_, **kwargs): user_id = kwargs.get('user_id') == 1 score = kwargs.get('score') == 2 ids = self.check_passed_ids(callback_query, kwargs) @@ -299,21 +311,24 @@ def make_assertion(*_, **kwargs): ['inline_message_id', 'message_id', 'chat_id'], [], ) - assert check_shortcut_call( + assert await check_shortcut_call( callback_query.set_game_score, callback_query.get_bot(), 'set_game_score', skip_params=self.skip_params(callback_query), shortcut_kwargs=self.shortcut_kwargs(callback_query), ) - assert check_defaults_handling(callback_query.set_game_score, callback_query.get_bot()) + assert await check_defaults_handling( + callback_query.set_game_score, callback_query.get_bot() + ) monkeypatch.setattr(callback_query.get_bot(), 'set_game_score', make_assertion) - assert callback_query.set_game_score(user_id=1, score=2) - assert callback_query.set_game_score(1, 2) + assert await callback_query.set_game_score(user_id=1, score=2) + assert await callback_query.set_game_score(1, 2) - def test_get_game_high_scores(self, monkeypatch, callback_query): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_get_game_high_scores(self, monkeypatch, callback_query): + async def make_assertion(*_, **kwargs): user_id = kwargs.get('user_id') == 1 ids = self.check_passed_ids(callback_query, kwargs) return ids and user_id @@ -324,26 +339,27 @@ def make_assertion(*_, **kwargs): ['inline_message_id', 'message_id', 'chat_id'], [], ) - assert check_shortcut_call( + assert await check_shortcut_call( callback_query.get_game_high_scores, callback_query.get_bot(), 'get_game_high_scores', skip_params=self.skip_params(callback_query), shortcut_kwargs=self.shortcut_kwargs(callback_query), ) - assert check_defaults_handling( + assert await check_defaults_handling( callback_query.get_game_high_scores, callback_query.get_bot() ) monkeypatch.setattr(callback_query.get_bot(), 'get_game_high_scores', make_assertion) - assert callback_query.get_game_high_scores(user_id=1) - assert callback_query.get_game_high_scores(1) + assert await callback_query.get_game_high_scores(user_id=1) + assert await callback_query.get_game_high_scores(1) - def test_delete_message(self, monkeypatch, callback_query): + @pytest.mark.asyncio + async def test_delete_message(self, monkeypatch, callback_query): if callback_query.inline_message_id: pytest.skip("Can't delete inline messages") - def make_assertion(*args, **kwargs): + async def make_assertion(*args, **kwargs): id_ = kwargs['chat_id'] == callback_query.message.chat_id message = kwargs['message_id'] == callback_query.message.message_id return id_ and message @@ -354,19 +370,22 @@ def make_assertion(*args, **kwargs): ['message_id', 'chat_id'], [], ) - assert check_shortcut_call( + assert await check_shortcut_call( callback_query.delete_message, callback_query.get_bot(), 'delete_message' ) - assert check_defaults_handling(callback_query.delete_message, callback_query.get_bot()) + assert await check_defaults_handling( + callback_query.delete_message, callback_query.get_bot() + ) monkeypatch.setattr(callback_query.get_bot(), 'delete_message', make_assertion) - assert callback_query.delete_message() + assert await callback_query.delete_message() - def test_pin_message(self, monkeypatch, callback_query): + @pytest.mark.asyncio + async def test_pin_message(self, monkeypatch, callback_query): if callback_query.inline_message_id: pytest.skip("Can't pin inline messages") - def make_assertion(*args, **kwargs): + async def make_assertion(*args, **kwargs): return kwargs['chat_id'] == callback_query.message.chat_id assert check_shortcut_signature( @@ -375,19 +394,20 @@ def make_assertion(*args, **kwargs): ['message_id', 'chat_id'], [], ) - assert check_shortcut_call( + assert await check_shortcut_call( callback_query.pin_message, callback_query.get_bot(), 'pin_chat_message' ) - assert check_defaults_handling(callback_query.pin_message, callback_query.get_bot()) + assert await check_defaults_handling(callback_query.pin_message, callback_query.get_bot()) monkeypatch.setattr(callback_query.get_bot(), 'pin_chat_message', make_assertion) - assert callback_query.pin_message() + assert await callback_query.pin_message() - def test_unpin_message(self, monkeypatch, callback_query): + @pytest.mark.asyncio + async def test_unpin_message(self, monkeypatch, callback_query): if callback_query.inline_message_id: pytest.skip("Can't unpin inline messages") - def make_assertion(*args, **kwargs): + async def make_assertion(*args, **kwargs): return kwargs['chat_id'] == callback_query.message.chat_id assert check_shortcut_signature( @@ -396,22 +416,25 @@ def make_assertion(*args, **kwargs): ['message_id', 'chat_id'], [], ) - assert check_shortcut_call( + assert await check_shortcut_call( callback_query.unpin_message, callback_query.get_bot(), 'unpin_chat_message', shortcut_kwargs=['message_id', 'chat_id'], ) - assert check_defaults_handling(callback_query.unpin_message, callback_query.get_bot()) + assert await check_defaults_handling( + callback_query.unpin_message, callback_query.get_bot() + ) monkeypatch.setattr(callback_query.get_bot(), 'unpin_chat_message', make_assertion) - assert callback_query.unpin_message() + assert await callback_query.unpin_message() - def test_copy_message(self, monkeypatch, callback_query): + @pytest.mark.asyncio + async def test_copy_message(self, monkeypatch, callback_query): if callback_query.inline_message_id: pytest.skip("Can't copy inline messages") - def make_assertion(*args, **kwargs): + async def make_assertion(*args, **kwargs): id_ = kwargs['from_chat_id'] == callback_query.message.chat_id chat_id = kwargs['chat_id'] == 1 message = kwargs['message_id'] == callback_query.message.message_id @@ -423,13 +446,13 @@ def make_assertion(*args, **kwargs): ['message_id', 'from_chat_id'], [], ) - assert check_shortcut_call( + assert await check_shortcut_call( callback_query.copy_message, callback_query.get_bot(), 'copy_message' ) - assert check_defaults_handling(callback_query.copy_message, callback_query.get_bot()) + assert await check_defaults_handling(callback_query.copy_message, callback_query.get_bot()) monkeypatch.setattr(callback_query.get_bot(), 'copy_message', make_assertion) - assert callback_query.copy_message(1) + assert await callback_query.copy_message(1) def test_equality(self): a = CallbackQuery(self.id_, self.from_user, 'chat') diff --git a/tests/test_callbackqueryhandler.py b/tests/test_callbackqueryhandler.py deleted file mode 100644 index 1089e215b6d..00000000000 --- a/tests/test_callbackqueryhandler.py +++ /dev/null @@ -1,211 +0,0 @@ -#!/usr/bin/env python -# -# A library that provides a Python interface to the Telegram Bot API -# Copyright (C) 2015-2022 -# Leandro Toledo de Souza -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Lesser Public License for more details. -# -# You should have received a copy of the GNU Lesser Public License -# along with this program. If not, see [http://www.gnu.org/licenses/]. -from queue import Queue - -import pytest - -from telegram import ( - Update, - CallbackQuery, - Bot, - Message, - User, - Chat, - InlineQuery, - ChosenInlineResult, - ShippingQuery, - PreCheckoutQuery, -) -from telegram.ext import CallbackQueryHandler, CallbackContext, JobQueue - -message = Message(1, None, Chat(1, ''), from_user=User(1, '', False), text='Text') - -params = [ - {'message': message}, - {'edited_message': message}, - {'channel_post': message}, - {'edited_channel_post': message}, - {'inline_query': InlineQuery(1, User(1, '', False), '', '')}, - {'chosen_inline_result': ChosenInlineResult('id', User(1, '', False), '')}, - {'shipping_query': ShippingQuery('id', User(1, '', False), '', None)}, - {'pre_checkout_query': PreCheckoutQuery('id', User(1, '', False), '', 0, '')}, -] - -ids = ( - 'message', - 'edited_message', - 'channel_post', - 'edited_channel_post', - 'inline_query', - 'chosen_inline_result', - 'shipping_query', - 'pre_checkout_query', -) - - -@pytest.fixture(scope='class', params=params, ids=ids) -def false_update(request): - return Update(update_id=2, **request.param) - - -@pytest.fixture(scope='function') -def callback_query(bot): - return Update(0, callback_query=CallbackQuery(2, User(1, '', False), None, data='test data')) - - -class TestCallbackQueryHandler: - test_flag = False - - def test_slot_behaviour(self, mro_slots): - handler = CallbackQueryHandler(self.callback_data_1) - for attr in handler.__slots__: - assert getattr(handler, attr, 'err') != 'err', f"got extra slot '{attr}'" - assert len(mro_slots(handler)) == len(set(mro_slots(handler))), "duplicate slot" - - @pytest.fixture(autouse=True) - def reset(self): - self.test_flag = False - - def callback_basic(self, update, context): - test_bot = isinstance(context.bot, Bot) - test_update = isinstance(update, Update) - self.test_flag = test_bot and test_update - - def callback_data_1(self, bot, update, user_data=None, chat_data=None): - self.test_flag = (user_data is not None) or (chat_data is not None) - - def callback_data_2(self, bot, update, user_data=None, chat_data=None): - self.test_flag = (user_data is not None) and (chat_data is not None) - - def callback_queue_1(self, bot, update, job_queue=None, update_queue=None): - self.test_flag = (job_queue is not None) or (update_queue is not None) - - def callback_queue_2(self, bot, update, job_queue=None, update_queue=None): - self.test_flag = (job_queue is not None) and (update_queue is not None) - - def callback_group(self, bot, update, groups=None, groupdict=None): - if groups is not None: - self.test_flag = groups == ('t', ' data') - if groupdict is not None: - self.test_flag = groupdict == {'begin': 't', 'end': ' data'} - - def callback_context(self, update, context): - self.test_flag = ( - isinstance(context, CallbackContext) - and isinstance(context.bot, Bot) - and isinstance(update, Update) - and isinstance(context.update_queue, Queue) - and isinstance(context.job_queue, JobQueue) - and isinstance(context.user_data, dict) - and context.chat_data is None - and isinstance(context.bot_data, dict) - and isinstance(update.callback_query, CallbackQuery) - ) - - def callback_context_pattern(self, update, context): - if context.matches[0].groups(): - self.test_flag = context.matches[0].groups() == ('t', ' data') - if context.matches[0].groupdict(): - self.test_flag = context.matches[0].groupdict() == {'begin': 't', 'end': ' data'} - - def test_with_pattern(self, callback_query): - handler = CallbackQueryHandler(self.callback_basic, pattern='.*est.*') - - assert handler.check_update(callback_query) - - callback_query.callback_query.data = 'nothing here' - assert not handler.check_update(callback_query) - - callback_query.callback_query.data = None - callback_query.callback_query.game_short_name = "this is a short game name" - assert not handler.check_update(callback_query) - - def test_with_callable_pattern(self, callback_query): - class CallbackData: - pass - - def pattern(callback_data): - return isinstance(callback_data, CallbackData) - - handler = CallbackQueryHandler(self.callback_basic, pattern=pattern) - - callback_query.callback_query.data = CallbackData() - assert handler.check_update(callback_query) - callback_query.callback_query.data = 'callback_data' - assert not handler.check_update(callback_query) - - def test_with_type_pattern(self, callback_query): - class CallbackData: - pass - - handler = CallbackQueryHandler(self.callback_basic, pattern=CallbackData) - - callback_query.callback_query.data = CallbackData() - assert handler.check_update(callback_query) - callback_query.callback_query.data = 'callback_data' - assert not handler.check_update(callback_query) - - handler = CallbackQueryHandler(self.callback_basic, pattern=bool) - - callback_query.callback_query.data = False - assert handler.check_update(callback_query) - callback_query.callback_query.data = 'callback_data' - assert not handler.check_update(callback_query) - - def test_other_update_types(self, false_update): - handler = CallbackQueryHandler(self.callback_basic) - assert not handler.check_update(false_update) - - def test_context(self, dp, callback_query): - handler = CallbackQueryHandler(self.callback_context) - dp.add_handler(handler) - - dp.process_update(callback_query) - assert self.test_flag - - def test_context_pattern(self, dp, callback_query): - handler = CallbackQueryHandler( - self.callback_context_pattern, pattern=r'(?P.*)est(?P.*)' - ) - dp.add_handler(handler) - - dp.process_update(callback_query) - assert self.test_flag - - dp.remove_handler(handler) - handler = CallbackQueryHandler(self.callback_context_pattern, pattern=r'(t)est(.*)') - dp.add_handler(handler) - - dp.process_update(callback_query) - assert self.test_flag - - def test_context_callable_pattern(self, dp, callback_query): - class CallbackData: - pass - - def pattern(callback_data): - return isinstance(callback_data, CallbackData) - - def callback(update, context): - assert context.matches is None - - handler = CallbackQueryHandler(callback, pattern=pattern) - dp.add_handler(handler) - - dp.process_update(callback_query) diff --git a/tests/test_chat.py b/tests/test_chat.py index 8b2246f44a5..e3eef6f5dee 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -37,7 +37,6 @@ def chat(bot): can_set_sticker_set=TestChat.can_set_sticker_set, permissions=TestChat.permissions, slow_mode_delay=TestChat.slow_mode_delay, - message_auto_delete_time=TestChat.message_auto_delete_time, bio=TestChat.bio, linked_chat_id=TestChat.linked_chat_id, location=TestChat.location, @@ -60,18 +59,12 @@ class TestChat: can_invite_users=True, ) slow_mode_delay = 30 - message_auto_delete_time = 42 bio = "I'm a Barbie Girl in a Barbie World" linked_chat_id = 11880 location = ChatLocation(Location(123, 456), 'Barbie World') has_protected_content = True has_private_forwards = True - def test_slot_behaviour(self, chat, mro_slots): - for attr in chat.__slots__: - assert getattr(chat, attr, 'err') != 'err', f"got extra slot '{attr}'" - assert len(mro_slots(chat)) == len(set(mro_slots(chat))), "duplicate slot" - def test_de_json(self, bot): json_dict = { 'id': self.id_, @@ -83,7 +76,6 @@ def test_de_json(self, bot): 'can_set_sticker_set': self.can_set_sticker_set, 'permissions': self.permissions.to_dict(), 'slow_mode_delay': self.slow_mode_delay, - 'message_auto_delete_time': self.message_auto_delete_time, 'bio': self.bio, 'has_protected_content': self.has_protected_content, 'has_private_forwards': self.has_private_forwards, @@ -101,7 +93,6 @@ def test_de_json(self, bot): assert chat.can_set_sticker_set == self.can_set_sticker_set assert chat.permissions == self.permissions assert chat.slow_mode_delay == self.slow_mode_delay - assert chat.message_auto_delete_time == self.message_auto_delete_time assert chat.bio == self.bio assert chat.has_protected_content == self.has_protected_content assert chat.has_private_forwards == self.has_private_forwards @@ -120,7 +111,6 @@ def test_to_dict(self, chat): assert chat_dict['all_members_are_administrators'] == chat.all_members_are_administrators assert chat_dict['permissions'] == chat.permissions.to_dict() assert chat_dict['slow_mode_delay'] == chat.slow_mode_delay - assert chat_dict['message_auto_delete_time'] == chat.message_auto_delete_time assert chat_dict['bio'] == chat.bio assert chat_dict['has_private_forwards'] == chat.has_private_forwards assert chat_dict['has_protected_content'] == chat.has_protected_content @@ -145,88 +135,97 @@ def test_full_name(self): ) assert chat.full_name is None - def test_send_action(self, monkeypatch, chat): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_send_action(self, monkeypatch, chat): + async def make_assertion(*_, **kwargs): id_ = kwargs['chat_id'] == chat.id action = kwargs['action'] == ChatAction.TYPING return id_ and action assert check_shortcut_signature(chat.send_action, Bot.send_chat_action, ['chat_id'], []) - assert check_shortcut_call(chat.send_action, chat.get_bot(), 'send_chat_action') - assert check_defaults_handling(chat.send_action, chat.get_bot()) + assert await check_shortcut_call(chat.send_action, chat.get_bot(), 'send_chat_action') + assert await check_defaults_handling(chat.send_action, chat.get_bot()) monkeypatch.setattr(chat.get_bot(), 'send_chat_action', make_assertion) - assert chat.send_action(action=ChatAction.TYPING) - assert chat.send_action(action=ChatAction.TYPING) + assert await chat.send_action(action=ChatAction.TYPING) + assert await chat.send_action(action=ChatAction.TYPING) - def test_leave(self, monkeypatch, chat): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_leave(self, monkeypatch, chat): + async def make_assertion(*_, **kwargs): return kwargs['chat_id'] == chat.id assert check_shortcut_signature(Chat.leave, Bot.leave_chat, ['chat_id'], []) - assert check_shortcut_call(chat.leave, chat.get_bot(), 'leave_chat') - assert check_defaults_handling(chat.leave, chat.get_bot()) + assert await check_shortcut_call(chat.leave, chat.get_bot(), 'leave_chat') + assert await check_defaults_handling(chat.leave, chat.get_bot()) monkeypatch.setattr(chat.get_bot(), 'leave_chat', make_assertion) - assert chat.leave() + assert await chat.leave() - def test_get_administrators(self, monkeypatch, chat): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_get_administrators(self, monkeypatch, chat): + async def make_assertion(*_, **kwargs): return kwargs['chat_id'] == chat.id assert check_shortcut_signature( Chat.get_administrators, Bot.get_chat_administrators, ['chat_id'], [] ) - assert check_shortcut_call( + assert await check_shortcut_call( chat.get_administrators, chat.get_bot(), 'get_chat_administrators' ) - assert check_defaults_handling(chat.get_administrators, chat.get_bot()) + assert await check_defaults_handling(chat.get_administrators, chat.get_bot()) monkeypatch.setattr(chat.get_bot(), 'get_chat_administrators', make_assertion) - assert chat.get_administrators() + assert await chat.get_administrators() - def test_get_member_count(self, monkeypatch, chat): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_get_members_count(self, monkeypatch, chat): + async def make_assertion(*_, **kwargs): return kwargs['chat_id'] == chat.id assert check_shortcut_signature( Chat.get_member_count, Bot.get_chat_member_count, ['chat_id'], [] ) - assert check_shortcut_call(chat.get_member_count, chat.get_bot(), 'get_chat_member_count') - assert check_defaults_handling(chat.get_member_count, chat.get_bot()) + assert await check_shortcut_call( + chat.get_member_count, chat.get_bot(), 'get_chat_member_count' + ) + assert await check_defaults_handling(chat.get_member_count, chat.get_bot()) monkeypatch.setattr(chat.get_bot(), 'get_chat_member_count', make_assertion) - assert chat.get_member_count() + assert await chat.get_member_count() - def test_get_member(self, monkeypatch, chat): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_get_member(self, monkeypatch, chat): + async def make_assertion(*_, **kwargs): chat_id = kwargs['chat_id'] == chat.id user_id = kwargs['user_id'] == 42 return chat_id and user_id assert check_shortcut_signature(Chat.get_member, Bot.get_chat_member, ['chat_id'], []) - assert check_shortcut_call(chat.get_member, chat.get_bot(), 'get_chat_member') - assert check_defaults_handling(chat.get_member, chat.get_bot()) + assert await check_shortcut_call(chat.get_member, chat.get_bot(), 'get_chat_member') + assert await check_defaults_handling(chat.get_member, chat.get_bot()) monkeypatch.setattr(chat.get_bot(), 'get_chat_member', make_assertion) - assert chat.get_member(user_id=42) + assert await chat.get_member(user_id=42) - def test_ban_member(self, monkeypatch, chat): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_ban_member(self, monkeypatch, chat): + async def make_assertion(*_, **kwargs): chat_id = kwargs['chat_id'] == chat.id user_id = kwargs['user_id'] == 42 until = kwargs['until_date'] == 43 return chat_id and user_id and until assert check_shortcut_signature(Chat.ban_member, Bot.ban_chat_member, ['chat_id'], []) - assert check_shortcut_call(chat.ban_member, chat.get_bot(), 'ban_chat_member') - assert check_defaults_handling(chat.ban_member, chat.get_bot()) + assert await check_shortcut_call(chat.ban_member, chat.get_bot(), 'ban_chat_member') + assert await check_defaults_handling(chat.ban_member, chat.get_bot()) monkeypatch.setattr(chat.get_bot(), 'ban_chat_member', make_assertion) - assert chat.ban_member(user_id=42, until_date=43) + assert await chat.ban_member(user_id=42, until_date=43) - def test_ban_sender_chat(self, monkeypatch, chat): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_ban_sender_chat(self, monkeypatch, chat): + async def make_assertion(*_, **kwargs): chat_id = kwargs['chat_id'] == chat.id sender_chat_id = kwargs['sender_chat_id'] == 42 return chat_id and sender_chat_id @@ -234,14 +233,17 @@ def make_assertion(*_, **kwargs): assert check_shortcut_signature( Chat.ban_sender_chat, Bot.ban_chat_sender_chat, ['chat_id'], [] ) - assert check_shortcut_call(chat.ban_sender_chat, chat.get_bot(), 'ban_chat_sender_chat') - assert check_defaults_handling(chat.ban_sender_chat, chat.get_bot()) + assert await check_shortcut_call( + chat.ban_sender_chat, chat.get_bot(), 'ban_chat_sender_chat' + ) + assert await check_defaults_handling(chat.ban_sender_chat, chat.get_bot()) monkeypatch.setattr(chat.get_bot(), 'ban_chat_sender_chat', make_assertion) - assert chat.ban_sender_chat(42) + assert await chat.ban_sender_chat(42) - def test_ban_chat(self, monkeypatch, chat): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_ban_chat(self, monkeypatch, chat): + async def make_assertion(*_, **kwargs): chat_id = kwargs['chat_id'] == 42 sender_chat_id = kwargs['sender_chat_id'] == chat.id return chat_id and sender_chat_id @@ -249,29 +251,31 @@ def make_assertion(*_, **kwargs): assert check_shortcut_signature( Chat.ban_chat, Bot.ban_chat_sender_chat, ['sender_chat_id'], [] ) - assert check_shortcut_call(chat.ban_chat, chat.get_bot(), 'ban_chat_sender_chat') - assert check_defaults_handling(chat.ban_chat, chat.get_bot()) + assert await check_shortcut_call(chat.ban_chat, chat.get_bot(), 'ban_chat_sender_chat') + assert await check_defaults_handling(chat.ban_chat, chat.get_bot()) monkeypatch.setattr(chat.get_bot(), 'ban_chat_sender_chat', make_assertion) - assert chat.ban_chat(42) + assert await chat.ban_chat(42) @pytest.mark.parametrize('only_if_banned', [True, False, None]) - def test_unban_member(self, monkeypatch, chat, only_if_banned): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_unban_member(self, monkeypatch, chat, only_if_banned): + async def make_assertion(*_, **kwargs): chat_id = kwargs['chat_id'] == chat.id user_id = kwargs['user_id'] == 42 - o_i_b = kwargs.get('only_if_banned') == only_if_banned + o_i_b = kwargs.get('only_if_banned', None) == only_if_banned return chat_id and user_id and o_i_b assert check_shortcut_signature(Chat.unban_member, Bot.unban_chat_member, ['chat_id'], []) - assert check_shortcut_call(chat.unban_member, chat.get_bot(), 'unban_chat_member') - assert check_defaults_handling(chat.unban_member, chat.get_bot()) + assert await check_shortcut_call(chat.unban_member, chat.get_bot(), 'unban_chat_member') + assert await check_defaults_handling(chat.unban_member, chat.get_bot()) monkeypatch.setattr(chat.get_bot(), 'unban_chat_member', make_assertion) - assert chat.unban_member(user_id=42, only_if_banned=only_if_banned) + assert await chat.unban_member(user_id=42, only_if_banned=only_if_banned) - def test_unban_sender_chat(self, monkeypatch, chat): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_unban_sender_chat(self, monkeypatch, chat): + async def make_assertion(*_, **kwargs): chat_id = kwargs['chat_id'] == chat.id sender_chat_id = kwargs['sender_chat_id'] == 42 return chat_id and sender_chat_id @@ -279,16 +283,17 @@ def make_assertion(*_, **kwargs): assert check_shortcut_signature( Chat.unban_sender_chat, Bot.unban_chat_sender_chat, ['chat_id'], [] ) - assert check_shortcut_call( + assert await check_shortcut_call( chat.unban_sender_chat, chat.get_bot(), 'unban_chat_sender_chat' ) - assert check_defaults_handling(chat.unban_sender_chat, chat.get_bot()) + assert await check_defaults_handling(chat.unban_sender_chat, chat.get_bot()) monkeypatch.setattr(chat.get_bot(), 'unban_chat_sender_chat', make_assertion) - assert chat.unban_sender_chat(42) + assert await chat.unban_sender_chat(42) - def test_unban_chat(self, monkeypatch, chat): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_unban_chat(self, monkeypatch, chat): + async def make_assertion(*_, **kwargs): chat_id = kwargs['chat_id'] == 42 sender_chat_id = kwargs['sender_chat_id'] == chat.id return chat_id and sender_chat_id @@ -296,49 +301,56 @@ def make_assertion(*_, **kwargs): assert check_shortcut_signature( Chat.unban_chat, Bot.ban_chat_sender_chat, ['sender_chat_id'], [] ) - assert check_shortcut_call(chat.unban_chat, chat.get_bot(), 'unban_chat_sender_chat') - assert check_defaults_handling(chat.unban_chat, chat.get_bot()) + assert await check_shortcut_call(chat.unban_chat, chat.get_bot(), 'unban_chat_sender_chat') + assert await check_defaults_handling(chat.unban_chat, chat.get_bot()) monkeypatch.setattr(chat.get_bot(), 'unban_chat_sender_chat', make_assertion) - assert chat.unban_chat(42) + assert await chat.unban_chat(42) @pytest.mark.parametrize('is_anonymous', [True, False, None]) - def test_promote_member(self, monkeypatch, chat, is_anonymous): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_promote_member(self, monkeypatch, chat, is_anonymous): + async def make_assertion(*_, **kwargs): chat_id = kwargs['chat_id'] == chat.id user_id = kwargs['user_id'] == 42 - o_i_b = kwargs.get('is_anonymous') == is_anonymous + o_i_b = kwargs.get('is_anonymous', None) == is_anonymous return chat_id and user_id and o_i_b assert check_shortcut_signature( Chat.promote_member, Bot.promote_chat_member, ['chat_id'], [] ) - assert check_shortcut_call(chat.promote_member, chat.get_bot(), 'promote_chat_member') - assert check_defaults_handling(chat.promote_member, chat.get_bot()) + assert await check_shortcut_call( + chat.promote_member, chat.get_bot(), 'promote_chat_member' + ) + assert await check_defaults_handling(chat.promote_member, chat.get_bot()) monkeypatch.setattr(chat.get_bot(), 'promote_chat_member', make_assertion) - assert chat.promote_member(user_id=42, is_anonymous=is_anonymous) + assert await chat.promote_member(user_id=42, is_anonymous=is_anonymous) - def test_restrict_member(self, monkeypatch, chat): + @pytest.mark.asyncio + async def test_restrict_member(self, monkeypatch, chat): permissions = ChatPermissions(True, False, True, False, True, False, True, False) - def make_assertion(*_, **kwargs): + async def make_assertion(*_, **kwargs): chat_id = kwargs['chat_id'] == chat.id user_id = kwargs['user_id'] == 42 - o_i_b = kwargs.get('permissions') == permissions + o_i_b = kwargs.get('permissions', None) == permissions return chat_id and user_id and o_i_b assert check_shortcut_signature( Chat.restrict_member, Bot.restrict_chat_member, ['chat_id'], [] ) - assert check_shortcut_call(chat.restrict_member, chat.get_bot(), 'restrict_chat_member') - assert check_defaults_handling(chat.restrict_member, chat.get_bot()) + assert await check_shortcut_call( + chat.restrict_member, chat.get_bot(), 'restrict_chat_member' + ) + assert await check_defaults_handling(chat.restrict_member, chat.get_bot()) monkeypatch.setattr(chat.get_bot(), 'restrict_chat_member', make_assertion) - assert chat.restrict_member(user_id=42, permissions=permissions) + assert await chat.restrict_member(user_id=42, permissions=permissions) - def test_set_permissions(self, monkeypatch, chat): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_set_permissions(self, monkeypatch, chat): + async def make_assertion(*_, **kwargs): chat_id = kwargs['chat_id'] == chat.id permissions = kwargs['permissions'] == self.permissions return chat_id and permissions @@ -346,153 +358,168 @@ def make_assertion(*_, **kwargs): assert check_shortcut_signature( Chat.set_permissions, Bot.set_chat_permissions, ['chat_id'], [] ) - assert check_shortcut_call(chat.set_permissions, chat.get_bot(), 'set_chat_permissions') - assert check_defaults_handling(chat.set_permissions, chat.get_bot()) + assert await check_shortcut_call( + chat.set_permissions, chat.get_bot(), 'set_chat_permissions' + ) + assert await check_defaults_handling(chat.set_permissions, chat.get_bot()) monkeypatch.setattr(chat.get_bot(), 'set_chat_permissions', make_assertion) - assert chat.set_permissions(permissions=self.permissions) + assert await chat.set_permissions(permissions=self.permissions) - def test_set_administrator_custom_title(self, monkeypatch, chat): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_set_administrator_custom_title(self, monkeypatch, chat): + async def make_assertion(*_, **kwargs): chat_id = kwargs['chat_id'] == chat.id user_id = kwargs['user_id'] == 42 custom_title = kwargs['custom_title'] == 'custom_title' return chat_id and user_id and custom_title monkeypatch.setattr('telegram.Bot.set_chat_administrator_custom_title', make_assertion) - assert chat.set_administrator_custom_title(user_id=42, custom_title='custom_title') + assert await chat.set_administrator_custom_title(user_id=42, custom_title='custom_title') - def test_pin_message(self, monkeypatch, chat): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_pin_message(self, monkeypatch, chat): + async def make_assertion(*_, **kwargs): return kwargs['chat_id'] == chat.id and kwargs['message_id'] == 42 assert check_shortcut_signature(Chat.pin_message, Bot.pin_chat_message, ['chat_id'], []) - assert check_shortcut_call(chat.pin_message, chat.get_bot(), 'pin_chat_message') - assert check_defaults_handling(chat.pin_message, chat.get_bot()) + assert await check_shortcut_call(chat.pin_message, chat.get_bot(), 'pin_chat_message') + assert await check_defaults_handling(chat.pin_message, chat.get_bot()) monkeypatch.setattr(chat.get_bot(), 'pin_chat_message', make_assertion) - assert chat.pin_message(message_id=42) + assert await chat.pin_message(message_id=42) - def test_unpin_message(self, monkeypatch, chat): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_unpin_message(self, monkeypatch, chat): + async def make_assertion(*_, **kwargs): return kwargs['chat_id'] == chat.id assert check_shortcut_signature( Chat.unpin_message, Bot.unpin_chat_message, ['chat_id'], [] ) - assert check_shortcut_call(chat.unpin_message, chat.get_bot(), 'unpin_chat_message') - assert check_defaults_handling(chat.unpin_message, chat.get_bot()) + assert await check_shortcut_call(chat.unpin_message, chat.get_bot(), 'unpin_chat_message') + assert await check_defaults_handling(chat.unpin_message, chat.get_bot()) monkeypatch.setattr(chat.get_bot(), 'unpin_chat_message', make_assertion) - assert chat.unpin_message() + assert await chat.unpin_message() - def test_unpin_all_messages(self, monkeypatch, chat): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_unpin_all_messages(self, monkeypatch, chat): + async def make_assertion(*_, **kwargs): return kwargs['chat_id'] == chat.id assert check_shortcut_signature( Chat.unpin_all_messages, Bot.unpin_all_chat_messages, ['chat_id'], [] ) - assert check_shortcut_call( + assert await check_shortcut_call( chat.unpin_all_messages, chat.get_bot(), 'unpin_all_chat_messages' ) - assert check_defaults_handling(chat.unpin_all_messages, chat.get_bot()) + assert await check_defaults_handling(chat.unpin_all_messages, chat.get_bot()) monkeypatch.setattr(chat.get_bot(), 'unpin_all_chat_messages', make_assertion) - assert chat.unpin_all_messages() + assert await chat.unpin_all_messages() - def test_instance_method_send_message(self, monkeypatch, chat): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_instance_method_send_message(self, monkeypatch, chat): + async def make_assertion(*_, **kwargs): return kwargs['chat_id'] == chat.id and kwargs['text'] == 'test' assert check_shortcut_signature(Chat.send_message, Bot.send_message, ['chat_id'], []) - assert check_shortcut_call(chat.send_message, chat.get_bot(), 'send_message') - assert check_defaults_handling(chat.send_message, chat.get_bot()) + assert await check_shortcut_call(chat.send_message, chat.get_bot(), 'send_message') + assert await check_defaults_handling(chat.send_message, chat.get_bot()) monkeypatch.setattr(chat.get_bot(), 'send_message', make_assertion) - assert chat.send_message(text='test') + assert await chat.send_message(text='test') - def test_instance_method_send_media_group(self, monkeypatch, chat): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_instance_method_send_media_group(self, monkeypatch, chat): + async def make_assertion(*_, **kwargs): return kwargs['chat_id'] == chat.id and kwargs['media'] == 'test_media_group' assert check_shortcut_signature( Chat.send_media_group, Bot.send_media_group, ['chat_id'], [] ) - assert check_shortcut_call(chat.send_media_group, chat.get_bot(), 'send_media_group') - assert check_defaults_handling(chat.send_media_group, chat.get_bot()) + assert await check_shortcut_call(chat.send_media_group, chat.get_bot(), 'send_media_group') + assert await check_defaults_handling(chat.send_media_group, chat.get_bot()) monkeypatch.setattr(chat.get_bot(), 'send_media_group', make_assertion) - assert chat.send_media_group(media='test_media_group') + assert await chat.send_media_group(media='test_media_group') - def test_instance_method_send_photo(self, monkeypatch, chat): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_instance_method_send_photo(self, monkeypatch, chat): + async def make_assertion(*_, **kwargs): return kwargs['chat_id'] == chat.id and kwargs['photo'] == 'test_photo' assert check_shortcut_signature(Chat.send_photo, Bot.send_photo, ['chat_id'], []) - assert check_shortcut_call(chat.send_photo, chat.get_bot(), 'send_photo') - assert check_defaults_handling(chat.send_photo, chat.get_bot()) + assert await check_shortcut_call(chat.send_photo, chat.get_bot(), 'send_photo') + assert await check_defaults_handling(chat.send_photo, chat.get_bot()) monkeypatch.setattr(chat.get_bot(), 'send_photo', make_assertion) - assert chat.send_photo(photo='test_photo') + assert await chat.send_photo(photo='test_photo') - def test_instance_method_send_contact(self, monkeypatch, chat): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_instance_method_send_contact(self, monkeypatch, chat): + async def make_assertion(*_, **kwargs): return kwargs['chat_id'] == chat.id and kwargs['phone_number'] == 'test_contact' assert check_shortcut_signature(Chat.send_contact, Bot.send_contact, ['chat_id'], []) - assert check_shortcut_call(chat.send_contact, chat.get_bot(), 'send_contact') - assert check_defaults_handling(chat.send_contact, chat.get_bot()) + assert await check_shortcut_call(chat.send_contact, chat.get_bot(), 'send_contact') + assert await check_defaults_handling(chat.send_contact, chat.get_bot()) monkeypatch.setattr(chat.get_bot(), 'send_contact', make_assertion) - assert chat.send_contact(phone_number='test_contact') + assert await chat.send_contact(phone_number='test_contact') - def test_instance_method_send_audio(self, monkeypatch, chat): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_instance_method_send_audio(self, monkeypatch, chat): + async def make_assertion(*_, **kwargs): return kwargs['chat_id'] == chat.id and kwargs['audio'] == 'test_audio' assert check_shortcut_signature(Chat.send_audio, Bot.send_audio, ['chat_id'], []) - assert check_shortcut_call(chat.send_audio, chat.get_bot(), 'send_audio') - assert check_defaults_handling(chat.send_audio, chat.get_bot()) + assert await check_shortcut_call(chat.send_audio, chat.get_bot(), 'send_audio') + assert await check_defaults_handling(chat.send_audio, chat.get_bot()) monkeypatch.setattr(chat.get_bot(), 'send_audio', make_assertion) - assert chat.send_audio(audio='test_audio') + assert await chat.send_audio(audio='test_audio') - def test_instance_method_send_document(self, monkeypatch, chat): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_instance_method_send_document(self, monkeypatch, chat): + async def make_assertion(*_, **kwargs): return kwargs['chat_id'] == chat.id and kwargs['document'] == 'test_document' assert check_shortcut_signature(Chat.send_document, Bot.send_document, ['chat_id'], []) - assert check_shortcut_call(chat.send_document, chat.get_bot(), 'send_document') - assert check_defaults_handling(chat.send_document, chat.get_bot()) + assert await check_shortcut_call(chat.send_document, chat.get_bot(), 'send_document') + assert await check_defaults_handling(chat.send_document, chat.get_bot()) monkeypatch.setattr(chat.get_bot(), 'send_document', make_assertion) - assert chat.send_document(document='test_document') + assert await chat.send_document(document='test_document') - def test_instance_method_send_dice(self, monkeypatch, chat): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_instance_method_send_dice(self, monkeypatch, chat): + async def make_assertion(*_, **kwargs): return kwargs['chat_id'] == chat.id and kwargs['emoji'] == 'test_dice' assert check_shortcut_signature(Chat.send_dice, Bot.send_dice, ['chat_id'], []) - assert check_shortcut_call(chat.send_dice, chat.get_bot(), 'send_dice') - assert check_defaults_handling(chat.send_dice, chat.get_bot()) + assert await check_shortcut_call(chat.send_dice, chat.get_bot(), 'send_dice') + assert await check_defaults_handling(chat.send_dice, chat.get_bot()) monkeypatch.setattr(chat.get_bot(), 'send_dice', make_assertion) - assert chat.send_dice(emoji='test_dice') + assert await chat.send_dice(emoji='test_dice') - def test_instance_method_send_game(self, monkeypatch, chat): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_instance_method_send_game(self, monkeypatch, chat): + async def make_assertion(*_, **kwargs): return kwargs['chat_id'] == chat.id and kwargs['game_short_name'] == 'test_game' assert check_shortcut_signature(Chat.send_game, Bot.send_game, ['chat_id'], []) - assert check_shortcut_call(chat.send_game, chat.get_bot(), 'send_game') - assert check_defaults_handling(chat.send_game, chat.get_bot()) + assert await check_shortcut_call(chat.send_game, chat.get_bot(), 'send_game') + assert await check_defaults_handling(chat.send_game, chat.get_bot()) monkeypatch.setattr(chat.get_bot(), 'send_game', make_assertion) - assert chat.send_game(game_short_name='test_game') + assert await chat.send_game(game_short_name='test_game') - def test_instance_method_send_invoice(self, monkeypatch, chat): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_instance_method_send_invoice(self, monkeypatch, chat): + async def make_assertion(*_, **kwargs): title = kwargs['title'] == 'title' description = kwargs['description'] == 'description' payload = kwargs['payload'] == 'payload' @@ -503,11 +530,11 @@ def make_assertion(*_, **kwargs): return kwargs['chat_id'] == chat.id and args assert check_shortcut_signature(Chat.send_invoice, Bot.send_invoice, ['chat_id'], []) - assert check_shortcut_call(chat.send_invoice, chat.get_bot(), 'send_invoice') - assert check_defaults_handling(chat.send_invoice, chat.get_bot()) + assert await check_shortcut_call(chat.send_invoice, chat.get_bot(), 'send_invoice') + assert await check_defaults_handling(chat.send_invoice, chat.get_bot()) monkeypatch.setattr(chat.get_bot(), 'send_invoice', make_assertion) - assert chat.send_invoice( + assert await chat.send_invoice( 'title', 'description', 'payload', @@ -516,213 +543,231 @@ def make_assertion(*_, **kwargs): 'prices', ) - def test_instance_method_send_location(self, monkeypatch, chat): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_instance_method_send_location(self, monkeypatch, chat): + async def make_assertion(*_, **kwargs): return kwargs['chat_id'] == chat.id and kwargs['latitude'] == 'test_location' assert check_shortcut_signature(Chat.send_location, Bot.send_location, ['chat_id'], []) - assert check_shortcut_call(chat.send_location, chat.get_bot(), 'send_location') - assert check_defaults_handling(chat.send_location, chat.get_bot()) + assert await check_shortcut_call(chat.send_location, chat.get_bot(), 'send_location') + assert await check_defaults_handling(chat.send_location, chat.get_bot()) monkeypatch.setattr(chat.get_bot(), 'send_location', make_assertion) - assert chat.send_location(latitude='test_location') + assert await chat.send_location(latitude='test_location') - def test_instance_method_send_sticker(self, monkeypatch, chat): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_instance_method_send_sticker(self, monkeypatch, chat): + async def make_assertion(*_, **kwargs): return kwargs['chat_id'] == chat.id and kwargs['sticker'] == 'test_sticker' assert check_shortcut_signature(Chat.send_sticker, Bot.send_sticker, ['chat_id'], []) - assert check_shortcut_call(chat.send_sticker, chat.get_bot(), 'send_sticker') - assert check_defaults_handling(chat.send_sticker, chat.get_bot()) + assert await check_shortcut_call(chat.send_sticker, chat.get_bot(), 'send_sticker') + assert await check_defaults_handling(chat.send_sticker, chat.get_bot()) monkeypatch.setattr(chat.get_bot(), 'send_sticker', make_assertion) - assert chat.send_sticker(sticker='test_sticker') + assert await chat.send_sticker(sticker='test_sticker') - def test_instance_method_send_venue(self, monkeypatch, chat): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_instance_method_send_venue(self, monkeypatch, chat): + async def make_assertion(*_, **kwargs): return kwargs['chat_id'] == chat.id and kwargs['title'] == 'test_venue' assert check_shortcut_signature(Chat.send_venue, Bot.send_venue, ['chat_id'], []) - assert check_shortcut_call(chat.send_venue, chat.get_bot(), 'send_venue') - assert check_defaults_handling(chat.send_venue, chat.get_bot()) + assert await check_shortcut_call(chat.send_venue, chat.get_bot(), 'send_venue') + assert await check_defaults_handling(chat.send_venue, chat.get_bot()) monkeypatch.setattr(chat.get_bot(), 'send_venue', make_assertion) - assert chat.send_venue(title='test_venue') + assert await chat.send_venue(title='test_venue') - def test_instance_method_send_video(self, monkeypatch, chat): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_instance_method_send_video(self, monkeypatch, chat): + async def make_assertion(*_, **kwargs): return kwargs['chat_id'] == chat.id and kwargs['video'] == 'test_video' assert check_shortcut_signature(Chat.send_video, Bot.send_video, ['chat_id'], []) - assert check_shortcut_call(chat.send_video, chat.get_bot(), 'send_video') - assert check_defaults_handling(chat.send_video, chat.get_bot()) + assert await check_shortcut_call(chat.send_video, chat.get_bot(), 'send_video') + assert await check_defaults_handling(chat.send_video, chat.get_bot()) monkeypatch.setattr(chat.get_bot(), 'send_video', make_assertion) - assert chat.send_video(video='test_video') + assert await chat.send_video(video='test_video') - def test_instance_method_send_video_note(self, monkeypatch, chat): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_instance_method_send_video_note(self, monkeypatch, chat): + async def make_assertion(*_, **kwargs): return kwargs['chat_id'] == chat.id and kwargs['video_note'] == 'test_video_note' assert check_shortcut_signature(Chat.send_video_note, Bot.send_video_note, ['chat_id'], []) - assert check_shortcut_call(chat.send_video_note, chat.get_bot(), 'send_video_note') - assert check_defaults_handling(chat.send_video_note, chat.get_bot()) + assert await check_shortcut_call(chat.send_video_note, chat.get_bot(), 'send_video_note') + assert await check_defaults_handling(chat.send_video_note, chat.get_bot()) monkeypatch.setattr(chat.get_bot(), 'send_video_note', make_assertion) - assert chat.send_video_note(video_note='test_video_note') + assert await chat.send_video_note(video_note='test_video_note') - def test_instance_method_send_voice(self, monkeypatch, chat): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_instance_method_send_voice(self, monkeypatch, chat): + async def make_assertion(*_, **kwargs): return kwargs['chat_id'] == chat.id and kwargs['voice'] == 'test_voice' assert check_shortcut_signature(Chat.send_voice, Bot.send_voice, ['chat_id'], []) - assert check_shortcut_call(chat.send_voice, chat.get_bot(), 'send_voice') - assert check_defaults_handling(chat.send_voice, chat.get_bot()) + assert await check_shortcut_call(chat.send_voice, chat.get_bot(), 'send_voice') + assert await check_defaults_handling(chat.send_voice, chat.get_bot()) monkeypatch.setattr(chat.get_bot(), 'send_voice', make_assertion) - assert chat.send_voice(voice='test_voice') + assert await chat.send_voice(voice='test_voice') - def test_instance_method_send_animation(self, monkeypatch, chat): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_instance_method_send_animation(self, monkeypatch, chat): + async def make_assertion(*_, **kwargs): return kwargs['chat_id'] == chat.id and kwargs['animation'] == 'test_animation' assert check_shortcut_signature(Chat.send_animation, Bot.send_animation, ['chat_id'], []) - assert check_shortcut_call(chat.send_animation, chat.get_bot(), 'send_animation') - assert check_defaults_handling(chat.send_animation, chat.get_bot()) + assert await check_shortcut_call(chat.send_animation, chat.get_bot(), 'send_animation') + assert await check_defaults_handling(chat.send_animation, chat.get_bot()) monkeypatch.setattr(chat.get_bot(), 'send_animation', make_assertion) - assert chat.send_animation(animation='test_animation') + assert await chat.send_animation(animation='test_animation') - def test_instance_method_send_poll(self, monkeypatch, chat): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_instance_method_send_poll(self, monkeypatch, chat): + async def make_assertion(*_, **kwargs): return kwargs['chat_id'] == chat.id and kwargs['question'] == 'test_poll' assert check_shortcut_signature(Chat.send_poll, Bot.send_poll, ['chat_id'], []) - assert check_shortcut_call(chat.send_poll, chat.get_bot(), 'send_poll') - assert check_defaults_handling(chat.send_poll, chat.get_bot()) + assert await check_shortcut_call(chat.send_poll, chat.get_bot(), 'send_poll') + assert await check_defaults_handling(chat.send_poll, chat.get_bot()) monkeypatch.setattr(chat.get_bot(), 'send_poll', make_assertion) - assert chat.send_poll(question='test_poll', options=[1, 2]) + assert await chat.send_poll(question='test_poll', options=[1, 2]) - def test_instance_method_send_copy(self, monkeypatch, chat): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_instance_method_send_copy(self, monkeypatch, chat): + async def make_assertion(*_, **kwargs): from_chat_id = kwargs['from_chat_id'] == 'test_copy' message_id = kwargs['message_id'] == 42 chat_id = kwargs['chat_id'] == chat.id return from_chat_id and message_id and chat_id assert check_shortcut_signature(Chat.send_copy, Bot.copy_message, ['chat_id'], []) - assert check_shortcut_call(chat.copy_message, chat.get_bot(), 'copy_message') - assert check_defaults_handling(chat.copy_message, chat.get_bot()) + assert await check_shortcut_call(chat.copy_message, chat.get_bot(), 'copy_message') + assert await check_defaults_handling(chat.copy_message, chat.get_bot()) monkeypatch.setattr(chat.get_bot(), 'copy_message', make_assertion) - assert chat.send_copy(from_chat_id='test_copy', message_id=42) + assert await chat.send_copy(from_chat_id='test_copy', message_id=42) - def test_instance_method_copy_message(self, monkeypatch, chat): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_instance_method_copy_message(self, monkeypatch, chat): + async def make_assertion(*_, **kwargs): from_chat_id = kwargs['from_chat_id'] == chat.id message_id = kwargs['message_id'] == 42 chat_id = kwargs['chat_id'] == 'test_copy' return from_chat_id and message_id and chat_id assert check_shortcut_signature(Chat.copy_message, Bot.copy_message, ['from_chat_id'], []) - assert check_shortcut_call(chat.copy_message, chat.get_bot(), 'copy_message') - assert check_defaults_handling(chat.copy_message, chat.get_bot()) + assert await check_shortcut_call(chat.copy_message, chat.get_bot(), 'copy_message') + assert await check_defaults_handling(chat.copy_message, chat.get_bot()) monkeypatch.setattr(chat.get_bot(), 'copy_message', make_assertion) - assert chat.copy_message(chat_id='test_copy', message_id=42) + assert await chat.copy_message(chat_id='test_copy', message_id=42) - def test_export_invite_link(self, monkeypatch, chat): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_export_invite_link(self, monkeypatch, chat): + async def make_assertion(*_, **kwargs): return kwargs['chat_id'] == chat.id assert check_shortcut_signature( Chat.export_invite_link, Bot.export_chat_invite_link, ['chat_id'], [] ) - assert check_shortcut_call( + assert await check_shortcut_call( chat.export_invite_link, chat.get_bot(), 'export_chat_invite_link' ) - assert check_defaults_handling(chat.export_invite_link, chat.get_bot()) + assert await check_defaults_handling(chat.export_invite_link, chat.get_bot()) monkeypatch.setattr(chat.get_bot(), 'export_chat_invite_link', make_assertion) - assert chat.export_invite_link() + assert await chat.export_invite_link() - def test_create_invite_link(self, monkeypatch, chat): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_create_invite_link(self, monkeypatch, chat): + async def make_assertion(*_, **kwargs): return kwargs['chat_id'] == chat.id assert check_shortcut_signature( Chat.create_invite_link, Bot.create_chat_invite_link, ['chat_id'], [] ) - assert check_shortcut_call( + assert await check_shortcut_call( chat.create_invite_link, chat.get_bot(), 'create_chat_invite_link' ) - assert check_defaults_handling(chat.create_invite_link, chat.get_bot()) + assert await check_defaults_handling(chat.create_invite_link, chat.get_bot()) monkeypatch.setattr(chat.get_bot(), 'create_chat_invite_link', make_assertion) - assert chat.create_invite_link() + assert await chat.create_invite_link() - def test_edit_invite_link(self, monkeypatch, chat): + @pytest.mark.asyncio + async def test_edit_invite_link(self, monkeypatch, chat): link = "ThisIsALink" - def make_assertion(*_, **kwargs): + async def make_assertion(*_, **kwargs): return kwargs['chat_id'] == chat.id and kwargs['invite_link'] == link assert check_shortcut_signature( Chat.edit_invite_link, Bot.edit_chat_invite_link, ['chat_id'], [] ) - assert check_shortcut_call(chat.edit_invite_link, chat.get_bot(), 'edit_chat_invite_link') - assert check_defaults_handling(chat.edit_invite_link, chat.get_bot()) + assert await check_shortcut_call( + chat.edit_invite_link, chat.get_bot(), 'edit_chat_invite_link' + ) + assert await check_defaults_handling(chat.edit_invite_link, chat.get_bot()) monkeypatch.setattr(chat.get_bot(), 'edit_chat_invite_link', make_assertion) - assert chat.edit_invite_link(invite_link=link) + assert await chat.edit_invite_link(invite_link=link) - def test_revoke_invite_link(self, monkeypatch, chat): + @pytest.mark.asyncio + async def test_revoke_invite_link(self, monkeypatch, chat): link = "ThisIsALink" - def make_assertion(*_, **kwargs): + async def make_assertion(*_, **kwargs): return kwargs['chat_id'] == chat.id and kwargs['invite_link'] == link assert check_shortcut_signature( Chat.revoke_invite_link, Bot.revoke_chat_invite_link, ['chat_id'], [] ) - assert check_shortcut_call( + assert await check_shortcut_call( chat.revoke_invite_link, chat.get_bot(), 'revoke_chat_invite_link' ) - assert check_defaults_handling(chat.revoke_invite_link, chat.get_bot()) + assert await check_defaults_handling(chat.revoke_invite_link, chat.get_bot()) monkeypatch.setattr(chat.get_bot(), 'revoke_chat_invite_link', make_assertion) - assert chat.revoke_invite_link(invite_link=link) + assert await chat.revoke_invite_link(invite_link=link) - def test_approve_join_request(self, monkeypatch, chat): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_approve_join_request(self, monkeypatch, chat): + async def make_assertion(*_, **kwargs): return kwargs['chat_id'] == chat.id and kwargs['user_id'] == 42 assert check_shortcut_signature( Chat.approve_join_request, Bot.approve_chat_join_request, ['chat_id'], [] ) - assert check_shortcut_call( + assert await check_shortcut_call( chat.approve_join_request, chat.get_bot(), 'approve_chat_join_request' ) - assert check_defaults_handling(chat.approve_join_request, chat.get_bot()) + assert await check_defaults_handling(chat.approve_join_request, chat.get_bot()) monkeypatch.setattr(chat.get_bot(), 'approve_chat_join_request', make_assertion) - assert chat.approve_join_request(user_id=42) + assert await chat.approve_join_request(user_id=42) - def test_decline_join_request(self, monkeypatch, chat): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_decline_join_request(self, monkeypatch, chat): + async def make_assertion(*_, **kwargs): return kwargs['chat_id'] == chat.id and kwargs['user_id'] == 42 assert check_shortcut_signature( Chat.decline_join_request, Bot.decline_chat_join_request, ['chat_id'], [] ) - assert check_shortcut_call( + assert await check_shortcut_call( chat.decline_join_request, chat.get_bot(), 'decline_chat_join_request' ) - assert check_defaults_handling(chat.decline_join_request, chat.get_bot()) + assert await check_defaults_handling(chat.decline_join_request, chat.get_bot()) monkeypatch.setattr(chat.get_bot(), 'decline_chat_join_request', make_assertion) - assert chat.decline_join_request(user_id=42) + assert await chat.decline_join_request(user_id=42) def test_equality(self): a = Chat(self.id_, self.title, self.type_) diff --git a/tests/test_chatjoinrequest.py b/tests/test_chatjoinrequest.py index 89b98db412a..c5a53398d49 100644 --- a/tests/test_chatjoinrequest.py +++ b/tests/test_chatjoinrequest.py @@ -119,8 +119,9 @@ def test_equality(self, chat_join_request, time): assert a != f assert hash(a) != hash(f) - def test_approve(self, monkeypatch, chat_join_request): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_approve(self, monkeypatch, chat_join_request): + async def make_assertion(*_, **kwargs): chat_id_test = kwargs['chat_id'] == chat_join_request.chat.id user_id_test = kwargs['user_id'] == chat_join_request.from_user.id @@ -129,18 +130,21 @@ def make_assertion(*_, **kwargs): assert check_shortcut_signature( ChatJoinRequest.approve, Bot.approve_chat_join_request, ['chat_id', 'user_id'], [] ) - assert check_shortcut_call( + assert await check_shortcut_call( chat_join_request.approve, chat_join_request.get_bot(), 'approve_chat_join_request' ) - assert check_defaults_handling(chat_join_request.approve, chat_join_request.get_bot()) + assert await check_defaults_handling( + chat_join_request.approve, chat_join_request.get_bot() + ) monkeypatch.setattr( chat_join_request.get_bot(), 'approve_chat_join_request', make_assertion ) - assert chat_join_request.approve() + assert await chat_join_request.approve() - def test_decline(self, monkeypatch, chat_join_request): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_decline(self, monkeypatch, chat_join_request): + async def make_assertion(*_, **kwargs): chat_id_test = kwargs['chat_id'] == chat_join_request.chat.id user_id_test = kwargs['user_id'] == chat_join_request.from_user.id @@ -149,12 +153,14 @@ def make_assertion(*_, **kwargs): assert check_shortcut_signature( ChatJoinRequest.decline, Bot.decline_chat_join_request, ['chat_id', 'user_id'], [] ) - assert check_shortcut_call( + assert await check_shortcut_call( chat_join_request.decline, chat_join_request.get_bot(), 'decline_chat_join_request' ) - assert check_defaults_handling(chat_join_request.decline, chat_join_request.get_bot()) + assert await check_defaults_handling( + chat_join_request.decline, chat_join_request.get_bot() + ) monkeypatch.setattr( chat_join_request.get_bot(), 'decline_chat_join_request', make_assertion ) - assert chat_join_request.decline() + assert await chat_join_request.decline() diff --git a/tests/test_chatjoinrequesthandler.py b/tests/test_chatjoinrequesthandler.py deleted file mode 100644 index d0b8cf6de42..00000000000 --- a/tests/test_chatjoinrequesthandler.py +++ /dev/null @@ -1,141 +0,0 @@ -#!/usr/bin/env python -# -# A library that provides a Python interface to the Telegram Bot API -# Copyright (C) 2015-2022 -# Leandro Toledo de Souza -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Lesser Public License for more details. -# -# You should have received a copy of the GNU Lesser Public License -# along with this program. If not, see [http://www.gnu.org/licenses/]. -import datetime -from queue import Queue - -import pytest -import pytz - -from telegram import ( - Update, - Bot, - Message, - User, - Chat, - CallbackQuery, - ChosenInlineResult, - ShippingQuery, - PreCheckoutQuery, - ChatJoinRequest, - ChatInviteLink, -) -from telegram.ext import CallbackContext, JobQueue, ChatJoinRequestHandler - - -message = Message(1, None, Chat(1, ''), from_user=User(1, '', False), text='Text') - -params = [ - {'message': message}, - {'edited_message': message}, - {'callback_query': CallbackQuery(1, User(1, '', False), 'chat', message=message)}, - {'channel_post': message}, - {'edited_channel_post': message}, - {'chosen_inline_result': ChosenInlineResult('id', User(1, '', False), '')}, - {'shipping_query': ShippingQuery('id', User(1, '', False), '', None)}, - {'pre_checkout_query': PreCheckoutQuery('id', User(1, '', False), '', 0, '')}, - {'callback_query': CallbackQuery(1, User(1, '', False), 'chat')}, -] - -ids = ( - 'message', - 'edited_message', - 'callback_query', - 'channel_post', - 'edited_channel_post', - 'chosen_inline_result', - 'shipping_query', - 'pre_checkout_query', - 'callback_query_without_message', -) - - -@pytest.fixture(scope='class', params=params, ids=ids) -def false_update(request): - return Update(update_id=2, **request.param) - - -@pytest.fixture(scope='class') -def time(): - return datetime.datetime.now(tz=pytz.utc) - - -@pytest.fixture(scope='class') -def chat_join_request(time, bot): - return ChatJoinRequest( - chat=Chat(1, Chat.SUPERGROUP), - from_user=User(2, 'first_name', False), - date=time, - bio='bio', - invite_link=ChatInviteLink( - 'https://invite.link', - User(42, 'creator', False), - creates_join_request=False, - name='InviteLink', - is_revoked=False, - is_primary=False, - ), - bot=bot, - ) - - -@pytest.fixture(scope='function') -def chat_join_request_update(bot, chat_join_request): - return Update(0, chat_join_request=chat_join_request) - - -class TestChatJoinRequestHandler: - test_flag = False - - def test_slot_behaviour(self, recwarn, mro_slots): - action = ChatJoinRequestHandler(self.callback_context) - for attr in action.__slots__: - assert getattr(action, attr, 'err') != 'err', f"got extra slot '{attr}'" - assert len(mro_slots(action)) == len(set(mro_slots(action))), "duplicate slot" - - @pytest.fixture(autouse=True) - def reset(self): - self.test_flag = False - - def callback_context(self, update, context): - self.test_flag = ( - isinstance(context, CallbackContext) - and isinstance(context.bot, Bot) - and isinstance(update, Update) - and isinstance(context.update_queue, Queue) - and isinstance(context.job_queue, JobQueue) - and isinstance(context.user_data, dict) - and isinstance(context.chat_data, dict) - and isinstance(context.bot_data, dict) - and isinstance( - update.chat_join_request, - ChatJoinRequest, - ) - ) - - def test_other_update_types(self, false_update): - handler = ChatJoinRequestHandler(self.callback_context) - assert not handler.check_update(false_update) - assert not handler.check_update(True) - - def test_context(self, dp, chat_join_request_update): - handler = ChatJoinRequestHandler(callback=self.callback_context) - dp.add_handler(handler) - - dp.process_update(chat_join_request_update) - assert self.test_flag diff --git a/tests/test_chatmemberhandler.py b/tests/test_chatmemberhandler.py deleted file mode 100644 index 3932ef40e6b..00000000000 --- a/tests/test_chatmemberhandler.py +++ /dev/null @@ -1,153 +0,0 @@ -#!/usr/bin/env python -# -# A library that provides a Python interface to the Telegram Bot API -# Copyright (C) 2015-2022 -# Leandro Toledo de Souza -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Lesser Public License for more details. -# -# You should have received a copy of the GNU Lesser Public License -# along with this program. If not, see [http://www.gnu.org/licenses/]. -import time -from queue import Queue - -import pytest - -from telegram import ( - Update, - Bot, - Message, - User, - Chat, - CallbackQuery, - ChosenInlineResult, - ShippingQuery, - PreCheckoutQuery, - ChatMemberUpdated, - ChatMember, -) -from telegram.ext import CallbackContext, JobQueue, ChatMemberHandler -from telegram._utils.datetime import from_timestamp - -message = Message(1, None, Chat(1, ''), from_user=User(1, '', False), text='Text') - -params = [ - {'message': message}, - {'edited_message': message}, - {'callback_query': CallbackQuery(1, User(1, '', False), 'chat', message=message)}, - {'channel_post': message}, - {'edited_channel_post': message}, - {'chosen_inline_result': ChosenInlineResult('id', User(1, '', False), '')}, - {'shipping_query': ShippingQuery('id', User(1, '', False), '', None)}, - {'pre_checkout_query': PreCheckoutQuery('id', User(1, '', False), '', 0, '')}, - {'callback_query': CallbackQuery(1, User(1, '', False), 'chat')}, -] - -ids = ( - 'message', - 'edited_message', - 'callback_query', - 'channel_post', - 'edited_channel_post', - 'chosen_inline_result', - 'shipping_query', - 'pre_checkout_query', - 'callback_query_without_message', -) - - -@pytest.fixture(scope='class', params=params, ids=ids) -def false_update(request): - return Update(update_id=2, **request.param) - - -@pytest.fixture(scope='class') -def chat_member_updated(): - return ChatMemberUpdated( - Chat(1, 'chat'), - User(1, '', False), - from_timestamp(int(time.time())), - ChatMember(User(1, '', False), ChatMember.CREATOR), - ChatMember(User(1, '', False), ChatMember.CREATOR), - ) - - -@pytest.fixture(scope='function') -def chat_member(bot, chat_member_updated): - return Update(0, my_chat_member=chat_member_updated) - - -class TestChatMemberHandler: - test_flag = False - - def test_slot_behaviour(self, mro_slots): - action = ChatMemberHandler(self.callback_context) - for attr in action.__slots__: - assert getattr(action, attr, 'err') != 'err', f"got extra slot '{attr}'" - assert len(mro_slots(action)) == len(set(mro_slots(action))), "duplicate slot" - - @pytest.fixture(autouse=True) - def reset(self): - self.test_flag = False - - def callback_context(self, update, context): - self.test_flag = ( - isinstance(context, CallbackContext) - and isinstance(context.bot, Bot) - and isinstance(update, Update) - and isinstance(context.update_queue, Queue) - and isinstance(context.job_queue, JobQueue) - and isinstance(context.user_data, dict) - and isinstance(context.chat_data, dict) - and isinstance(context.bot_data, dict) - and isinstance(update.chat_member or update.my_chat_member, ChatMemberUpdated) - ) - - @pytest.mark.parametrize( - argnames=['allowed_types', 'expected'], - argvalues=[ - (ChatMemberHandler.MY_CHAT_MEMBER, (True, False)), - (ChatMemberHandler.CHAT_MEMBER, (False, True)), - (ChatMemberHandler.ANY_CHAT_MEMBER, (True, True)), - ], - ids=['MY_CHAT_MEMBER', 'CHAT_MEMBER', 'ANY_CHAT_MEMBER'], - ) - def test_chat_member_types( - self, dp, chat_member_updated, chat_member, expected, allowed_types - ): - result_1, result_2 = expected - - handler = ChatMemberHandler(self.callback_context, chat_member_types=allowed_types) - dp.add_handler(handler) - - assert handler.check_update(chat_member) == result_1 - dp.process_update(chat_member) - assert self.test_flag == result_1 - - self.test_flag = False - chat_member.my_chat_member = None - chat_member.chat_member = chat_member_updated - - assert handler.check_update(chat_member) == result_2 - dp.process_update(chat_member) - assert self.test_flag == result_2 - - def test_other_update_types(self, false_update): - handler = ChatMemberHandler(self.callback_context) - assert not handler.check_update(false_update) - assert not handler.check_update(True) - - def test_context(self, dp, chat_member): - handler = ChatMemberHandler(self.callback_context) - dp.add_handler(handler) - - dp.process_update(chat_member) - assert self.test_flag diff --git a/tests/test_chatphoto.py b/tests/test_chatphoto.py index 84d40175354..9fd3ce954b8 100644 --- a/tests/test_chatphoto.py +++ b/tests/test_chatphoto.py @@ -16,6 +16,7 @@ # # You should have received a copy of the GNU Lesser Public License # along with this program. If not, see [http://www.gnu.org/licenses/]. + import os from pathlib import Path @@ -24,6 +25,7 @@ from telegram import ChatPhoto, Voice, Bot from telegram.error import TelegramError +from telegram.request import RequestData from tests.conftest import ( expect_bad_request, check_shortcut_call, @@ -41,11 +43,14 @@ def chatphoto_file(): @pytest.fixture(scope='function') -def chat_photo(bot, super_group_id): - def func(): - return bot.get_chat(super_group_id, timeout=50).photo +@pytest.mark.asyncio +async def chat_photo(bot, super_group_id): + async def func(): + return (await bot.get_chat(super_group_id, read_timeout=50)).photo - return expect_bad_request(func, 'Type of file mismatch', 'Telegram did not accept the file.') + return await expect_bad_request( + func, 'Type of file mismatch', 'Telegram did not accept the file.' + ) class TestChatPhoto: @@ -55,45 +60,50 @@ class TestChatPhoto: chatphoto_big_file_unique_id = 'bigadc3145fd2e84d95b64d68eaa22aa33e' chatphoto_file_url = 'https://python-telegram-bot.org/static/testfiles/telegram.jpg' - def test_slot_behaviour(self, chat_photo, mro_slots): - for attr in chat_photo.__slots__: - assert getattr(chat_photo, attr, 'err') != 'err', f"got extra slot '{attr}'" - assert len(mro_slots(chat_photo)) == len(set(mro_slots(chat_photo))), "duplicate slot" - @flaky(3, 1) - def test_send_all_args(self, bot, super_group_id, chatphoto_file, chat_photo, thumb_file): - def func(): - assert bot.set_chat_photo(super_group_id, chatphoto_file) - - expect_bad_request(func, 'Type of file mismatch', 'Telegram did not accept the file.') + @pytest.mark.asyncio + async def test_send_all_args( + self, bot, super_group_id, chatphoto_file, chat_photo, thumb_file + ): + async def func(): + assert await bot.set_chat_photo(super_group_id, chatphoto_file) + + await expect_bad_request( + func, 'Type of file mismatch', 'Telegram did not accept the file.' + ) @flaky(3, 1) - def test_get_and_download(self, bot, chat_photo): + @pytest.mark.asyncio + async def test_get_and_download(self, bot, chat_photo): jpg_file = Path('telegram.jpg') - new_file = bot.get_file(chat_photo.small_file_id) + if jpg_file.is_file(): + jpg_file.unlink() + + new_file = await bot.get_file(chat_photo.small_file_id) assert new_file.file_id == chat_photo.small_file_id assert new_file.file_path.startswith('https://') - new_file.download(jpg_file) + await new_file.download(jpg_file) assert jpg_file.is_file() - new_file = bot.get_file(chat_photo.big_file_id) + new_file = await bot.get_file(chat_photo.big_file_id) assert new_file.file_id == chat_photo.big_file_id assert new_file.file_path.startswith('https://') - new_file.download(jpg_file) + await new_file.download(jpg_file) assert jpg_file.is_file() - def test_send_with_chat_photo(self, monkeypatch, bot, super_group_id, chat_photo): - def test(url, data, **kwargs): - return data['photo'] == chat_photo + @pytest.mark.asyncio + async def test_send_with_chat_photo(self, monkeypatch, bot, super_group_id, chat_photo): + async def make_assertion(url, request_data: RequestData, *args, **kwargs): + return request_data.parameters['photo'] == chat_photo.to_dict() - monkeypatch.setattr(bot.request, 'post', test) - message = bot.set_chat_photo(photo=chat_photo, chat_id=super_group_id) + monkeypatch.setattr(bot.request, 'post', make_assertion) + message = await bot.set_chat_photo(photo=chat_photo, chat_id=super_group_id) assert message def test_de_json(self, bot, chat_photo): @@ -109,7 +119,8 @@ def test_de_json(self, bot, chat_photo): assert chat_photo.small_file_unique_id == self.chatphoto_small_file_unique_id assert chat_photo.big_file_unique_id == self.chatphoto_big_file_unique_id - def test_to_dict(self, chat_photo): + @pytest.mark.asyncio + async def test_to_dict(self, chat_photo): chat_photo_dict = chat_photo.to_dict() assert isinstance(chat_photo_dict, dict) @@ -119,42 +130,49 @@ def test_to_dict(self, chat_photo): assert chat_photo_dict['big_file_unique_id'] == chat_photo.big_file_unique_id @flaky(3, 1) - def test_error_send_empty_file(self, bot, super_group_id): + @pytest.mark.asyncio + async def test_error_send_empty_file(self, bot, super_group_id): chatphoto_file = open(os.devnull, 'rb') with pytest.raises(TelegramError): - bot.set_chat_photo(chat_id=super_group_id, photo=chatphoto_file) + await bot.set_chat_photo(chat_id=super_group_id, photo=chatphoto_file) @flaky(3, 1) - def test_error_send_empty_file_id(self, bot, super_group_id): + @pytest.mark.asyncio + async def test_error_send_empty_file_id(self, bot, super_group_id): with pytest.raises(TelegramError): - bot.set_chat_photo(chat_id=super_group_id, photo='') + await bot.set_chat_photo(chat_id=super_group_id, photo='') - def test_error_send_without_required_args(self, bot, super_group_id): + @pytest.mark.asyncio + async def test_error_send_without_required_args(self, bot, super_group_id): with pytest.raises(TypeError): - bot.set_chat_photo(chat_id=super_group_id) + await bot.set_chat_photo(chat_id=super_group_id) - def test_get_small_file_instance_method(self, monkeypatch, chat_photo): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_get_small_file_instance_method(self, monkeypatch, chat_photo): + async def make_assertion(*_, **kwargs): return kwargs['file_id'] == chat_photo.small_file_id assert check_shortcut_signature(ChatPhoto.get_small_file, Bot.get_file, ['file_id'], []) - assert check_shortcut_call(chat_photo.get_small_file, chat_photo.get_bot(), 'get_file') - assert check_defaults_handling(chat_photo.get_small_file, chat_photo.get_bot()) + assert await check_shortcut_call( + chat_photo.get_small_file, chat_photo.get_bot(), 'get_file' + ) + assert await check_defaults_handling(chat_photo.get_small_file, chat_photo.get_bot()) monkeypatch.setattr(chat_photo.get_bot(), 'get_file', make_assertion) - assert chat_photo.get_small_file() + assert await chat_photo.get_small_file() - def test_get_big_file_instance_method(self, monkeypatch, chat_photo): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_get_big_file_instance_method(self, monkeypatch, chat_photo): + async def make_assertion(*_, **kwargs): return kwargs['file_id'] == chat_photo.big_file_id assert check_shortcut_signature(ChatPhoto.get_big_file, Bot.get_file, ['file_id'], []) - assert check_shortcut_call(chat_photo.get_big_file, chat_photo.get_bot(), 'get_file') - assert check_defaults_handling(chat_photo.get_big_file, chat_photo.get_bot()) + assert await check_shortcut_call(chat_photo.get_big_file, chat_photo.get_bot(), 'get_file') + assert await check_defaults_handling(chat_photo.get_big_file, chat_photo.get_bot()) monkeypatch.setattr(chat_photo.get_bot(), 'get_file', make_assertion) - assert chat_photo.get_big_file() + assert await chat_photo.get_big_file() def test_equality(self): a = ChatPhoto( diff --git a/tests/test_choseninlineresulthandler.py b/tests/test_choseninlineresulthandler.py deleted file mode 100644 index 517db81165f..00000000000 --- a/tests/test_choseninlineresulthandler.py +++ /dev/null @@ -1,159 +0,0 @@ -#!/usr/bin/env python -# -# A library that provides a Python interface to the Telegram Bot API -# Copyright (C) 2015-2022 -# Leandro Toledo de Souza -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Lesser Public License for more details. -# -# You should have received a copy of the GNU Lesser Public License -# along with this program. If not, see [http://www.gnu.org/licenses/]. -from queue import Queue - -import pytest - -from telegram import ( - Update, - Chat, - Bot, - ChosenInlineResult, - User, - Message, - CallbackQuery, - InlineQuery, - ShippingQuery, - PreCheckoutQuery, -) -from telegram.ext import ChosenInlineResultHandler, CallbackContext, JobQueue - -message = Message(1, None, Chat(1, ''), from_user=User(1, '', False), text='Text') - -params = [ - {'message': message}, - {'edited_message': message}, - {'callback_query': CallbackQuery(1, User(1, '', False), 'chat', message=message)}, - {'channel_post': message}, - {'edited_channel_post': message}, - {'inline_query': InlineQuery(1, User(1, '', False), '', '')}, - {'shipping_query': ShippingQuery('id', User(1, '', False), '', None)}, - {'pre_checkout_query': PreCheckoutQuery('id', User(1, '', False), '', 0, '')}, - {'callback_query': CallbackQuery(1, User(1, '', False), 'chat')}, -] - -ids = ( - 'message', - 'edited_message', - 'callback_query', - 'channel_post', - 'edited_channel_post', - 'inline_query', - 'shipping_query', - 'pre_checkout_query', - 'callback_query_without_message', -) - - -@pytest.fixture(scope='class', params=params, ids=ids) -def false_update(request): - return Update(update_id=1, **request.param) - - -@pytest.fixture(scope='class') -def chosen_inline_result(): - return Update( - 1, - chosen_inline_result=ChosenInlineResult('result_id', User(1, 'test_user', False), 'query'), - ) - - -class TestChosenInlineResultHandler: - test_flag = False - - @pytest.fixture(autouse=True) - def reset(self): - self.test_flag = False - - def test_slot_behaviour(self, mro_slots): - handler = ChosenInlineResultHandler(self.callback_basic) - for attr in handler.__slots__: - assert getattr(handler, attr, 'err') != 'err', f"got extra slot '{attr}'" - assert len(mro_slots(handler)) == len(set(mro_slots(handler))), "duplicate slot" - - def callback_basic(self, update, context): - test_bot = isinstance(context.bot, Bot) - test_update = isinstance(update, Update) - self.test_flag = test_bot and test_update - - def callback_data_1(self, bot, update, user_data=None, chat_data=None): - self.test_flag = (user_data is not None) or (chat_data is not None) - - def callback_data_2(self, bot, update, user_data=None, chat_data=None): - self.test_flag = (user_data is not None) and (chat_data is not None) - - def callback_queue_1(self, bot, update, job_queue=None, update_queue=None): - self.test_flag = (job_queue is not None) or (update_queue is not None) - - def callback_queue_2(self, bot, update, job_queue=None, update_queue=None): - self.test_flag = (job_queue is not None) and (update_queue is not None) - - def callback_context(self, update, context): - self.test_flag = ( - isinstance(context, CallbackContext) - and isinstance(context.bot, Bot) - and isinstance(update, Update) - and isinstance(context.update_queue, Queue) - and isinstance(context.job_queue, JobQueue) - and isinstance(context.user_data, dict) - and context.chat_data is None - and isinstance(context.bot_data, dict) - and isinstance(update.chosen_inline_result, ChosenInlineResult) - ) - - def callback_context_pattern(self, update, context): - if context.matches[0].groups(): - self.test_flag = context.matches[0].groups() == ('res', '_id') - if context.matches[0].groupdict(): - self.test_flag = context.matches[0].groupdict() == {'begin': 'res', 'end': '_id'} - - def test_other_update_types(self, false_update): - handler = ChosenInlineResultHandler(self.callback_basic) - assert not handler.check_update(false_update) - - def test_context(self, dp, chosen_inline_result): - handler = ChosenInlineResultHandler(self.callback_context) - dp.add_handler(handler) - - dp.process_update(chosen_inline_result) - assert self.test_flag - - def test_with_pattern(self, chosen_inline_result): - handler = ChosenInlineResultHandler(self.callback_basic, pattern='.*ult.*') - - assert handler.check_update(chosen_inline_result) - - chosen_inline_result.chosen_inline_result.result_id = 'nothing here' - assert not handler.check_update(chosen_inline_result) - chosen_inline_result.chosen_inline_result.result_id = 'result_id' - - def test_context_pattern(self, dp, chosen_inline_result): - handler = ChosenInlineResultHandler( - self.callback_context_pattern, pattern=r'(?P.*)ult(?P.*)' - ) - dp.add_handler(handler) - dp.process_update(chosen_inline_result) - assert self.test_flag - - dp.remove_handler(handler) - handler = ChosenInlineResultHandler(self.callback_context_pattern, pattern=r'(res)ult(.*)') - dp.add_handler(handler) - - dp.process_update(chosen_inline_result) - assert self.test_flag diff --git a/tests/test_commandhandler.py b/tests/test_commandhandler.py deleted file mode 100644 index d2622e89233..00000000000 --- a/tests/test_commandhandler.py +++ /dev/null @@ -1,383 +0,0 @@ -#!/usr/bin/env python -# -# A library that provides a Python interface to the Telegram Bot API -# Copyright (C) 2015-2022 -# Leandro Toledo de Souza -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Lesser Public License for more details. -# -# You should have received a copy of the GNU Lesser Public License -# along with this program. If not, see [http://www.gnu.org/licenses/]. -import re -from queue import Queue - -import pytest - -from telegram import Message, Update, Chat, Bot -from telegram.ext import CommandHandler, filters, CallbackContext, JobQueue, PrefixHandler -from tests.conftest import ( - make_command_message, - make_command_update, - make_message, - make_message_update, -) - - -def is_match(handler, update): - """ - Utility function that returns whether an update matched - against a specific handler. - :param handler: ``CommandHandler`` to check against - :param update: update to check - :return: (bool) whether ``update`` matched with ``handler`` - """ - check = handler.check_update(update) - return check is not None and check is not False - - -class BaseTest: - """Base class for command and prefix handler test classes. Contains - utility methods an several callbacks used by both classes.""" - - test_flag = False - SRE_TYPE = type(re.match("", "")) - - @pytest.fixture(autouse=True) - def reset(self): - self.test_flag = False - - def response(self, dispatcher, update): - """ - Utility to send an update to a dispatcher and assert - whether the callback was called appropriately. Its purpose is - for repeated usage in the same test function. - """ - self.test_flag = False - dispatcher.process_update(update) - return self.test_flag - - def callback_basic(self, update, context): - test_bot = isinstance(context.bot, Bot) - test_update = isinstance(update, Update) - self.test_flag = test_bot and test_update - - def make_callback_for(self, pass_keyword): - def callback(bot, update, **kwargs): - self.test_flag = kwargs.get(keyword) is not None - - keyword = pass_keyword[5:] - return callback - - def callback_context(self, update, context): - self.test_flag = ( - isinstance(context, CallbackContext) - and isinstance(context.bot, Bot) - and isinstance(update, Update) - and isinstance(context.update_queue, Queue) - and isinstance(context.job_queue, JobQueue) - and isinstance(context.user_data, dict) - and isinstance(context.chat_data, dict) - and isinstance(context.bot_data, dict) - and isinstance(update.message, Message) - ) - - def callback_context_args(self, update, context): - self.test_flag = context.args == ['one', 'two'] - - def callback_context_regex1(self, update, context): - if context.matches: - types = all(type(res) is self.SRE_TYPE for res in context.matches) - num = len(context.matches) == 1 - self.test_flag = types and num - - def callback_context_regex2(self, update, context): - if context.matches: - types = all(type(res) is self.SRE_TYPE for res in context.matches) - num = len(context.matches) == 2 - self.test_flag = types and num - - def _test_context_args_or_regex(self, dp, handler, text): - dp.add_handler(handler) - update = make_command_update(text) - assert not self.response(dp, update) - update.message.text += ' one two' - assert self.response(dp, update) - - def _test_edited(self, message, handler_edited, handler_not_edited): - """ - Assert whether a handler that should accept edited messages - and a handler that shouldn't work correctly. - :param message: ``telegram.Message`` to check against the handlers - :param handler_edited: handler that should accept edited messages - :param handler_not_edited: handler that should not accept edited messages - """ - update = make_command_update(message) - edited_update = make_command_update(message, edited=True) - - assert is_match(handler_edited, update) - assert is_match(handler_edited, edited_update) - assert is_match(handler_not_edited, update) - assert not is_match(handler_not_edited, edited_update) - - -# ----------------------------- CommandHandler ----------------------------- - - -class TestCommandHandler(BaseTest): - CMD = '/test' - - def test_slot_behaviour(self, mro_slots): - handler = self.make_default_handler() - for attr in handler.__slots__: - assert getattr(handler, attr, 'err') != 'err', f"got extra slot '{attr}'" - assert len(mro_slots(handler)) == len(set(mro_slots(handler))), "duplicate slot" - - @pytest.fixture(scope='class') - def command(self): - return self.CMD - - @pytest.fixture(scope='class') - def command_message(self, command): - return make_command_message(command) - - @pytest.fixture(scope='class') - def command_update(self, command_message): - return make_command_update(command_message) - - def make_default_handler(self, callback=None, **kwargs): - callback = callback or self.callback_basic - return CommandHandler(self.CMD[1:], callback, **kwargs) - - def test_basic(self, dp, command): - """Test whether a command handler responds to its command - and not to others, or badly formatted commands""" - handler = self.make_default_handler() - dp.add_handler(handler) - - assert self.response(dp, make_command_update(command)) - assert not is_match(handler, make_command_update(command[1:])) - assert not is_match(handler, make_command_update(f'/not{command[1:]}')) - assert not is_match(handler, make_command_update(f'not {command} at start')) - - @pytest.mark.parametrize( - 'cmd', - ['way_too_longcommand1234567yes_way_toooooooLong', 'ïñválídletters', 'invalid #&* chars'], - ids=['too long', 'invalid letter', 'invalid characters'], - ) - def test_invalid_commands(self, cmd): - with pytest.raises( - ValueError, match=f'`{re.escape(cmd.lower())}` is not a valid bot command' - ): - CommandHandler(cmd, self.callback_basic) - - def test_command_list(self): - """A command handler with multiple commands registered should respond to all of them.""" - handler = CommandHandler(['test', 'star'], self.callback_basic) - assert is_match(handler, make_command_update('/test')) - assert is_match(handler, make_command_update('/star')) - assert not is_match(handler, make_command_update('/stop')) - - def test_edited(self, command_message): - """Test that a CH responds to an edited message if its filters allow it""" - handler_edited = self.make_default_handler() - handler_no_edited = self.make_default_handler(filters=~filters.UpdateType.EDITED_MESSAGE) - self._test_edited(command_message, handler_edited, handler_no_edited) - - def test_directed_commands(self, bot, command): - """Test recognition of commands with a mention to the bot""" - handler = self.make_default_handler() - assert is_match(handler, make_command_update(command + '@' + bot.username, bot=bot)) - assert not is_match(handler, make_command_update(command + '@otherbot', bot=bot)) - - def test_with_filter(self, command): - """Test that a CH with a (generic) filter responds if its filters match""" - handler = self.make_default_handler(filters=filters.ChatType.GROUP) - assert is_match(handler, make_command_update(command, chat=Chat(-23, Chat.GROUP))) - assert not is_match(handler, make_command_update(command, chat=Chat(23, Chat.PRIVATE))) - - def test_newline(self, dp, command): - """Assert that newlines don't interfere with a command handler matching a message""" - handler = self.make_default_handler() - dp.add_handler(handler) - update = make_command_update(command + '\nfoobar') - assert is_match(handler, update) - assert self.response(dp, update) - - def test_other_update_types(self, false_update): - """Test that a command handler doesn't respond to unrelated updates""" - handler = self.make_default_handler() - assert not is_match(handler, false_update) - - def test_filters_for_wrong_command(self, mock_filter): - """Filters should not be executed if the command does not match the handler""" - handler = self.make_default_handler(filters=mock_filter) - assert not is_match(handler, make_command_update('/star')) - assert not mock_filter.tested - - def test_context(self, dp, command_update): - """Test correct behaviour of CHs with context-based callbacks""" - handler = self.make_default_handler(self.callback_context) - dp.add_handler(handler) - assert self.response(dp, command_update) - - def test_context_args(self, dp, command): - """Test CHs that pass arguments through ``context``""" - handler = self.make_default_handler(self.callback_context_args) - self._test_context_args_or_regex(dp, handler, command) - - def test_context_regex(self, dp, command): - """Test CHs with context-based callbacks and a single filter""" - handler = self.make_default_handler( - self.callback_context_regex1, filters=filters.Regex('one two') - ) - self._test_context_args_or_regex(dp, handler, command) - - def test_context_multiple_regex(self, dp, command): - """Test CHs with context-based callbacks and filters combined""" - handler = self.make_default_handler( - self.callback_context_regex2, filters=filters.Regex('one') & filters.Regex('two') - ) - self._test_context_args_or_regex(dp, handler, command) - - -# ----------------------------- PrefixHandler ----------------------------- - - -def combinations(prefixes, commands): - return (prefix + command for prefix in prefixes for command in commands) - - -class TestPrefixHandler(BaseTest): - # Prefixes and commands with which to test PrefixHandler: - PREFIXES = ['!', '#', 'mytrig-'] - COMMANDS = ['help', 'test'] - COMBINATIONS = list(combinations(PREFIXES, COMMANDS)) - - def test_slot_behaviour(self, mro_slots): - handler = self.make_default_handler() - for attr in handler.__slots__: - assert getattr(handler, attr, 'err') != 'err', f"got extra slot '{attr}'" - assert len(mro_slots(handler)) == len(set(mro_slots(handler))), "duplicate slot" - - @pytest.fixture(scope='class', params=PREFIXES) - def prefix(self, request): - return request.param - - @pytest.fixture(scope='class', params=[1, 2], ids=['single prefix', 'multiple prefixes']) - def prefixes(self, request): - return TestPrefixHandler.PREFIXES[: request.param] - - @pytest.fixture(scope='class', params=COMMANDS) - def command(self, request): - return request.param - - @pytest.fixture(scope='class', params=[1, 2], ids=['single command', 'multiple commands']) - def commands(self, request): - return TestPrefixHandler.COMMANDS[: request.param] - - @pytest.fixture(scope='class') - def prefix_message_text(self, prefix, command): - return prefix + command - - @pytest.fixture(scope='class') - def prefix_message(self, prefix_message_text): - return make_message(prefix_message_text) - - @pytest.fixture(scope='class') - def prefix_message_update(self, prefix_message): - return make_message_update(prefix_message) - - def make_default_handler(self, callback=None, **kwargs): - callback = callback or self.callback_basic - return PrefixHandler(self.PREFIXES, self.COMMANDS, callback, **kwargs) - - def test_basic(self, dp, prefix, command): - """Test the basic expected response from a prefix handler""" - handler = self.make_default_handler() - dp.add_handler(handler) - text = prefix + command - - assert self.response(dp, make_message_update(text)) - assert not is_match(handler, make_message_update(command)) - assert not is_match(handler, make_message_update(prefix + 'notacommand')) - assert not is_match(handler, make_command_update(f'not {text} at start')) - - def test_single_multi_prefixes_commands(self, prefixes, commands, prefix_message_update): - """Test various combinations of prefixes and commands""" - handler = self.make_default_handler() - result = is_match(handler, prefix_message_update) - expected = prefix_message_update.message.text in combinations(prefixes, commands) - return result == expected - - def test_edited(self, prefix_message): - handler_edited = self.make_default_handler() - handler_no_edited = self.make_default_handler(filters=~filters.UpdateType.EDITED_MESSAGE) - self._test_edited(prefix_message, handler_edited, handler_no_edited) - - def test_with_filter(self, prefix_message_text): - handler = self.make_default_handler(filters=filters.ChatType.GROUP) - text = prefix_message_text - assert is_match(handler, make_message_update(text, chat=Chat(-23, Chat.GROUP))) - assert not is_match(handler, make_message_update(text, chat=Chat(23, Chat.PRIVATE))) - - def test_other_update_types(self, false_update): - handler = self.make_default_handler() - assert not is_match(handler, false_update) - - def test_filters_for_wrong_command(self, mock_filter): - """Filters should not be executed if the command does not match the handler""" - handler = self.make_default_handler(filters=mock_filter) - assert not is_match(handler, make_message_update('/test')) - assert not mock_filter.tested - - def test_edit_prefix(self): - handler = self.make_default_handler() - handler.prefix = ['?', '§'] - assert handler._commands == list(combinations(['?', '§'], self.COMMANDS)) - handler.prefix = '+' - assert handler._commands == list(combinations(['+'], self.COMMANDS)) - - def test_edit_command(self): - handler = self.make_default_handler() - handler.command = 'foo' - assert handler._commands == list(combinations(self.PREFIXES, ['foo'])) - - def test_basic_after_editing(self, dp, prefix, command): - """Test the basic expected response from a prefix handler""" - handler = self.make_default_handler() - dp.add_handler(handler) - text = prefix + command - - assert self.response(dp, make_message_update(text)) - handler.command = 'foo' - text = prefix + 'foo' - assert self.response(dp, make_message_update(text)) - - def test_context(self, dp, prefix_message_update): - handler = self.make_default_handler(self.callback_context) - dp.add_handler(handler) - assert self.response(dp, prefix_message_update) - - def test_context_args(self, dp, prefix_message_text): - handler = self.make_default_handler(self.callback_context_args) - self._test_context_args_or_regex(dp, handler, prefix_message_text) - - def test_context_regex(self, dp, prefix_message_text): - handler = self.make_default_handler( - self.callback_context_regex1, filters=filters.Regex('one two') - ) - self._test_context_args_or_regex(dp, handler, prefix_message_text) - - def test_context_multiple_regex(self, dp, prefix_message_text): - handler = self.make_default_handler( - self.callback_context_regex2, filters=filters.Regex('one') & filters.Regex('two') - ) - self._test_context_args_or_regex(dp, handler, prefix_message_text) diff --git a/tests/test_constants.py b/tests/test_constants.py index 42e8a6794b9..a034341e722 100644 --- a/tests/test_constants.py +++ b/tests/test_constants.py @@ -88,23 +88,27 @@ def test_int_inheritance(self): assert hash(IntEnumTest.FOO) == hash(1) @flaky(3, 1) - def test_max_message_length(self, bot, chat_id): - bot.send_message(chat_id=chat_id, text='a' * constants.MessageLimit.TEXT_LENGTH) + @pytest.mark.asyncio + async def test_max_message_length(self, bot, chat_id): + await bot.send_message(chat_id=chat_id, text='a' * constants.MessageLimit.TEXT_LENGTH) with pytest.raises( BadRequest, match='Message is too long', ): - bot.send_message(chat_id=chat_id, text='a' * (constants.MessageLimit.TEXT_LENGTH + 1)) + await bot.send_message( + chat_id=chat_id, text='a' * (constants.MessageLimit.TEXT_LENGTH + 1) + ) @flaky(3, 1) - def test_max_caption_length(self, bot, chat_id): + @pytest.mark.asyncio + async def test_max_caption_length(self, bot, chat_id): good_caption = 'a' * constants.MessageLimit.CAPTION_LENGTH with data_file('telegram.png').open('rb') as f: - good_msg = bot.send_photo(photo=f, caption=good_caption, chat_id=chat_id) + good_msg = await bot.send_photo(photo=f, caption=good_caption, chat_id=chat_id) assert good_msg.caption == good_caption bad_caption = good_caption + 'Z' match = "Media_caption_too_long" with pytest.raises(BadRequest, match=match), data_file('telegram.png').open('rb') as f: - bot.send_photo(photo=f, caption=bad_caption, chat_id=chat_id) + await bot.send_photo(photo=f, caption=bad_caption, chat_id=chat_id) diff --git a/tests/test_contact.py b/tests/test_contact.py index 1ddd4970ef6..633d20e2f50 100644 --- a/tests/test_contact.py +++ b/tests/test_contact.py @@ -22,6 +22,7 @@ from telegram import Contact, Voice from telegram.error import BadRequest +from telegram.request import RequestData @pytest.fixture(scope='class') @@ -66,15 +67,17 @@ def test_de_json_all(self, bot): assert contact.last_name == self.last_name assert contact.user_id == self.user_id - def test_send_with_contact(self, monkeypatch, bot, chat_id, contact): - def test(url, data, **kwargs): + @pytest.mark.asyncio + async def test_send_with_contact(self, monkeypatch, bot, chat_id, contact): + async def make_assertion(url, request_data: RequestData, *args, **kwargs): + data = request_data.json_parameters phone = data['phone_number'] == contact.phone_number first = data['first_name'] == contact.first_name last = data['last_name'] == contact.last_name return phone and first and last - monkeypatch.setattr(bot.request, 'post', test) - message = bot.send_contact(contact=contact, chat_id=chat_id) + monkeypatch.setattr(bot.request, 'post', make_assertion) + message = await bot.send_contact(contact=contact, chat_id=chat_id) assert message @flaky(3, 1) @@ -87,13 +90,14 @@ def test(url, data, **kwargs): ], indirect=['default_bot'], ) - def test_send_contact_default_allow_sending_without_reply( + @pytest.mark.asyncio + async def test_send_contact_default_allow_sending_without_reply( self, default_bot, chat_id, contact, custom ): - reply_to_message = default_bot.send_message(chat_id, 'test') - reply_to_message.delete() + reply_to_message = await default_bot.send_message(chat_id, 'test') + await reply_to_message.delete() if custom is not None: - message = default_bot.send_contact( + message = await default_bot.send_contact( chat_id, contact=contact, allow_sending_without_reply=custom, @@ -101,27 +105,31 @@ def test_send_contact_default_allow_sending_without_reply( ) assert message.reply_to_message is None elif default_bot.defaults.allow_sending_without_reply: - message = default_bot.send_contact( + message = await default_bot.send_contact( chat_id, contact=contact, reply_to_message_id=reply_to_message.message_id ) assert message.reply_to_message is None else: with pytest.raises(BadRequest, match='message not found'): - default_bot.send_contact( + await default_bot.send_contact( chat_id, contact=contact, reply_to_message_id=reply_to_message.message_id ) @flaky(3, 1) + @pytest.mark.asyncio @pytest.mark.parametrize('default_bot', [{'protect_content': True}], indirect=True) - def test_send_contact_default_protect_content(self, chat_id, default_bot, contact): - protected = default_bot.send_contact(chat_id, contact=contact) + async def test_send_contact_default_protect_content(self, chat_id, default_bot, contact): + protected = await default_bot.send_contact(chat_id, contact=contact) assert protected.has_protected_content - unprotected = default_bot.send_contact(chat_id, contact=contact, protect_content=False) + unprotected = await default_bot.send_contact( + chat_id, contact=contact, protect_content=False + ) assert not unprotected.has_protected_content - def test_send_contact_without_required(self, bot, chat_id): + @pytest.mark.asyncio + async def test_send_contact_without_required(self, bot, chat_id): with pytest.raises(ValueError, match='Either contact or phone_number and first_name'): - bot.send_contact(chat_id=chat_id) + await bot.send_contact(chat_id=chat_id) def test_to_dict(self, contact): contact_dict = contact.to_dict() diff --git a/tests/test_conversationhandler.py b/tests/test_conversationhandler.py deleted file mode 100644 index 0cc449f586a..00000000000 --- a/tests/test_conversationhandler.py +++ /dev/null @@ -1,1785 +0,0 @@ -#!/usr/bin/env python -# -# A library that provides a Python interface to the Telegram Bot API -# Copyright (C) 2015-2022 -# Leandro Toledo de Souza -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Lesser Public License for more details. -# -# You should have received a copy of the GNU Lesser Public License -# along with this program. If not, see [http://www.gnu.org/licenses/]. -import logging -from time import sleep -from warnings import filterwarnings - -import pytest -from flaky import flaky - -from telegram import ( - CallbackQuery, - Chat, - ChosenInlineResult, - InlineQuery, - Message, - PreCheckoutQuery, - ShippingQuery, - Update, - User, - MessageEntity, -) -from telegram.ext import ( - ConversationHandler, - CommandHandler, - CallbackQueryHandler, - MessageHandler, - filters, - InlineQueryHandler, - CallbackContext, - DispatcherHandlerStop, - TypeHandler, - JobQueue, - StringCommandHandler, - StringRegexHandler, - PollHandler, - ShippingQueryHandler, - ChosenInlineResultHandler, - PreCheckoutQueryHandler, - PollAnswerHandler, -) -from telegram.warnings import PTBUserWarning - - -@pytest.fixture(scope='class') -def user1(): - return User(first_name='Misses Test', id=123, is_bot=False) - - -@pytest.fixture(scope='class') -def user2(): - return User(first_name='Mister Test', id=124, is_bot=False) - - -@pytest.fixture(autouse=True) -def start_stop_job_queue(dp): - dp.job_queue = JobQueue() - dp.job_queue.set_dispatcher(dp) - dp.job_queue.start() - yield - dp.job_queue.stop() - - -def raise_dphs(func): - def decorator(self, *args, **kwargs): - result = func(self, *args, **kwargs) - if self.raise_dp_handler_stop: - raise DispatcherHandlerStop(result) - return result - - return decorator - - -class TestConversationHandler: - # State definitions - # At first we're thirsty. Then we brew coffee, we drink it - # and then we can start coding! - END, THIRSTY, BREWING, DRINKING, CODING = range(-1, 4) - - # Drinking state definitions (nested) - # At first we're holding the cup. Then we sip coffee, and last we swallow it - HOLDING, SIPPING, SWALLOWING, REPLENISHING, STOPPING = map(chr, range(ord('a'), ord('f'))) - - current_state, entry_points, states, fallbacks = None, None, None, None - group = Chat(0, Chat.GROUP) - second_group = Chat(1, Chat.GROUP) - - raise_dp_handler_stop = False - test_flag = False - - def test_slot_behaviour(self, mro_slots): - handler = ConversationHandler(self.entry_points, self.states, self.fallbacks) - for attr in handler.__slots__: - assert getattr(handler, attr, 'err') != 'err', f"got extra slot '{attr}'" - assert len(mro_slots(handler)) == len(set(mro_slots(handler))), "duplicate slot" - - # Test related - @pytest.fixture(autouse=True) - def reset(self): - self.raise_dp_handler_stop = False - self.test_flag = False - self.current_state = {} - self.entry_points = [CommandHandler('start', self.start)] - self.states = { - self.THIRSTY: [CommandHandler('brew', self.brew), CommandHandler('wait', self.start)], - self.BREWING: [CommandHandler('pourCoffee', self.drink)], - self.DRINKING: [ - CommandHandler('startCoding', self.code), - CommandHandler('drinkMore', self.drink), - CommandHandler('end', self.end), - ], - self.CODING: [ - CommandHandler('keepCoding', self.code), - CommandHandler('gettingThirsty', self.start), - CommandHandler('drinkMore', self.drink), - ], - } - self.fallbacks = [CommandHandler('eat', self.start)] - self.is_timeout = False - - # for nesting tests - self.nested_states = { - self.THIRSTY: [CommandHandler('brew', self.brew), CommandHandler('wait', self.start)], - self.BREWING: [CommandHandler('pourCoffee', self.drink)], - self.CODING: [ - CommandHandler('keepCoding', self.code), - CommandHandler('gettingThirsty', self.start), - CommandHandler('drinkMore', self.drink), - ], - } - self.drinking_entry_points = [CommandHandler('hold', self.hold)] - self.drinking_states = { - self.HOLDING: [CommandHandler('sip', self.sip)], - self.SIPPING: [CommandHandler('swallow', self.swallow)], - self.SWALLOWING: [CommandHandler('hold', self.hold)], - } - self.drinking_fallbacks = [ - CommandHandler('replenish', self.replenish), - CommandHandler('stop', self.stop), - CommandHandler('end', self.end), - CommandHandler('startCoding', self.code), - CommandHandler('drinkMore', self.drink), - ] - self.drinking_entry_points.extend(self.drinking_fallbacks) - - # Map nested states to parent states: - self.drinking_map_to_parent = { - # Option 1 - Map a fictional internal state to an external parent state - self.REPLENISHING: self.BREWING, - # Option 2 - Map a fictional internal state to the END state on the parent - self.STOPPING: self.END, - # Option 3 - Map the internal END state to an external parent state - self.END: self.CODING, - # Option 4 - Map an external state to the same external parent state - self.CODING: self.CODING, - # Option 5 - Map an external state to the internal entry point - self.DRINKING: self.DRINKING, - } - - # State handlers - def _set_state(self, update, state): - self.current_state[update.message.from_user.id] = state - return state - - # Actions - @raise_dphs - def start(self, update, context): - if isinstance(update, Update): - return self._set_state(update, self.THIRSTY) - return self._set_state(context.bot, self.THIRSTY) - - @raise_dphs - def end(self, update, context): - return self._set_state(update, self.END) - - @raise_dphs - def start_end(self, update, context): - return self._set_state(update, self.END) - - @raise_dphs - def start_none(self, update, context): - return self._set_state(update, None) - - @raise_dphs - def brew(self, update, context): - if isinstance(update, Update): - return self._set_state(update, self.BREWING) - return self._set_state(context.bot, self.BREWING) - - @raise_dphs - def drink(self, update, context): - return self._set_state(update, self.DRINKING) - - @raise_dphs - def code(self, update, context): - return self._set_state(update, self.CODING) - - @raise_dphs - def passout(self, update, context): - assert update.message.text == '/brew' - assert isinstance(update, Update) - self.is_timeout = True - - @raise_dphs - def passout2(self, update, context): - assert isinstance(update, Update) - self.is_timeout = True - - @raise_dphs - def passout_context(self, update, context): - assert update.message.text == '/brew' - assert isinstance(context, CallbackContext) - self.is_timeout = True - - @raise_dphs - def passout2_context(self, update, context): - assert isinstance(context, CallbackContext) - self.is_timeout = True - - # Drinking actions (nested) - - @raise_dphs - def hold(self, update, context): - return self._set_state(update, self.HOLDING) - - @raise_dphs - def sip(self, update, context): - return self._set_state(update, self.SIPPING) - - @raise_dphs - def swallow(self, update, context): - return self._set_state(update, self.SWALLOWING) - - @raise_dphs - def replenish(self, update, context): - return self._set_state(update, self.REPLENISHING) - - @raise_dphs - def stop(self, update, context): - return self._set_state(update, self.STOPPING) - - # Tests - @pytest.mark.parametrize( - 'attr', - [ - 'entry_points', - 'states', - 'fallbacks', - 'per_chat', - 'name', - 'per_user', - 'allow_reentry', - 'conversation_timeout', - 'map_to_parent', - ], - indirect=False, - ) - def test_immutable(self, attr): - ch = ConversationHandler( - 'entry_points', - {'states': ['states']}, - 'fallbacks', - per_chat='per_chat', - per_user='per_user', - per_message=False, - allow_reentry='allow_reentry', - conversation_timeout='conversation_timeout', - name='name', - map_to_parent='map_to_parent', - ) - - value = getattr(ch, attr) - if isinstance(value, list): - assert value[0] == attr - elif isinstance(value, dict): - assert list(value.keys())[0] == attr - else: - assert getattr(ch, attr) == attr - with pytest.raises(AttributeError, match=f'You can not assign a new value to {attr}'): - setattr(ch, attr, True) - - def test_immutable_per_message(self): - ch = ConversationHandler( - 'entry_points', - {'states': ['states']}, - 'fallbacks', - per_chat='per_chat', - per_user='per_user', - per_message=False, - allow_reentry='allow_reentry', - conversation_timeout='conversation_timeout', - name='name', - map_to_parent='map_to_parent', - ) - assert ch.per_message is False - with pytest.raises(AttributeError, match='You can not assign a new value to per_message'): - ch.per_message = True - - def test_per_all_false(self): - with pytest.raises(ValueError, match="can't all be 'False'"): - ConversationHandler( - self.entry_points, - self.states, - self.fallbacks, - per_chat=False, - per_user=False, - per_message=False, - ) - - def test_name_and_persistent(self, dp): - with pytest.raises(ValueError, match="when handler is unnamed"): - dp.add_handler(ConversationHandler([], {}, [], persistent=True)) - c = ConversationHandler([], {}, [], name="handler", persistent=True) - assert c.name == "handler" - - def test_conversation_handler(self, dp, bot, user1, user2): - handler = ConversationHandler( - entry_points=self.entry_points, states=self.states, fallbacks=self.fallbacks - ) - dp.add_handler(handler) - - # User one, starts the state machine. - message = Message( - 0, - None, - self.group, - from_user=user1, - text='/start', - entities=[ - MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0, length=len('/start')) - ], - bot=bot, - ) - dp.process_update(Update(update_id=0, message=message)) - assert self.current_state[user1.id] == self.THIRSTY - - # The user is thirsty and wants to brew coffee. - message.text = '/brew' - message.entities[0].length = len('/brew') - dp.process_update(Update(update_id=0, message=message)) - assert self.current_state[user1.id] == self.BREWING - - # Lets see if an invalid command makes sure, no state is changed. - message.text = '/nothing' - message.entities[0].length = len('/nothing') - dp.process_update(Update(update_id=0, message=message)) - assert self.current_state[user1.id] == self.BREWING - - # Lets see if the state machine still works by pouring coffee. - message.text = '/pourCoffee' - message.entities[0].length = len('/pourCoffee') - dp.process_update(Update(update_id=0, message=message)) - assert self.current_state[user1.id] == self.DRINKING - - # Let's now verify that for another user, who did not start yet, - # the state has not been changed. - message.from_user = user2 - dp.process_update(Update(update_id=0, message=message)) - with pytest.raises(KeyError): - self.current_state[user2.id] - - def test_conversation_handler_end(self, caplog, dp, bot, user1): - handler = ConversationHandler( - entry_points=self.entry_points, states=self.states, fallbacks=self.fallbacks - ) - dp.add_handler(handler) - - message = Message( - 0, - None, - self.group, - from_user=user1, - text='/start', - entities=[ - MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0, length=len('/start')) - ], - bot=bot, - ) - dp.process_update(Update(update_id=0, message=message)) - message.text = '/brew' - message.entities[0].length = len('/brew') - dp.process_update(Update(update_id=0, message=message)) - message.text = '/pourCoffee' - message.entities[0].length = len('/pourCoffee') - dp.process_update(Update(update_id=0, message=message)) - message.text = '/end' - message.entities[0].length = len('/end') - caplog.clear() - with caplog.at_level(logging.ERROR): - dp.process_update(Update(update_id=0, message=message)) - assert len(caplog.records) == 0 - assert self.current_state[user1.id] == self.END - with pytest.raises(KeyError): - print(handler.conversations[(self.group.id, user1.id)]) - - def test_conversation_handler_fallback(self, dp, bot, user1, user2): - handler = ConversationHandler( - entry_points=self.entry_points, states=self.states, fallbacks=self.fallbacks - ) - dp.add_handler(handler) - - # first check if fallback will not trigger start when not started - message = Message( - 0, - None, - self.group, - from_user=user1, - text='/eat', - entities=[MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0, length=len('/eat'))], - bot=bot, - ) - dp.process_update(Update(update_id=0, message=message)) - with pytest.raises(KeyError): - self.current_state[user1.id] - - # User starts the state machine. - message.text = '/start' - message.entities[0].length = len('/start') - dp.process_update(Update(update_id=0, message=message)) - assert self.current_state[user1.id] == self.THIRSTY - - # The user is thirsty and wants to brew coffee. - message.text = '/brew' - message.entities[0].length = len('/brew') - dp.process_update(Update(update_id=0, message=message)) - assert self.current_state[user1.id] == self.BREWING - - # Now a fallback command is issued - message.text = '/eat' - message.entities[0].length = len('/eat') - dp.process_update(Update(update_id=0, message=message)) - assert self.current_state[user1.id] == self.THIRSTY - - def test_unknown_state_warning(self, dp, bot, user1, recwarn): - handler = ConversationHandler( - entry_points=[CommandHandler("start", lambda u, c: 1)], - states={ - 1: [TypeHandler(Update, lambda u, c: 69)], - 2: [TypeHandler(Update, lambda u, c: -1)], - }, - fallbacks=self.fallbacks, - name="xyz", - ) - dp.add_handler(handler) - message = Message( - 0, - None, - self.group, - from_user=user1, - text='/start', - entities=[ - MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0, length=len('/start')) - ], - bot=bot, - ) - dp.process_update(Update(update_id=0, message=message)) - sleep(0.5) - dp.process_update(Update(update_id=1, message=message)) - sleep(0.5) - assert len(recwarn) == 1 - assert str(recwarn[0].message) == ( - "Handler returned state 69 which is unknown to the ConversationHandler xyz." - ) - - def test_conversation_handler_per_chat(self, dp, bot, user1, user2): - handler = ConversationHandler( - entry_points=self.entry_points, - states=self.states, - fallbacks=self.fallbacks, - per_user=False, - ) - dp.add_handler(handler) - - # User one, starts the state machine. - message = Message( - 0, - None, - self.group, - from_user=user1, - text='/start', - entities=[ - MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0, length=len('/start')) - ], - bot=bot, - ) - dp.process_update(Update(update_id=0, message=message)) - - # The user is thirsty and wants to brew coffee. - message.text = '/brew' - message.entities[0].length = len('/brew') - dp.process_update(Update(update_id=0, message=message)) - - # Let's now verify that for another user, who did not start yet, - # the state will be changed because they are in the same group. - message.from_user = user2 - message.text = '/pourCoffee' - message.entities[0].length = len('/pourCoffee') - dp.process_update(Update(update_id=0, message=message)) - - assert handler.conversations[(self.group.id,)] == self.DRINKING - - def test_conversation_handler_per_user(self, dp, bot, user1): - handler = ConversationHandler( - entry_points=self.entry_points, - states=self.states, - fallbacks=self.fallbacks, - per_chat=False, - ) - dp.add_handler(handler) - - # User one, starts the state machine. - message = Message( - 0, - None, - self.group, - from_user=user1, - text='/start', - entities=[ - MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0, length=len('/start')) - ], - bot=bot, - ) - dp.process_update(Update(update_id=0, message=message)) - - # The user is thirsty and wants to brew coffee. - message.text = '/brew' - message.entities[0].length = len('/brew') - dp.process_update(Update(update_id=0, message=message)) - - # Let's now verify that for the same user in a different group, the state will still be - # updated - message.chat = self.second_group - message.text = '/pourCoffee' - message.entities[0].length = len('/pourCoffee') - dp.process_update(Update(update_id=0, message=message)) - - assert handler.conversations[(user1.id,)] == self.DRINKING - - def test_conversation_handler_per_message(self, dp, bot, user1, user2): - def entry(update, context): - return 1 - - def one(update, context): - return 2 - - def two(update, context): - return ConversationHandler.END - - handler = ConversationHandler( - entry_points=[CallbackQueryHandler(entry)], - states={1: [CallbackQueryHandler(one)], 2: [CallbackQueryHandler(two)]}, - fallbacks=[], - per_message=True, - ) - dp.add_handler(handler) - - # User one, starts the state machine. - message = Message( - 0, None, self.group, from_user=user1, text='msg w/ inlinekeyboard', bot=bot - ) - - cbq = CallbackQuery(0, user1, None, message=message, data='data', bot=bot) - dp.process_update(Update(update_id=0, callback_query=cbq)) - - assert handler.conversations[(self.group.id, user1.id, message.message_id)] == 1 - - dp.process_update(Update(update_id=0, callback_query=cbq)) - - assert handler.conversations[(self.group.id, user1.id, message.message_id)] == 2 - - # Let's now verify that for a different user in the same group, the state will not be - # updated - cbq.from_user = user2 - dp.process_update(Update(update_id=0, callback_query=cbq)) - - assert handler.conversations[(self.group.id, user1.id, message.message_id)] == 2 - - def test_end_on_first_message(self, dp, bot, user1): - handler = ConversationHandler( - entry_points=[CommandHandler('start', self.start_end)], states={}, fallbacks=[] - ) - dp.add_handler(handler) - - # User starts the state machine and immediately ends it. - message = Message( - 0, - None, - self.group, - from_user=user1, - text='/start', - entities=[ - MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0, length=len('/start')) - ], - bot=bot, - ) - dp.process_update(Update(update_id=0, message=message)) - assert len(handler.conversations) == 0 - - def test_end_on_first_message_async(self, dp, bot, user1): - handler = ConversationHandler( - entry_points=[ - CommandHandler( - 'start', lambda update, context: dp.run_async(self.start_end, update, context) - ) - ], - states={}, - fallbacks=[], - ) - dp.add_handler(handler) - - # User starts the state machine with an async function that immediately ends the - # conversation. Async results are resolved when the users state is queried next time. - message = Message( - 0, - None, - self.group, - from_user=user1, - text='/start', - entities=[ - MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0, length=len('/start')) - ], - bot=bot, - ) - dp.update_queue.put(Update(update_id=0, message=message)) - sleep(0.1) - # Assert that the Promise has been accepted as the new state - assert len(handler.conversations) == 1 - - message.text = 'resolve promise pls' - message.entities[0].length = len('resolve promise pls') - dp.update_queue.put(Update(update_id=0, message=message)) - sleep(0.1) - # Assert that the Promise has been resolved and the conversation ended. - assert len(handler.conversations) == 0 - - def test_end_on_first_message_async_handler(self, dp, bot, user1): - handler = ConversationHandler( - entry_points=[CommandHandler('start', self.start_end, run_async=True)], - states={}, - fallbacks=[], - ) - dp.add_handler(handler) - - # User starts the state machine with an async function that immediately ends the - # conversation. Async results are resolved when the users state is queried next time. - message = Message( - 0, - None, - self.group, - text='/start', - from_user=user1, - entities=[ - MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0, length=len('/start')) - ], - bot=bot, - ) - dp.update_queue.put(Update(update_id=0, message=message)) - sleep(0.1) - # Assert that the Promise has been accepted as the new state - assert len(handler.conversations) == 1 - - message.text = 'resolve promise pls' - message.entities[0].length = len('resolve promise pls') - dp.update_queue.put(Update(update_id=0, message=message)) - sleep(0.1) - # Assert that the Promise has been resolved and the conversation ended. - assert len(handler.conversations) == 0 - - def test_none_on_first_message(self, dp, bot, user1): - handler = ConversationHandler( - entry_points=[CommandHandler('start', self.start_none)], states={}, fallbacks=[] - ) - dp.add_handler(handler) - - # User starts the state machine and a callback function returns None - message = Message(0, None, self.group, from_user=user1, text='/start', bot=bot) - dp.process_update(Update(update_id=0, message=message)) - assert len(handler.conversations) == 0 - - def test_none_on_first_message_async(self, dp, bot, user1): - handler = ConversationHandler( - entry_points=[ - CommandHandler( - 'start', lambda update, context: dp.run_async(self.start_none, update, context) - ) - ], - states={}, - fallbacks=[], - ) - dp.add_handler(handler) - - # User starts the state machine with an async function that returns None - # Async results are resolved when the users state is queried next time. - message = Message( - 0, - None, - self.group, - from_user=user1, - text='/start', - entities=[ - MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0, length=len('/start')) - ], - bot=bot, - ) - dp.update_queue.put(Update(update_id=0, message=message)) - sleep(0.1) - # Assert that the Promise has been accepted as the new state - assert len(handler.conversations) == 1 - - message.text = 'resolve promise pls' - dp.update_queue.put(Update(update_id=0, message=message)) - sleep(0.1) - # Assert that the Promise has been resolved and the conversation ended. - assert len(handler.conversations) == 0 - - def test_none_on_first_message_async_handler(self, dp, bot, user1): - handler = ConversationHandler( - entry_points=[CommandHandler('start', self.start_none, run_async=True)], - states={}, - fallbacks=[], - ) - dp.add_handler(handler) - - # User starts the state machine with an async function that returns None - # Async results are resolved when the users state is queried next time. - message = Message( - 0, - None, - self.group, - text='/start', - from_user=user1, - entities=[ - MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0, length=len('/start')) - ], - bot=bot, - ) - dp.update_queue.put(Update(update_id=0, message=message)) - sleep(0.1) - # Assert that the Promise has been accepted as the new state - assert len(handler.conversations) == 1 - - message.text = 'resolve promise pls' - dp.update_queue.put(Update(update_id=0, message=message)) - sleep(0.1) - # Assert that the Promise has been resolved and the conversation ended. - assert len(handler.conversations) == 0 - - def test_per_chat_message_without_chat(self, bot, user1): - handler = ConversationHandler( - entry_points=[CommandHandler('start', self.start_end)], states={}, fallbacks=[] - ) - cbq = CallbackQuery(0, user1, None, None, bot=bot) - update = Update(0, callback_query=cbq) - assert not handler.check_update(update) - - def test_channel_message_without_chat(self, bot): - handler = ConversationHandler( - entry_points=[MessageHandler(filters.ALL, self.start_end)], states={}, fallbacks=[] - ) - message = Message(0, date=None, chat=Chat(0, Chat.CHANNEL, 'Misses Test'), bot=bot) - - update = Update(0, channel_post=message) - assert not handler.check_update(update) - - update = Update(0, edited_channel_post=message) - assert not handler.check_update(update) - - def test_all_update_types(self, dp, bot, user1): - handler = ConversationHandler( - entry_points=[CommandHandler('start', self.start_end)], states={}, fallbacks=[] - ) - message = Message(0, None, self.group, from_user=user1, text='ignore', bot=bot) - callback_query = CallbackQuery(0, user1, None, message=message, data='data', bot=bot) - chosen_inline_result = ChosenInlineResult(0, user1, 'query', bot=bot) - inline_query = InlineQuery(0, user1, 'query', 0, bot=bot) - pre_checkout_query = PreCheckoutQuery(0, user1, 'USD', 100, [], bot=bot) - shipping_query = ShippingQuery(0, user1, [], None, bot=bot) - assert not handler.check_update(Update(0, callback_query=callback_query)) - assert not handler.check_update(Update(0, chosen_inline_result=chosen_inline_result)) - assert not handler.check_update(Update(0, inline_query=inline_query)) - assert not handler.check_update(Update(0, message=message)) - assert not handler.check_update(Update(0, pre_checkout_query=pre_checkout_query)) - assert not handler.check_update(Update(0, shipping_query=shipping_query)) - - def test_no_jobqueue_warning(self, dp, bot, user1, recwarn): - handler = ConversationHandler( - entry_points=self.entry_points, - states=self.states, - fallbacks=self.fallbacks, - conversation_timeout=0.5, - ) - # save dp.job_queue in temp variable jqueue - # and then set dp.job_queue to None. - jqueue = dp.job_queue - dp.job_queue = None - dp.add_handler(handler) - - message = Message( - 0, - None, - self.group, - from_user=user1, - text='/start', - entities=[ - MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0, length=len('/start')) - ], - bot=bot, - ) - - dp.process_update(Update(update_id=0, message=message)) - sleep(0.5) - assert len(recwarn) == 1 - assert ( - str(recwarn[0].message) - == "Ignoring `conversation_timeout` because the Dispatcher has no JobQueue." - ) - # now set dp.job_queue back to it's original value - dp.job_queue = jqueue - - def test_schedule_job_exception(self, dp, bot, user1, monkeypatch, caplog): - def mocked_run_once(*a, **kw): - raise Exception("job error") - - class DictJB(JobQueue): - pass - - dp.job_queue = DictJB() - monkeypatch.setattr(dp.job_queue, "run_once", mocked_run_once) - handler = ConversationHandler( - entry_points=self.entry_points, - states=self.states, - fallbacks=self.fallbacks, - conversation_timeout=100, - ) - dp.add_handler(handler) - - message = Message( - 0, - None, - self.group, - from_user=user1, - text='/start', - entities=[ - MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0, length=len('/start')) - ], - bot=bot, - ) - - with caplog.at_level(logging.ERROR): - dp.process_update(Update(update_id=0, message=message)) - sleep(0.5) - assert len(caplog.records) == 2 - assert ( - caplog.records[0].message - == "Failed to schedule timeout job due to the following exception:" - ) - assert caplog.records[1].message == "job error" - - def test_promise_exception(self, dp, bot, user1, caplog): - """ - Here we make sure that when a run_async handle raises an - exception, the state isn't changed. - """ - - def conv_entry(*a, **kw): - return 1 - - def raise_error(*a, **kw): - raise Exception("promise exception") - - handler = ConversationHandler( - entry_points=[CommandHandler("start", conv_entry)], - states={1: [MessageHandler(filters.ALL, raise_error)]}, - fallbacks=self.fallbacks, - run_async=True, - ) - dp.add_handler(handler) - - message = Message( - 0, - None, - self.group, - from_user=user1, - text='/start', - entities=[ - MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0, length=len('/start')) - ], - bot=bot, - ) - # start the conversation - dp.process_update(Update(update_id=0, message=message)) - sleep(0.1) - message.text = "error" - dp.process_update(Update(update_id=0, message=message)) - sleep(0.1) - message.text = "resolve promise pls" - caplog.clear() - with caplog.at_level(logging.ERROR): - dp.process_update(Update(update_id=0, message=message)) - sleep(0.5) - assert len(caplog.records) == 3 - assert caplog.records[0].message == "Promise function raised exception" - assert caplog.records[1].message == "promise exception" - # assert res is old state - assert handler.conversations.get((self.group.id, user1.id))[0] == 1 - - def test_conversation_timeout(self, dp, bot, user1): - handler = ConversationHandler( - entry_points=self.entry_points, - states=self.states, - fallbacks=self.fallbacks, - conversation_timeout=0.5, - ) - dp.add_handler(handler) - - # Start state machine, then reach timeout - message = Message( - 0, - None, - self.group, - from_user=user1, - text='/start', - entities=[ - MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0, length=len('/start')) - ], - bot=bot, - ) - dp.process_update(Update(update_id=0, message=message)) - assert handler.conversations.get((self.group.id, user1.id)) == self.THIRSTY - sleep(0.75) - assert handler.conversations.get((self.group.id, user1.id)) is None - - # Start state machine, do something, then reach timeout - dp.process_update(Update(update_id=1, message=message)) - assert handler.conversations.get((self.group.id, user1.id)) == self.THIRSTY - message.text = '/brew' - message.entities[0].length = len('/brew') - dp.process_update(Update(update_id=2, message=message)) - assert handler.conversations.get((self.group.id, user1.id)) == self.BREWING - sleep(0.7) - assert handler.conversations.get((self.group.id, user1.id)) is None - - def test_timeout_not_triggered_on_conv_end_async(self, bot, dp, user1): - def timeout(*a, **kw): - self.test_flag = True - - self.states.update({ConversationHandler.TIMEOUT: [TypeHandler(Update, timeout)]}) - handler = ConversationHandler( - entry_points=self.entry_points, - states=self.states, - fallbacks=self.fallbacks, - conversation_timeout=0.5, - run_async=True, - ) - dp.add_handler(handler) - - message = Message( - 0, - None, - self.group, - from_user=user1, - text='/start', - entities=[ - MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0, length=len('/start')) - ], - bot=bot, - ) - # start the conversation - dp.process_update(Update(update_id=0, message=message)) - sleep(0.1) - message.text = '/brew' - message.entities[0].length = len('/brew') - dp.process_update(Update(update_id=1, message=message)) - sleep(0.1) - message.text = '/pourCoffee' - message.entities[0].length = len('/pourCoffee') - dp.process_update(Update(update_id=2, message=message)) - sleep(0.1) - message.text = '/end' - message.entities[0].length = len('/end') - dp.process_update(Update(update_id=3, message=message)) - sleep(1) - # assert timeout handler didn't got called - assert self.test_flag is False - - def test_conversation_timeout_dispatcher_handler_stop(self, dp, bot, user1, recwarn): - handler = ConversationHandler( - entry_points=self.entry_points, - states=self.states, - fallbacks=self.fallbacks, - conversation_timeout=0.5, - ) - - def timeout(*args, **kwargs): - raise DispatcherHandlerStop() - - self.states.update({ConversationHandler.TIMEOUT: [TypeHandler(Update, timeout)]}) - dp.add_handler(handler) - - # Start state machine, then reach timeout - message = Message( - 0, - None, - self.group, - text='/start', - from_user=user1, - entities=[ - MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0, length=len('/start')) - ], - bot=bot, - ) - - dp.process_update(Update(update_id=0, message=message)) - assert handler.conversations.get((self.group.id, user1.id)) == self.THIRSTY - sleep(0.9) - assert handler.conversations.get((self.group.id, user1.id)) is None - assert len(recwarn) == 1 - assert str(recwarn[0].message).startswith('DispatcherHandlerStop in TIMEOUT') - - def test_conversation_handler_timeout_update_and_context(self, dp, bot, user1): - context = None - - def start_callback(u, c): - nonlocal context, self - context = c - return self.start(u, c) - - states = self.states - timeout_handler = CommandHandler('start', None) - states.update({ConversationHandler.TIMEOUT: [timeout_handler]}) - handler = ConversationHandler( - entry_points=[CommandHandler('start', start_callback)], - states=states, - fallbacks=self.fallbacks, - conversation_timeout=0.5, - ) - dp.add_handler(handler) - - # Start state machine, then reach timeout - message = Message( - 0, - None, - self.group, - from_user=user1, - text='/start', - entities=[ - MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0, length=len('/start')) - ], - bot=bot, - ) - update = Update(update_id=0, message=message) - - def timeout_callback(u, c): - nonlocal update, context, self - self.is_timeout = True - assert u is update - assert c is context - - timeout_handler.callback = timeout_callback - - dp.process_update(update) - sleep(0.7) - assert handler.conversations.get((self.group.id, user1.id)) is None - assert self.is_timeout - - @flaky(3, 1) - def test_conversation_timeout_keeps_extending(self, dp, bot, user1): - handler = ConversationHandler( - entry_points=self.entry_points, - states=self.states, - fallbacks=self.fallbacks, - conversation_timeout=0.5, - ) - dp.add_handler(handler) - - # Start state machine, wait, do something, verify the timeout is extended. - # t=0 /start (timeout=.5) - # t=.35 /brew (timeout=.85) - # t=.5 original timeout - # t=.6 /pourCoffee (timeout=1.1) - # t=.85 second timeout - # t=1.1 actual timeout - message = Message( - 0, - None, - self.group, - from_user=user1, - text='/start', - entities=[ - MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0, length=len('/start')) - ], - bot=bot, - ) - dp.process_update(Update(update_id=0, message=message)) - assert handler.conversations.get((self.group.id, user1.id)) == self.THIRSTY - sleep(0.35) # t=.35 - assert handler.conversations.get((self.group.id, user1.id)) == self.THIRSTY - message.text = '/brew' - message.entities[0].length = len('/brew') - dp.process_update(Update(update_id=0, message=message)) - assert handler.conversations.get((self.group.id, user1.id)) == self.BREWING - sleep(0.25) # t=.6 - assert handler.conversations.get((self.group.id, user1.id)) == self.BREWING - message.text = '/pourCoffee' - message.entities[0].length = len('/pourCoffee') - dp.process_update(Update(update_id=0, message=message)) - assert handler.conversations.get((self.group.id, user1.id)) == self.DRINKING - sleep(0.4) # t=1.0 - assert handler.conversations.get((self.group.id, user1.id)) == self.DRINKING - sleep(0.3) # t=1.3 - assert handler.conversations.get((self.group.id, user1.id)) is None - - def test_conversation_timeout_two_users(self, dp, bot, user1, user2): - handler = ConversationHandler( - entry_points=self.entry_points, - states=self.states, - fallbacks=self.fallbacks, - conversation_timeout=0.5, - ) - dp.add_handler(handler) - - # Start state machine, do something as second user, then reach timeout - message = Message( - 0, - None, - self.group, - from_user=user1, - text='/start', - entities=[ - MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0, length=len('/start')) - ], - bot=bot, - ) - dp.process_update(Update(update_id=0, message=message)) - assert handler.conversations.get((self.group.id, user1.id)) == self.THIRSTY - message.text = '/brew' - message.entities[0].length = len('/brew') - message.entities[0].length = len('/brew') - message.from_user = user2 - dp.process_update(Update(update_id=0, message=message)) - assert handler.conversations.get((self.group.id, user2.id)) is None - message.text = '/start' - message.entities[0].length = len('/start') - dp.process_update(Update(update_id=0, message=message)) - assert handler.conversations.get((self.group.id, user2.id)) == self.THIRSTY - sleep(0.7) - assert handler.conversations.get((self.group.id, user1.id)) is None - assert handler.conversations.get((self.group.id, user2.id)) is None - - def test_conversation_handler_timeout_state(self, dp, bot, user1): - states = self.states - states.update( - { - ConversationHandler.TIMEOUT: [ - CommandHandler('brew', self.passout), - MessageHandler(~filters.Regex('oding'), self.passout2), - ] - } - ) - handler = ConversationHandler( - entry_points=self.entry_points, - states=states, - fallbacks=self.fallbacks, - conversation_timeout=0.5, - ) - dp.add_handler(handler) - - # CommandHandler timeout - message = Message( - 0, - None, - self.group, - from_user=user1, - text='/start', - entities=[ - MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0, length=len('/start')) - ], - bot=bot, - ) - dp.process_update(Update(update_id=0, message=message)) - message.text = '/brew' - message.entities[0].length = len('/brew') - dp.process_update(Update(update_id=0, message=message)) - sleep(0.7) - assert handler.conversations.get((self.group.id, user1.id)) is None - assert self.is_timeout - - # MessageHandler timeout - self.is_timeout = False - message.text = '/start' - message.entities[0].length = len('/start') - dp.process_update(Update(update_id=1, message=message)) - sleep(0.7) - assert handler.conversations.get((self.group.id, user1.id)) is None - assert self.is_timeout - - # Timeout but no valid handler - self.is_timeout = False - dp.process_update(Update(update_id=0, message=message)) - message.text = '/brew' - message.entities[0].length = len('/brew') - dp.process_update(Update(update_id=0, message=message)) - message.text = '/startCoding' - message.entities[0].length = len('/startCoding') - dp.process_update(Update(update_id=0, message=message)) - sleep(0.7) - assert handler.conversations.get((self.group.id, user1.id)) is None - assert not self.is_timeout - - def test_conversation_handler_timeout_state_context(self, dp, bot, user1): - states = self.states - states.update( - { - ConversationHandler.TIMEOUT: [ - CommandHandler('brew', self.passout_context), - MessageHandler(~filters.Regex('oding'), self.passout2_context), - ] - } - ) - handler = ConversationHandler( - entry_points=self.entry_points, - states=states, - fallbacks=self.fallbacks, - conversation_timeout=0.5, - ) - dp.add_handler(handler) - - # CommandHandler timeout - message = Message( - 0, - None, - self.group, - from_user=user1, - text='/start', - entities=[ - MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0, length=len('/start')) - ], - bot=bot, - ) - dp.process_update(Update(update_id=0, message=message)) - message.text = '/brew' - message.entities[0].length = len('/brew') - dp.process_update(Update(update_id=0, message=message)) - sleep(0.7) - assert handler.conversations.get((self.group.id, user1.id)) is None - assert self.is_timeout - - # MessageHandler timeout - self.is_timeout = False - message.text = '/start' - message.entities[0].length = len('/start') - dp.process_update(Update(update_id=1, message=message)) - sleep(0.7) - assert handler.conversations.get((self.group.id, user1.id)) is None - assert self.is_timeout - - # Timeout but no valid handler - self.is_timeout = False - dp.process_update(Update(update_id=0, message=message)) - message.text = '/brew' - message.entities[0].length = len('/brew') - dp.process_update(Update(update_id=0, message=message)) - message.text = '/startCoding' - message.entities[0].length = len('/startCoding') - dp.process_update(Update(update_id=0, message=message)) - sleep(0.7) - assert handler.conversations.get((self.group.id, user1.id)) is None - assert not self.is_timeout - - def test_conversation_timeout_cancel_conflict(self, dp, bot, user1): - # Start state machine, wait half the timeout, - # then call a callback that takes more than the timeout - # t=0 /start (timeout=.5) - # t=.25 /slowbrew (sleep .5) - # | t=.5 original timeout (should not execute) - # | t=.75 /slowbrew returns (timeout=1.25) - # t=1.25 timeout - - def slowbrew(_update, context): - sleep(0.25) - # Let's give to the original timeout a chance to execute - sleep(0.25) - # By returning None we do not override the conversation state so - # we can see if the timeout has been executed - - states = self.states - states[self.THIRSTY].append(CommandHandler('slowbrew', slowbrew)) - states.update({ConversationHandler.TIMEOUT: [MessageHandler(None, self.passout2)]}) - - handler = ConversationHandler( - entry_points=self.entry_points, - states=states, - fallbacks=self.fallbacks, - conversation_timeout=0.5, - ) - dp.add_handler(handler) - - # CommandHandler timeout - message = Message( - 0, - None, - self.group, - from_user=user1, - text='/start', - entities=[ - MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0, length=len('/start')) - ], - bot=bot, - ) - dp.process_update(Update(update_id=0, message=message)) - sleep(0.25) - message.text = '/slowbrew' - message.entities[0].length = len('/slowbrew') - dp.process_update(Update(update_id=0, message=message)) - assert handler.conversations.get((self.group.id, user1.id)) is not None - assert not self.is_timeout - - sleep(0.7) - assert handler.conversations.get((self.group.id, user1.id)) is None - assert self.is_timeout - - def test_handlers_generate_warning(self, recwarn): - """ - this function tests all handler + per_* setting combinations. - """ - - # the warning message action needs to be set to always, - # otherwise only the first occurrence will be issued - filterwarnings(action="always", category=PTBUserWarning) - - # this class doesn't do anything, its just not the Update class - class NotUpdate: - pass - - # this conversation handler has the string, string_regex, Pollhandler and TypeHandler - # which should all generate a warning no matter the per_* setting. TypeHandler should - # not when the class is Update - ConversationHandler( - entry_points=[StringCommandHandler("code", self.code)], - states={ - self.BREWING: [ - StringRegexHandler("code", self.code), - PollHandler(self.code), - TypeHandler(NotUpdate, self.code), - ], - }, - fallbacks=[TypeHandler(Update, self.code)], - ) - - # these handlers should all raise a warning when per_chat is True - ConversationHandler( - entry_points=[ShippingQueryHandler(self.code)], - states={ - self.BREWING: [ - InlineQueryHandler(self.code), - PreCheckoutQueryHandler(self.code), - PollAnswerHandler(self.code), - ], - }, - fallbacks=[ChosenInlineResultHandler(self.code)], - per_chat=True, - ) - - # the CallbackQueryHandler should *not* raise when per_message is True, - # but any other one should - ConversationHandler( - entry_points=[CallbackQueryHandler(self.code)], - states={ - self.BREWING: [CommandHandler("code", self.code)], - }, - fallbacks=[CallbackQueryHandler(self.code)], - per_message=True, - ) - - # the CallbackQueryHandler should raise when per_message is False - ConversationHandler( - entry_points=[CommandHandler("code", self.code)], - states={ - self.BREWING: [CommandHandler("code", self.code)], - }, - fallbacks=[CallbackQueryHandler(self.code)], - per_message=False, - ) - - # adding a nested conv to a conversation with timeout should warn - child = ConversationHandler( - entry_points=[CommandHandler("code", self.code)], - states={ - self.BREWING: [CommandHandler("code", self.code)], - }, - fallbacks=[CommandHandler("code", self.code)], - ) - - ConversationHandler( - entry_points=[CommandHandler("code", self.code)], - states={ - self.BREWING: [child], - }, - fallbacks=[CommandHandler("code", self.code)], - conversation_timeout=42, - ) - - # If per_message is True, per_chat should also be True, since msg ids are not unique - ConversationHandler( - entry_points=[CallbackQueryHandler(self.code, "code")], - states={ - self.BREWING: [CallbackQueryHandler(self.code, "code")], - }, - fallbacks=[CallbackQueryHandler(self.code, "code")], - per_message=True, - per_chat=False, - ) - - # the overall number of handlers throwing a warning is 13 - assert len(recwarn) == 13 - # now we test the messages, they are raised in the order they are inserted - # into the conversation handler - assert str(recwarn[0].message) == ( - "The `ConversationHandler` only handles updates of type `telegram.Update`. " - "StringCommandHandler handles updates of type `str`." - ) - assert str(recwarn[1].message) == ( - "The `ConversationHandler` only handles updates of type `telegram.Update`. " - "StringRegexHandler handles updates of type `str`." - ) - assert str(recwarn[2].message) == ( - "PollHandler will never trigger in a conversation since it has no information " - "about the chat or the user who voted in it. Do you mean the " - "`PollAnswerHandler`?" - ) - assert str(recwarn[3].message) == ( - "The `ConversationHandler` only handles updates of type `telegram.Update`. " - "The TypeHandler is set to handle NotUpdate." - ) - - per_faq_link = ( - " Read this FAQ entry to learn more about the per_* settings: " - "https://github.com/python-telegram-bot/python-telegram-bot/wiki" - "/Frequently-Asked-Questions#what-do-the-per_-settings-in-conversationhandler-do." - ) - - assert str(recwarn[4].message) == ( - "Updates handled by ShippingQueryHandler only have information about the user," - " so this handler won't ever be triggered if `per_chat=True`." + per_faq_link - ) - assert str(recwarn[5].message) == ( - "Updates handled by ChosenInlineResultHandler only have information about the user," - " so this handler won't ever be triggered if `per_chat=True`." + per_faq_link - ) - assert str(recwarn[6].message) == ( - "Updates handled by InlineQueryHandler only have information about the user," - " so this handler won't ever be triggered if `per_chat=True`." + per_faq_link - ) - assert str(recwarn[7].message) == ( - "Updates handled by PreCheckoutQueryHandler only have information about the user," - " so this handler won't ever be triggered if `per_chat=True`." + per_faq_link - ) - assert str(recwarn[8].message) == ( - "Updates handled by PollAnswerHandler only have information about the user," - " so this handler won't ever be triggered if `per_chat=True`." + per_faq_link - ) - assert str(recwarn[9].message) == ( - "If 'per_message=True', all entry points, state handlers, and fallbacks must be " - "'CallbackQueryHandler', since no other handlers have a message context." - + per_faq_link - ) - assert str(recwarn[10].message) == ( - "If 'per_message=False', 'CallbackQueryHandler' will not be tracked for " - "every message." + per_faq_link - ) - assert str(recwarn[11].message) == ( - "Using `conversation_timeout` with nested conversations is currently not " - "supported. You can still try to use it, but it will likely behave differently" - " from what you expect." - ) - - assert str(recwarn[12].message) == ( - "If 'per_message=True' is used, 'per_chat=True' should also be used, " - "since message IDs are not globally unique." - ) - - # this for loop checks if the correct stacklevel is used when generating the warning - for warning in recwarn: - assert warning.filename == __file__, "incorrect stacklevel!" - - def test_nested_conversation_handler(self, dp, bot, user1, user2): - self.nested_states[self.DRINKING] = [ - ConversationHandler( - entry_points=self.drinking_entry_points, - states=self.drinking_states, - fallbacks=self.drinking_fallbacks, - map_to_parent=self.drinking_map_to_parent, - ) - ] - handler = ConversationHandler( - entry_points=self.entry_points, states=self.nested_states, fallbacks=self.fallbacks - ) - dp.add_handler(handler) - - # User one, starts the state machine. - message = Message( - 0, - None, - self.group, - from_user=user1, - text='/start', - bot=bot, - entities=[ - MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0, length=len('/start')) - ], - ) - dp.process_update(Update(update_id=0, message=message)) - assert self.current_state[user1.id] == self.THIRSTY - - # The user is thirsty and wants to brew coffee. - message.text = '/brew' - message.entities[0].length = len('/brew') - dp.process_update(Update(update_id=0, message=message)) - assert self.current_state[user1.id] == self.BREWING - - # Lets pour some coffee. - message.text = '/pourCoffee' - message.entities[0].length = len('/pourCoffee') - dp.process_update(Update(update_id=0, message=message)) - assert self.current_state[user1.id] == self.DRINKING - - # The user is holding the cup - message.text = '/hold' - message.entities[0].length = len('/hold') - dp.process_update(Update(update_id=0, message=message)) - assert self.current_state[user1.id] == self.HOLDING - - # The user is sipping coffee - message.text = '/sip' - message.entities[0].length = len('/sip') - dp.process_update(Update(update_id=0, message=message)) - assert self.current_state[user1.id] == self.SIPPING - - # The user is swallowing - message.text = '/swallow' - message.entities[0].length = len('/swallow') - dp.process_update(Update(update_id=0, message=message)) - assert self.current_state[user1.id] == self.SWALLOWING - - # The user is holding the cup again - message.text = '/hold' - message.entities[0].length = len('/hold') - dp.process_update(Update(update_id=0, message=message)) - assert self.current_state[user1.id] == self.HOLDING - - # The user wants to replenish the coffee supply - message.text = '/replenish' - message.entities[0].length = len('/replenish') - dp.process_update(Update(update_id=0, message=message)) - assert self.current_state[user1.id] == self.REPLENISHING - assert handler.conversations[(0, user1.id)] == self.BREWING - - # The user wants to drink their coffee again - message.text = '/pourCoffee' - message.entities[0].length = len('/pourCoffee') - dp.process_update(Update(update_id=0, message=message)) - assert self.current_state[user1.id] == self.DRINKING - - # The user is now ready to start coding - message.text = '/startCoding' - message.entities[0].length = len('/startCoding') - dp.process_update(Update(update_id=0, message=message)) - assert self.current_state[user1.id] == self.CODING - - # The user decides it's time to drink again - message.text = '/drinkMore' - message.entities[0].length = len('/drinkMore') - dp.process_update(Update(update_id=0, message=message)) - assert self.current_state[user1.id] == self.DRINKING - - # The user is holding their cup - message.text = '/hold' - message.entities[0].length = len('/hold') - dp.process_update(Update(update_id=0, message=message)) - assert self.current_state[user1.id] == self.HOLDING - - # The user wants to end with the drinking and go back to coding - message.text = '/end' - message.entities[0].length = len('/end') - dp.process_update(Update(update_id=0, message=message)) - assert self.current_state[user1.id] == self.END - assert handler.conversations[(0, user1.id)] == self.CODING - - # The user wants to drink once more - message.text = '/drinkMore' - message.entities[0].length = len('/drinkMore') - dp.process_update(Update(update_id=0, message=message)) - assert self.current_state[user1.id] == self.DRINKING - - # The user wants to stop altogether - message.text = '/stop' - message.entities[0].length = len('/stop') - dp.process_update(Update(update_id=0, message=message)) - assert self.current_state[user1.id] == self.STOPPING - assert handler.conversations.get((0, user1.id)) is None - - def test_conversation_dispatcher_handler_stop(self, dp, bot, user1, user2): - self.nested_states[self.DRINKING] = [ - ConversationHandler( - entry_points=self.drinking_entry_points, - states=self.drinking_states, - fallbacks=self.drinking_fallbacks, - map_to_parent=self.drinking_map_to_parent, - ) - ] - handler = ConversationHandler( - entry_points=self.entry_points, states=self.nested_states, fallbacks=self.fallbacks - ) - - def test_callback(u, c): - self.test_flag = True - - dp.add_handler(handler) - dp.add_handler(TypeHandler(Update, test_callback), group=1) - self.raise_dp_handler_stop = True - - # User one, starts the state machine. - message = Message( - 0, - None, - self.group, - text='/start', - bot=bot, - from_user=user1, - entities=[ - MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0, length=len('/start')) - ], - ) - dp.process_update(Update(update_id=0, message=message)) - assert self.current_state[user1.id] == self.THIRSTY - assert not self.test_flag - - # The user is thirsty and wants to brew coffee. - message.text = '/brew' - message.entities[0].length = len('/brew') - dp.process_update(Update(update_id=0, message=message)) - assert self.current_state[user1.id] == self.BREWING - assert not self.test_flag - - # Lets pour some coffee. - message.text = '/pourCoffee' - message.entities[0].length = len('/pourCoffee') - dp.process_update(Update(update_id=0, message=message)) - assert self.current_state[user1.id] == self.DRINKING - assert not self.test_flag - - # The user is holding the cup - message.text = '/hold' - message.entities[0].length = len('/hold') - dp.process_update(Update(update_id=0, message=message)) - assert self.current_state[user1.id] == self.HOLDING - assert not self.test_flag - - # The user is sipping coffee - message.text = '/sip' - message.entities[0].length = len('/sip') - dp.process_update(Update(update_id=0, message=message)) - assert self.current_state[user1.id] == self.SIPPING - assert not self.test_flag - - # The user is swallowing - message.text = '/swallow' - message.entities[0].length = len('/swallow') - dp.process_update(Update(update_id=0, message=message)) - assert self.current_state[user1.id] == self.SWALLOWING - assert not self.test_flag - - # The user is holding the cup again - message.text = '/hold' - message.entities[0].length = len('/hold') - dp.process_update(Update(update_id=0, message=message)) - assert self.current_state[user1.id] == self.HOLDING - assert not self.test_flag - - # The user wants to replenish the coffee supply - message.text = '/replenish' - message.entities[0].length = len('/replenish') - dp.process_update(Update(update_id=0, message=message)) - assert self.current_state[user1.id] == self.REPLENISHING - assert handler.conversations[(0, user1.id)] == self.BREWING - assert not self.test_flag - - # The user wants to drink their coffee again - message.text = '/pourCoffee' - message.entities[0].length = len('/pourCoffee') - dp.process_update(Update(update_id=0, message=message)) - assert self.current_state[user1.id] == self.DRINKING - assert not self.test_flag - - # The user is now ready to start coding - message.text = '/startCoding' - message.entities[0].length = len('/startCoding') - dp.process_update(Update(update_id=0, message=message)) - assert self.current_state[user1.id] == self.CODING - assert not self.test_flag - - # The user decides it's time to drink again - message.text = '/drinkMore' - message.entities[0].length = len('/drinkMore') - dp.process_update(Update(update_id=0, message=message)) - assert self.current_state[user1.id] == self.DRINKING - assert not self.test_flag - - # The user is holding their cup - message.text = '/hold' - message.entities[0].length = len('/hold') - dp.process_update(Update(update_id=0, message=message)) - assert self.current_state[user1.id] == self.HOLDING - assert not self.test_flag - - # The user wants to end with the drinking and go back to coding - message.text = '/end' - message.entities[0].length = len('/end') - dp.process_update(Update(update_id=0, message=message)) - assert self.current_state[user1.id] == self.END - assert handler.conversations[(0, user1.id)] == self.CODING - assert not self.test_flag - - # The user wants to drink once more - message.text = '/drinkMore' - message.entities[0].length = len('/drinkMore') - dp.process_update(Update(update_id=0, message=message)) - assert self.current_state[user1.id] == self.DRINKING - assert not self.test_flag - - # The user wants to stop altogether - message.text = '/stop' - message.entities[0].length = len('/stop') - dp.process_update(Update(update_id=0, message=message)) - assert self.current_state[user1.id] == self.STOPPING - assert handler.conversations.get((0, user1.id)) is None - assert not self.test_flag - - def test_conversation_handler_run_async_true(self, dp): - conv_handler = ConversationHandler( - entry_points=self.entry_points, - states=self.states, - fallbacks=self.fallbacks, - run_async=True, - ) - - all_handlers = conv_handler.entry_points + conv_handler.fallbacks - for state_handlers in conv_handler.states.values(): - all_handlers += state_handlers - - for handler in all_handlers: - assert handler.run_async - - def test_conversation_handler_run_async_false(self, dp): - conv_handler = ConversationHandler( - entry_points=[CommandHandler('start', self.start_end, run_async=True)], - states=self.states, - fallbacks=self.fallbacks, - run_async=False, - ) - - for handler in conv_handler.entry_points: - assert handler.run_async - - all_handlers = conv_handler.fallbacks - for state_handlers in conv_handler.states.values(): - all_handlers += state_handlers - - for handler in all_handlers: - assert not handler.run_async.value diff --git a/tests/test_defaults.py b/tests/test_defaults.py deleted file mode 100644 index bee205c4d3f..00000000000 --- a/tests/test_defaults.py +++ /dev/null @@ -1,59 +0,0 @@ -#!/usr/bin/env python -# -# A library that provides a Python interface to the Telegram Bot API -# Copyright (C) 2015-2022 -# Leandro Toledo de Souza -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Lesser Public License for more details. -# -# You should have received a copy of the GNU Lesser Public License -# along with this program. If not, see [http://www.gnu.org/licenses/]. - -import pytest -import inspect - -from telegram.ext import Defaults -from telegram import User - - -class TestDefault: - def test_slot_behaviour(self, mro_slots): - a = Defaults(parse_mode='HTML', quote=True) - for attr in a.__slots__: - assert getattr(a, attr, 'err') != 'err', f"got extra slot '{attr}'" - assert len(mro_slots(a)) == len(set(mro_slots(a))), "duplicate slot" - - def test_data_assignment(self, dp): - defaults = Defaults() - - for name, val in inspect.getmembers(Defaults, lambda x: isinstance(x, property)): - with pytest.raises(AttributeError): - setattr(defaults, name, True) - - def test_equality(self): - a = Defaults(parse_mode='HTML', quote=True) - b = Defaults(parse_mode='HTML', quote=True) - c = Defaults(parse_mode='HTML', quote=True, protect_content=True) - d = Defaults(parse_mode='HTML', timeout=50) - e = User(123, 'test_user', False) - - assert a == b - assert hash(a) == hash(b) - assert a is not b - - assert a != c - assert hash(a) != hash(c) - - assert a != d - assert hash(a) != hash(d) - - assert a != e - assert hash(a) != hash(e) diff --git a/tests/test_dispatcher.py b/tests/test_dispatcher.py deleted file mode 100644 index eecc123b484..00000000000 --- a/tests/test_dispatcher.py +++ /dev/null @@ -1,1130 +0,0 @@ -#!/usr/bin/env python -# -# A library that provides a Python interface to the Telegram Bot API -# Copyright (C) 2015-2022 -# Leandro Toledo de Souza -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Lesser Public License for more details. -# -# You should have received a copy of the GNU Lesser Public License -# along with this program. If not, see [http://www.gnu.org/licenses/]. -import logging -from queue import Queue -from threading import current_thread -from time import sleep - -import pytest - -from telegram import Message, User, Chat, Update, Bot, MessageEntity -from telegram.ext import ( - CommandHandler, - MessageHandler, - JobQueue, - filters, - Defaults, - CallbackContext, - ContextTypes, - BasePersistence, - PersistenceInput, - Dispatcher, - DispatcherHandlerStop, - DispatcherBuilder, - UpdaterBuilder, -) - -from telegram._utils.defaultvalue import DEFAULT_FALSE -from telegram.error import TelegramError -from tests.conftest import create_dp -from collections import defaultdict - - -@pytest.fixture(scope='function') -def dp2(bot): - yield from create_dp(bot) - - -class CustomContext(CallbackContext): - pass - - -class TestDispatcher: - message_update = Update( - 1, message=Message(1, None, Chat(1, ''), from_user=User(1, '', False), text='Text') - ) - received = None - count = 0 - - @pytest.fixture(autouse=True, name='reset') - def reset_fixture(self): - self.reset() - - def reset(self): - self.received = None - self.count = 0 - - def error_handler_context(self, update, context): - self.received = context.error.message - - def error_handler_raise_error(self, update, context): - raise Exception('Failing bigly') - - def callback_increase_count(self, update, context): - self.count += 1 - - def callback_set_count(self, count): - def callback(update, context): - self.count = count - - return callback - - def callback_raise_error(self, update, context): - raise TelegramError(update.message.text) - - def callback_received(self, update, context): - self.received = update.message - - def callback_context(self, update, context): - if ( - isinstance(context, CallbackContext) - and isinstance(context.bot, Bot) - and isinstance(context.update_queue, Queue) - and isinstance(context.job_queue, JobQueue) - and isinstance(context.error, TelegramError) - ): - self.received = context.error.message - - def test_slot_behaviour(self, bot, mro_slots): - dp = DispatcherBuilder().bot(bot).build() - for at in dp.__slots__: - at = f"_Dispatcher{at}" if at.startswith('__') and not at.endswith('__') else at - assert getattr(dp, at, 'err') != 'err', f"got extra slot '{at}'" - assert len(mro_slots(dp)) == len(set(mro_slots(dp))), "duplicate slot" - - def test_manual_init_warning(self, recwarn): - Dispatcher( - bot=None, - update_queue=None, - workers=7, - exception_event=None, - job_queue=None, - persistence=None, - context_types=ContextTypes(), - ) - assert len(recwarn) == 1 - assert ( - str(recwarn[-1].message) - == '`Dispatcher` instances should be built via the `DispatcherBuilder`.' - ) - assert recwarn[0].filename == __file__, "stacklevel is incorrect!" - - @pytest.mark.parametrize("data", ["chat_data", "user_data"]) - def test_chat_user_data_read_only(self, dp, data): - read_only_data = getattr(dp, data) - writable_data = getattr(dp, f"_{data}") - writable_data[123] = 321 - assert read_only_data == writable_data - with pytest.raises(TypeError): - read_only_data[111] = 123 - - @pytest.mark.parametrize( - 'builder', - (DispatcherBuilder(), UpdaterBuilder()), - ids=('DispatcherBuilder', 'UpdaterBuilder'), - ) - def test_less_than_one_worker_warning(self, dp, recwarn, builder): - builder.bot(dp.bot).workers(0).build() - assert len(recwarn) == 1 - assert ( - str(recwarn[0].message) - == 'Asynchronous callbacks can not be processed without at least one worker thread.' - ) - assert recwarn[0].filename == __file__, "stacklevel is incorrect!" - - def test_builder(self, dp): - builder_1 = dp.builder() - builder_2 = dp.builder() - assert isinstance(builder_1, DispatcherBuilder) - assert isinstance(builder_2, DispatcherBuilder) - assert builder_1 is not builder_2 - - # Make sure that setting a token doesn't raise an exception - # i.e. check that the builders are "empty"/new - builder_1.token(dp.bot.token) - builder_2.token(dp.bot.token) - - def test_one_context_per_update(self, dp): - def one(update, context): - if update.message.text == 'test': - context.my_flag = True - - def two(update, context): - if update.message.text == 'test': - if not hasattr(context, 'my_flag'): - pytest.fail() - else: - if hasattr(context, 'my_flag'): - pytest.fail() - - dp.add_handler(MessageHandler(filters.Regex('test'), one), group=1) - dp.add_handler(MessageHandler(None, two), group=2) - u = Update(1, Message(1, None, None, None, text='test')) - dp.process_update(u) - u.message.text = 'something' - dp.process_update(u) - - def test_error_handler(self, dp): - dp.add_error_handler(self.error_handler_context) - error = TelegramError('Unauthorized.') - dp.update_queue.put(error) - sleep(0.1) - assert self.received == 'Unauthorized.' - - # Remove handler - dp.remove_error_handler(self.error_handler_context) - self.reset() - - dp.update_queue.put(error) - sleep(0.1) - assert self.received is None - - def test_double_add_error_handler(self, dp, caplog): - dp.add_error_handler(self.error_handler_context) - with caplog.at_level(logging.DEBUG): - dp.add_error_handler(self.error_handler_context) - assert len(caplog.records) == 1 - assert caplog.records[-1].getMessage().startswith('The callback is already registered') - - def test_construction_with_bad_persistence(self, caplog, bot): - class my_per: - def __init__(self): - self.store_data = PersistenceInput(False, False, False, False) - - with pytest.raises( - TypeError, match='persistence must be based on telegram.ext.BasePersistence' - ): - DispatcherBuilder().bot(bot).persistence(my_per()).build() - - def test_error_handler_that_raises_errors(self, dp): - """ - Make sure that errors raised in error handlers don't break the main loop of the dispatcher - """ - handler_raise_error = MessageHandler(filters.ALL, self.callback_raise_error) - handler_increase_count = MessageHandler(filters.ALL, self.callback_increase_count) - error = TelegramError('Unauthorized.') - - dp.add_error_handler(self.error_handler_raise_error) - - # From errors caused by handlers - dp.add_handler(handler_raise_error) - dp.update_queue.put(self.message_update) - sleep(0.1) - - # From errors in the update_queue - dp.remove_handler(handler_raise_error) - dp.add_handler(handler_increase_count) - dp.update_queue.put(error) - dp.update_queue.put(self.message_update) - sleep(0.1) - - assert self.count == 1 - - @pytest.mark.parametrize(['run_async', 'expected_output'], [(True, 5), (False, 0)]) - def test_default_run_async_error_handler(self, dp, monkeypatch, run_async, expected_output): - def mock_async_err_handler(*args, **kwargs): - self.count = 5 - - # set defaults value to dp.bot - dp.bot._defaults = Defaults(run_async=run_async) - try: - dp.add_handler(MessageHandler(filters.ALL, self.callback_raise_error)) - dp.add_error_handler(self.error_handler_context) - - monkeypatch.setattr(dp, 'run_async', mock_async_err_handler) - dp.process_update(self.message_update) - - assert self.count == expected_output - - finally: - # reset dp.bot.defaults values - dp.bot._defaults = None - - @pytest.mark.parametrize( - ['run_async', 'expected_output'], [(True, 'running async'), (False, None)] - ) - def test_default_run_async(self, monkeypatch, dp, run_async, expected_output): - def mock_run_async(*args, **kwargs): - self.received = 'running async' - - # set defaults value to dp.bot - dp.bot._defaults = Defaults(run_async=run_async) - try: - dp.add_handler(MessageHandler(filters.ALL, lambda u, c: None)) - monkeypatch.setattr(dp, 'run_async', mock_run_async) - dp.process_update(self.message_update) - assert self.received == expected_output - - finally: - # reset defaults value - dp.bot._defaults = None - - def test_run_async_multiple(self, bot, dp, dp2): - def get_dispatcher_name(q): - q.put(current_thread().name) - - q1 = Queue() - q2 = Queue() - - dp.run_async(get_dispatcher_name, q1) - dp2.run_async(get_dispatcher_name, q2) - - sleep(0.1) - - name1 = q1.get() - name2 = q2.get() - - assert name1 != name2 - - def test_async_raises_dispatcher_handler_stop(self, dp, recwarn): - def callback(update, context): - raise DispatcherHandlerStop() - - dp.add_handler(MessageHandler(filters.ALL, callback, run_async=True)) - - dp.update_queue.put(self.message_update) - sleep(0.1) - assert len(recwarn) == 1 - assert str(recwarn[-1].message).startswith( - 'DispatcherHandlerStop is not supported with async functions' - ) - - def test_add_async_handler(self, dp): - dp.add_handler( - MessageHandler( - filters.ALL, - self.callback_received, - run_async=True, - ) - ) - - dp.update_queue.put(self.message_update) - sleep(0.1) - assert self.received == self.message_update.message - - def test_run_async_no_error_handler(self, dp, caplog): - def func(): - raise RuntimeError('Async Error') - - with caplog.at_level(logging.ERROR): - dp.run_async(func) - sleep(0.1) - assert len(caplog.records) == 1 - assert caplog.records[-1].getMessage().startswith('No error handlers are registered') - - def test_async_handler_async_error_handler_context(self, dp): - dp.add_handler(MessageHandler(filters.ALL, self.callback_raise_error, run_async=True)) - dp.add_error_handler(self.error_handler_context, run_async=True) - - dp.update_queue.put(self.message_update) - sleep(2) - assert self.received == self.message_update.message.text - - def test_async_handler_error_handler_that_raises_error(self, dp, caplog): - handler = MessageHandler(filters.ALL, self.callback_raise_error, run_async=True) - dp.add_handler(handler) - dp.add_error_handler(self.error_handler_raise_error, run_async=False) - - with caplog.at_level(logging.ERROR): - dp.update_queue.put(self.message_update) - sleep(0.1) - assert len(caplog.records) == 1 - assert ( - caplog.records[-1].getMessage().startswith('An error was raised and an uncaught') - ) - - # Make sure that the main loop still runs - dp.remove_handler(handler) - dp.add_handler(MessageHandler(filters.ALL, self.callback_increase_count, run_async=True)) - dp.update_queue.put(self.message_update) - sleep(0.1) - assert self.count == 1 - - def test_async_handler_async_error_handler_that_raises_error(self, dp, caplog): - handler = MessageHandler(filters.ALL, self.callback_raise_error, run_async=True) - dp.add_handler(handler) - dp.add_error_handler(self.error_handler_raise_error, run_async=True) - - with caplog.at_level(logging.ERROR): - dp.update_queue.put(self.message_update) - sleep(0.1) - assert len(caplog.records) == 1 - assert ( - caplog.records[-1].getMessage().startswith('An error was raised and an uncaught') - ) - - # Make sure that the main loop still runs - dp.remove_handler(handler) - dp.add_handler(MessageHandler(filters.ALL, self.callback_increase_count, run_async=True)) - dp.update_queue.put(self.message_update) - sleep(0.1) - assert self.count == 1 - - def test_error_in_handler(self, dp): - dp.add_handler(MessageHandler(filters.ALL, self.callback_raise_error)) - dp.add_error_handler(self.error_handler_context) - - dp.update_queue.put(self.message_update) - sleep(0.1) - assert self.received == self.message_update.message.text - - def test_add_remove_handler(self, dp): - handler = MessageHandler(filters.ALL, self.callback_increase_count) - dp.add_handler(handler) - dp.update_queue.put(self.message_update) - sleep(0.1) - assert self.count == 1 - dp.remove_handler(handler) - dp.update_queue.put(self.message_update) - assert self.count == 1 - - def test_add_remove_handler_non_default_group(self, dp): - handler = MessageHandler(filters.ALL, self.callback_increase_count) - dp.add_handler(handler, group=2) - with pytest.raises(KeyError): - dp.remove_handler(handler) - dp.remove_handler(handler, group=2) - - def test_error_start_twice(self, dp): - assert dp.running - dp.start() - - def test_handler_order_in_group(self, dp): - dp.add_handler(MessageHandler(filters.PHOTO, self.callback_set_count(1))) - dp.add_handler(MessageHandler(filters.ALL, self.callback_set_count(2))) - dp.add_handler(MessageHandler(filters.TEXT, self.callback_set_count(3))) - dp.update_queue.put(self.message_update) - sleep(0.1) - assert self.count == 2 - - def test_groups(self, dp): - dp.add_handler(MessageHandler(filters.ALL, self.callback_increase_count)) - dp.add_handler(MessageHandler(filters.ALL, self.callback_increase_count), group=2) - dp.add_handler(MessageHandler(filters.ALL, self.callback_increase_count), group=-1) - - dp.update_queue.put(self.message_update) - sleep(0.1) - assert self.count == 3 - - def test_add_handlers_complex(self, dp): - """Tests both add_handler & add_handlers together & confirms the correct insertion order""" - msg_handler_set_count = MessageHandler(filters.TEXT, self.callback_set_count(1)) - msg_handler_inc_count = MessageHandler(filters.PHOTO, self.callback_increase_count) - - dp.add_handler(msg_handler_set_count, 1) - dp.add_handlers((msg_handler_inc_count, msg_handler_inc_count), 1) - - photo_update = Update(2, message=Message(2, None, None, photo=True)) - dp.update_queue.put(self.message_update) # Putting updates in the queue calls the callback - dp.update_queue.put(photo_update) - sleep(0.1) # sleep is required otherwise there is random behaviour - - # Test if handler was added to correct group with correct order- - assert ( - self.count == 2 - and len(dp.handlers[1]) == 3 - and dp.handlers[1][0] is msg_handler_set_count - ) - - # Now lets test add_handlers when `handlers` is a dict- - voice_filter_handler_to_check = MessageHandler(filters.VOICE, self.callback_increase_count) - dp.add_handlers( - handlers={ - 1: [ - MessageHandler(filters.USER, self.callback_increase_count), - voice_filter_handler_to_check, - ], - -1: [MessageHandler(filters.CAPTION, self.callback_set_count(2))], - } - ) - - user_update = Update(3, message=Message(3, None, None, from_user=User(1, 's', True))) - voice_update = Update(4, message=Message(4, None, None, voice=True)) - dp.update_queue.put(user_update) - dp.update_queue.put(voice_update) - sleep(0.1) - - assert ( - self.count == 4 - and len(dp.handlers[1]) == 5 - and dp.handlers[1][-1] is voice_filter_handler_to_check - ) - - dp.update_queue.put(Update(5, message=Message(5, None, None, caption='cap'))) - sleep(0.1) - - assert self.count == 2 and len(dp.handlers[-1]) == 1 - - # Now lets test the errors which can be produced- - with pytest.raises(ValueError, match="The `group` argument"): - dp.add_handlers({2: [msg_handler_set_count]}, group=0) - with pytest.raises(ValueError, match="Handlers for group 3"): - dp.add_handlers({3: msg_handler_set_count}) - with pytest.raises(ValueError, match="The `handlers` argument must be a sequence"): - dp.add_handlers({msg_handler_set_count}) - - def test_add_handler_errors(self, dp): - handler = 'not a handler' - with pytest.raises(TypeError, match='handler is not an instance of'): - dp.add_handler(handler) - - handler = MessageHandler(filters.PHOTO, self.callback_set_count(1)) - with pytest.raises(TypeError, match='group is not int'): - dp.add_handler(handler, 'one') - - def test_flow_stop(self, dp, bot): - passed = [] - - def start1(b, u): - passed.append('start1') - raise DispatcherHandlerStop - - def start2(b, u): - passed.append('start2') - - def start3(b, u): - passed.append('start3') - - def error(b, u, e): - passed.append('error') - passed.append(e) - - update = Update( - 1, - message=Message( - 1, - None, - None, - None, - text='/start', - entities=[ - MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0, length=len('/start')) - ], - bot=bot, - ), - ) - - # If Stop raised handlers in other groups should not be called. - passed = [] - dp.add_handler(CommandHandler('start', start1), 1) - dp.add_handler(CommandHandler('start', start3), 1) - dp.add_handler(CommandHandler('start', start2), 2) - dp.process_update(update) - assert passed == ['start1'] - - def test_exception_in_handler(self, dp, bot): - passed = [] - err = Exception('General exception') - - def start1(u, c): - passed.append('start1') - raise err - - def start2(u, c): - passed.append('start2') - - def start3(u, c): - passed.append('start3') - - def error(u, c): - passed.append('error') - passed.append(c.error) - - update = Update( - 1, - message=Message( - 1, - None, - None, - None, - text='/start', - entities=[ - MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0, length=len('/start')) - ], - bot=bot, - ), - ) - - # If an unhandled exception was caught, no further handlers from the same group should be - # called. Also, the error handler should be called and receive the exception - passed = [] - dp.add_handler(CommandHandler('start', start1), 1) - dp.add_handler(CommandHandler('start', start2), 1) - dp.add_handler(CommandHandler('start', start3), 2) - dp.add_error_handler(error) - dp.process_update(update) - assert passed == ['start1', 'error', err, 'start3'] - - def test_telegram_error_in_handler(self, dp, bot): - passed = [] - err = TelegramError('Telegram error') - - def start1(u, c): - passed.append('start1') - raise err - - def start2(u, c): - passed.append('start2') - - def start3(u, c): - passed.append('start3') - - def error(u, c): - passed.append('error') - passed.append(c.error) - - update = Update( - 1, - message=Message( - 1, - None, - None, - None, - text='/start', - entities=[ - MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0, length=len('/start')) - ], - bot=bot, - ), - ) - - # If a TelegramException was caught, an error handler should be called and no further - # handlers from the same group should be called. - dp.add_handler(CommandHandler('start', start1), 1) - dp.add_handler(CommandHandler('start', start2), 1) - dp.add_handler(CommandHandler('start', start3), 2) - dp.add_error_handler(error) - dp.process_update(update) - assert passed == ['start1', 'error', err, 'start3'] - assert passed[2] is err - - def test_error_while_saving_chat_data(self, bot): - increment = [] - - class OwnPersistence(BasePersistence): - def get_callback_data(self): - return None - - def update_callback_data(self, data): - raise Exception - - def get_bot_data(self): - return {} - - def update_bot_data(self, data): - raise Exception - - def drop_chat_data(self, chat_id): - pass - - def drop_user_data(self, user_id): - pass - - def get_chat_data(self): - return defaultdict(dict) - - def update_chat_data(self, chat_id, data): - raise Exception - - def get_user_data(self): - return defaultdict(dict) - - def update_user_data(self, user_id, data): - raise Exception - - def get_conversations(self, name): - pass - - def update_conversation(self, name, key, new_state): - pass - - def refresh_user_data(self, user_id, user_data): - pass - - def refresh_chat_data(self, chat_id, chat_data): - pass - - def refresh_bot_data(self, bot_data): - pass - - def flush(self): - pass - - def start1(u, c): - pass - - def error(u, c): - increment.append("error") - - # If updating a user_data or chat_data from a persistence object throws an error, - # the error handler should catch it - - update = Update( - 1, - message=Message( - 1, - None, - Chat(1, "lala"), - from_user=User(1, "Test", False), - text='/start', - entities=[ - MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0, length=len('/start')) - ], - bot=bot, - ), - ) - my_persistence = OwnPersistence() - dp = DispatcherBuilder().bot(bot).persistence(my_persistence).build() - dp.add_handler(CommandHandler('start', start1)) - dp.add_error_handler(error) - dp.process_update(update) - assert increment == ["error", "error", "error", "error"] - - def test_flow_stop_in_error_handler(self, dp, bot): - passed = [] - err = TelegramError('Telegram error') - - def start1(u, c): - passed.append('start1') - raise err - - def start2(u, c): - passed.append('start2') - - def start3(u, c): - passed.append('start3') - - def error(u, c): - passed.append('error') - passed.append(c.error) - raise DispatcherHandlerStop - - update = Update( - 1, - message=Message( - 1, - None, - None, - None, - text='/start', - entities=[ - MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0, length=len('/start')) - ], - bot=bot, - ), - ) - - # If a TelegramException was caught, an error handler should be called and no further - # handlers from the same group should be called. - dp.add_handler(CommandHandler('start', start1), 1) - dp.add_handler(CommandHandler('start', start2), 1) - dp.add_handler(CommandHandler('start', start3), 2) - dp.add_error_handler(error) - dp.process_update(update) - assert passed == ['start1', 'error', err] - assert passed[2] is err - - def test_sensible_worker_thread_names(self, dp2): - thread_names = [thread.name for thread in dp2._Dispatcher__async_threads] - for thread_name in thread_names: - assert thread_name.startswith(f"Bot:{dp2.bot.id}:worker:") - - @pytest.mark.parametrize( - 'message', - [ - Message(message_id=1, chat=Chat(id=2, type=None), migrate_from_chat_id=1, date=None), - Message(message_id=1, chat=Chat(id=1, type=None), migrate_to_chat_id=2, date=None), - Message(message_id=1, chat=Chat(id=1, type=None), date=None), - None, - ], - ) - @pytest.mark.parametrize('old_chat_id', [None, 1, "1"]) - @pytest.mark.parametrize('new_chat_id', [None, 2, "1"]) - def test_migrate_chat_data(self, dp, message: 'Message', old_chat_id: int, new_chat_id: int): - def call(match: str): - with pytest.raises(ValueError, match=match): - dp.migrate_chat_data( - message=message, old_chat_id=old_chat_id, new_chat_id=new_chat_id - ) - - if message and (old_chat_id or new_chat_id): - call(r"^Message and chat_id pair are mutually exclusive$") - return - - if not any((message, old_chat_id, new_chat_id)): - call(r"^chat_id pair or message must be passed$") - return - - if message: - if message.migrate_from_chat_id is None and message.migrate_to_chat_id is None: - call(r"^Invalid message instance") - return - effective_old_chat_id = message.migrate_from_chat_id or message.chat.id - effective_new_chat_id = message.migrate_to_chat_id or message.chat.id - - elif not (isinstance(old_chat_id, int) and isinstance(new_chat_id, int)): - call(r"^old_chat_id and new_chat_id must be integers$") - return - else: - effective_old_chat_id = old_chat_id - effective_new_chat_id = new_chat_id - - dp.chat_data[effective_old_chat_id]['key'] = "test" - dp.migrate_chat_data(message=message, old_chat_id=old_chat_id, new_chat_id=new_chat_id) - assert effective_old_chat_id not in dp.chat_data - assert dp.chat_data[effective_new_chat_id]['key'] == "test" - - def test_error_while_persisting(self, dp, caplog): - class OwnPersistence(BasePersistence): - def update(self, data): - raise Exception('PersistenceError') - - def update_callback_data(self, data): - self.update(data) - - def update_bot_data(self, data): - self.update(data) - - def update_chat_data(self, chat_id, data): - self.update(data) - - def update_user_data(self, user_id, data): - self.update(data) - - def drop_user_data(self, user_id): - pass - - def drop_chat_data(self, chat_id): - pass - - def get_chat_data(self): - pass - - def get_bot_data(self): - pass - - def get_user_data(self): - pass - - def get_callback_data(self): - pass - - def get_conversations(self, name): - pass - - def update_conversation(self, name, key, new_state): - pass - - def refresh_bot_data(self, bot_data): - pass - - def refresh_user_data(self, user_id, user_data): - pass - - def refresh_chat_data(self, chat_id, chat_data): - pass - - def flush(self): - pass - - def callback(update, context): - pass - - test_flag = [] - - def error(update, context): - nonlocal test_flag - test_flag.append(str(context.error) == 'PersistenceError') - raise Exception('ErrorHandlingError') - - update = Update( - 1, message=Message(1, None, Chat(1, ''), from_user=User(1, '', False), text='Text') - ) - handler = MessageHandler(filters.ALL, callback) - dp.add_handler(handler) - dp.add_error_handler(error) - - dp.persistence = OwnPersistence() - - with caplog.at_level(logging.ERROR): - dp.process_update(update) - - assert test_flag == [True, True, True, True] - assert len(caplog.records) == 4 - for record in caplog.records: - message = record.getMessage() - assert message.startswith('An error was raised and an uncaught') - - def test_persisting_no_user_no_chat(self, dp): - class OwnPersistence(BasePersistence): - def __init__(self): - super().__init__() - self.test_flag_bot_data = False - self.test_flag_chat_data = False - self.test_flag_user_data = False - - def update_bot_data(self, data): - self.test_flag_bot_data = True - - def update_chat_data(self, chat_id, data): - self.test_flag_chat_data = True - - def update_user_data(self, user_id, data): - self.test_flag_user_data = True - - def update_conversation(self, name, key, new_state): - pass - - def drop_chat_data(self, chat_id): - pass - - def drop_user_data(self, user_id): - pass - - def get_conversations(self, name): - pass - - def get_user_data(self): - pass - - def get_bot_data(self): - pass - - def get_chat_data(self): - pass - - def refresh_bot_data(self, bot_data): - pass - - def refresh_user_data(self, user_id, user_data): - pass - - def refresh_chat_data(self, chat_id, chat_data): - pass - - def get_callback_data(self): - pass - - def update_callback_data(self, data): - pass - - def flush(self): - pass - - def callback(update, context): - pass - - handler = MessageHandler(filters.ALL, callback) - dp.add_handler(handler) - dp.persistence = OwnPersistence() - - update = Update( - 1, message=Message(1, None, None, from_user=User(1, '', False), text='Text') - ) - dp.process_update(update) - assert dp.persistence.test_flag_bot_data - assert dp.persistence.test_flag_user_data - assert not dp.persistence.test_flag_chat_data - - dp.persistence.test_flag_bot_data = False - dp.persistence.test_flag_user_data = False - dp.persistence.test_flag_chat_data = False - update = Update(1, message=Message(1, None, Chat(1, ''), from_user=None, text='Text')) - dp.process_update(update) - assert dp.persistence.test_flag_bot_data - assert not dp.persistence.test_flag_user_data - assert dp.persistence.test_flag_chat_data - - @pytest.mark.parametrize( - "c_id,expected", - [(321, {222: "remove_me"}), (111, {321: {'not_empty': 'no'}, 222: "remove_me"})], - ids=["test chat_id removal", "test no key in data (no error)"], - ) - def test_drop_chat_data(self, dp, c_id, expected): - dp._chat_data.update({321: {'not_empty': 'no'}, 222: "remove_me"}) - dp.drop_chat_data(c_id) - assert dp.chat_data == expected - - @pytest.mark.parametrize( - "u_id,expected", - [(321, {222: "remove_me"}), (111, {321: {'not_empty': 'no'}, 222: "remove_me"})], - ids=["test user_id removal", "test no key in data (no error)"], - ) - def test_drop_user_data(self, dp, u_id, expected): - dp._user_data.update({321: {'not_empty': 'no'}, 222: "remove_me"}) - dp.drop_user_data(u_id) - assert dp.user_data == expected - - def test_update_persistence_once_per_update(self, monkeypatch, dp): - def update_persistence(*args, **kwargs): - self.count += 1 - - def dummy_callback(*args): - pass - - monkeypatch.setattr(dp, 'update_persistence', update_persistence) - - for group in range(5): - dp.add_handler(MessageHandler(filters.TEXT, dummy_callback), group=group) - - update = Update(1, message=Message(1, None, Chat(1, ''), from_user=None, text=None)) - dp.process_update(update) - assert self.count == 0 - - update = Update(1, message=Message(1, None, Chat(1, ''), from_user=None, text='text')) - dp.process_update(update) - assert self.count == 1 - - def test_update_persistence_all_async(self, monkeypatch, dp): - def update_persistence(*args, **kwargs): - self.count += 1 - - def dummy_callback(*args, **kwargs): - pass - - monkeypatch.setattr(dp, 'update_persistence', update_persistence) - monkeypatch.setattr(dp, 'run_async', dummy_callback) - - for group in range(5): - dp.add_handler( - MessageHandler(filters.TEXT, dummy_callback, run_async=True), group=group - ) - - update = Update(1, message=Message(1, None, Chat(1, ''), from_user=None, text='Text')) - dp.process_update(update) - assert self.count == 0 - - dp.bot._defaults = Defaults(run_async=True) - try: - for group in range(5): - dp.add_handler(MessageHandler(filters.TEXT, dummy_callback), group=group) - - update = Update(1, message=Message(1, None, Chat(1, ''), from_user=None, text='Text')) - dp.process_update(update) - assert self.count == 0 - finally: - dp.bot._defaults = None - - @pytest.mark.parametrize('run_async', [DEFAULT_FALSE, False]) - def test_update_persistence_one_sync(self, monkeypatch, dp, run_async): - def update_persistence(*args, **kwargs): - self.count += 1 - - def dummy_callback(*args, **kwargs): - pass - - monkeypatch.setattr(dp, 'update_persistence', update_persistence) - monkeypatch.setattr(dp, 'run_async', dummy_callback) - - for group in range(5): - dp.add_handler( - MessageHandler(filters.TEXT, dummy_callback, run_async=True), group=group - ) - dp.add_handler(MessageHandler(filters.TEXT, dummy_callback, run_async=run_async), group=5) - - update = Update(1, message=Message(1, None, Chat(1, ''), from_user=None, text='Text')) - dp.process_update(update) - assert self.count == 1 - - @pytest.mark.parametrize('run_async,expected', [(DEFAULT_FALSE, 1), (False, 1), (True, 0)]) - def test_update_persistence_defaults_async(self, monkeypatch, dp, run_async, expected): - def update_persistence(*args, **kwargs): - self.count += 1 - - def dummy_callback(*args, **kwargs): - pass - - monkeypatch.setattr(dp, 'update_persistence', update_persistence) - monkeypatch.setattr(dp, 'run_async', dummy_callback) - dp.bot._defaults = Defaults(run_async=run_async) - - try: - for group in range(5): - dp.add_handler(MessageHandler(filters.TEXT, dummy_callback), group=group) - - update = Update(1, message=Message(1, None, Chat(1, ''), from_user=None, text='Text')) - dp.process_update(update) - assert self.count == expected - finally: - dp.bot._defaults = None - - def test_custom_context_init(self, bot): - cc = ContextTypes( - context=CustomContext, - user_data=int, - chat_data=float, - bot_data=complex, - ) - - dispatcher = DispatcherBuilder().bot(bot).context_types(cc).build() - - assert isinstance(dispatcher.user_data[1], int) - assert isinstance(dispatcher.chat_data[1], float) - assert isinstance(dispatcher.bot_data, complex) - - def test_custom_context_error_handler(self, bot): - def error_handler(_, context): - self.received = ( - type(context), - type(context.user_data), - type(context.chat_data), - type(context.bot_data), - ) - - dispatcher = ( - DispatcherBuilder() - .bot(bot) - .context_types( - ContextTypes( - context=CustomContext, bot_data=int, user_data=float, chat_data=complex - ) - ) - .build() - ) - dispatcher.add_error_handler(error_handler) - dispatcher.add_handler(MessageHandler(filters.ALL, self.callback_raise_error)) - - dispatcher.process_update(self.message_update) - sleep(0.1) - assert self.received == (CustomContext, float, complex, int) - - def test_custom_context_handler_callback(self, bot): - def callback(_, context): - self.received = ( - type(context), - type(context.user_data), - type(context.chat_data), - type(context.bot_data), - ) - - dispatcher = ( - DispatcherBuilder() - .bot(bot) - .context_types( - ContextTypes( - context=CustomContext, bot_data=int, user_data=float, chat_data=complex - ) - ) - .build() - ) - dispatcher.add_handler(MessageHandler(filters.ALL, callback)) - - dispatcher.process_update(self.message_update) - sleep(0.1) - assert self.received == (CustomContext, float, complex, int) diff --git a/tests/test_document.py b/tests/test_document.py index ffd74662bd3..c77ab345bbb 100644 --- a/tests/test_document.py +++ b/tests/test_document.py @@ -25,6 +25,7 @@ from telegram import Document, PhotoSize, Voice, MessageEntity, Bot from telegram.error import BadRequest, TelegramError from telegram.helpers import escape_markdown +from telegram.request import RequestData from tests.conftest import ( check_shortcut_signature, check_shortcut_call, @@ -41,9 +42,10 @@ def document_file(): @pytest.fixture(scope='class') -def document(bot, chat_id): +@pytest.mark.asyncio +async def document(bot, chat_id): with data_file('telegram.png').open('rb') as f: - return bot.send_document(chat_id, document=f, timeout=50).document + return (await bot.send_document(chat_id, document=f, read_timeout=50)).document class TestDocument: @@ -58,11 +60,6 @@ class TestDocument: document_file_id = '5a3128a4d2a04750b5b58397f3b5e812' document_file_unique_id = 'adc3145fd2e84d95b64d68eaa22aa33e' - def test_slot_behaviour(self, document, mro_slots): - for attr in document.__slots__: - assert getattr(document, attr, 'err') != 'err', f"got extra slot '{attr}'" - assert len(mro_slots(document)) == len(set(mro_slots(document))), "duplicate slot" - def test_creation(self, document): assert isinstance(document, Document) assert isinstance(document.file_id, str) @@ -79,8 +76,9 @@ def test_expected_values(self, document): assert document.thumb.height == self.thumb_height @flaky(3, 1) - def test_send_all_args(self, bot, chat_id, document_file, document, thumb_file): - message = bot.send_document( + @pytest.mark.asyncio + async def test_send_all_args(self, bot, chat_id, document_file, document, thumb_file): + message = await bot.send_document( chat_id, document=document_file, caption=self.caption, @@ -106,21 +104,27 @@ def test_send_all_args(self, bot, chat_id, document_file, document, thumb_file): assert message.has_protected_content @flaky(3, 1) - def test_get_and_download(self, bot, document): - new_file = bot.get_file(document.file_id) + @pytest.mark.asyncio + async def test_get_and_download(self, bot, document): + path = Path('telegram.png') + if path.is_file(): + path.unlink() + + new_file = await bot.get_file(document.file_id) assert new_file.file_size == document.file_size assert new_file.file_id == document.file_id assert new_file.file_unique_id == document.file_unique_id assert new_file.file_path.startswith('https://') - new_file.download('telegram.png') + await new_file.download('telegram.png') - assert Path('telegram.png').is_file() + assert path.is_file() @flaky(3, 1) - def test_send_url_gif_file(self, bot, chat_id): - message = bot.send_document(chat_id, self.document_file_url) + @pytest.mark.asyncio + async def test_send_url_gif_file(self, bot, chat_id): + message = await bot.send_document(chat_id, self.document_file_url) document = message.document @@ -135,16 +139,19 @@ def test_send_url_gif_file(self, bot, chat_id): assert document.file_size == 3878 @flaky(3, 1) - def test_send_resend(self, bot, chat_id, document): - message = bot.send_document(chat_id=chat_id, document=document.file_id) + @pytest.mark.asyncio + async def test_send_resend(self, bot, chat_id, document): + message = await bot.send_document(chat_id=chat_id, document=document.file_id) assert message.document == document @pytest.mark.parametrize('disable_content_type_detection', [True, False, None]) - def test_send_with_document( + @pytest.mark.asyncio + async def test_send_with_document( self, monkeypatch, bot, chat_id, document, disable_content_type_detection ): - def make_assertion(url, data, **kwargs): + async def make_assertion(url, request_data: RequestData, *args, **kwargs): + data = request_data.parameters type_detection = ( data.get('disable_content_type_detection') == disable_content_type_detection ) @@ -152,7 +159,7 @@ def make_assertion(url, data, **kwargs): monkeypatch.setattr(bot.request, 'post', make_assertion) - message = bot.send_document( + message = await bot.send_document( document=document, chat_id=chat_id, disable_content_type_detection=disable_content_type_detection, @@ -161,14 +168,15 @@ def make_assertion(url, data, **kwargs): assert message @flaky(3, 1) - def test_send_document_caption_entities(self, bot, chat_id, document): + @pytest.mark.asyncio + async def test_send_document_caption_entities(self, bot, chat_id, document): test_string = 'Italic Bold Code' entities = [ MessageEntity(MessageEntity.ITALIC, 0, 6), MessageEntity(MessageEntity.ITALIC, 7, 4), MessageEntity(MessageEntity.ITALIC, 12, 4), ] - message = bot.send_document( + message = await bot.send_document( chat_id, document, caption=test_string, caption_entities=entities ) @@ -177,20 +185,22 @@ def test_send_document_caption_entities(self, bot, chat_id, document): @flaky(3, 1) @pytest.mark.parametrize('default_bot', [{'parse_mode': 'Markdown'}], indirect=True) - def test_send_document_default_parse_mode_1(self, default_bot, chat_id, document): + @pytest.mark.asyncio + async def test_send_document_default_parse_mode_1(self, default_bot, chat_id, document): test_string = 'Italic Bold Code' test_markdown_string = '_Italic_ *Bold* `Code`' - message = default_bot.send_document(chat_id, document, caption=test_markdown_string) + message = await default_bot.send_document(chat_id, document, caption=test_markdown_string) assert message.caption_markdown == test_markdown_string assert message.caption == test_string @flaky(3, 1) @pytest.mark.parametrize('default_bot', [{'parse_mode': 'Markdown'}], indirect=True) - def test_send_document_default_parse_mode_2(self, default_bot, chat_id, document): + @pytest.mark.asyncio + async def test_send_document_default_parse_mode_2(self, default_bot, chat_id, document): test_markdown_string = '_Italic_ *Bold* `Code`' - message = default_bot.send_document( + message = await default_bot.send_document( chat_id, document, caption=test_markdown_string, parse_mode=None ) assert message.caption == test_markdown_string @@ -198,10 +208,11 @@ def test_send_document_default_parse_mode_2(self, default_bot, chat_id, document @flaky(3, 1) @pytest.mark.parametrize('default_bot', [{'parse_mode': 'Markdown'}], indirect=True) - def test_send_document_default_parse_mode_3(self, default_bot, chat_id, document): + @pytest.mark.asyncio + async def test_send_document_default_parse_mode_3(self, default_bot, chat_id, document): test_markdown_string = '_Italic_ *Bold* `Code`' - message = default_bot.send_document( + message = await default_bot.send_document( chat_id, document, caption=test_markdown_string, parse_mode='HTML' ) assert message.caption == test_markdown_string @@ -217,13 +228,14 @@ def test_send_document_default_parse_mode_3(self, default_bot, chat_id, document ], indirect=['default_bot'], ) - def test_send_document_default_allow_sending_without_reply( + @pytest.mark.asyncio + async def test_send_document_default_allow_sending_without_reply( self, default_bot, chat_id, document, custom ): - reply_to_message = default_bot.send_message(chat_id, 'test') - reply_to_message.delete() + reply_to_message = await default_bot.send_message(chat_id, 'test') + await reply_to_message.delete() if custom is not None: - message = default_bot.send_document( + message = await default_bot.send_document( chat_id, document, allow_sending_without_reply=custom, @@ -231,38 +243,39 @@ def test_send_document_default_allow_sending_without_reply( ) assert message.reply_to_message is None elif default_bot.defaults.allow_sending_without_reply: - message = default_bot.send_document( + message = await default_bot.send_document( chat_id, document, reply_to_message_id=reply_to_message.message_id ) assert message.reply_to_message is None else: with pytest.raises(BadRequest, match='message not found'): - default_bot.send_document( + await default_bot.send_document( chat_id, document, reply_to_message_id=reply_to_message.message_id ) @flaky(3, 1) + @pytest.mark.asyncio @pytest.mark.parametrize('default_bot', [{'protect_content': True}], indirect=True) - def test_send_document_default_protect_content(self, chat_id, default_bot, document): - protected = default_bot.send_document(chat_id, document) + async def test_send_document_default_protect_content(self, chat_id, default_bot, document): + protected = await default_bot.send_document(chat_id, document) assert protected.has_protected_content - unprotected = default_bot.send_document(chat_id, document, protect_content=False) + unprotected = await default_bot.send_document(chat_id, document, protect_content=False) assert not unprotected.has_protected_content - def test_send_document_local_files(self, monkeypatch, bot, chat_id): + @pytest.mark.asyncio + async def test_send_document_local_files(self, monkeypatch, bot, chat_id): # For just test that the correct paths are passed as we have no local bot API set up test_flag = False file = data_file('telegram.jpg') expected = file.as_uri() - def make_assertion(_, data, *args, **kwargs): + async def make_assertion(_, data, *args, **kwargs): nonlocal test_flag test_flag = data.get('document') == expected and data.get('thumb') == expected monkeypatch.setattr(bot, '_post', make_assertion) - bot.send_document(chat_id, file, thumb=file) + await bot.send_document(chat_id, file, thumb=file) assert test_flag - monkeypatch.delattr(bot, '_post') def test_de_json(self, bot, document): json_dict = { @@ -293,29 +306,34 @@ def test_to_dict(self, document): assert document_dict['file_size'] == document.file_size @flaky(3, 1) - def test_error_send_empty_file(self, bot, chat_id): - with Path(os.devnull).open('rb') as f, pytest.raises(TelegramError): - bot.send_document(chat_id=chat_id, document=f) + @pytest.mark.asyncio + async def test_error_send_empty_file(self, bot, chat_id): + with open(os.devnull, 'rb') as f: + with pytest.raises(TelegramError): + await bot.send_document(chat_id=chat_id, document=f) @flaky(3, 1) - def test_error_send_empty_file_id(self, bot, chat_id): + @pytest.mark.asyncio + async def test_error_send_empty_file_id(self, bot, chat_id): with pytest.raises(TelegramError): - bot.send_document(chat_id=chat_id, document='') + await bot.send_document(chat_id=chat_id, document='') - def test_error_send_without_required_args(self, bot, chat_id): + @pytest.mark.asyncio + async def test_error_send_without_required_args(self, bot, chat_id): with pytest.raises(TypeError): - bot.send_document(chat_id=chat_id) + await bot.send_document(chat_id=chat_id) - def test_get_file_instance_method(self, monkeypatch, document): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_get_file_instance_method(self, monkeypatch, document): + async def make_assertion(*_, **kwargs): return kwargs['file_id'] == document.file_id assert check_shortcut_signature(Document.get_file, Bot.get_file, ['file_id'], []) - assert check_shortcut_call(document.get_file, document.get_bot(), 'get_file') - assert check_defaults_handling(document.get_file, document.get_bot()) + assert await check_shortcut_call(document.get_file, document.get_bot(), 'get_file') + assert await check_defaults_handling(document.get_file, document.get_bot()) monkeypatch.setattr(document.get_bot(), 'get_file', make_assertion) - assert document.get_file() + assert await document.get_file() def test_equality(self, document): a = Document(document.file_id, document.file_unique_id) diff --git a/tests/test_error.py b/tests/test_error.py index eecefcd2b67..c425fdf9e3a 100644 --- a/tests/test_error.py +++ b/tests/test_error.py @@ -22,7 +22,7 @@ import pytest from telegram.error import ( - Unauthorized, + Forbidden, InvalidToken, NetworkError, BadRequest, @@ -48,14 +48,14 @@ def test_telegram_error(self): raise TelegramError("Bad Request: test message") def test_unauthorized(self): - with pytest.raises(Unauthorized, match="test message"): - raise Unauthorized("test message") - with pytest.raises(Unauthorized, match="^Test message$"): - raise Unauthorized("Error: test message") - with pytest.raises(Unauthorized, match="^Test message$"): - raise Unauthorized("[Error]: test message") - with pytest.raises(Unauthorized, match="^Test message$"): - raise Unauthorized("Bad Request: test message") + with pytest.raises(Forbidden, match="test message"): + raise Forbidden("test message") + with pytest.raises(Forbidden, match="^Test message$"): + raise Forbidden("Error: test message") + with pytest.raises(Forbidden, match="^Test message$"): + raise Forbidden("[Error]: test message") + with pytest.raises(Forbidden, match="^Test message$"): + raise Forbidden("Bad Request: test message") def test_invalid_token(self): with pytest.raises(InvalidToken, match="Invalid token"): @@ -105,7 +105,7 @@ def test_conflict(self): "exception, attributes", [ (TelegramError("test message"), ["message"]), - (Unauthorized("test message"), ["message"]), + (Forbidden("test message"), ["message"]), (InvalidToken(), ["message"]), (NetworkError("test message"), ["message"]), (BadRequest("test message"), ["message"]), @@ -130,7 +130,7 @@ def test_errors_pickling(self, exception, attributes): "inst", [ (TelegramError("test message")), - (Unauthorized("test message")), + (Forbidden("test message")), (InvalidToken()), (NetworkError("test message")), (BadRequest("test message")), @@ -164,7 +164,7 @@ def make_assertion(cls): covered_subclasses.update( { TelegramError: { - Unauthorized, + Forbidden, InvalidToken, NetworkError, ChatMigrated, diff --git a/tests/test_file.py b/tests/test_file.py index e423f298462..d9fa647c637 100644 --- a/tests/test_file.py +++ b/tests/test_file.py @@ -88,41 +88,46 @@ def test_to_dict(self, file): assert file_dict['file_size'] == file.file_size @flaky(3, 1) - def test_error_get_empty_file_id(self, bot): + @pytest.mark.asyncio + async def test_error_get_empty_file_id(self, bot): with pytest.raises(TelegramError): - bot.get_file(file_id='') + await bot.get_file(file_id='') - def test_download_mutuall_exclusive(self, file): + @pytest.mark.asyncio + async def test_download_mutually_exclusive(self, file): with pytest.raises(ValueError, match='`custom_path` and `out` are mutually exclusive'): - file.download('custom_path', 'out') + await file.download('custom_path', 'out') - def test_download(self, monkeypatch, file): - def test(*args, **kwargs): + @pytest.mark.asyncio + async def test_download(self, monkeypatch, file): + async def test(*args, **kwargs): return self.file_content - monkeypatch.setattr('telegram.request.Request.retrieve', test) - out_file = file.download() + monkeypatch.setattr(file.get_bot().request, 'retrieve', test) + out_file = await file.download() try: assert out_file.read_bytes() == self.file_content finally: out_file.unlink() - def test_download_local_file(self, local_file): - assert local_file.download() == Path(local_file.file_path) + @pytest.mark.asyncio + async def test_download_local_file(self, local_file): + assert await local_file.download() == Path(local_file.file_path) @pytest.mark.parametrize( 'custom_path_type', [str, Path], ids=['str custom_path', 'pathlib.Path custom_path'] ) - def test_download_custom_path(self, monkeypatch, file, custom_path_type): - def test(*args, **kwargs): + @pytest.mark.asyncio + async def test_download_custom_path(self, monkeypatch, file, custom_path_type): + async def test(*args, **kwargs): return self.file_content - monkeypatch.setattr('telegram.request.Request.retrieve', test) + monkeypatch.setattr(file.get_bot().request, 'retrieve', test) file_handle, custom_path = mkstemp() custom_path = Path(custom_path) try: - out_file = file.download(custom_path_type(custom_path)) + out_file = await file.download(custom_path_type(custom_path)) assert out_file == custom_path assert out_file.read_bytes() == self.file_content finally: @@ -132,25 +137,27 @@ def test(*args, **kwargs): @pytest.mark.parametrize( 'custom_path_type', [str, Path], ids=['str custom_path', 'pathlib.Path custom_path'] ) - def test_download_custom_path_local_file(self, local_file, custom_path_type): + @pytest.mark.asyncio + async def test_download_custom_path_local_file(self, local_file, custom_path_type): file_handle, custom_path = mkstemp() custom_path = Path(custom_path) try: - out_file = local_file.download(custom_path_type(custom_path)) + out_file = await local_file.download(custom_path_type(custom_path)) assert out_file == custom_path assert out_file.read_bytes() == self.file_content finally: os.close(file_handle) custom_path.unlink() - def test_download_no_filename(self, monkeypatch, file): - def test(*args, **kwargs): + @pytest.mark.asyncio + async def test_download_no_filename(self, monkeypatch, file): + async def test(*args, **kwargs): return self.file_content file.file_path = None - monkeypatch.setattr('telegram.request.Request.retrieve', test) - out_file = file.download() + monkeypatch.setattr(file.get_bot().request, 'retrieve', test) + out_file = await file.download() assert str(out_file)[-len(file.file_id) :] == file.file_id try: @@ -158,51 +165,55 @@ def test(*args, **kwargs): finally: out_file.unlink() - def test_download_file_obj(self, monkeypatch, file): - def test(*args, **kwargs): + @pytest.mark.asyncio + async def test_download_file_obj(self, monkeypatch, file): + async def test(*args, **kwargs): return self.file_content - monkeypatch.setattr('telegram.request.Request.retrieve', test) + monkeypatch.setattr(file.get_bot().request, 'retrieve', test) with TemporaryFile() as custom_fobj: - out_fobj = file.download(out=custom_fobj) + out_fobj = await file.download(out=custom_fobj) assert out_fobj is custom_fobj out_fobj.seek(0) assert out_fobj.read() == self.file_content - def test_download_file_obj_local_file(self, local_file): + @pytest.mark.asyncio + async def test_download_file_obj_local_file(self, local_file): with TemporaryFile() as custom_fobj: - out_fobj = local_file.download(out=custom_fobj) + out_fobj = await local_file.download(out=custom_fobj) assert out_fobj is custom_fobj out_fobj.seek(0) assert out_fobj.read() == self.file_content - def test_download_bytearray(self, monkeypatch, file): - def test(*args, **kwargs): + @pytest.mark.asyncio + async def test_download_bytearray(self, monkeypatch, file): + async def test(*args, **kwargs): return self.file_content - monkeypatch.setattr('telegram.request.Request.retrieve', test) + monkeypatch.setattr(file.get_bot().request, 'retrieve', test) # Check that a download to a newly allocated bytearray works. - buf = file.download_as_bytearray() + buf = await file.download_as_bytearray() assert buf == bytearray(self.file_content) # Check that a download to a given bytearray works (extends the bytearray). buf2 = buf[:] - buf3 = file.download_as_bytearray(buf=buf2) + buf3 = await file.download_as_bytearray(buf=buf2) assert buf3 is buf2 assert buf2[len(buf) :] == buf assert buf2[: len(buf)] == buf - def test_download_bytearray_local_file(self, local_file): + @pytest.mark.asyncio + async def test_download_bytearray_local_file(self, local_file): # Check that a download to a newly allocated bytearray works. - buf = local_file.download_as_bytearray() + buf = await local_file.download_as_bytearray() assert buf == bytearray(self.file_content) # Check that a download to a given bytearray works (extends the bytearray). buf2 = buf[:] - buf3 = local_file.download_as_bytearray(buf=buf2) + buf3 = await local_file.download_as_bytearray(buf=buf2) assert buf3 is buf2 assert buf2[len(buf) :] == buf assert buf2[: len(buf)] == buf diff --git a/tests/test_files.py b/tests/test_files.py index 5158bb6e904..df9d227be03 100644 --- a/tests/test_files.py +++ b/tests/test_files.py @@ -67,16 +67,12 @@ def test_parse_file_input_file_like(self): parsed = telegram._utils.files.parse_file_input(file) assert isinstance(parsed, InputFile) - assert not parsed.attach assert parsed.filename == 'game.gif' with source_file.open('rb') as file: - parsed = telegram._utils.files.parse_file_input( - file, attach=True, filename='test_file' - ) + parsed = telegram._utils.files.parse_file_input(file, filename='test_file') assert isinstance(parsed, InputFile) - assert parsed.attach assert parsed.filename == 'test_file' def test_parse_file_input_bytes(self): @@ -84,15 +80,13 @@ def test_parse_file_input_bytes(self): parsed = telegram._utils.files.parse_file_input(source_file.read_bytes()) assert isinstance(parsed, InputFile) - assert not parsed.attach assert parsed.filename == 'application.octet-stream' parsed = telegram._utils.files.parse_file_input( - source_file.read_bytes(), attach=True, filename='test_file' + source_file.read_bytes(), filename='test_file' ) assert isinstance(parsed, InputFile) - assert parsed.attach assert parsed.filename == 'test_file' def test_parse_file_input_tg_object(self): diff --git a/tests/test_filters.py b/tests/test_filters.py deleted file mode 100644 index 853460730c9..00000000000 --- a/tests/test_filters.py +++ /dev/null @@ -1,2274 +0,0 @@ -#!/usr/bin/env python -# -# A library that provides a Python interface to the Telegram Bot API -# Copyright (C) 2015-2022 -# Leandro Toledo de Souza -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Lesser Public License for more details. -# -# You should have received a copy of the GNU Lesser Public License -# along with this program. If not, see [http://www.gnu.org/licenses/]. -import datetime - -import pytest - -from telegram import Message, User, Chat, MessageEntity, Document, Update, Dice, CallbackQuery -from telegram.ext import filters -import inspect -import re - - -@pytest.fixture(scope='function') -def update(): - return Update( - 0, - Message( - 0, - datetime.datetime.utcnow(), - Chat(0, 'private'), - from_user=User(0, 'Testuser', False), - via_bot=User(0, "Testbot", True), - sender_chat=Chat(0, 'Channel'), - forward_from=User(0, "HAL9000", False), - forward_from_chat=Chat(0, "Channel"), - ), - ) - - -@pytest.fixture(scope='function', params=MessageEntity.ALL_TYPES) -def message_entity(request): - return MessageEntity(request.param, 0, 0, url='', user=User(1, 'first_name', False)) - - -@pytest.fixture( - scope='class', - params=[{'class': filters.MessageFilter}, {'class': filters.UpdateFilter}], - ids=['MessageFilter', 'UpdateFilter'], -) -def base_class(request): - return request.param['class'] - - -class TestFilters: - def test_all_filters_slot_behaviour(self, mro_slots): - """ - Use depth first search to get all nested filters, and instantiate them (which need it) with - the correct number of arguments, then test each filter separately. Also tests setting - custom attributes on custom filters. - """ - - def filter_class(obj): - return True if inspect.isclass(obj) and "filters" in repr(obj) else False - - # The total no. of filters is about 72 as of 31/10/21. - # Gather all the filters to test using DFS- - visited = [] - classes = inspect.getmembers(filters, predicate=filter_class) # List[Tuple[str, type]] - stack = classes.copy() - while stack: - cls = stack[-1][-1] # get last element and its class - for inner_cls in inspect.getmembers( - cls, # Get inner filters - lambda a: inspect.isclass(a) and not issubclass(a, cls.__class__), - ): - if inner_cls[1] not in visited: - stack.append(inner_cls) - visited.append(inner_cls[1]) - classes.append(inner_cls) - break - else: - stack.pop() - - # Now start the actual testing - for name, cls in classes: - # Can't instantiate abstract classes without overriding methods, so skip them for now - exclude = {'_MergedFilter', '_XORFilter'} - if inspect.isabstract(cls) or name in {'__class__', '__base__'} | exclude: - continue - - assert '__slots__' in cls.__dict__, f"Filter {name!r} doesn't have __slots__" - # get no. of args minus the 'self', 'args' and 'kwargs' argument - init_sig = inspect.signature(cls.__init__).parameters - extra = 0 - for param in init_sig: - if param in {'self', 'args', 'kwargs'}: - extra += 1 - args = len(init_sig) - extra - - if not args: - inst = cls() - elif args == 1: - inst = cls('1') - else: - inst = cls(*['blah']) - - assert len(mro_slots(inst)) == len(set(mro_slots(inst))), f"same slot in {name}" - - for attr in cls.__slots__: - assert getattr(inst, attr, 'err') != 'err', f"got extra slot '{attr}' for {name}" - - def test_filters_all(self, update): - assert filters.ALL.check_update(update) - - def test_filters_text(self, update): - update.message.text = 'test' - assert filters.TEXT.check_update(update) - update.message.text = '/test' - assert filters.Text().check_update(update) - - def test_filters_text_strings(self, update): - update.message.text = '/test' - assert filters.Text(('/test', 'test1')).check_update(update) - assert not filters.Text(['test1', 'test2']).check_update(update) - - def test_filters_caption(self, update): - update.message.caption = 'test' - assert filters.CAPTION.check_update(update) - update.message.caption = None - assert not filters.CAPTION.check_update(update) - - def test_filters_caption_strings(self, update): - update.message.caption = 'test' - assert filters.Caption(('test', 'test1')).check_update(update) - assert not filters.Caption(['test1', 'test2']).check_update(update) - - def test_filters_command_default(self, update): - update.message.text = 'test' - assert not filters.COMMAND.check_update(update) - update.message.text = '/test' - assert not filters.COMMAND.check_update(update) - # Only accept commands at the beginning - update.message.entities = [MessageEntity(MessageEntity.BOT_COMMAND, 3, 5)] - assert not filters.COMMAND.check_update(update) - update.message.entities = [MessageEntity(MessageEntity.BOT_COMMAND, 0, 5)] - assert filters.COMMAND.check_update(update) - - def test_filters_command_anywhere(self, update): - update.message.entities = [MessageEntity(MessageEntity.BOT_COMMAND, 5, 4)] - assert filters.Command(False).check_update(update) - - def test_filters_regex(self, update): - sre_type = type(re.match("", "")) - update.message.text = '/start deep-linked param' - result = filters.Regex(r'deep-linked param').check_update(update) - assert result - assert isinstance(result, dict) - matches = result['matches'] - assert isinstance(matches, list) - assert type(matches[0]) is sre_type - update.message.text = '/help' - assert filters.Regex(r'help').check_update(update) - - update.message.text = 'test' - assert not filters.Regex(r'fail').check_update(update) - assert filters.Regex(r'test').check_update(update) - assert filters.Regex(re.compile(r'test')).check_update(update) - assert filters.Regex(re.compile(r'TEST', re.IGNORECASE)).check_update(update) - - update.message.text = 'i love python' - assert filters.Regex(r'.\b[lo]{2}ve python').check_update(update) - - update.message.text = None - assert not filters.Regex(r'fail').check_update(update) - - def test_filters_regex_multiple(self, update): - sre_type = type(re.match("", "")) - update.message.text = '/start deep-linked param' - result = (filters.Regex('deep') & filters.Regex(r'linked param')).check_update(update) - assert result - assert isinstance(result, dict) - matches = result['matches'] - assert isinstance(matches, list) - assert all(type(res) is sre_type for res in matches) - result = (filters.Regex('deep') | filters.Regex(r'linked param')).check_update(update) - assert result - assert isinstance(result, dict) - matches = result['matches'] - assert isinstance(matches, list) - assert all(type(res) is sre_type for res in matches) - result = (filters.Regex('not int') | filters.Regex(r'linked param')).check_update(update) - assert result - assert isinstance(result, dict) - matches = result['matches'] - assert isinstance(matches, list) - assert all(type(res) is sre_type for res in matches) - result = (filters.Regex('not int') & filters.Regex(r'linked param')).check_update(update) - assert not result - - def test_filters_merged_with_regex(self, update): - sre_type = type(re.match("", "")) - update.message.text = '/start deep-linked param' - update.message.entities = [MessageEntity(MessageEntity.BOT_COMMAND, 0, 6)] - result = (filters.COMMAND & filters.Regex(r'linked param')).check_update(update) - assert result - assert isinstance(result, dict) - matches = result['matches'] - assert isinstance(matches, list) - assert all(type(res) is sre_type for res in matches) - result = (filters.Regex(r'linked param') & filters.COMMAND).check_update(update) - assert result - assert isinstance(result, dict) - matches = result['matches'] - assert isinstance(matches, list) - assert all(type(res) is sre_type for res in matches) - result = (filters.Regex(r'linked param') | filters.COMMAND).check_update(update) - assert result - assert isinstance(result, dict) - matches = result['matches'] - assert isinstance(matches, list) - assert all(type(res) is sre_type for res in matches) - # Should not give a match since it's a or filter and it short circuits - result = (filters.COMMAND | filters.Regex(r'linked param')).check_update(update) - assert result is True - - def test_regex_complex_merges(self, update): - sre_type = type(re.match("", "")) - update.message.text = 'test it out' - test_filter = filters.Regex('test') & ( - (filters.StatusUpdate.ALL | filters.FORWARDED) | filters.Regex('out') - ) - result = test_filter.check_update(update) - assert result - assert isinstance(result, dict) - matches = result['matches'] - assert isinstance(matches, list) - assert len(matches) == 2 - assert all(type(res) is sre_type for res in matches) - update.message.forward_date = datetime.datetime.utcnow() - result = test_filter.check_update(update) - assert result - assert isinstance(result, dict) - matches = result['matches'] - assert isinstance(matches, list) - assert all(type(res) is sre_type for res in matches) - update.message.text = 'test it' - result = test_filter.check_update(update) - assert result - assert isinstance(result, dict) - matches = result['matches'] - assert isinstance(matches, list) - assert all(type(res) is sre_type for res in matches) - update.message.forward_date = None - result = test_filter.check_update(update) - assert not result - update.message.text = 'test it out' - result = test_filter.check_update(update) - assert result - assert isinstance(result, dict) - matches = result['matches'] - assert isinstance(matches, list) - assert all(type(res) is sre_type for res in matches) - update.message.pinned_message = True - result = test_filter.check_update(update) - assert result - assert isinstance(result, dict) - matches = result['matches'] - assert isinstance(matches, list) - assert all(type(res) is sre_type for res in matches) - update.message.text = 'it out' - result = test_filter.check_update(update) - assert not result - - update.message.text = 'test it out' - update.message.forward_date = None - update.message.pinned_message = None - test_filter = (filters.Regex('test') | filters.COMMAND) & ( - filters.Regex('it') | filters.StatusUpdate.ALL - ) - result = test_filter.check_update(update) - assert result - assert isinstance(result, dict) - matches = result['matches'] - assert isinstance(matches, list) - assert len(matches) == 2 - assert all(type(res) is sre_type for res in matches) - update.message.text = 'test' - result = test_filter.check_update(update) - assert not result - update.message.pinned_message = True - result = test_filter.check_update(update) - assert result - assert isinstance(result, dict) - matches = result['matches'] - assert isinstance(matches, list) - assert len(matches) == 1 - assert all(type(res) is sre_type for res in matches) - update.message.text = 'nothing' - result = test_filter.check_update(update) - assert not result - update.message.text = '/start' - update.message.entities = [MessageEntity(MessageEntity.BOT_COMMAND, 0, 6)] - result = test_filter.check_update(update) - assert result - assert isinstance(result, bool) - update.message.text = '/start it' - result = test_filter.check_update(update) - assert result - assert isinstance(result, dict) - matches = result['matches'] - assert isinstance(matches, list) - assert len(matches) == 1 - assert all(type(res) is sre_type for res in matches) - - def test_regex_inverted(self, update): - update.message.text = '/start deep-linked param' - update.message.entities = [MessageEntity(MessageEntity.BOT_COMMAND, 0, 5)] - inv = ~filters.Regex(r'deep-linked param') - result = inv.check_update(update) - assert not result - update.message.text = 'not it' - result = inv.check_update(update) - assert result - assert isinstance(result, bool) - - inv = ~filters.Regex('linked') & filters.COMMAND - update.message.text = "it's linked" - result = inv.check_update(update) - assert not result - update.message.text = '/start' - update.message.entities = [MessageEntity(MessageEntity.BOT_COMMAND, 0, 6)] - result = inv.check_update(update) - assert result - update.message.text = '/linked' - result = inv.check_update(update) - assert not result - - inv = ~filters.Regex('linked') | filters.COMMAND - update.message.text = "it's linked" - update.message.entities = [] - result = inv.check_update(update) - assert not result - update.message.text = '/start linked' - update.message.entities = [MessageEntity(MessageEntity.BOT_COMMAND, 0, 6)] - result = inv.check_update(update) - assert result - update.message.text = '/start' - result = inv.check_update(update) - assert result - update.message.text = 'nothig' - update.message.entities = [] - result = inv.check_update(update) - assert result - - def test_filters_caption_regex(self, update): - sre_type = type(re.match("", "")) - update.message.caption = '/start deep-linked param' - result = filters.CaptionRegex(r'deep-linked param').check_update(update) - assert result - assert isinstance(result, dict) - matches = result['matches'] - assert isinstance(matches, list) - assert type(matches[0]) is sre_type - update.message.caption = '/help' - assert filters.CaptionRegex(r'help').check_update(update) - - update.message.caption = 'test' - assert not filters.CaptionRegex(r'fail').check_update(update) - assert filters.CaptionRegex(r'test').check_update(update) - assert filters.CaptionRegex(re.compile(r'test')).check_update(update) - assert filters.CaptionRegex(re.compile(r'TEST', re.IGNORECASE)).check_update(update) - - update.message.caption = 'i love python' - assert filters.CaptionRegex(r'.\b[lo]{2}ve python').check_update(update) - - update.message.caption = None - assert not filters.CaptionRegex(r'fail').check_update(update) - - def test_filters_caption_regex_multiple(self, update): - sre_type = type(re.match("", "")) - update.message.caption = '/start deep-linked param' - _and = filters.CaptionRegex('deep') & filters.CaptionRegex(r'linked param') - result = _and.check_update(update) - assert result - assert isinstance(result, dict) - matches = result['matches'] - assert isinstance(matches, list) - assert all(type(res) is sre_type for res in matches) - _or = filters.CaptionRegex('deep') | filters.CaptionRegex(r'linked param') - result = _or.check_update(update) - assert result - assert isinstance(result, dict) - matches = result['matches'] - assert isinstance(matches, list) - assert all(type(res) is sre_type for res in matches) - _or = filters.CaptionRegex('not int') | filters.CaptionRegex(r'linked param') - result = _or.check_update(update) - assert result - assert isinstance(result, dict) - matches = result['matches'] - assert isinstance(matches, list) - assert all(type(res) is sre_type for res in matches) - _and = filters.CaptionRegex('not int') & filters.CaptionRegex(r'linked param') - result = _and.check_update(update) - assert not result - - def test_filters_merged_with_caption_regex(self, update): - sre_type = type(re.match("", "")) - update.message.caption = '/start deep-linked param' - update.message.entities = [MessageEntity(MessageEntity.BOT_COMMAND, 0, 6)] - result = (filters.COMMAND & filters.CaptionRegex(r'linked param')).check_update(update) - assert result - assert isinstance(result, dict) - matches = result['matches'] - assert isinstance(matches, list) - assert all(type(res) is sre_type for res in matches) - result = (filters.CaptionRegex(r'linked param') & filters.COMMAND).check_update(update) - assert result - assert isinstance(result, dict) - matches = result['matches'] - assert isinstance(matches, list) - assert all(type(res) is sre_type for res in matches) - result = (filters.CaptionRegex(r'linked param') | filters.COMMAND).check_update(update) - assert result - assert isinstance(result, dict) - matches = result['matches'] - assert isinstance(matches, list) - assert all(type(res) is sre_type for res in matches) - # Should not give a match since it's a or filter and it short circuits - result = (filters.COMMAND | filters.CaptionRegex(r'linked param')).check_update(update) - assert result is True - - def test_caption_regex_complex_merges(self, update): - sre_type = type(re.match("", "")) - update.message.caption = 'test it out' - test_filter = filters.CaptionRegex('test') & ( - (filters.StatusUpdate.ALL | filters.FORWARDED) | filters.CaptionRegex('out') - ) - result = test_filter.check_update(update) - assert result - assert isinstance(result, dict) - matches = result['matches'] - assert isinstance(matches, list) - assert len(matches) == 2 - assert all(type(res) is sre_type for res in matches) - update.message.forward_date = datetime.datetime.utcnow() - result = test_filter.check_update(update) - assert result - assert isinstance(result, dict) - matches = result['matches'] - assert isinstance(matches, list) - assert all(type(res) is sre_type for res in matches) - update.message.caption = 'test it' - result = test_filter.check_update(update) - assert result - assert isinstance(result, dict) - matches = result['matches'] - assert isinstance(matches, list) - assert all(type(res) is sre_type for res in matches) - update.message.forward_date = None - result = test_filter.check_update(update) - assert not result - update.message.caption = 'test it out' - result = test_filter.check_update(update) - assert result - assert isinstance(result, dict) - matches = result['matches'] - assert isinstance(matches, list) - assert all(type(res) is sre_type for res in matches) - update.message.pinned_message = True - result = test_filter.check_update(update) - assert result - assert isinstance(result, dict) - matches = result['matches'] - assert isinstance(matches, list) - assert all(type(res) is sre_type for res in matches) - update.message.caption = 'it out' - result = test_filter.check_update(update) - assert not result - - update.message.caption = 'test it out' - update.message.forward_date = None - update.message.pinned_message = None - test_filter = (filters.CaptionRegex('test') | filters.COMMAND) & ( - filters.CaptionRegex('it') | filters.StatusUpdate.ALL - ) - result = test_filter.check_update(update) - assert result - assert isinstance(result, dict) - matches = result['matches'] - assert isinstance(matches, list) - assert len(matches) == 2 - assert all(type(res) is sre_type for res in matches) - update.message.caption = 'test' - result = test_filter.check_update(update) - assert not result - update.message.pinned_message = True - result = test_filter.check_update(update) - assert result - assert isinstance(result, dict) - matches = result['matches'] - assert isinstance(matches, list) - assert len(matches) == 1 - assert all(type(res) is sre_type for res in matches) - update.message.caption = 'nothing' - result = test_filter.check_update(update) - assert not result - update.message.caption = '/start' - update.message.entities = [MessageEntity(MessageEntity.BOT_COMMAND, 0, 6)] - result = test_filter.check_update(update) - assert result - assert isinstance(result, bool) - update.message.caption = '/start it' - result = test_filter.check_update(update) - assert result - assert isinstance(result, dict) - matches = result['matches'] - assert isinstance(matches, list) - assert len(matches) == 1 - assert all(type(res) is sre_type for res in matches) - - def test_caption_regex_inverted(self, update): - update.message.caption = '/start deep-linked param' - update.message.entities = [MessageEntity(MessageEntity.BOT_COMMAND, 0, 5)] - test_filter = ~filters.CaptionRegex(r'deep-linked param') - result = test_filter.check_update(update) - assert not result - update.message.caption = 'not it' - result = test_filter.check_update(update) - assert result - assert isinstance(result, bool) - - test_filter = ~filters.CaptionRegex('linked') & filters.COMMAND - update.message.caption = "it's linked" - result = test_filter.check_update(update) - assert not result - update.message.caption = '/start' - update.message.entities = [MessageEntity(MessageEntity.BOT_COMMAND, 0, 6)] - result = test_filter.check_update(update) - assert result - update.message.caption = '/linked' - result = test_filter.check_update(update) - assert not result - - test_filter = ~filters.CaptionRegex('linked') | filters.COMMAND - update.message.caption = "it's linked" - update.message.entities = [] - result = test_filter.check_update(update) - assert not result - update.message.caption = '/start linked' - update.message.entities = [MessageEntity(MessageEntity.BOT_COMMAND, 0, 6)] - result = test_filter.check_update(update) - assert result - update.message.caption = '/start' - result = test_filter.check_update(update) - assert result - update.message.caption = 'nothig' - update.message.entities = [] - result = test_filter.check_update(update) - assert result - - def test_filters_reply(self, update): - another_message = Message( - 1, - datetime.datetime.utcnow(), - Chat(0, 'private'), - from_user=User(1, 'TestOther', False), - ) - update.message.text = 'test' - assert not filters.REPLY.check_update(update) - update.message.reply_to_message = another_message - assert filters.REPLY.check_update(update) - - def test_filters_audio(self, update): - assert not filters.AUDIO.check_update(update) - update.message.audio = 'test' - assert filters.AUDIO.check_update(update) - - def test_filters_document(self, update): - assert not filters.DOCUMENT.check_update(update) - update.message.document = 'test' - assert filters.DOCUMENT.check_update(update) - - def test_filters_document_type(self, update): - update.message.document = Document( - "file_id", 'unique_id', mime_type="application/vnd.android.package-archive" - ) - assert filters.Document.APK.check_update(update) - assert filters.Document.APPLICATION.check_update(update) - assert not filters.Document.DOC.check_update(update) - assert not filters.Document.AUDIO.check_update(update) - - update.message.document.mime_type = "application/msword" - assert filters.Document.DOC.check_update(update) - assert filters.Document.APPLICATION.check_update(update) - assert not filters.Document.DOCX.check_update(update) - assert not filters.Document.AUDIO.check_update(update) - - update.message.document.mime_type = ( - "application/vnd.openxmlformats-officedocument.wordprocessingml.document" - ) - assert filters.Document.DOCX.check_update(update) - assert filters.Document.APPLICATION.check_update(update) - assert not filters.Document.EXE.check_update(update) - assert not filters.Document.AUDIO.check_update(update) - - update.message.document.mime_type = "application/octet-stream" - assert filters.Document.EXE.check_update(update) - assert filters.Document.APPLICATION.check_update(update) - assert not filters.Document.DOCX.check_update(update) - assert not filters.Document.AUDIO.check_update(update) - - update.message.document.mime_type = "image/gif" - assert filters.Document.GIF.check_update(update) - assert filters.Document.IMAGE.check_update(update) - assert not filters.Document.JPG.check_update(update) - assert not filters.Document.TEXT.check_update(update) - - update.message.document.mime_type = "image/jpeg" - assert filters.Document.JPG.check_update(update) - assert filters.Document.IMAGE.check_update(update) - assert not filters.Document.MP3.check_update(update) - assert not filters.Document.VIDEO.check_update(update) - - update.message.document.mime_type = "audio/mpeg" - assert filters.Document.MP3.check_update(update) - assert filters.Document.AUDIO.check_update(update) - assert not filters.Document.PDF.check_update(update) - assert not filters.Document.IMAGE.check_update(update) - - update.message.document.mime_type = "application/pdf" - assert filters.Document.PDF.check_update(update) - assert filters.Document.APPLICATION.check_update(update) - assert not filters.Document.PY.check_update(update) - assert not filters.Document.AUDIO.check_update(update) - - update.message.document.mime_type = "text/x-python" - assert filters.Document.PY.check_update(update) - assert filters.Document.TEXT.check_update(update) - assert not filters.Document.SVG.check_update(update) - assert not filters.Document.APPLICATION.check_update(update) - - update.message.document.mime_type = "image/svg+xml" - assert filters.Document.SVG.check_update(update) - assert filters.Document.IMAGE.check_update(update) - assert not filters.Document.TXT.check_update(update) - assert not filters.Document.VIDEO.check_update(update) - - update.message.document.mime_type = "text/plain" - assert filters.Document.TXT.check_update(update) - assert filters.Document.TEXT.check_update(update) - assert not filters.Document.TARGZ.check_update(update) - assert not filters.Document.APPLICATION.check_update(update) - - update.message.document.mime_type = "application/x-compressed-tar" - assert filters.Document.TARGZ.check_update(update) - assert filters.Document.APPLICATION.check_update(update) - assert not filters.Document.WAV.check_update(update) - assert not filters.Document.AUDIO.check_update(update) - - update.message.document.mime_type = "audio/x-wav" - assert filters.Document.WAV.check_update(update) - assert filters.Document.AUDIO.check_update(update) - assert not filters.Document.XML.check_update(update) - assert not filters.Document.IMAGE.check_update(update) - - update.message.document.mime_type = "text/xml" - assert filters.Document.XML.check_update(update) - assert filters.Document.TEXT.check_update(update) - assert not filters.Document.ZIP.check_update(update) - assert not filters.Document.AUDIO.check_update(update) - - update.message.document.mime_type = "application/zip" - assert filters.Document.ZIP.check_update(update) - assert filters.Document.APPLICATION.check_update(update) - assert not filters.Document.APK.check_update(update) - assert not filters.Document.AUDIO.check_update(update) - - update.message.document.mime_type = "image/x-rgb" - assert not filters.Document.Category("application/").check_update(update) - assert not filters.Document.MimeType("application/x-sh").check_update(update) - update.message.document.mime_type = "application/x-sh" - assert filters.Document.Category("application/").check_update(update) - assert filters.Document.MimeType("application/x-sh").check_update(update) - - update.message.document.mime_type = None - assert not filters.Document.Category("application/").check_update(update) - assert not filters.Document.MimeType("application/x-sh").check_update(update) - - def test_filters_file_extension_basic(self, update): - update.message.document = Document( - "file_id", - "unique_id", - file_name="file.jpg", - mime_type="image/jpeg", - ) - assert filters.Document.FileExtension("jpg").check_update(update) - assert not filters.Document.FileExtension("jpeg").check_update(update) - assert not filters.Document.FileExtension("file.jpg").check_update(update) - - update.message.document.file_name = "file.tar.gz" - assert filters.Document.FileExtension("tar.gz").check_update(update) - assert filters.Document.FileExtension("gz").check_update(update) - assert not filters.Document.FileExtension("tgz").check_update(update) - assert not filters.Document.FileExtension("jpg").check_update(update) - - update.message.document.file_name = None - assert not filters.Document.FileExtension("jpg").check_update(update) - - update.message.document = None - assert not filters.Document.FileExtension("jpg").check_update(update) - - def test_filters_file_extension_minds_dots(self, update): - update.message.document = Document( - "file_id", - "unique_id", - file_name="file.jpg", - mime_type="image/jpeg", - ) - assert not filters.Document.FileExtension(".jpg").check_update(update) - assert not filters.Document.FileExtension("e.jpg").check_update(update) - assert not filters.Document.FileExtension("file.jpg").check_update(update) - assert not filters.Document.FileExtension("").check_update(update) - - update.message.document.file_name = "file..jpg" - assert filters.Document.FileExtension("jpg").check_update(update) - assert filters.Document.FileExtension(".jpg").check_update(update) - assert not filters.Document.FileExtension("..jpg").check_update(update) - - update.message.document.file_name = "file.docx" - assert filters.Document.FileExtension("docx").check_update(update) - assert not filters.Document.FileExtension("doc").check_update(update) - assert not filters.Document.FileExtension("ocx").check_update(update) - - update.message.document.file_name = "file" - assert not filters.Document.FileExtension("").check_update(update) - assert not filters.Document.FileExtension("file").check_update(update) - - update.message.document.file_name = "file." - assert filters.Document.FileExtension("").check_update(update) - - def test_filters_file_extension_none_arg(self, update): - update.message.document = Document( - "file_id", - "unique_id", - file_name="file.jpg", - mime_type="image/jpeg", - ) - assert not filters.Document.FileExtension(None).check_update(update) - - update.message.document.file_name = "file" - assert filters.Document.FileExtension(None).check_update(update) - assert not filters.Document.FileExtension("None").check_update(update) - - update.message.document.file_name = "file." - assert not filters.Document.FileExtension(None).check_update(update) - - update.message.document = None - assert not filters.Document.FileExtension(None).check_update(update) - - def test_filters_file_extension_case_sensitivity(self, update): - update.message.document = Document( - "file_id", - "unique_id", - file_name="file.jpg", - mime_type="image/jpeg", - ) - assert filters.Document.FileExtension("JPG").check_update(update) - assert filters.Document.FileExtension("jpG").check_update(update) - - update.message.document.file_name = "file.JPG" - assert filters.Document.FileExtension("jpg").check_update(update) - assert not filters.Document.FileExtension("jpg", case_sensitive=True).check_update(update) - - update.message.document.file_name = "file.Dockerfile" - assert filters.Document.FileExtension("Dockerfile", case_sensitive=True).check_update( - update - ) - assert not filters.Document.FileExtension("DOCKERFILE", case_sensitive=True).check_update( - update - ) - - def test_filters_file_extension_name(self): - assert filters.Document.FileExtension("jpg").name == ( - "filters.Document.FileExtension('jpg')" - ) - assert filters.Document.FileExtension("JPG").name == ( - "filters.Document.FileExtension('jpg')" - ) - assert filters.Document.FileExtension("jpg", case_sensitive=True).name == ( - "filters.Document.FileExtension('jpg', case_sensitive=True)" - ) - assert filters.Document.FileExtension("JPG", case_sensitive=True).name == ( - "filters.Document.FileExtension('JPG', case_sensitive=True)" - ) - assert filters.Document.FileExtension(".jpg").name == ( - "filters.Document.FileExtension('.jpg')" - ) - assert filters.Document.FileExtension("").name == "filters.Document.FileExtension('')" - assert filters.Document.FileExtension(None).name == "filters.Document.FileExtension(None)" - - def test_filters_animation(self, update): - assert not filters.ANIMATION.check_update(update) - update.message.animation = 'test' - assert filters.ANIMATION.check_update(update) - - def test_filters_photo(self, update): - assert not filters.PHOTO.check_update(update) - update.message.photo = 'test' - assert filters.PHOTO.check_update(update) - - def test_filters_sticker(self, update): - assert not filters.STICKER.check_update(update) - update.message.sticker = 'test' - assert filters.STICKER.check_update(update) - - def test_filters_video(self, update): - assert not filters.VIDEO.check_update(update) - update.message.video = 'test' - assert filters.VIDEO.check_update(update) - - def test_filters_voice(self, update): - assert not filters.VOICE.check_update(update) - update.message.voice = 'test' - assert filters.VOICE.check_update(update) - - def test_filters_video_note(self, update): - assert not filters.VIDEO_NOTE.check_update(update) - update.message.video_note = 'test' - assert filters.VIDEO_NOTE.check_update(update) - - def test_filters_contact(self, update): - assert not filters.CONTACT.check_update(update) - update.message.contact = 'test' - assert filters.CONTACT.check_update(update) - - def test_filters_location(self, update): - assert not filters.LOCATION.check_update(update) - update.message.location = 'test' - assert filters.LOCATION.check_update(update) - - def test_filters_venue(self, update): - assert not filters.VENUE.check_update(update) - update.message.venue = 'test' - assert filters.VENUE.check_update(update) - - def test_filters_status_update(self, update): - assert not filters.StatusUpdate.ALL.check_update(update) - - update.message.new_chat_members = ['test'] - assert filters.StatusUpdate.ALL.check_update(update) - assert filters.StatusUpdate.NEW_CHAT_MEMBERS.check_update(update) - update.message.new_chat_members = None - - update.message.left_chat_member = 'test' - assert filters.StatusUpdate.ALL.check_update(update) - assert filters.StatusUpdate.LEFT_CHAT_MEMBER.check_update(update) - update.message.left_chat_member = None - - update.message.new_chat_title = 'test' - assert filters.StatusUpdate.ALL.check_update(update) - assert filters.StatusUpdate.NEW_CHAT_TITLE.check_update(update) - update.message.new_chat_title = '' - - update.message.new_chat_photo = 'test' - assert filters.StatusUpdate.ALL.check_update(update) - assert filters.StatusUpdate.NEW_CHAT_PHOTO.check_update(update) - update.message.new_chat_photo = None - - update.message.delete_chat_photo = True - assert filters.StatusUpdate.ALL.check_update(update) - assert filters.StatusUpdate.DELETE_CHAT_PHOTO.check_update(update) - update.message.delete_chat_photo = False - - update.message.group_chat_created = True - assert filters.StatusUpdate.ALL.check_update(update) - assert filters.StatusUpdate.CHAT_CREATED.check_update(update) - update.message.group_chat_created = False - - update.message.supergroup_chat_created = True - assert filters.StatusUpdate.ALL.check_update(update) - assert filters.StatusUpdate.CHAT_CREATED.check_update(update) - update.message.supergroup_chat_created = False - - update.message.channel_chat_created = True - assert filters.StatusUpdate.ALL.check_update(update) - assert filters.StatusUpdate.CHAT_CREATED.check_update(update) - update.message.channel_chat_created = False - - update.message.message_auto_delete_timer_changed = True - assert filters.StatusUpdate.ALL.check_update(update) - assert filters.StatusUpdate.MESSAGE_AUTO_DELETE_TIMER_CHANGED.check_update(update) - update.message.message_auto_delete_timer_changed = False - - update.message.migrate_to_chat_id = 100 - assert filters.StatusUpdate.ALL.check_update(update) - assert filters.StatusUpdate.MIGRATE.check_update(update) - update.message.migrate_to_chat_id = 0 - - update.message.migrate_from_chat_id = 100 - assert filters.StatusUpdate.ALL.check_update(update) - assert filters.StatusUpdate.MIGRATE.check_update(update) - update.message.migrate_from_chat_id = 0 - - update.message.pinned_message = 'test' - assert filters.StatusUpdate.ALL.check_update(update) - assert filters.StatusUpdate.PINNED_MESSAGE.check_update(update) - update.message.pinned_message = None - - update.message.connected_website = 'https://example.com/' - assert filters.StatusUpdate.ALL.check_update(update) - assert filters.StatusUpdate.CONNECTED_WEBSITE.check_update(update) - update.message.connected_website = None - - update.message.proximity_alert_triggered = 'alert' - assert filters.StatusUpdate.ALL.check_update(update) - assert filters.StatusUpdate.PROXIMITY_ALERT_TRIGGERED.check_update(update) - update.message.proximity_alert_triggered = None - - update.message.voice_chat_scheduled = 'scheduled' - assert filters.StatusUpdate.ALL.check_update(update) - assert filters.StatusUpdate.VOICE_CHAT_SCHEDULED.check_update(update) - update.message.voice_chat_scheduled = None - - update.message.voice_chat_started = 'hello' - assert filters.StatusUpdate.ALL.check_update(update) - assert filters.StatusUpdate.VOICE_CHAT_STARTED.check_update(update) - update.message.voice_chat_started = None - - update.message.voice_chat_ended = 'bye' - assert filters.StatusUpdate.ALL.check_update(update) - assert filters.StatusUpdate.VOICE_CHAT_ENDED.check_update(update) - update.message.voice_chat_ended = None - - update.message.voice_chat_participants_invited = 'invited' - assert filters.StatusUpdate.ALL.check_update(update) - assert filters.StatusUpdate.VOICE_CHAT_PARTICIPANTS_INVITED.check_update(update) - update.message.voice_chat_participants_invited = None - - def test_filters_forwarded(self, update): - assert not filters.FORWARDED.check_update(update) - update.message.forward_date = datetime.datetime.utcnow() - assert filters.FORWARDED.check_update(update) - - def test_filters_game(self, update): - assert not filters.GAME.check_update(update) - update.message.game = 'test' - assert filters.GAME.check_update(update) - - def test_entities_filter(self, update, message_entity): - update.message.entities = [message_entity] - assert filters.Entity(message_entity.type).check_update(update) - - update.message.entities = [] - assert not filters.Entity(MessageEntity.MENTION).check_update(update) - - second = message_entity.to_dict() - second['type'] = 'bold' - second = MessageEntity.de_json(second, None) - update.message.entities = [message_entity, second] - assert filters.Entity(message_entity.type).check_update(update) - assert not filters.CaptionEntity(message_entity.type).check_update(update) - - def test_caption_entities_filter(self, update, message_entity): - update.message.caption_entities = [message_entity] - assert filters.CaptionEntity(message_entity.type).check_update(update) - - update.message.caption_entities = [] - assert not filters.CaptionEntity(MessageEntity.MENTION).check_update(update) - - second = message_entity.to_dict() - second['type'] = 'bold' - second = MessageEntity.de_json(second, None) - update.message.caption_entities = [message_entity, second] - assert filters.CaptionEntity(message_entity.type).check_update(update) - assert not filters.Entity(message_entity.type).check_update(update) - - @pytest.mark.parametrize( - 'chat_type, results', - [ - (Chat.PRIVATE, (True, False, False, False, False)), - (Chat.GROUP, (False, True, False, True, False)), - (Chat.SUPERGROUP, (False, False, True, True, False)), - (Chat.CHANNEL, (False, False, False, False, True)), - ], - ) - def test_filters_chat_types(self, update, chat_type, results): - update.message.chat.type = chat_type - assert filters.ChatType.PRIVATE.check_update(update) is results[0] - assert filters.ChatType.GROUP.check_update(update) is results[1] - assert filters.ChatType.SUPERGROUP.check_update(update) is results[2] - assert filters.ChatType.GROUPS.check_update(update) is results[3] - assert filters.ChatType.CHANNEL.check_update(update) is results[4] - - def test_filters_user_init(self): - with pytest.raises(RuntimeError, match='in conjunction with'): - filters.User(user_id=1, username='user') - - def test_filters_user_allow_empty(self, update): - assert not filters.User().check_update(update) - assert filters.User(allow_empty=True).check_update(update) - - def test_filters_user_id(self, update): - assert not filters.User(user_id=1).check_update(update) - update.message.from_user.id = 1 - assert filters.User(user_id=1).check_update(update) - assert filters.USER.check_update(update) - update.message.from_user.id = 2 - assert filters.User(user_id=[1, 2]).check_update(update) - assert not filters.User(user_id=[3, 4]).check_update(update) - update.message.from_user = None - assert not filters.USER.check_update(update) - assert not filters.User(user_id=[3, 4]).check_update(update) - - def test_filters_username(self, update): - assert not filters.User(username='user').check_update(update) - assert not filters.User(username='Testuser').check_update(update) - update.message.from_user.username = 'user@' - assert filters.User(username='@user@').check_update(update) - assert filters.User(username='user@').check_update(update) - assert filters.User(username=['user1', 'user@', 'user2']).check_update(update) - assert not filters.User(username=['@username', '@user_2']).check_update(update) - update.message.from_user = None - assert not filters.User(username=['@username', '@user_2']).check_update(update) - - def test_filters_user_change_id(self, update): - f = filters.User(user_id=1) - assert f.user_ids == {1} - update.message.from_user.id = 1 - assert f.check_update(update) - update.message.from_user.id = 2 - assert not f.check_update(update) - f.user_ids = 2 - assert f.user_ids == {2} - assert f.check_update(update) - - with pytest.raises(RuntimeError, match='username in conjunction'): - f.usernames = 'user' - - def test_filters_user_change_username(self, update): - f = filters.User(username='user') - update.message.from_user.username = 'user' - assert f.check_update(update) - update.message.from_user.username = 'User' - assert not f.check_update(update) - f.usernames = 'User' - assert f.check_update(update) - - with pytest.raises(RuntimeError, match='user_id in conjunction'): - f.user_ids = 1 - - def test_filters_user_add_user_by_name(self, update): - users = ['user_a', 'user_b', 'user_c'] - f = filters.User() - - for user in users: - update.message.from_user.username = user - assert not f.check_update(update) - - f.add_usernames('user_a') - f.add_usernames(['user_b', 'user_c']) - - for user in users: - update.message.from_user.username = user - assert f.check_update(update) - - with pytest.raises(RuntimeError, match='user_id in conjunction'): - f.add_user_ids(1) - - def test_filters_user_add_user_by_id(self, update): - users = [1, 2, 3] - f = filters.User() - - for user in users: - update.message.from_user.id = user - assert not f.check_update(update) - - f.add_user_ids(1) - f.add_user_ids([2, 3]) - - for user in users: - update.message.from_user.username = user - assert f.check_update(update) - - with pytest.raises(RuntimeError, match='username in conjunction'): - f.add_usernames('user') - - def test_filters_user_remove_user_by_name(self, update): - users = ['user_a', 'user_b', 'user_c'] - f = filters.User(username=users) - - with pytest.raises(RuntimeError, match='user_id in conjunction'): - f.remove_user_ids(1) - - for user in users: - update.message.from_user.username = user - assert f.check_update(update) - - f.remove_usernames('user_a') - f.remove_usernames(['user_b', 'user_c']) - - for user in users: - update.message.from_user.username = user - assert not f.check_update(update) - - def test_filters_user_remove_user_by_id(self, update): - users = [1, 2, 3] - f = filters.User(user_id=users) - - with pytest.raises(RuntimeError, match='username in conjunction'): - f.remove_usernames('user') - - for user in users: - update.message.from_user.id = user - assert f.check_update(update) - - f.remove_user_ids(1) - f.remove_user_ids([2, 3]) - - for user in users: - update.message.from_user.username = user - assert not f.check_update(update) - - def test_filters_user_repr(self): - f = filters.User([1, 2]) - assert str(f) == 'filters.User(1, 2)' - f.remove_user_ids(1) - f.remove_user_ids(2) - assert str(f) == 'filters.User()' - f.add_usernames('@foobar') - assert str(f) == 'filters.User(foobar)' - f.add_usernames('@barfoo') - assert str(f).startswith('filters.User(') - # we don't know th exact order - assert 'barfoo' in str(f) and 'foobar' in str(f) - - with pytest.raises(RuntimeError, match='Cannot set name'): - f.name = 'foo' - - def test_filters_chat_init(self): - with pytest.raises(RuntimeError, match='in conjunction with'): - filters.Chat(chat_id=1, username='chat') - - def test_filters_chat_allow_empty(self, update): - assert not filters.Chat().check_update(update) - assert filters.Chat(allow_empty=True).check_update(update) - - def test_filters_chat_id(self, update): - assert not filters.Chat(chat_id=1).check_update(update) - assert filters.CHAT.check_update(update) - update.message.chat.id = 1 - assert filters.Chat(chat_id=1).check_update(update) - assert filters.CHAT.check_update(update) - update.message.chat.id = 2 - assert filters.Chat(chat_id=[1, 2]).check_update(update) - assert not filters.Chat(chat_id=[3, 4]).check_update(update) - update.message.chat = None - assert not filters.CHAT.check_update(update) - assert not filters.Chat(chat_id=[3, 4]).check_update(update) - - def test_filters_chat_username(self, update): - assert not filters.Chat(username='chat').check_update(update) - assert not filters.Chat(username='Testchat').check_update(update) - update.message.chat.username = 'chat@' - assert filters.Chat(username='@chat@').check_update(update) - assert filters.Chat(username='chat@').check_update(update) - assert filters.Chat(username=['chat1', 'chat@', 'chat2']).check_update(update) - assert not filters.Chat(username=['@username', '@chat_2']).check_update(update) - update.message.chat = None - assert not filters.Chat(username=['@username', '@chat_2']).check_update(update) - - def test_filters_chat_change_id(self, update): - f = filters.Chat(chat_id=1) - assert f.chat_ids == {1} - update.message.chat.id = 1 - assert f.check_update(update) - update.message.chat.id = 2 - assert not f.check_update(update) - f.chat_ids = 2 - assert f.chat_ids == {2} - assert f.check_update(update) - - with pytest.raises(RuntimeError, match='username in conjunction'): - f.usernames = 'chat' - - def test_filters_chat_change_username(self, update): - f = filters.Chat(username='chat') - update.message.chat.username = 'chat' - assert f.check_update(update) - update.message.chat.username = 'User' - assert not f.check_update(update) - f.usernames = 'User' - assert f.check_update(update) - - with pytest.raises(RuntimeError, match='chat_id in conjunction'): - f.chat_ids = 1 - - def test_filters_chat_add_chat_by_name(self, update): - chats = ['chat_a', 'chat_b', 'chat_c'] - f = filters.Chat() - - for chat in chats: - update.message.chat.username = chat - assert not f.check_update(update) - - f.add_usernames('chat_a') - f.add_usernames(['chat_b', 'chat_c']) - - for chat in chats: - update.message.chat.username = chat - assert f.check_update(update) - - with pytest.raises(RuntimeError, match='chat_id in conjunction'): - f.add_chat_ids(1) - - def test_filters_chat_add_chat_by_id(self, update): - chats = [1, 2, 3] - f = filters.Chat() - - for chat in chats: - update.message.chat.id = chat - assert not f.check_update(update) - - f.add_chat_ids(1) - f.add_chat_ids([2, 3]) - - for chat in chats: - update.message.chat.username = chat - assert f.check_update(update) - - with pytest.raises(RuntimeError, match='username in conjunction'): - f.add_usernames('chat') - - def test_filters_chat_remove_chat_by_name(self, update): - chats = ['chat_a', 'chat_b', 'chat_c'] - f = filters.Chat(username=chats) - - with pytest.raises(RuntimeError, match='chat_id in conjunction'): - f.remove_chat_ids(1) - - for chat in chats: - update.message.chat.username = chat - assert f.check_update(update) - - f.remove_usernames('chat_a') - f.remove_usernames(['chat_b', 'chat_c']) - - for chat in chats: - update.message.chat.username = chat - assert not f.check_update(update) - - def test_filters_chat_remove_chat_by_id(self, update): - chats = [1, 2, 3] - f = filters.Chat(chat_id=chats) - - with pytest.raises(RuntimeError, match='username in conjunction'): - f.remove_usernames('chat') - - for chat in chats: - update.message.chat.id = chat - assert f.check_update(update) - - f.remove_chat_ids(1) - f.remove_chat_ids([2, 3]) - - for chat in chats: - update.message.chat.username = chat - assert not f.check_update(update) - - def test_filters_chat_repr(self): - f = filters.Chat([1, 2]) - assert str(f) == 'filters.Chat(1, 2)' - f.remove_chat_ids(1) - f.remove_chat_ids(2) - assert str(f) == 'filters.Chat()' - f.add_usernames('@foobar') - assert str(f) == 'filters.Chat(foobar)' - f.add_usernames('@barfoo') - assert str(f).startswith('filters.Chat(') - # we don't know th exact order - assert 'barfoo' in str(f) and 'foobar' in str(f) - - with pytest.raises(RuntimeError, match='Cannot set name'): - f.name = 'foo' - - def test_filters_forwarded_from_init(self): - with pytest.raises(RuntimeError, match='in conjunction with'): - filters.ForwardedFrom(chat_id=1, username='chat') - - def test_filters_forwarded_from_allow_empty(self, update): - assert not filters.ForwardedFrom().check_update(update) - assert filters.ForwardedFrom(allow_empty=True).check_update(update) - - def test_filters_forwarded_from_id(self, update): - # Test with User id- - assert not filters.ForwardedFrom(chat_id=1).check_update(update) - update.message.forward_from.id = 1 - assert filters.ForwardedFrom(chat_id=1).check_update(update) - update.message.forward_from.id = 2 - assert filters.ForwardedFrom(chat_id=[1, 2]).check_update(update) - assert not filters.ForwardedFrom(chat_id=[3, 4]).check_update(update) - update.message.forward_from = None - assert not filters.ForwardedFrom(chat_id=[3, 4]).check_update(update) - - # Test with Chat id- - update.message.forward_from_chat.id = 4 - assert filters.ForwardedFrom(chat_id=[4]).check_update(update) - assert filters.ForwardedFrom(chat_id=[3, 4]).check_update(update) - - update.message.forward_from_chat.id = 2 - assert not filters.ForwardedFrom(chat_id=[3, 4]).check_update(update) - assert filters.ForwardedFrom(chat_id=2).check_update(update) - update.message.forward_from_chat = None - - def test_filters_forwarded_from_username(self, update): - # For User username - assert not filters.ForwardedFrom(username='chat').check_update(update) - assert not filters.ForwardedFrom(username='Testchat').check_update(update) - update.message.forward_from.username = 'chat@' - assert filters.ForwardedFrom(username='@chat@').check_update(update) - assert filters.ForwardedFrom(username='chat@').check_update(update) - assert filters.ForwardedFrom(username=['chat1', 'chat@', 'chat2']).check_update(update) - assert not filters.ForwardedFrom(username=['@username', '@chat_2']).check_update(update) - update.message.forward_from = None - assert not filters.ForwardedFrom(username=['@username', '@chat_2']).check_update(update) - - # For Chat username - assert not filters.ForwardedFrom(username='chat').check_update(update) - assert not filters.ForwardedFrom(username='Testchat').check_update(update) - update.message.forward_from_chat.username = 'chat@' - assert filters.ForwardedFrom(username='@chat@').check_update(update) - assert filters.ForwardedFrom(username='chat@').check_update(update) - assert filters.ForwardedFrom(username=['chat1', 'chat@', 'chat2']).check_update(update) - assert not filters.ForwardedFrom(username=['@username', '@chat_2']).check_update(update) - update.message.forward_from_chat = None - assert not filters.ForwardedFrom(username=['@username', '@chat_2']).check_update(update) - - def test_filters_forwarded_from_change_id(self, update): - f = filters.ForwardedFrom(chat_id=1) - # For User ids- - assert f.chat_ids == {1} - update.message.forward_from.id = 1 - assert f.check_update(update) - update.message.forward_from.id = 2 - assert not f.check_update(update) - f.chat_ids = 2 - assert f.chat_ids == {2} - assert f.check_update(update) - - # For Chat ids- - f = filters.ForwardedFrom(chat_id=1) # reset this - update.message.forward_from = None # and change this to None, only one of them can be True - assert f.chat_ids == {1} - update.message.forward_from_chat.id = 1 - assert f.check_update(update) - update.message.forward_from_chat.id = 2 - assert not f.check_update(update) - f.chat_ids = 2 - assert f.chat_ids == {2} - assert f.check_update(update) - - with pytest.raises(RuntimeError, match='username in conjunction'): - f.usernames = 'chat' - - def test_filters_forwarded_from_change_username(self, update): - # For User usernames - f = filters.ForwardedFrom(username='chat') - update.message.forward_from.username = 'chat' - assert f.check_update(update) - update.message.forward_from.username = 'User' - assert not f.check_update(update) - f.usernames = 'User' - assert f.check_update(update) - - # For Chat usernames - update.message.forward_from = None - f = filters.ForwardedFrom(username='chat') - update.message.forward_from_chat.username = 'chat' - assert f.check_update(update) - update.message.forward_from_chat.username = 'User' - assert not f.check_update(update) - f.usernames = 'User' - assert f.check_update(update) - - with pytest.raises(RuntimeError, match='chat_id in conjunction'): - f.chat_ids = 1 - - def test_filters_forwarded_from_add_chat_by_name(self, update): - chats = ['chat_a', 'chat_b', 'chat_c'] - f = filters.ForwardedFrom() - - # For User usernames - for chat in chats: - update.message.forward_from.username = chat - assert not f.check_update(update) - - f.add_usernames('chat_a') - f.add_usernames(['chat_b', 'chat_c']) - - for chat in chats: - update.message.forward_from.username = chat - assert f.check_update(update) - - # For Chat usernames - update.message.forward_from = None - f = filters.ForwardedFrom() - for chat in chats: - update.message.forward_from_chat.username = chat - assert not f.check_update(update) - - f.add_usernames('chat_a') - f.add_usernames(['chat_b', 'chat_c']) - - for chat in chats: - update.message.forward_from_chat.username = chat - assert f.check_update(update) - - with pytest.raises(RuntimeError, match='chat_id in conjunction'): - f.add_chat_ids(1) - - def test_filters_forwarded_from_add_chat_by_id(self, update): - chats = [1, 2, 3] - f = filters.ForwardedFrom() - - # For User ids - for chat in chats: - update.message.forward_from.id = chat - assert not f.check_update(update) - - f.add_chat_ids(1) - f.add_chat_ids([2, 3]) - - for chat in chats: - update.message.forward_from.username = chat - assert f.check_update(update) - - # For Chat ids- - update.message.forward_from = None - f = filters.ForwardedFrom() - for chat in chats: - update.message.forward_from_chat.id = chat - assert not f.check_update(update) - - f.add_chat_ids(1) - f.add_chat_ids([2, 3]) - - for chat in chats: - update.message.forward_from_chat.username = chat - assert f.check_update(update) - - with pytest.raises(RuntimeError, match='username in conjunction'): - f.add_usernames('chat') - - def test_filters_forwarded_from_remove_chat_by_name(self, update): - chats = ['chat_a', 'chat_b', 'chat_c'] - f = filters.ForwardedFrom(username=chats) - - with pytest.raises(RuntimeError, match='chat_id in conjunction'): - f.remove_chat_ids(1) - - # For User usernames - for chat in chats: - update.message.forward_from.username = chat - assert f.check_update(update) - - f.remove_usernames('chat_a') - f.remove_usernames(['chat_b', 'chat_c']) - - for chat in chats: - update.message.forward_from.username = chat - assert not f.check_update(update) - - # For Chat usernames - update.message.forward_from = None - f = filters.ForwardedFrom(username=chats) - for chat in chats: - update.message.forward_from_chat.username = chat - assert f.check_update(update) - - f.remove_usernames('chat_a') - f.remove_usernames(['chat_b', 'chat_c']) - - for chat in chats: - update.message.forward_from_chat.username = chat - assert not f.check_update(update) - - def test_filters_forwarded_from_remove_chat_by_id(self, update): - chats = [1, 2, 3] - f = filters.ForwardedFrom(chat_id=chats) - - with pytest.raises(RuntimeError, match='username in conjunction'): - f.remove_usernames('chat') - - # For User ids - for chat in chats: - update.message.forward_from.id = chat - assert f.check_update(update) - - f.remove_chat_ids(1) - f.remove_chat_ids([2, 3]) - - for chat in chats: - update.message.forward_from.username = chat - assert not f.check_update(update) - - # For Chat ids - update.message.forward_from = None - f = filters.ForwardedFrom(chat_id=chats) - for chat in chats: - update.message.forward_from_chat.id = chat - assert f.check_update(update) - - f.remove_chat_ids(1) - f.remove_chat_ids([2, 3]) - - for chat in chats: - update.message.forward_from_chat.username = chat - assert not f.check_update(update) - - def test_filters_forwarded_from_repr(self): - f = filters.ForwardedFrom([1, 2]) - assert str(f) == 'filters.ForwardedFrom(1, 2)' - f.remove_chat_ids(1) - f.remove_chat_ids(2) - assert str(f) == 'filters.ForwardedFrom()' - f.add_usernames('@foobar') - assert str(f) == 'filters.ForwardedFrom(foobar)' - f.add_usernames('@barfoo') - assert str(f).startswith('filters.ForwardedFrom(') - # we don't know the exact order - assert 'barfoo' in str(f) and 'foobar' in str(f) - - with pytest.raises(RuntimeError, match='Cannot set name'): - f.name = 'foo' - - def test_filters_sender_chat_init(self): - with pytest.raises(RuntimeError, match='in conjunction with'): - filters.SenderChat(chat_id=1, username='chat') - - def test_filters_sender_chat_allow_empty(self, update): - assert not filters.SenderChat().check_update(update) - assert filters.SenderChat(allow_empty=True).check_update(update) - - def test_filters_sender_chat_id(self, update): - assert not filters.SenderChat(chat_id=1).check_update(update) - update.message.sender_chat.id = 1 - assert filters.SenderChat(chat_id=1).check_update(update) - update.message.sender_chat.id = 2 - assert filters.SenderChat(chat_id=[1, 2]).check_update(update) - assert not filters.SenderChat(chat_id=[3, 4]).check_update(update) - assert filters.SenderChat.ALL.check_update(update) - update.message.sender_chat = None - assert not filters.SenderChat(chat_id=[3, 4]).check_update(update) - assert not filters.SenderChat.ALL.check_update(update) - - def test_filters_sender_chat_username(self, update): - assert not filters.SenderChat(username='chat').check_update(update) - assert not filters.SenderChat(username='Testchat').check_update(update) - update.message.sender_chat.username = 'chat@' - assert filters.SenderChat(username='@chat@').check_update(update) - assert filters.SenderChat(username='chat@').check_update(update) - assert filters.SenderChat(username=['chat1', 'chat@', 'chat2']).check_update(update) - assert not filters.SenderChat(username=['@username', '@chat_2']).check_update(update) - assert filters.SenderChat.ALL.check_update(update) - update.message.sender_chat = None - assert not filters.SenderChat(username=['@username', '@chat_2']).check_update(update) - assert not filters.SenderChat.ALL.check_update(update) - - def test_filters_sender_chat_change_id(self, update): - f = filters.SenderChat(chat_id=1) - assert f.chat_ids == {1} - update.message.sender_chat.id = 1 - assert f.check_update(update) - update.message.sender_chat.id = 2 - assert not f.check_update(update) - f.chat_ids = 2 - assert f.chat_ids == {2} - assert f.check_update(update) - - with pytest.raises(RuntimeError, match='username in conjunction'): - f.usernames = 'chat' - - def test_filters_sender_chat_change_username(self, update): - f = filters.SenderChat(username='chat') - update.message.sender_chat.username = 'chat' - assert f.check_update(update) - update.message.sender_chat.username = 'User' - assert not f.check_update(update) - f.usernames = 'User' - assert f.check_update(update) - - with pytest.raises(RuntimeError, match='chat_id in conjunction'): - f.chat_ids = 1 - - def test_filters_sender_chat_add_sender_chat_by_name(self, update): - chats = ['chat_a', 'chat_b', 'chat_c'] - f = filters.SenderChat() - - for chat in chats: - update.message.sender_chat.username = chat - assert not f.check_update(update) - - f.add_usernames('chat_a') - f.add_usernames(['chat_b', 'chat_c']) - - for chat in chats: - update.message.sender_chat.username = chat - assert f.check_update(update) - - with pytest.raises(RuntimeError, match='chat_id in conjunction'): - f.add_chat_ids(1) - - def test_filters_sender_chat_add_sender_chat_by_id(self, update): - chats = [1, 2, 3] - f = filters.SenderChat() - - for chat in chats: - update.message.sender_chat.id = chat - assert not f.check_update(update) - - f.add_chat_ids(1) - f.add_chat_ids([2, 3]) - - for chat in chats: - update.message.sender_chat.username = chat - assert f.check_update(update) - - with pytest.raises(RuntimeError, match='username in conjunction'): - f.add_usernames('chat') - - def test_filters_sender_chat_remove_sender_chat_by_name(self, update): - chats = ['chat_a', 'chat_b', 'chat_c'] - f = filters.SenderChat(username=chats) - - with pytest.raises(RuntimeError, match='chat_id in conjunction'): - f.remove_chat_ids(1) - - for chat in chats: - update.message.sender_chat.username = chat - assert f.check_update(update) - - f.remove_usernames('chat_a') - f.remove_usernames(['chat_b', 'chat_c']) - - for chat in chats: - update.message.sender_chat.username = chat - assert not f.check_update(update) - - def test_filters_sender_chat_remove_sender_chat_by_id(self, update): - chats = [1, 2, 3] - f = filters.SenderChat(chat_id=chats) - - with pytest.raises(RuntimeError, match='username in conjunction'): - f.remove_usernames('chat') - - for chat in chats: - update.message.sender_chat.id = chat - assert f.check_update(update) - - f.remove_chat_ids(1) - f.remove_chat_ids([2, 3]) - - for chat in chats: - update.message.sender_chat.username = chat - assert not f.check_update(update) - - def test_filters_sender_chat_repr(self): - f = filters.SenderChat([1, 2]) - assert str(f) == 'filters.SenderChat(1, 2)' - f.remove_chat_ids(1) - f.remove_chat_ids(2) - assert str(f) == 'filters.SenderChat()' - f.add_usernames('@foobar') - assert str(f) == 'filters.SenderChat(foobar)' - f.add_usernames('@barfoo') - assert str(f).startswith('filters.SenderChat(') - # we don't know th exact order - assert 'barfoo' in str(f) and 'foobar' in str(f) - - with pytest.raises(RuntimeError, match='Cannot set name'): - f.name = 'foo' - - def test_filters_sender_chat_super_group(self, update): - update.message.sender_chat.type = Chat.PRIVATE - assert not filters.SenderChat.SUPER_GROUP.check_update(update) - assert filters.SenderChat.ALL.check_update(update) - update.message.sender_chat.type = Chat.CHANNEL - assert not filters.SenderChat.SUPER_GROUP.check_update(update) - update.message.sender_chat.type = Chat.SUPERGROUP - assert filters.SenderChat.SUPER_GROUP.check_update(update) - assert filters.SenderChat.ALL.check_update(update) - update.message.sender_chat = None - assert not filters.SenderChat.SUPER_GROUP.check_update(update) - assert not filters.SenderChat.ALL.check_update(update) - - def test_filters_sender_chat_channel(self, update): - update.message.sender_chat.type = Chat.PRIVATE - assert not filters.SenderChat.CHANNEL.check_update(update) - update.message.sender_chat.type = Chat.SUPERGROUP - assert not filters.SenderChat.CHANNEL.check_update(update) - update.message.sender_chat.type = Chat.CHANNEL - assert filters.SenderChat.CHANNEL.check_update(update) - update.message.sender_chat = None - assert not filters.SenderChat.CHANNEL.check_update(update) - - def test_filters_is_automatic_forward(self, update): - assert not filters.IS_AUTOMATIC_FORWARD.check_update(update) - update.message.is_automatic_forward = True - assert filters.IS_AUTOMATIC_FORWARD.check_update(update) - - def test_filters_has_protected_content(self, update): - assert not filters.HAS_PROTECTED_CONTENT.check_update(update) - update.message.has_protected_content = True - assert filters.HAS_PROTECTED_CONTENT.check_update(update) - - def test_filters_invoice(self, update): - assert not filters.INVOICE.check_update(update) - update.message.invoice = 'test' - assert filters.INVOICE.check_update(update) - - def test_filters_successful_payment(self, update): - assert not filters.SUCCESSFUL_PAYMENT.check_update(update) - update.message.successful_payment = 'test' - assert filters.SUCCESSFUL_PAYMENT.check_update(update) - - def test_filters_passport_data(self, update): - assert not filters.PASSPORT_DATA.check_update(update) - update.message.passport_data = 'test' - assert filters.PASSPORT_DATA.check_update(update) - - def test_filters_poll(self, update): - assert not filters.POLL.check_update(update) - update.message.poll = 'test' - assert filters.POLL.check_update(update) - - @pytest.mark.parametrize('emoji', Dice.ALL_EMOJI) - def test_filters_dice(self, update, emoji): - update.message.dice = Dice(4, emoji) - assert filters.Dice.ALL.check_update(update) and filters.Dice().check_update(update) - - to_camel = emoji.name.title().replace('_', '') - assert repr(filters.Dice.ALL) == "filters.Dice.ALL" - assert repr(getattr(filters.Dice, to_camel)(4)) == f"filters.Dice.{to_camel}([4])" - - update.message.dice = None - assert not filters.Dice.ALL.check_update(update) - - @pytest.mark.parametrize('emoji', Dice.ALL_EMOJI) - def test_filters_dice_list(self, update, emoji): - update.message.dice = None - assert not filters.Dice(5).check_update(update) - - update.message.dice = Dice(5, emoji) - assert filters.Dice(5).check_update(update) - assert repr(filters.Dice(5)) == "filters.Dice([5])" - assert filters.Dice({5, 6}).check_update(update) - assert not filters.Dice(1).check_update(update) - assert not filters.Dice([2, 3]).check_update(update) - - def test_filters_dice_type(self, update): - update.message.dice = Dice(5, '🎲') - assert filters.Dice.DICE.check_update(update) - assert repr(filters.Dice.DICE) == "filters.Dice.DICE" - assert filters.Dice.Dice([4, 5]).check_update(update) - assert not filters.Dice.Darts(5).check_update(update) - assert not filters.Dice.BASKETBALL.check_update(update) - assert not filters.Dice.Dice([6]).check_update(update) - - update.message.dice = Dice(5, '🎯') - assert filters.Dice.DARTS.check_update(update) - assert filters.Dice.Darts([4, 5]).check_update(update) - assert not filters.Dice.Dice(5).check_update(update) - assert not filters.Dice.BASKETBALL.check_update(update) - assert not filters.Dice.Darts([6]).check_update(update) - - update.message.dice = Dice(5, '🏀') - assert filters.Dice.BASKETBALL.check_update(update) - assert filters.Dice.Basketball([4, 5]).check_update(update) - assert not filters.Dice.Dice(5).check_update(update) - assert not filters.Dice.DARTS.check_update(update) - assert not filters.Dice.Basketball([4]).check_update(update) - - update.message.dice = Dice(5, '⚽') - assert filters.Dice.FOOTBALL.check_update(update) - assert filters.Dice.Football([4, 5]).check_update(update) - assert not filters.Dice.Dice(5).check_update(update) - assert not filters.Dice.DARTS.check_update(update) - assert not filters.Dice.Football([4]).check_update(update) - - update.message.dice = Dice(5, '🎰') - assert filters.Dice.SLOT_MACHINE.check_update(update) - assert filters.Dice.SlotMachine([4, 5]).check_update(update) - assert not filters.Dice.Dice(5).check_update(update) - assert not filters.Dice.DARTS.check_update(update) - assert not filters.Dice.SlotMachine([4]).check_update(update) - - update.message.dice = Dice(5, '🎳') - assert filters.Dice.BOWLING.check_update(update) - assert filters.Dice.Bowling([4, 5]).check_update(update) - assert not filters.Dice.Dice(5).check_update(update) - assert not filters.Dice.DARTS.check_update(update) - assert not filters.Dice.Bowling([4]).check_update(update) - - def test_language_filter_single(self, update): - update.message.from_user.language_code = 'en_US' - assert filters.Language('en_US').check_update(update) - assert filters.Language('en').check_update(update) - assert not filters.Language('en_GB').check_update(update) - assert not filters.Language('da').check_update(update) - update.message.from_user.language_code = 'da' - assert not filters.Language('en_US').check_update(update) - assert not filters.Language('en').check_update(update) - assert not filters.Language('en_GB').check_update(update) - assert filters.Language('da').check_update(update) - - update.message.from_user = None - assert not filters.Language('da').check_update(update) - - def test_language_filter_multiple(self, update): - f = filters.Language(['en_US', 'da']) - update.message.from_user.language_code = 'en_US' - assert f.check_update(update) - update.message.from_user.language_code = 'en_GB' - assert not f.check_update(update) - update.message.from_user.language_code = 'da' - assert f.check_update(update) - - def test_and_filters(self, update): - update.message.text = 'test' - update.message.forward_date = datetime.datetime.utcnow() - assert (filters.TEXT & filters.FORWARDED).check_update(update) - update.message.text = '/test' - assert (filters.TEXT & filters.FORWARDED).check_update(update) - update.message.text = 'test' - update.message.forward_date = None - assert not (filters.TEXT & filters.FORWARDED).check_update(update) - - update.message.text = 'test' - update.message.forward_date = datetime.datetime.utcnow() - assert (filters.TEXT & filters.FORWARDED & filters.ChatType.PRIVATE).check_update(update) - - def test_or_filters(self, update): - update.message.text = 'test' - assert (filters.TEXT | filters.StatusUpdate.ALL).check_update(update) - update.message.group_chat_created = True - assert (filters.TEXT | filters.StatusUpdate.ALL).check_update(update) - update.message.text = None - assert (filters.TEXT | filters.StatusUpdate.ALL).check_update(update) - update.message.group_chat_created = False - assert not (filters.TEXT | filters.StatusUpdate.ALL).check_update(update) - - def test_and_or_filters(self, update): - update.message.text = 'test' - update.message.forward_date = datetime.datetime.utcnow() - assert (filters.TEXT & (filters.StatusUpdate.ALL | filters.FORWARDED)).check_update(update) - update.message.forward_date = None - assert not (filters.TEXT & (filters.FORWARDED | filters.StatusUpdate.ALL)).check_update( - update - ) - update.message.pinned_message = True - assert filters.TEXT & (filters.FORWARDED | filters.StatusUpdate.ALL).check_update(update) - - assert ( - str(filters.TEXT & (filters.FORWARDED | filters.Entity(MessageEntity.MENTION))) - == '>' - ) - - def test_xor_filters(self, update): - update.message.text = 'test' - update.effective_user.id = 123 - assert not (filters.TEXT ^ filters.User(123)).check_update(update) - update.message.text = None - update.effective_user.id = 1234 - assert not (filters.TEXT ^ filters.User(123)).check_update(update) - update.message.text = 'test' - assert (filters.TEXT ^ filters.User(123)).check_update(update) - update.message.text = None - update.effective_user.id = 123 - assert (filters.TEXT ^ filters.User(123)).check_update(update) - - def test_xor_filters_repr(self, update): - assert str(filters.TEXT ^ filters.User(123)) == '' - with pytest.raises(RuntimeError, match='Cannot set name'): - (filters.TEXT ^ filters.User(123)).name = 'foo' - - def test_and_xor_filters(self, update): - update.message.text = 'test' - update.message.forward_date = datetime.datetime.utcnow() - assert (filters.FORWARDED & (filters.TEXT ^ filters.User(123))).check_update(update) - update.message.text = None - update.effective_user.id = 123 - assert (filters.FORWARDED & (filters.TEXT ^ filters.User(123))).check_update(update) - update.message.text = 'test' - assert not (filters.FORWARDED & (filters.TEXT ^ filters.User(123))).check_update(update) - update.message.forward_date = None - update.message.text = None - update.effective_user.id = 123 - assert not (filters.FORWARDED & (filters.TEXT ^ filters.User(123))).check_update(update) - update.message.text = 'test' - update.effective_user.id = 456 - assert not (filters.FORWARDED & (filters.TEXT ^ filters.User(123))).check_update(update) - - assert ( - str(filters.FORWARDED & (filters.TEXT ^ filters.User(123))) - == '>' - ) - - def test_xor_regex_filters(self, update): - sre_type = type(re.match("", "")) - update.message.text = 'test' - update.message.forward_date = datetime.datetime.utcnow() - assert not (filters.FORWARDED ^ filters.Regex('^test$')).check_update(update) - update.message.forward_date = None - result = (filters.FORWARDED ^ filters.Regex('^test$')).check_update(update) - assert result - assert isinstance(result, dict) - matches = result['matches'] - assert isinstance(matches, list) - assert type(matches[0]) is sre_type - update.message.forward_date = datetime.datetime.utcnow() - update.message.text = None - assert (filters.FORWARDED ^ filters.Regex('^test$')).check_update(update) is True - - def test_inverted_filters(self, update): - update.message.text = '/test' - update.message.entities = [MessageEntity(MessageEntity.BOT_COMMAND, 0, 5)] - assert filters.COMMAND.check_update(update) - assert not (~filters.COMMAND).check_update(update) - update.message.text = 'test' - update.message.entities = [] - assert not filters.COMMAND.check_update(update) - assert (~filters.COMMAND).check_update(update) - - def test_inverted_filters_repr(self, update): - assert str(~filters.TEXT) == '' - with pytest.raises(RuntimeError, match='Cannot set name'): - (~filters.TEXT).name = 'foo' - - def test_inverted_and_filters(self, update): - update.message.text = '/test' - update.message.entities = [MessageEntity(MessageEntity.BOT_COMMAND, 0, 5)] - update.message.forward_date = 1 - assert (filters.FORWARDED & filters.COMMAND).check_update(update) - assert not (~filters.FORWARDED & filters.COMMAND).check_update(update) - assert not (filters.FORWARDED & ~filters.COMMAND).check_update(update) - assert not (~(filters.FORWARDED & filters.COMMAND)).check_update(update) - update.message.forward_date = None - assert not (filters.FORWARDED & filters.COMMAND).check_update(update) - assert (~filters.FORWARDED & filters.COMMAND).check_update(update) - assert not (filters.FORWARDED & ~filters.COMMAND).check_update(update) - assert (~(filters.FORWARDED & filters.COMMAND)).check_update(update) - update.message.text = 'test' - update.message.entities = [] - assert not (filters.FORWARDED & filters.COMMAND).check_update(update) - assert not (~filters.FORWARDED & filters.COMMAND).check_update(update) - assert not (filters.FORWARDED & ~filters.COMMAND).check_update(update) - assert (~(filters.FORWARDED & filters.COMMAND)).check_update(update) - - def test_indirect_message(self, update): - class _CustomFilter(filters.MessageFilter): - test_flag = False - - def filter(self, message: Message): - self.test_flag = True - return self.test_flag - - c = _CustomFilter() - u = Update(0, callback_query=CallbackQuery('0', update.effective_user, '', update.message)) - assert not c.check_update(u) - assert not c.test_flag - assert c.check_update(update) - assert c.test_flag - - def test_custom_unnamed_filter(self, update, base_class): - class Unnamed(base_class): - def filter(self, _): - return True - - unnamed = Unnamed() - assert str(unnamed) == Unnamed.__name__ - - def test_update_type_message(self, update): - assert filters.UpdateType.MESSAGE.check_update(update) - assert not filters.UpdateType.EDITED_MESSAGE.check_update(update) - assert filters.UpdateType.MESSAGES.check_update(update) - assert not filters.UpdateType.CHANNEL_POST.check_update(update) - assert not filters.UpdateType.EDITED_CHANNEL_POST.check_update(update) - assert not filters.UpdateType.CHANNEL_POSTS.check_update(update) - assert not filters.UpdateType.EDITED.check_update(update) - - def test_update_type_edited_message(self, update): - update.edited_message, update.message = update.message, update.edited_message - assert not filters.UpdateType.MESSAGE.check_update(update) - assert filters.UpdateType.EDITED_MESSAGE.check_update(update) - assert filters.UpdateType.MESSAGES.check_update(update) - assert not filters.UpdateType.CHANNEL_POST.check_update(update) - assert not filters.UpdateType.EDITED_CHANNEL_POST.check_update(update) - assert not filters.UpdateType.CHANNEL_POSTS.check_update(update) - assert filters.UpdateType.EDITED.check_update(update) - - def test_update_type_channel_post(self, update): - update.channel_post, update.message = update.message, update.edited_message - assert not filters.UpdateType.MESSAGE.check_update(update) - assert not filters.UpdateType.EDITED_MESSAGE.check_update(update) - assert not filters.UpdateType.MESSAGES.check_update(update) - assert filters.UpdateType.CHANNEL_POST.check_update(update) - assert not filters.UpdateType.EDITED_CHANNEL_POST.check_update(update) - assert filters.UpdateType.CHANNEL_POSTS.check_update(update) - assert not filters.UpdateType.EDITED.check_update(update) - - def test_update_type_edited_channel_post(self, update): - update.edited_channel_post, update.message = update.message, update.edited_message - assert not filters.UpdateType.MESSAGE.check_update(update) - assert not filters.UpdateType.EDITED_MESSAGE.check_update(update) - assert not filters.UpdateType.MESSAGES.check_update(update) - assert not filters.UpdateType.CHANNEL_POST.check_update(update) - assert filters.UpdateType.EDITED_CHANNEL_POST.check_update(update) - assert filters.UpdateType.CHANNEL_POSTS.check_update(update) - assert filters.UpdateType.EDITED.check_update(update) - - def test_merged_short_circuit_and(self, update, base_class): - update.message.text = '/test' - update.message.entities = [MessageEntity(MessageEntity.BOT_COMMAND, 0, 5)] - - class TestException(Exception): - pass - - class RaisingFilter(base_class): - def filter(self, _): - raise TestException - - raising_filter = RaisingFilter() - - with pytest.raises(TestException): - (filters.COMMAND & raising_filter).check_update(update) - - update.message.text = 'test' - update.message.entities = [] - (filters.COMMAND & raising_filter).check_update(update) - - def test_merged_filters_repr(self, update): - with pytest.raises(RuntimeError, match='Cannot set name'): - (filters.TEXT & filters.PHOTO).name = 'foo' - - def test_merged_short_circuit_or(self, update, base_class): - update.message.text = 'test' - - class TestException(Exception): - pass - - class RaisingFilter(base_class): - def filter(self, _): - raise TestException - - raising_filter = RaisingFilter() - - with pytest.raises(TestException): - (filters.COMMAND | raising_filter).check_update(update) - - update.message.text = '/test' - update.message.entities = [MessageEntity(MessageEntity.BOT_COMMAND, 0, 5)] - (filters.COMMAND | raising_filter).check_update(update) - - def test_merged_data_merging_and(self, update, base_class): - update.message.text = '/test' - update.message.entities = [MessageEntity(MessageEntity.BOT_COMMAND, 0, 5)] - - class DataFilter(base_class): - data_filter = True - - def __init__(self, data): - self.data = data - - def filter(self, _): - return {'test': [self.data]} - - result = (filters.COMMAND & DataFilter('blah')).check_update(update) - assert result['test'] == ['blah'] - - result = (DataFilter('blah1') & DataFilter('blah2')).check_update(update) - assert result['test'] == ['blah1', 'blah2'] - - update.message.text = 'test' - update.message.entities = [] - result = (filters.COMMAND & DataFilter('blah')).check_update(update) - assert not result - - def test_merged_data_merging_or(self, update, base_class): - update.message.text = '/test' - - class DataFilter(base_class): - data_filter = True - - def __init__(self, data): - self.data = data - - def filter(self, _): - return {'test': [self.data]} - - result = (filters.COMMAND | DataFilter('blah')).check_update(update) - assert result - - result = (DataFilter('blah1') | DataFilter('blah2')).check_update(update) - assert result['test'] == ['blah1'] - - update.message.text = 'test' - result = (filters.COMMAND | DataFilter('blah')).check_update(update) - assert result['test'] == ['blah'] - - def test_filters_via_bot_init(self): - with pytest.raises(RuntimeError, match='in conjunction with'): - filters.ViaBot(bot_id=1, username='bot') - - def test_filters_via_bot_allow_empty(self, update): - assert not filters.ViaBot().check_update(update) - assert filters.ViaBot(allow_empty=True).check_update(update) - - def test_filters_via_bot_id(self, update): - assert not filters.ViaBot(bot_id=1).check_update(update) - update.message.via_bot.id = 1 - assert filters.ViaBot(bot_id=1).check_update(update) - update.message.via_bot.id = 2 - assert filters.ViaBot(bot_id=[1, 2]).check_update(update) - assert not filters.ViaBot(bot_id=[3, 4]).check_update(update) - update.message.via_bot = None - assert not filters.ViaBot(bot_id=[3, 4]).check_update(update) - - def test_filters_via_bot_username(self, update): - assert not filters.ViaBot(username='bot').check_update(update) - assert not filters.ViaBot(username='Testbot').check_update(update) - update.message.via_bot.username = 'bot@' - assert filters.ViaBot(username='@bot@').check_update(update) - assert filters.ViaBot(username='bot@').check_update(update) - assert filters.ViaBot(username=['bot1', 'bot@', 'bot2']).check_update(update) - assert not filters.ViaBot(username=['@username', '@bot_2']).check_update(update) - update.message.via_bot = None - assert not filters.User(username=['@username', '@bot_2']).check_update(update) - - def test_filters_via_bot_change_id(self, update): - f = filters.ViaBot(bot_id=3) - assert f.bot_ids == {3} - update.message.via_bot.id = 3 - assert f.check_update(update) - update.message.via_bot.id = 2 - assert not f.check_update(update) - f.bot_ids = 2 - assert f.bot_ids == {2} - assert f.check_update(update) - - with pytest.raises(RuntimeError, match='username in conjunction'): - f.usernames = 'user' - - def test_filters_via_bot_change_username(self, update): - f = filters.ViaBot(username='bot') - update.message.via_bot.username = 'bot' - assert f.check_update(update) - update.message.via_bot.username = 'Bot' - assert not f.check_update(update) - f.usernames = 'Bot' - assert f.check_update(update) - - with pytest.raises(RuntimeError, match='bot_id in conjunction'): - f.bot_ids = 1 - - def test_filters_via_bot_add_user_by_name(self, update): - users = ['bot_a', 'bot_b', 'bot_c'] - f = filters.ViaBot() - - for user in users: - update.message.via_bot.username = user - assert not f.check_update(update) - - f.add_usernames('bot_a') - f.add_usernames(['bot_b', 'bot_c']) - - for user in users: - update.message.via_bot.username = user - assert f.check_update(update) - - with pytest.raises(RuntimeError, match='bot_id in conjunction'): - f.add_bot_ids(1) - - def test_filters_via_bot_add_user_by_id(self, update): - users = [1, 2, 3] - f = filters.ViaBot() - - for user in users: - update.message.via_bot.id = user - assert not f.check_update(update) - - f.add_bot_ids(1) - f.add_bot_ids([2, 3]) - - for user in users: - update.message.via_bot.username = user - assert f.check_update(update) - - with pytest.raises(RuntimeError, match='username in conjunction'): - f.add_usernames('bot') - - def test_filters_via_bot_remove_user_by_name(self, update): - users = ['bot_a', 'bot_b', 'bot_c'] - f = filters.ViaBot(username=users) - - with pytest.raises(RuntimeError, match='bot_id in conjunction'): - f.remove_bot_ids(1) - - for user in users: - update.message.via_bot.username = user - assert f.check_update(update) - - f.remove_usernames('bot_a') - f.remove_usernames(['bot_b', 'bot_c']) - - for user in users: - update.message.via_bot.username = user - assert not f.check_update(update) - - def test_filters_via_bot_remove_user_by_id(self, update): - users = [1, 2, 3] - f = filters.ViaBot(bot_id=users) - - with pytest.raises(RuntimeError, match='username in conjunction'): - f.remove_usernames('bot') - - for user in users: - update.message.via_bot.id = user - assert f.check_update(update) - - f.remove_bot_ids(1) - f.remove_bot_ids([2, 3]) - - for user in users: - update.message.via_bot.username = user - assert not f.check_update(update) - - def test_filters_via_bot_repr(self): - f = filters.ViaBot([1, 2]) - assert str(f) == 'filters.ViaBot(1, 2)' - f.remove_bot_ids(1) - f.remove_bot_ids(2) - assert str(f) == 'filters.ViaBot()' - f.add_usernames('@foobar') - assert str(f) == 'filters.ViaBot(foobar)' - f.add_usernames('@barfoo') - assert str(f).startswith('filters.ViaBot(') - # we don't know th exact order - assert 'barfoo' in str(f) and 'foobar' in str(f) - - with pytest.raises(RuntimeError, match='Cannot set name'): - f.name = 'foo' - - def test_filters_attachment(self, update): - assert not filters.ATTACHMENT.check_update(update) - # we need to define a new Update (or rather, message class) here because - # effective_attachment is only evaluated once per instance, and the filter relies on that - up = Update( - 0, - Message( - 0, - datetime.datetime.utcnow(), - Chat(0, 'private'), - document=Document("str", "other_str"), - ), - ) - assert filters.ATTACHMENT.check_update(up) diff --git a/tests/test_forcereply.py b/tests/test_forcereply.py index d21c9e0d193..15d35c9d544 100644 --- a/tests/test_forcereply.py +++ b/tests/test_forcereply.py @@ -42,8 +42,9 @@ def test_slot_behaviour(self, force_reply, mro_slots): assert len(mro_slots(force_reply)) == len(set(mro_slots(force_reply))), "duplicate slot" @flaky(3, 1) - def test_send_message_with_force_reply(self, bot, chat_id, force_reply): - message = bot.send_message(chat_id, 'text', reply_markup=force_reply) + @pytest.mark.asyncio + async def test_send_message_with_force_reply(self, bot, chat_id, force_reply): + message = await bot.send_message(chat_id, 'text', reply_markup=force_reply) assert message.text == 'text' diff --git a/tests/test_inlinekeyboardmarkup.py b/tests/test_inlinekeyboardmarkup.py index abe9b744fca..84859eef41e 100644 --- a/tests/test_inlinekeyboardmarkup.py +++ b/tests/test_inlinekeyboardmarkup.py @@ -49,8 +49,11 @@ def test_slot_behaviour(self, inline_keyboard_markup, mro_slots): assert len(mro_slots(inst)) == len(set(mro_slots(inst))), "duplicate slot" @flaky(3, 1) - def test_send_message_with_inline_keyboard_markup(self, bot, chat_id, inline_keyboard_markup): - message = bot.send_message( + @pytest.mark.asyncio + async def test_send_message_with_inline_keyboard_markup( + self, bot, chat_id, inline_keyboard_markup + ): + message = await bot.send_message( chat_id, 'Testing InlineKeyboardMarkup', reply_markup=inline_keyboard_markup ) @@ -95,8 +98,9 @@ def test_wrong_keyboard_inputs(self): with pytest.raises(ValueError): InlineKeyboardMarkup(InlineKeyboardButton('b1', '1')) - def test_expected_values_empty_switch(self, inline_keyboard_markup, bot, monkeypatch): - def test( + @pytest.mark.asyncio + async def test_expected_values_empty_switch(self, inline_keyboard_markup, bot, monkeypatch): + async def make_assertion( url, data, reply_to_message_id=None, @@ -125,8 +129,8 @@ def test( inline_keyboard_markup.inline_keyboard[0][1].callback_data = None inline_keyboard_markup.inline_keyboard[0][1].switch_inline_query_current_chat = '' - monkeypatch.setattr(bot, '_message', test) - bot.send_message(123, 'test', reply_markup=inline_keyboard_markup) + monkeypatch.setattr(bot, '_send_message', make_assertion) + await bot.send_message(123, 'test', reply_markup=inline_keyboard_markup) def test_to_dict(self, inline_keyboard_markup): inline_keyboard_markup_dict = inline_keyboard_markup.to_dict() diff --git a/tests/test_inlinequery.py b/tests/test_inlinequery.py index 31589bbf233..fdd15a1fdf7 100644 --- a/tests/test_inlinequery.py +++ b/tests/test_inlinequery.py @@ -19,7 +19,7 @@ import pytest -from telegram import User, Location, InlineQuery, Update, Bot, Chat +from telegram import User, Location, InlineQuery, Update, Bot from tests.conftest import check_shortcut_signature, check_shortcut_call, check_defaults_handling @@ -31,7 +31,6 @@ def inline_query(bot): TestInlineQuery.query, TestInlineQuery.offset, location=TestInlineQuery.location, - chat_type=TestInlineQuery.chat_type, bot=bot, ) @@ -42,12 +41,6 @@ class TestInlineQuery: query = 'query text' offset = 'offset' location = Location(8.8, 53.1) - chat_type = Chat.SENDER - - def test_slot_behaviour(self, inline_query, mro_slots): - for attr in inline_query.__slots__: - assert getattr(inline_query, attr, 'err') != 'err', f"got extra slot '{attr}'" - assert len(mro_slots(inline_query)) == len(set(mro_slots(inline_query))), "duplicate slot" def test_de_json(self, bot): json_dict = { @@ -56,7 +49,6 @@ def test_de_json(self, bot): 'query': self.query, 'offset': self.offset, 'location': self.location.to_dict(), - 'chat_type': self.chat_type, } inline_query_json = InlineQuery.de_json(json_dict, bot) @@ -65,7 +57,6 @@ def test_de_json(self, bot): assert inline_query_json.location == self.location assert inline_query_json.query == self.query assert inline_query_json.offset == self.offset - assert inline_query_json.chat_type == self.chat_type def test_to_dict(self, inline_query): inline_query_dict = inline_query.to_dict() @@ -76,35 +67,37 @@ def test_to_dict(self, inline_query): assert inline_query_dict['location'] == inline_query.location.to_dict() assert inline_query_dict['query'] == inline_query.query assert inline_query_dict['offset'] == inline_query.offset - assert inline_query_dict['chat_type'] == inline_query.chat_type - def test_answer(self, monkeypatch, inline_query): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_answer(self, monkeypatch, inline_query): + async def make_assertion(*_, **kwargs): return kwargs['inline_query_id'] == inline_query.id assert check_shortcut_signature( InlineQuery.answer, Bot.answer_inline_query, ['inline_query_id'], ['auto_pagination'] ) - assert check_shortcut_call( + assert await check_shortcut_call( inline_query.answer, inline_query.get_bot(), 'answer_inline_query' ) - assert check_defaults_handling(inline_query.answer, inline_query.get_bot()) + assert await check_defaults_handling(inline_query.answer, inline_query.get_bot()) monkeypatch.setattr(inline_query.get_bot(), 'answer_inline_query', make_assertion) - assert inline_query.answer(results=[]) + assert await inline_query.answer(results=[]) - def test_answer_error(self, inline_query): + @pytest.mark.asyncio + async def test_answer_error(self, inline_query): with pytest.raises(ValueError, match='mutually exclusive'): - inline_query.answer(results=[], auto_pagination=True, current_offset='foobar') + await inline_query.answer(results=[], auto_pagination=True, current_offset='foobar') - def test_answer_auto_pagination(self, monkeypatch, inline_query): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_answer_auto_pagination(self, monkeypatch, inline_query): + async def make_assertion(*_, **kwargs): inline_query_id_matches = kwargs['inline_query_id'] == inline_query.id offset_matches = kwargs.get('current_offset') == inline_query.offset return offset_matches and inline_query_id_matches monkeypatch.setattr(inline_query.get_bot(), 'answer_inline_query', make_assertion) - assert inline_query.answer(results=[], auto_pagination=True) + assert await inline_query.answer(results=[], auto_pagination=True) def test_equality(self): a = InlineQuery(self.id_, User(1, '', False), '', '') diff --git a/tests/test_inlinequeryhandler.py b/tests/test_inlinequeryhandler.py deleted file mode 100644 index fc0be644a21..00000000000 --- a/tests/test_inlinequeryhandler.py +++ /dev/null @@ -1,161 +0,0 @@ -#!/usr/bin/env python -# -# A library that provides a Python interface to the Telegram Bot API -# Copyright (C) 2015-2022 -# Leandro Toledo de Souza -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Lesser Public License for more details. -# -# You should have received a copy of the GNU Lesser Public License -# along with this program. If not, see [http://www.gnu.org/licenses/]. -from queue import Queue - -import pytest - -from telegram import ( - Update, - CallbackQuery, - Bot, - Message, - User, - Chat, - InlineQuery, - ChosenInlineResult, - ShippingQuery, - PreCheckoutQuery, - Location, -) -from telegram.ext import InlineQueryHandler, CallbackContext, JobQueue - -message = Message(1, None, Chat(1, ''), from_user=User(1, '', False), text='Text') - -params = [ - {'message': message}, - {'edited_message': message}, - {'callback_query': CallbackQuery(1, User(1, '', False), 'chat', message=message)}, - {'channel_post': message}, - {'edited_channel_post': message}, - {'chosen_inline_result': ChosenInlineResult('id', User(1, '', False), '')}, - {'shipping_query': ShippingQuery('id', User(1, '', False), '', None)}, - {'pre_checkout_query': PreCheckoutQuery('id', User(1, '', False), '', 0, '')}, - {'callback_query': CallbackQuery(1, User(1, '', False), 'chat')}, -] - -ids = ( - 'message', - 'edited_message', - 'callback_query', - 'channel_post', - 'edited_channel_post', - 'chosen_inline_result', - 'shipping_query', - 'pre_checkout_query', - 'callback_query_without_message', -) - - -@pytest.fixture(scope='class', params=params, ids=ids) -def false_update(request): - return Update(update_id=2, **request.param) - - -@pytest.fixture(scope='function') -def inline_query(bot): - return Update( - 0, - inline_query=InlineQuery( - 'id', - User(2, 'test user', False), - 'test query', - offset='22', - location=Location(latitude=-23.691288, longitude=-46.788279), - ), - ) - - -class TestInlineQueryHandler: - test_flag = False - - def test_slot_behaviour(self, mro_slots): - handler = InlineQueryHandler(self.callback_context) - for attr in handler.__slots__: - assert getattr(handler, attr, 'err') != 'err', f"got extra slot '{attr}'" - assert len(mro_slots(handler)) == len(set(mro_slots(handler))), "duplicate slot" - - @pytest.fixture(autouse=True) - def reset(self): - self.test_flag = False - - def callback_context(self, update, context): - self.test_flag = ( - isinstance(context, CallbackContext) - and isinstance(context.bot, Bot) - and isinstance(update, Update) - and isinstance(context.update_queue, Queue) - and isinstance(context.job_queue, JobQueue) - and isinstance(context.user_data, dict) - and context.chat_data is None - and isinstance(context.bot_data, dict) - and isinstance(update.inline_query, InlineQuery) - ) - - def callback_context_pattern(self, update, context): - if context.matches[0].groups(): - self.test_flag = context.matches[0].groups() == ('t', ' query') - if context.matches[0].groupdict(): - self.test_flag = context.matches[0].groupdict() == {'begin': 't', 'end': ' query'} - - def test_other_update_types(self, false_update): - handler = InlineQueryHandler(self.callback_context) - assert not handler.check_update(false_update) - - def test_context(self, dp, inline_query): - handler = InlineQueryHandler(self.callback_context) - dp.add_handler(handler) - - dp.process_update(inline_query) - assert self.test_flag - - def test_context_pattern(self, dp, inline_query): - handler = InlineQueryHandler( - self.callback_context_pattern, pattern=r'(?P.*)est(?P.*)' - ) - dp.add_handler(handler) - - dp.process_update(inline_query) - assert self.test_flag - - dp.remove_handler(handler) - handler = InlineQueryHandler(self.callback_context_pattern, pattern=r'(t)est(.*)') - dp.add_handler(handler) - - dp.process_update(inline_query) - assert self.test_flag - - @pytest.mark.parametrize('chat_types', [[Chat.SENDER], [Chat.SENDER, Chat.SUPERGROUP], []]) - @pytest.mark.parametrize( - 'chat_type,result', [(Chat.SENDER, True), (Chat.CHANNEL, False), (None, False)] - ) - def test_chat_types(self, dp, inline_query, chat_types, chat_type, result): - try: - inline_query.inline_query.chat_type = chat_type - - handler = InlineQueryHandler(self.callback_context, chat_types=chat_types) - dp.add_handler(handler) - dp.process_update(inline_query) - - if not chat_types: - assert self.test_flag is False - else: - assert self.test_flag == result - - finally: - inline_query.inline_query.chat_type = None diff --git a/tests/test_inputfile.py b/tests/test_inputfile.py index 84ae4c3dac5..22634068fc6 100644 --- a/tests/test_inputfile.py +++ b/tests/test_inputfile.py @@ -122,14 +122,28 @@ def read(self): == 'blah.jpg' ) - def test_send_bytes(self, bot, chat_id): + @pytest.mark.asyncio + async def test_send_bytes(self, bot, chat_id): # We test this here and not at the respective test modules because it's not worth # duplicating the test for the different methods - message = bot.send_document(chat_id, data_file('text_file.txt').read_bytes()) + message = await bot.send_document(chat_id, data_file('text_file.txt').read_bytes()) out = BytesIO() - assert message.document.get_file().download(out=out) + assert await (await message.document.get_file()).download(out=out) + out.seek(0) + + assert out.read().decode('utf-8') == 'PTB Rocks! ⅞' + + @pytest.mark.asyncio + async def test_send_string(self, bot, chat_id): + # We test this here and not at the respective test modules because it's not worth + # duplicating the test for the different methods + message = await bot.send_document( + chat_id, InputFile(data_file('text_file.txt').read_text(encoding='utf-8')) + ) + out = BytesIO() + assert await (await message.document.get_file()).download(out=out) out.seek(0) - assert out.read().decode('utf-8') == 'PTB Rocks!' + assert out.read().decode('utf-8') == 'PTB Rocks! ⅞' diff --git a/tests/test_inputmedia.py b/tests/test_inputmedia.py index 885a1128a15..015eb027e05 100644 --- a/tests/test_inputmedia.py +++ b/tests/test_inputmedia.py @@ -33,6 +33,7 @@ # noinspection PyUnresolvedReferences from telegram.error import BadRequest +from telegram.request import RequestData from .test_animation import animation, animation_file # noqa: F401 # noinspection PyUnresolvedReferences @@ -430,8 +431,9 @@ def media_group(photo, thumb): # noqa: F811 class TestSendMediaGroup: @flaky(3, 1) - def test_send_media_group_photo(self, bot, chat_id, media_group): - messages = bot.send_media_group(chat_id, media_group) + @pytest.mark.asyncio + async def test_send_media_group_photo(self, bot, chat_id, media_group): + messages = await bot.send_media_group(chat_id, media_group) assert isinstance(messages, list) assert len(messages) == 3 assert all(isinstance(mes, Message) for mes in messages) @@ -442,9 +444,10 @@ def test_send_media_group_photo(self, bot, chat_id, media_group): ) @flaky(3, 1) - def test_send_media_group_all_args(self, bot, chat_id, media_group): - m1 = bot.send_message(chat_id, text="test") - messages = bot.send_media_group( + @pytest.mark.asyncio + async def test_send_media_group_all_args(self, bot, chat_id, media_group): + m1 = await bot.send_message(chat_id, text="test") + messages = await bot.send_media_group( chat_id, media_group, disable_notification=True, @@ -462,7 +465,8 @@ def test_send_media_group_all_args(self, bot, chat_id, media_group): assert all(mes.has_protected_content for mes in messages) @flaky(3, 1) - def test_send_media_group_custom_filename( + @pytest.mark.asyncio + async def test_send_media_group_custom_filename( self, bot, chat_id, @@ -472,10 +476,13 @@ def test_send_media_group_custom_filename( video_file, # noqa: F811 monkeypatch, ): - def make_assertion(url, data, **kwargs): - result = all(im.media.filename == 'custom_filename' for im in data['media']) - # We are a bit hacky here b/c Bot.send_media_group expects a list of Message-dicts - return [Message(0, None, None, text=result).to_dict()] + async def make_assertion(url, request_data: RequestData, *args, **kwargs): + result = all( + field_tuple[0] == 'custom_filename' + for field_tuple in request_data.multipart_data.values() + ) + if result is True: + raise Exception('Test was successful') monkeypatch.setattr(bot.request, 'post', make_assertion) @@ -486,25 +493,28 @@ def make_assertion(url, data, **kwargs): InputMediaVideo(video_file, filename='custom_filename'), ] - assert bot.send_media_group(chat_id, media)[0].text is True + with pytest.raises(Exception, match='Test was successful'): + await bot.send_media_group(chat_id, media) - def test_send_media_group_with_thumbs( + @pytest.mark.asyncio + async def test_send_media_group_with_thumbs( self, bot, chat_id, video_file, photo_file, monkeypatch # noqa: F811 ): - def test(*args, **kwargs): - data = kwargs['fields'] - video_check = data[input_video.media.attach] == input_video.media.field_tuple - thumb_check = data[input_video.thumb.attach] == input_video.thumb.field_tuple + async def make_assertion(method, url, request_data: RequestData, *args, **kwargs): + files = request_data.multipart_data + video_check = files[input_video.media.attach_name] == input_video.media.field_tuple + thumb_check = files[input_video.thumb.attach_name] == input_video.thumb.field_tuple result = video_check and thumb_check raise Exception(f"Test was {'successful' if result else 'failing'}") - monkeypatch.setattr('telegram.request.Request._request_wrapper', test) + monkeypatch.setattr(bot.request, '_request_wrapper', make_assertion) input_video = InputMediaVideo(video_file, thumb=photo_file) with pytest.raises(Exception, match='Test was successful'): - bot.send_media_group(chat_id, [input_video, input_video]) + await bot.send_media_group(chat_id, [input_video, input_video]) @flaky(3, 1) # noqa: F811 - def test_send_media_group_new_files( + @pytest.mark.asyncio + async def test_send_media_group_new_files( self, bot, chat_id, @@ -512,8 +522,8 @@ def test_send_media_group_new_files( photo_file, # noqa: F811 animation_file, # noqa: F811 ): - def func(): - return bot.send_media_group( + async def func(): + return await bot.send_media_group( chat_id, [ InputMediaVideo(video_file), @@ -522,7 +532,7 @@ def func(): ], ) - messages = expect_bad_request( + messages = await expect_bad_request( func, 'Type of file mismatch', 'Telegram did not accept the file.' ) @@ -541,13 +551,14 @@ def func(): ], indirect=['default_bot'], ) - def test_send_media_group_default_allow_sending_without_reply( + @pytest.mark.asyncio + async def test_send_media_group_default_allow_sending_without_reply( self, default_bot, chat_id, media_group, custom ): - reply_to_message = default_bot.send_message(chat_id, 'test') - reply_to_message.delete() + reply_to_message = await default_bot.send_message(chat_id, 'test') + await reply_to_message.delete() if custom is not None: - messages = default_bot.send_media_group( + messages = await default_bot.send_media_group( chat_id, media_group, allow_sending_without_reply=custom, @@ -555,63 +566,76 @@ def test_send_media_group_default_allow_sending_without_reply( ) assert [m.reply_to_message is None for m in messages] elif default_bot.defaults.allow_sending_without_reply: - messages = default_bot.send_media_group( + messages = await default_bot.send_media_group( chat_id, media_group, reply_to_message_id=reply_to_message.message_id ) assert [m.reply_to_message is None for m in messages] else: with pytest.raises(BadRequest, match='message not found'): - default_bot.send_media_group( + await default_bot.send_media_group( chat_id, media_group, reply_to_message_id=reply_to_message.message_id ) @flaky(3, 1) + @pytest.mark.asyncio @pytest.mark.parametrize('default_bot', [{'protect_content': True}], indirect=True) - def test_send_media_group_default_protect_content(self, chat_id, media_group, default_bot): - protected = default_bot.send_media_group(chat_id, media_group) + async def test_send_media_group_default_protect_content( + self, chat_id, media_group, default_bot + ): + protected = await default_bot.send_media_group(chat_id, media_group) assert all(msg.has_protected_content for msg in protected) - unprotected = default_bot.send_media_group(chat_id, media_group, protect_content=False) + unprotected = await default_bot.send_media_group( + chat_id, media_group, protect_content=False + ) assert not all(msg.has_protected_content for msg in unprotected) @flaky(3, 1) - def test_edit_message_media(self, bot, chat_id, media_group): - messages = bot.send_media_group(chat_id, media_group) + @pytest.mark.asyncio + async def test_edit_message_media(self, bot, chat_id, media_group): + messages = await bot.send_media_group(chat_id, media_group) cid = messages[-1].chat.id mid = messages[-1].message_id - new_message = bot.edit_message_media(chat_id=cid, message_id=mid, media=media_group[0]) + new_message = await bot.edit_message_media( + chat_id=cid, message_id=mid, media=media_group[0] + ) assert isinstance(new_message, Message) @flaky(3, 1) - def test_edit_message_media_new_file(self, bot, chat_id, media_group, thumb_file): - messages = bot.send_media_group(chat_id, media_group) + @pytest.mark.asyncio + async def test_edit_message_media_new_file(self, bot, chat_id, media_group, thumb_file): + messages = await bot.send_media_group(chat_id, media_group) cid = messages[-1].chat.id mid = messages[-1].message_id - new_message = bot.edit_message_media( + new_message = await bot.edit_message_media( chat_id=cid, message_id=mid, media=InputMediaPhoto(thumb_file) ) assert isinstance(new_message, Message) - def test_edit_message_media_with_thumb( + @pytest.mark.asyncio + async def test_edit_message_media_with_thumb( self, bot, chat_id, video_file, photo_file, monkeypatch # noqa: F811 ): - def test(*args, **kwargs): - data = kwargs['fields'] - video_check = data[input_video.media.attach] == input_video.media.field_tuple - thumb_check = data[input_video.thumb.attach] == input_video.thumb.field_tuple + async def make_assertion( + method: str, url: str, request_data: RequestData = None, *args, **kwargs + ): + files = request_data.multipart_data + video_check = files[input_video.media.attach_name] == input_video.media.field_tuple + thumb_check = files[input_video.thumb.attach_name] == input_video.thumb.field_tuple result = video_check and thumb_check raise Exception(f"Test was {'successful' if result else 'failing'}") - monkeypatch.setattr('telegram.request.Request._request_wrapper', test) + monkeypatch.setattr(bot.request, '_request_wrapper', make_assertion) input_video = InputMediaVideo(video_file, thumb=photo_file) with pytest.raises(Exception, match='Test was successful'): - bot.edit_message_media(chat_id=chat_id, message_id=123, media=input_video) + await bot.edit_message_media(chat_id=chat_id, message_id=123, media=input_video) @flaky(3, 1) @pytest.mark.parametrize( 'default_bot', [{'parse_mode': ParseMode.HTML}], indirect=True, ids=['HTML-Bot'] ) @pytest.mark.parametrize('media_type', ['animation', 'document', 'audio', 'photo', 'video']) - def test_edit_message_media_default_parse_mode( + @pytest.mark.asyncio + async def test_edit_message_media_default_parse_mode( self, chat_id, default_bot, @@ -650,9 +674,9 @@ def build_media(parse_mode, med_type): if med_type == 'video': return InputMediaVideo(video, **kwargs) - message = default_bot.send_photo(chat_id, photo) + message = await default_bot.send_photo(chat_id, photo) - message = default_bot.edit_message_media( + message = await default_bot.edit_message_media( build_media(parse_mode=ParseMode.HTML, med_type=media_type), message.chat_id, message.message_id, @@ -661,9 +685,9 @@ def build_media(parse_mode, med_type): assert message.caption_entities == test_entities # Remove caption to avoid "Message not changed" - message.edit_caption() + await message.edit_caption() - message = default_bot.edit_message_media( + message = await default_bot.edit_message_media( build_media(parse_mode=ParseMode.MARKDOWN_V2, med_type=media_type), message.chat_id, message.message_id, @@ -672,9 +696,9 @@ def build_media(parse_mode, med_type): assert message.caption_entities == test_entities # Remove caption to avoid "Message not changed" - message.edit_caption() + await message.edit_caption() - message = default_bot.edit_message_media( + message = await default_bot.edit_message_media( build_media(parse_mode=None, med_type=media_type), message.chat_id, message.message_id, diff --git a/tests/test_invoice.py b/tests/test_invoice.py index e38321e8294..977e0d8dd77 100644 --- a/tests/test_invoice.py +++ b/tests/test_invoice.py @@ -21,6 +21,7 @@ from telegram import LabeledPrice, Invoice from telegram.error import BadRequest +from telegram.request import RequestData @pytest.fixture(scope='class') @@ -80,8 +81,9 @@ def test_to_dict(self, invoice): assert invoice_dict['total_amount'] == invoice.total_amount @flaky(3, 1) - def test_send_required_args_only(self, bot, chat_id, provider_token): - message = bot.send_invoice( + @pytest.mark.asyncio + async def test_send_required_args_only(self, bot, chat_id, provider_token): + message = await bot.send_invoice( chat_id=chat_id, title=self.title, description=self.description, @@ -98,8 +100,9 @@ def test_send_required_args_only(self, bot, chat_id, provider_token): assert message.invoice.total_amount == self.total_amount @flaky(3, 1) - def test_send_all_args(self, bot, chat_id, provider_token, monkeypatch): - message = bot.send_invoice( + @pytest.mark.asyncio + async def test_send_all_args(self, bot, chat_id, provider_token, monkeypatch): + message = await bot.send_invoice( chat_id, self.title, self.description, @@ -137,7 +140,7 @@ def test_send_all_args(self, bot, chat_id, provider_token, monkeypatch): # We do this next one as safety guard to make sure that we pass all of the optional # parameters correctly because #2526 went unnoticed for 3 years … - def make_assertion(*args, **_): + async def make_assertion(*args, **_): kwargs = args[1] return ( kwargs['chat_id'] == 'chat_id' @@ -146,7 +149,7 @@ def make_assertion(*args, **_): and kwargs['payload'] == 'payload' and kwargs['provider_token'] == 'provider_token' and kwargs['currency'] == 'currency' - and kwargs['prices'] == [p.to_dict() for p in self.prices] + and kwargs['prices'] == self.prices and kwargs['max_tip_amount'] == 'max_tip_amount' and kwargs['suggested_tip_amounts'] == 'suggested_tip_amounts' and kwargs['start_parameter'] == 'start_parameter' @@ -164,8 +167,8 @@ def make_assertion(*args, **_): and kwargs['is_flexible'] == 'is_flexible' ) - monkeypatch.setattr(bot, '_message', make_assertion) - assert bot.send_invoice( + monkeypatch.setattr(bot, '_send_message', make_assertion) + assert await bot.send_invoice( chat_id='chat_id', title='title', description='description', @@ -192,14 +195,18 @@ def make_assertion(*args, **_): protect_content=True, ) - def test_send_object_as_provider_data(self, monkeypatch, bot, chat_id, provider_token): - def test(url, data, **kwargs): + @pytest.mark.asyncio + async def test_send_object_as_provider_data(self, monkeypatch, bot, chat_id, provider_token): + async def make_assertion(url, request_data: RequestData, *args, **kwargs): # depends on whether we're using ujson - return data['provider_data'] in ['{"test_data": 123456789}', '{"test_data":123456789}'] + return request_data.json_parameters['provider_data'] in [ + '{"test_data": 123456789}', + '{"test_data":123456789}', + ] - monkeypatch.setattr(bot.request, 'post', test) + monkeypatch.setattr(bot.request, 'post', make_assertion) - assert bot.send_invoice( + assert await bot.send_invoice( chat_id, self.title, self.description, @@ -221,13 +228,14 @@ def test(url, data, **kwargs): ], indirect=['default_bot'], ) - def test_send_invoice_default_allow_sending_without_reply( + @pytest.mark.asyncio + async def test_send_invoice_default_allow_sending_without_reply( self, default_bot, chat_id, custom, provider_token ): - reply_to_message = default_bot.send_message(chat_id, 'test') - reply_to_message.delete() + reply_to_message = await default_bot.send_message(chat_id, 'test') + await reply_to_message.delete() if custom is not None: - message = default_bot.send_invoice( + message = await default_bot.send_invoice( chat_id, self.title, self.description, @@ -240,7 +248,7 @@ def test_send_invoice_default_allow_sending_without_reply( ) assert message.reply_to_message is None elif default_bot.defaults.allow_sending_without_reply: - message = default_bot.send_invoice( + message = await default_bot.send_invoice( chat_id, self.title, self.description, @@ -253,7 +261,7 @@ def test_send_invoice_default_allow_sending_without_reply( assert message.reply_to_message is None else: with pytest.raises(BadRequest, match='message not found'): - default_bot.send_invoice( + await default_bot.send_invoice( chat_id, self.title, self.description, @@ -265,9 +273,12 @@ def test_send_invoice_default_allow_sending_without_reply( ) @flaky(3, 1) + @pytest.mark.asyncio @pytest.mark.parametrize('default_bot', [{'protect_content': True}], indirect=True) - def test_send_invoice_default_protect_content(self, chat_id, default_bot, provider_token): - protected = default_bot.send_invoice( + async def test_send_invoice_default_protect_content( + self, chat_id, default_bot, provider_token + ): + protected = await default_bot.send_invoice( chat_id, self.title, self.description, @@ -277,7 +288,7 @@ def test_send_invoice_default_protect_content(self, chat_id, default_bot, provid self.prices, ) assert protected.has_protected_content - unprotected = default_bot.send_invoice( + unprotected = await default_bot.send_invoice( chat_id, self.title, self.description, diff --git a/tests/test_jobqueue.py b/tests/test_jobqueue.py deleted file mode 100644 index 984e1cf51f7..00000000000 --- a/tests/test_jobqueue.py +++ /dev/null @@ -1,528 +0,0 @@ -#!/usr/bin/env python -# -# A library that provides a Python interface to the Telegram Bot API -# Copyright (C) 2015-2022 -# Leandro Toledo de Souza -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Lesser Public License for more details. -# -# You should have received a copy of the GNU Lesser Public License -# along with this program. If not, see [http://www.gnu.org/licenses/]. -import calendar -import datetime as dtm -import logging -import os -import platform -import time -from queue import Queue -from time import sleep - -import pytest -import pytz -from apscheduler.schedulers import SchedulerNotRunningError -from flaky import flaky -from telegram.ext import ( - JobQueue, - Job, - CallbackContext, - ContextTypes, - DispatcherBuilder, -) - - -class CustomContext(CallbackContext): - pass - - -@pytest.fixture(scope='function') -def job_queue(bot, _dp): - jq = JobQueue() - jq.set_dispatcher(_dp) - jq.start() - yield jq - jq.stop() - - -@pytest.mark.skipif( - os.getenv('GITHUB_ACTIONS', False) and platform.system() in ['Windows', 'Darwin'], - reason="On Windows & MacOS precise timings are not accurate.", -) -@flaky(10, 1) # Timings aren't quite perfect -class TestJobQueue: - result = 0 - job_time = 0 - received_error = None - - @pytest.fixture(autouse=True) - def reset(self): - self.result = 0 - self.job_time = 0 - self.received_error = None - - def job_run_once(self, context): - self.result += 1 - - def job_with_exception(self, context): - raise Exception('Test Error') - - def job_remove_self(self, context): - self.result += 1 - context.job.schedule_removal() - - def job_run_once_with_context(self, context): - self.result += context.job.context - - def job_datetime_tests(self, context): - self.job_time = time.time() - - def job_context_based_callback(self, context): - if ( - isinstance(context, CallbackContext) - and isinstance(context.job, Job) - and isinstance(context.update_queue, Queue) - and context.job.context == 2 - and context.chat_data is None - and context.user_data is None - and isinstance(context.bot_data, dict) - ): - self.result += 1 - - def error_handler_context(self, update, context): - self.received_error = (str(context.error), context.job) - - def error_handler_raise_error(self, *args): - raise Exception('Failing bigly') - - def test_slot_behaviour(self, job_queue, mro_slots, _dp): - for attr in job_queue.__slots__: - assert getattr(job_queue, attr, 'err') != 'err', f"got extra slot '{attr}'" - assert len(mro_slots(job_queue)) == len(set(mro_slots(job_queue))), "duplicate slot" - - def test_dispatcher_weakref(self, bot): - jq = JobQueue() - dispatcher = DispatcherBuilder().bot(bot).job_queue(None).build() - with pytest.raises(RuntimeError, match='No dispatcher was set'): - jq.dispatcher - jq.set_dispatcher(dispatcher) - assert jq.dispatcher is dispatcher - del dispatcher - with pytest.raises(RuntimeError, match='no longer alive'): - jq.dispatcher - - def test_run_once(self, job_queue): - job_queue.run_once(self.job_run_once, 0.01) - sleep(0.02) - assert self.result == 1 - - def test_run_once_timezone(self, job_queue, timezone): - """Test the correct handling of aware datetimes""" - # we're parametrizing this with two different UTC offsets to exclude the possibility - # of an xpass when the test is run in a timezone with the same UTC offset - when = dtm.datetime.now(timezone) - job_queue.run_once(self.job_run_once, when) - sleep(0.001) - assert self.result == 1 - - def test_job_with_context(self, job_queue): - job_queue.run_once(self.job_run_once_with_context, 0.01, context=5) - sleep(0.02) - assert self.result == 5 - - def test_run_repeating(self, job_queue): - job_queue.run_repeating(self.job_run_once, 0.02) - sleep(0.05) - assert self.result == 2 - - def test_run_repeating_first(self, job_queue): - job_queue.run_repeating(self.job_run_once, 0.05, first=0.2) - sleep(0.15) - assert self.result == 0 - sleep(0.07) - assert self.result == 1 - - def test_run_repeating_first_timezone(self, job_queue, timezone): - """Test correct scheduling of job when passing a timezone-aware datetime as ``first``""" - job_queue.run_repeating( - self.job_run_once, 0.1, first=dtm.datetime.now(timezone) + dtm.timedelta(seconds=0.05) - ) - sleep(0.1) - assert self.result == 1 - - def test_run_repeating_last(self, job_queue): - job_queue.run_repeating(self.job_run_once, 0.05, last=0.06) - sleep(0.1) - assert self.result == 1 - sleep(0.1) - assert self.result == 1 - - def test_run_repeating_last_timezone(self, job_queue, timezone): - """Test correct scheduling of job when passing a timezone-aware datetime as ``first``""" - job_queue.run_repeating( - self.job_run_once, 0.05, last=dtm.datetime.now(timezone) + dtm.timedelta(seconds=0.06) - ) - sleep(0.1) - assert self.result == 1 - sleep(0.1) - assert self.result == 1 - - def test_run_repeating_last_before_first(self, job_queue): - with pytest.raises(ValueError, match="'last' must not be before 'first'!"): - job_queue.run_repeating(self.job_run_once, 0.05, first=1, last=0.5) - - def test_run_repeating_timedelta(self, job_queue): - job_queue.run_repeating(self.job_run_once, dtm.timedelta(minutes=3.3333e-4)) - sleep(0.05) - assert self.result == 2 - - def test_run_custom(self, job_queue): - job_queue.run_custom(self.job_run_once, {'trigger': 'interval', 'seconds': 0.02}) - sleep(0.05) - assert self.result == 2 - - def test_multiple(self, job_queue): - job_queue.run_once(self.job_run_once, 0.01) - job_queue.run_once(self.job_run_once, 0.02) - job_queue.run_repeating(self.job_run_once, 0.02) - sleep(0.055) - assert self.result == 4 - - def test_disabled(self, job_queue): - j1 = job_queue.run_once(self.job_run_once, 0.1) - j2 = job_queue.run_repeating(self.job_run_once, 0.05) - - j1.enabled = False - j2.enabled = False - - sleep(0.06) - - assert self.result == 0 - - j1.enabled = True - - sleep(0.2) - - assert self.result == 1 - - def test_schedule_removal(self, job_queue): - j1 = job_queue.run_once(self.job_run_once, 0.03) - j2 = job_queue.run_repeating(self.job_run_once, 0.02) - - sleep(0.025) - - j1.schedule_removal() - j2.schedule_removal() - - sleep(0.04) - - assert self.result == 1 - - def test_schedule_removal_from_within(self, job_queue): - job_queue.run_repeating(self.job_remove_self, 0.01) - - sleep(0.05) - - assert self.result == 1 - - def test_longer_first(self, job_queue): - job_queue.run_once(self.job_run_once, 0.02) - job_queue.run_once(self.job_run_once, 0.01) - - sleep(0.015) - - assert self.result == 1 - - def test_error(self, job_queue): - job_queue.run_repeating(self.job_with_exception, 0.01) - job_queue.run_repeating(self.job_run_once, 0.02) - sleep(0.03) - assert self.result == 1 - - def test_in_dispatcher(self, bot): - dispatcher = DispatcherBuilder().bot(bot).build() - dispatcher.job_queue.start() - try: - dispatcher.job_queue.run_repeating(self.job_run_once, 0.02) - sleep(0.03) - assert self.result == 1 - dispatcher.stop() - sleep(1) - assert self.result == 1 - finally: - try: - dispatcher.stop() - except SchedulerNotRunningError: - pass - - def test_time_unit_int(self, job_queue): - # Testing seconds in int - delta = 0.05 - expected_time = time.time() + delta - - job_queue.run_once(self.job_datetime_tests, delta) - sleep(0.06) - assert pytest.approx(self.job_time) == expected_time - - def test_time_unit_dt_timedelta(self, job_queue): - # Testing seconds, minutes and hours as datetime.timedelta object - # This is sufficient to test that it actually works. - interval = dtm.timedelta(seconds=0.05) - expected_time = time.time() + interval.total_seconds() - - job_queue.run_once(self.job_datetime_tests, interval) - sleep(0.06) - assert pytest.approx(self.job_time) == expected_time - - def test_time_unit_dt_datetime(self, job_queue): - # Testing running at a specific datetime - delta, now = dtm.timedelta(seconds=0.05), dtm.datetime.now(pytz.utc) - when = now + delta - expected_time = (now + delta).timestamp() - - job_queue.run_once(self.job_datetime_tests, when) - sleep(0.06) - assert self.job_time == pytest.approx(expected_time) - - def test_time_unit_dt_time_today(self, job_queue): - # Testing running at a specific time today - delta, now = 0.05, dtm.datetime.now(pytz.utc) - expected_time = now + dtm.timedelta(seconds=delta) - when = expected_time.time() - expected_time = expected_time.timestamp() - - job_queue.run_once(self.job_datetime_tests, when) - sleep(0.06) - assert self.job_time == pytest.approx(expected_time) - - def test_time_unit_dt_time_tomorrow(self, job_queue): - # Testing running at a specific time that has passed today. Since we can't wait a day, we - # test if the job's next scheduled execution time has been calculated correctly - delta, now = -2, dtm.datetime.now(pytz.utc) - when = (now + dtm.timedelta(seconds=delta)).time() - expected_time = (now + dtm.timedelta(seconds=delta, days=1)).timestamp() - - job_queue.run_once(self.job_datetime_tests, when) - scheduled_time = job_queue.jobs()[0].next_t.timestamp() - assert scheduled_time == pytest.approx(expected_time) - - def test_run_daily(self, job_queue): - delta, now = 1, dtm.datetime.now(pytz.utc) - time_of_day = (now + dtm.timedelta(seconds=delta)).time() - expected_reschedule_time = (now + dtm.timedelta(seconds=delta, days=1)).timestamp() - - job_queue.run_daily(self.job_run_once, time_of_day) - sleep(delta + 0.1) - assert self.result == 1 - scheduled_time = job_queue.jobs()[0].next_t.timestamp() - assert scheduled_time == pytest.approx(expected_reschedule_time) - - def test_run_monthly(self, job_queue, timezone): - delta, now = 1, dtm.datetime.now(timezone) - expected_reschedule_time = now + dtm.timedelta(seconds=delta) - time_of_day = expected_reschedule_time.time().replace(tzinfo=timezone) - - day = now.day - this_months_days = calendar.monthrange(now.year, now.month)[1] - if now.month == 12: - next_months_days = calendar.monthrange(now.year + 1, 1)[1] - else: - next_months_days = calendar.monthrange(now.year, now.month + 1)[1] - - expected_reschedule_time += dtm.timedelta(this_months_days) - if day > next_months_days: - expected_reschedule_time += dtm.timedelta(next_months_days) - - expected_reschedule_time = timezone.normalize(expected_reschedule_time) - # Adjust the hour for the special case that between now and next month a DST switch happens - expected_reschedule_time += dtm.timedelta( - hours=time_of_day.hour - expected_reschedule_time.hour - ) - expected_reschedule_time = expected_reschedule_time.timestamp() - - job_queue.run_monthly(self.job_run_once, time_of_day, day) - sleep(delta + 0.1) - assert self.result == 1 - scheduled_time = job_queue.jobs()[0].next_t.timestamp() - assert scheduled_time == pytest.approx(expected_reschedule_time, rel=1e-3) - - def test_run_monthly_non_strict_day(self, job_queue, timezone): - delta, now = 1, dtm.datetime.now(timezone) - expected_reschedule_time = now + dtm.timedelta(seconds=delta) - time_of_day = expected_reschedule_time.time().replace(tzinfo=timezone) - - expected_reschedule_time += dtm.timedelta( - calendar.monthrange(now.year, now.month)[1] - ) - dtm.timedelta(days=now.day) - # Adjust the hour for the special case that between now & end of month a DST switch happens - expected_reschedule_time = timezone.normalize(expected_reschedule_time) - expected_reschedule_time += dtm.timedelta( - hours=time_of_day.hour - expected_reschedule_time.hour - ) - expected_reschedule_time = expected_reschedule_time.timestamp() - - job_queue.run_monthly(self.job_run_once, time_of_day, -1) - scheduled_time = job_queue.jobs()[0].next_t.timestamp() - assert scheduled_time == pytest.approx(expected_reschedule_time) - - def test_default_tzinfo(self, _dp, tz_bot): - # we're parametrizing this with two different UTC offsets to exclude the possibility - # of an xpass when the test is run in a timezone with the same UTC offset - jq = JobQueue() - original_bot = _dp.bot - _dp.bot = tz_bot - jq.set_dispatcher(_dp) - try: - jq.start() - - when = dtm.datetime.now(tz_bot.defaults.tzinfo) + dtm.timedelta(seconds=0.0005) - jq.run_once(self.job_run_once, when.time()) - sleep(0.001) - assert self.result == 1 - - jq.stop() - finally: - _dp.bot = original_bot - - def test_get_jobs(self, job_queue): - callback = self.job_context_based_callback - - job1 = job_queue.run_once(callback, 10, name='name1') - job2 = job_queue.run_once(callback, 10, name='name1') - job3 = job_queue.run_once(callback, 10, name='name2') - - assert job_queue.jobs() == (job1, job2, job3) - assert job_queue.get_jobs_by_name('name1') == (job1, job2) - assert job_queue.get_jobs_by_name('name2') == (job3,) - - def test_job_run(self, _dp): - job_queue = JobQueue() - job_queue.set_dispatcher(_dp) - job = job_queue.run_repeating(self.job_context_based_callback, 0.02, context=2) - assert self.result == 0 - job.run(_dp) - assert self.result == 1 - - def test_enable_disable_job(self, job_queue): - job = job_queue.run_repeating(self.job_run_once, 0.02) - sleep(0.05) - assert self.result == 2 - job.enabled = False - assert not job.enabled - sleep(0.05) - assert self.result == 2 - job.enabled = True - assert job.enabled - sleep(0.05) - assert self.result == 4 - - def test_remove_job(self, job_queue): - job = job_queue.run_repeating(self.job_run_once, 0.02) - sleep(0.05) - assert self.result == 2 - assert not job.removed - job.schedule_removal() - assert job.removed - sleep(0.05) - assert self.result == 2 - - def test_job_lt_eq(self, job_queue): - job = job_queue.run_repeating(self.job_run_once, 0.02) - assert not job == job_queue - assert not job < job - - def test_dispatch_error_context(self, job_queue, dp): - dp.add_error_handler(self.error_handler_context) - - job = job_queue.run_once(self.job_with_exception, 0.05) - sleep(0.1) - assert self.received_error[0] == 'Test Error' - assert self.received_error[1] is job - self.received_error = None - job.run(dp) - assert self.received_error[0] == 'Test Error' - assert self.received_error[1] is job - - # Remove handler - dp.remove_error_handler(self.error_handler_context) - self.received_error = None - - job = job_queue.run_once(self.job_with_exception, 0.05) - sleep(0.1) - assert self.received_error is None - job.run(dp) - assert self.received_error is None - - def test_dispatch_error_that_raises_errors(self, job_queue, dp, caplog): - dp.add_error_handler(self.error_handler_raise_error) - - with caplog.at_level(logging.ERROR): - job = job_queue.run_once(self.job_with_exception, 0.05) - sleep(0.1) - assert len(caplog.records) == 1 - rec = caplog.records[-1] - assert 'An error was raised and an uncaught' in rec.getMessage() - caplog.clear() - - with caplog.at_level(logging.ERROR): - job.run(dp) - assert len(caplog.records) == 1 - rec = caplog.records[-1] - assert 'uncaught error was raised while handling' in rec.getMessage() - caplog.clear() - - # Remove handler - dp.remove_error_handler(self.error_handler_raise_error) - self.received_error = None - - with caplog.at_level(logging.ERROR): - job = job_queue.run_once(self.job_with_exception, 0.05) - sleep(0.1) - assert len(caplog.records) == 1 - rec = caplog.records[-1] - assert 'No error handlers are registered' in rec.getMessage() - caplog.clear() - - with caplog.at_level(logging.ERROR): - job.run(dp) - assert len(caplog.records) == 1 - rec = caplog.records[-1] - assert 'No error handlers are registered' in rec.getMessage() - - def test_custom_context(self, bot, job_queue): - dispatcher = ( - DispatcherBuilder() - .bot(bot) - .context_types( - ContextTypes( - context=CustomContext, bot_data=int, user_data=float, chat_data=complex - ) - ) - .build() - ) - job_queue.set_dispatcher(dispatcher) - - def callback(context): - self.result = ( - type(context), - context.user_data, - context.chat_data, - type(context.bot_data), - ) - - job_queue.run_once(callback, 0.1) - sleep(0.15) - assert self.result == (CustomContext, None, None, int) - - def test_attribute_error(self): - job = Job(self.job_run_once) - with pytest.raises( - AttributeError, match="nor 'apscheduler.job.Job' has attribute 'error'" - ): - job.error diff --git a/tests/test_location.py b/tests/test_location.py index 6514ae7e9a5..a8b3d83db3b 100644 --- a/tests/test_location.py +++ b/tests/test_location.py @@ -21,6 +21,7 @@ from telegram import Location from telegram.error import BadRequest +from telegram.request import RequestData @pytest.fixture(scope='class') @@ -68,8 +69,9 @@ def test_de_json(self, bot): @flaky(3, 1) @pytest.mark.xfail - def test_send_live_location(self, bot, chat_id): - message = bot.send_location( + @pytest.mark.asyncio + async def test_send_live_location(self, bot, chat_id): + message = await bot.send_location( chat_id=chat_id, latitude=52.223880, longitude=5.166146, @@ -88,7 +90,7 @@ def test_send_live_location(self, bot, chat_id): assert message.location.proximity_alert_radius == 1000 assert message.has_protected_content - message2 = bot.edit_message_live_location( + message2 = await bot.edit_message_live_location( message.chat_id, message.message_id, latitude=52.223098, @@ -104,25 +106,27 @@ def test_send_live_location(self, bot, chat_id): assert message2.location.heading == 10 assert message2.location.proximity_alert_radius == 500 - bot.stop_message_live_location(message.chat_id, message.message_id) + await bot.stop_message_live_location(message.chat_id, message.message_id) with pytest.raises(BadRequest, match="Message can't be edited"): - bot.edit_message_live_location( + await bot.edit_message_live_location( message.chat_id, message.message_id, latitude=52.223880, longitude=5.164306 ) # TODO: Needs improvement with in inline sent live location. - def test_edit_live_inline_message(self, monkeypatch, bot, location): - def make_assertion(url, data, **kwargs): - lat = data['latitude'] == location.latitude - lon = data['longitude'] == location.longitude - id_ = data['inline_message_id'] == 1234 - ha = data['horizontal_accuracy'] == 50 - heading = data['heading'] == 90 - prox_alert = data['proximity_alert_radius'] == 1000 + @pytest.mark.asyncio + async def test_edit_live_inline_message(self, monkeypatch, bot, location): + async def make_assertion(url, request_data: RequestData, *args, **kwargs): + data = request_data.json_parameters + lat = data['latitude'] == str(location.latitude) + lon = data['longitude'] == str(location.longitude) + id_ = data['inline_message_id'] == '1234' + ha = data['horizontal_accuracy'] == '50' + heading = data['heading'] == '90' + prox_alert = data['proximity_alert_radius'] == '1000' return lat and lon and id_ and ha and heading and prox_alert monkeypatch.setattr(bot.request, 'post', make_assertion) - assert bot.edit_message_live_location( + assert await bot.edit_message_live_location( inline_message_id=1234, location=location, horizontal_accuracy=50, @@ -131,22 +135,24 @@ def make_assertion(url, data, **kwargs): ) # TODO: Needs improvement with in inline sent live location. - def test_stop_live_inline_message(self, monkeypatch, bot): - def test(url, data, **kwargs): - id_ = data['inline_message_id'] == 1234 + @pytest.mark.asyncio + async def test_stop_live_inline_message(self, monkeypatch, bot): + async def make_assertion(url, request_data: RequestData, *args, **kwargs): + id_ = request_data.json_parameters['inline_message_id'] == '1234' return id_ - monkeypatch.setattr(bot.request, 'post', test) - assert bot.stop_message_live_location(inline_message_id=1234) + monkeypatch.setattr(bot.request, 'post', make_assertion) + assert await bot.stop_message_live_location(inline_message_id=1234) - def test_send_with_location(self, monkeypatch, bot, chat_id, location): - def test(url, data, **kwargs): - lat = data['latitude'] == location.latitude - lon = data['longitude'] == location.longitude + @pytest.mark.asyncio + async def test_send_with_location(self, monkeypatch, bot, chat_id, location): + async def make_assertion(url, request_data: RequestData, *args, **kwargs): + lat = request_data.json_parameters['latitude'] == str(location.latitude) + lon = request_data.json_parameters['longitude'] == str(location.longitude) return lat and lon - monkeypatch.setattr(bot.request, 'post', test) - assert bot.send_location(location=location, chat_id=chat_id) + monkeypatch.setattr(bot.request, 'post', make_assertion) + assert await bot.send_location(location=location, chat_id=chat_id) @flaky(3, 1) @pytest.mark.parametrize( @@ -158,13 +164,14 @@ def test(url, data, **kwargs): ], indirect=['default_bot'], ) - def test_send_location_default_allow_sending_without_reply( + @pytest.mark.asyncio + async def test_send_location_default_allow_sending_without_reply( self, default_bot, chat_id, location, custom ): - reply_to_message = default_bot.send_message(chat_id, 'test') - reply_to_message.delete() + reply_to_message = await default_bot.send_message(chat_id, 'test') + await reply_to_message.delete() if custom is not None: - message = default_bot.send_location( + message = await default_bot.send_location( chat_id, location=location, allow_sending_without_reply=custom, @@ -172,48 +179,56 @@ def test_send_location_default_allow_sending_without_reply( ) assert message.reply_to_message is None elif default_bot.defaults.allow_sending_without_reply: - message = default_bot.send_location( + message = await default_bot.send_location( chat_id, location=location, reply_to_message_id=reply_to_message.message_id ) assert message.reply_to_message is None else: with pytest.raises(BadRequest, match='message not found'): - default_bot.send_location( + await default_bot.send_location( chat_id, location=location, reply_to_message_id=reply_to_message.message_id ) @flaky(3, 1) + @pytest.mark.asyncio @pytest.mark.parametrize('default_bot', [{'protect_content': True}], indirect=True) - def test_send_location_default_protect_content(self, chat_id, default_bot, location): - protected = default_bot.send_location(chat_id, location=location) + async def test_send_location_default_protect_content(self, chat_id, default_bot, location): + protected = await default_bot.send_location(chat_id, location=location) assert protected.has_protected_content - unprotected = default_bot.send_location(chat_id, location=location, protect_content=False) + unprotected = await default_bot.send_location( + chat_id, location=location, protect_content=False + ) assert not unprotected.has_protected_content - def test_edit_live_location_with_location(self, monkeypatch, bot, location): - def test(url, data, **kwargs): - lat = data['latitude'] == location.latitude - lon = data['longitude'] == location.longitude + @pytest.mark.asyncio + async def test_edit_live_location_with_location(self, monkeypatch, bot, location): + async def make_assertion(url, request_data: RequestData, *args, **kwargs): + lat = request_data.json_parameters['latitude'] == str(location.latitude) + lon = request_data.json_parameters['longitude'] == str(location.longitude) return lat and lon - monkeypatch.setattr(bot.request, 'post', test) - assert bot.edit_message_live_location(None, None, location=location) + monkeypatch.setattr(bot.request, 'post', make_assertion) + assert await bot.edit_message_live_location(None, None, location=location) - def test_send_location_without_required(self, bot, chat_id): + @pytest.mark.asyncio + async def test_send_location_without_required(self, bot, chat_id): with pytest.raises(ValueError, match='Either location or latitude and longitude'): - bot.send_location(chat_id=chat_id) + await bot.send_location(chat_id=chat_id) - def test_edit_location_without_required(self, bot): + @pytest.mark.asyncio + async def test_edit_location_without_required(self, bot): with pytest.raises(ValueError, match='Either location or latitude and longitude'): - bot.edit_message_live_location(chat_id=2, message_id=3) + await bot.edit_message_live_location(chat_id=2, message_id=3) - def test_send_location_with_all_args(self, bot, location): + @pytest.mark.asyncio + async def test_send_location_with_all_args(self, bot, location): with pytest.raises(ValueError, match='Not both'): - bot.send_location(chat_id=1, latitude=2.5, longitude=4.6, location=location) + await bot.send_location(chat_id=1, latitude=2.5, longitude=4.6, location=location) - def test_edit_location_with_all_args(self, bot, location): + @pytest.mark.asyncio + async def test_edit_location_with_all_args(self, bot, location): with pytest.raises(ValueError, match='Not both'): - bot.edit_message_live_location( + await bot.edit_message_live_location( chat_id=1, message_id=7, latitude=2.5, longitude=4.6, location=location ) diff --git a/tests/test_message.py b/tests/test_message.py index 9583bbdbdc2..daa7d377817 100644 --- a/tests/test_message.py +++ b/tests/test_message.py @@ -163,6 +163,7 @@ def message(bot): ] }, }, + {'quote': True}, {'dice': Dice(4, '🎲')}, {'via_bot': User(9, 'A_Bot', True)}, { @@ -222,6 +223,7 @@ def message(bot): 'passport_data', 'poll', 'reply_markup', + 'default_quote', 'dice', 'via_bot', 'proximity_alert_triggered', @@ -311,17 +313,13 @@ class TestMessage: caption_entities=[MessageEntity(**e) for e in test_entities_v2], ) - def test_slot_behaviour(self, message, mro_slots): - for attr in message.__slots__: - assert getattr(message, attr, 'err') != 'err', f"got extra slot '{attr}'" - assert len(mro_slots(message)) == len(set(mro_slots(message))), "duplicate slot" - def test_all_possibilities_de_json_and_to_dict(self, bot, message_params): new = Message.de_json(message_params.to_dict(), bot) assert new.to_dict() == message_params.to_dict() - def test_parse_entity(self): + @pytest.mark.asyncio + async def test_parse_entity(self): text = ( b'\\U0001f469\\u200d\\U0001f469\\u200d\\U0001f467' b'\\u200d\\U0001f467\\U0001f431http://google.com' @@ -330,7 +328,8 @@ def test_parse_entity(self): message = Message(1, self.from_user, self.date, self.chat, text=text, entities=[entity]) assert message.parse_entity(entity) == 'http://google.com' - def test_parse_caption_entity(self): + @pytest.mark.asyncio + async def test_parse_caption_entity(self): caption = ( b'\\U0001f469\\u200d\\U0001f469\\u200d\\U0001f467' b'\\u200d\\U0001f467\\U0001f431http://google.com' @@ -341,7 +340,8 @@ def test_parse_caption_entity(self): ) assert message.parse_caption_entity(entity) == 'http://google.com' - def test_parse_entities(self): + @pytest.mark.asyncio + async def test_parse_entities(self): text = ( b'\\U0001f469\\u200d\\U0001f469\\u200d\\U0001f467' b'\\u200d\\U0001f467\\U0001f431http://google.com' @@ -354,7 +354,8 @@ def test_parse_entities(self): assert message.parse_entities(MessageEntity.URL) == {entity: 'http://google.com'} assert message.parse_entities() == {entity: 'http://google.com', entity_2: 'h'} - def test_parse_caption_entities(self): + @pytest.mark.asyncio + async def test_parse_caption_entities(self): text = ( b'\\U0001f469\\u200d\\U0001f469\\u200d\\U0001f467' b'\\u200d\\U0001f467\\U0001f431http://google.com' @@ -370,7 +371,10 @@ def test_parse_caption_entities(self): caption_entities=[entity_2, entity], ) assert message.parse_caption_entities(MessageEntity.URL) == {entity: 'http://google.com'} - assert message.parse_caption_entities() == {entity: 'http://google.com', entity_2: 'h'} + assert message.parse_caption_entities() == { + entity: 'http://google.com', + entity_2: 'h', + } def test_text_html_simple(self): test_html_string = ( @@ -603,7 +607,8 @@ def test_caption_markdown_emoji(self): ) assert expected == message.caption_markdown - def test_parse_entities_url_emoji(self): + @pytest.mark.asyncio + async def test_parse_entities_url_emoji(self): url = b'http://github.com/?unicode=\\u2713\\U0001f469'.decode('unicode-escape') text = 'some url' link_entity = MessageEntity(type=MessageEntity.URL, offset=0, length=8, url=url) @@ -616,26 +621,26 @@ def test_parse_entities_url_emoji(self): def test_chat_id(self, message): assert message.chat_id == message.chat.id - @pytest.mark.parametrize('_type', argvalues=[Chat.SUPERGROUP, Chat.CHANNEL]) - def test_link_with_username(self, message, _type): + @pytest.mark.parametrize('type', argvalues=[Chat.SUPERGROUP, Chat.CHANNEL]) + def test_link_with_username(self, message, type): message.chat.username = 'username' - message.chat.type = _type + message.chat.type = type assert message.link == f'https://t.me/{message.chat.username}/{message.message_id}' @pytest.mark.parametrize( - '_type, _id', argvalues=[(Chat.CHANNEL, -1003), (Chat.SUPERGROUP, -1003)] + 'type, id', argvalues=[(Chat.CHANNEL, -1003), (Chat.SUPERGROUP, -1003)] ) - def test_link_with_id(self, message, _type, _id): + def test_link_with_id(self, message, type, id): message.chat.username = None - message.chat.id = _id - message.chat.type = _type + message.chat.id = id + message.chat.type = type # The leading - for group ids/ -100 for supergroup ids isn't supposed to be in the link assert message.link == f'https://t.me/c/{3}/{message.message_id}' - @pytest.mark.parametrize('_id, username', argvalues=[(None, 'username'), (-3, None)]) - def test_link_private_chats(self, message, _id, username): + @pytest.mark.parametrize('id, username', argvalues=[(None, 'username'), (-3, None)]) + def test_link_private_chats(self, message, id, username): message.chat.type = Chat.PRIVATE - message.chat.id = _id + message.chat.id = id message.chat.username = username assert message.link is None message.chat.type = Chat.GROUP @@ -677,8 +682,9 @@ def test_effective_attachment(self, message_params): ) assert not condition, 'effective_attachment was None even though it should not be' - def test_reply_text(self, monkeypatch, message): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_reply_text(self, monkeypatch, message): + async def make_assertion(*_, **kwargs): id_ = kwargs['chat_id'] == message.chat_id text = kwargs['text'] == 'test' if kwargs.get('reply_to_message_id') is not None: @@ -690,15 +696,16 @@ def make_assertion(*_, **kwargs): assert check_shortcut_signature( Message.reply_text, Bot.send_message, ['chat_id'], ['quote'] ) - assert check_shortcut_call(message.reply_text, message.get_bot(), 'send_message') - assert check_defaults_handling(message.reply_text, message.get_bot()) + assert await check_shortcut_call(message.reply_text, message.get_bot(), 'send_message') + assert await check_defaults_handling(message.reply_text, message.get_bot()) monkeypatch.setattr(message.get_bot(), 'send_message', make_assertion) - assert message.reply_text('test') - assert message.reply_text('test', quote=True) - assert message.reply_text('test', reply_to_message_id=message.message_id, quote=True) + assert await message.reply_text('test') + assert await message.reply_text('test', quote=True) + assert await message.reply_text('test', reply_to_message_id=message.message_id, quote=True) - def test_reply_markdown(self, monkeypatch, message): + @pytest.mark.asyncio + async def test_reply_markdown(self, monkeypatch, message): test_md_string = ( r'Test for <*bold*, _ita_\__lic_, `code`, ' '[links](http://github.com/ab_), ' @@ -706,7 +713,7 @@ def test_reply_markdown(self, monkeypatch, message): r'http://google.com/ab\_' ) - def make_assertion(*_, **kwargs): + async def make_assertion(*_, **kwargs): cid = kwargs['chat_id'] == message.chat_id markdown_text = kwargs['text'] == test_md_string markdown_enabled = kwargs['parse_mode'] == ParseMode.MARKDOWN @@ -719,20 +726,21 @@ def make_assertion(*_, **kwargs): assert check_shortcut_signature( Message.reply_markdown, Bot.send_message, ['chat_id', 'parse_mode'], ['quote'] ) - assert check_shortcut_call(message.reply_text, message.get_bot(), 'send_message') - assert check_defaults_handling(message.reply_text, message.get_bot()) + assert await check_shortcut_call(message.reply_text, message.get_bot(), 'send_message') + assert await check_defaults_handling(message.reply_text, message.get_bot()) text_markdown = self.test_message.text_markdown assert text_markdown == test_md_string monkeypatch.setattr(message.get_bot(), 'send_message', make_assertion) - assert message.reply_markdown(self.test_message.text_markdown) - assert message.reply_markdown(self.test_message.text_markdown, quote=True) - assert message.reply_markdown( + assert await message.reply_markdown(self.test_message.text_markdown) + assert await message.reply_markdown(self.test_message.text_markdown, quote=True) + assert await message.reply_markdown( self.test_message.text_markdown, reply_to_message_id=message.message_id, quote=True ) - def test_reply_markdown_v2(self, monkeypatch, message): + @pytest.mark.asyncio + async def test_reply_markdown_v2(self, monkeypatch, message): test_md_string = ( r'__Test__ for <*bold*, _ita\_lic_, `\\\`code`, ' '[links](http://github.com/abc\\\\\\)def), ' @@ -741,7 +749,7 @@ def test_reply_markdown_v2(self, monkeypatch, message): '```python\nPython pre```\\. ||Spoiled||\\.' ) - def make_assertion(*_, **kwargs): + async def make_assertion(*_, **kwargs): cid = kwargs['chat_id'] == message.chat_id markdown_text = kwargs['text'] == test_md_string markdown_enabled = kwargs['parse_mode'] == ParseMode.MARKDOWN_V2 @@ -754,22 +762,23 @@ def make_assertion(*_, **kwargs): assert check_shortcut_signature( Message.reply_markdown_v2, Bot.send_message, ['chat_id', 'parse_mode'], ['quote'] ) - assert check_shortcut_call(message.reply_text, message.get_bot(), 'send_message') - assert check_defaults_handling(message.reply_text, message.get_bot()) + assert await check_shortcut_call(message.reply_text, message.get_bot(), 'send_message') + assert await check_defaults_handling(message.reply_text, message.get_bot()) text_markdown = self.test_message_v2.text_markdown_v2 assert text_markdown == test_md_string monkeypatch.setattr(message.get_bot(), 'send_message', make_assertion) - assert message.reply_markdown_v2(self.test_message_v2.text_markdown_v2) - assert message.reply_markdown_v2(self.test_message_v2.text_markdown_v2, quote=True) - assert message.reply_markdown_v2( + assert await message.reply_markdown_v2(self.test_message_v2.text_markdown_v2) + assert await message.reply_markdown_v2(self.test_message_v2.text_markdown_v2, quote=True) + assert await message.reply_markdown_v2( self.test_message_v2.text_markdown_v2, reply_to_message_id=message.message_id, quote=True, ) - def test_reply_html(self, monkeypatch, message): + @pytest.mark.asyncio + async def test_reply_html(self, monkeypatch, message): test_html_string = ( 'Test for <bold, ita_lic, ' r'\`code, ' @@ -781,7 +790,7 @@ def test_reply_html(self, monkeypatch, message): 'Spoiled.' ) - def make_assertion(*_, **kwargs): + async def make_assertion(*_, **kwargs): cid = kwargs['chat_id'] == message.chat_id html_text = kwargs['text'] == test_html_string html_enabled = kwargs['parse_mode'] == ParseMode.HTML @@ -794,21 +803,22 @@ def make_assertion(*_, **kwargs): assert check_shortcut_signature( Message.reply_html, Bot.send_message, ['chat_id', 'parse_mode'], ['quote'] ) - assert check_shortcut_call(message.reply_text, message.get_bot(), 'send_message') - assert check_defaults_handling(message.reply_text, message.get_bot()) + assert await check_shortcut_call(message.reply_text, message.get_bot(), 'send_message') + assert await check_defaults_handling(message.reply_text, message.get_bot()) text_html = self.test_message_v2.text_html assert text_html == test_html_string monkeypatch.setattr(message.get_bot(), 'send_message', make_assertion) - assert message.reply_html(self.test_message_v2.text_html) - assert message.reply_html(self.test_message_v2.text_html, quote=True) - assert message.reply_html( + assert await message.reply_html(self.test_message_v2.text_html) + assert await message.reply_html(self.test_message_v2.text_html, quote=True) + assert await message.reply_html( self.test_message_v2.text_html, reply_to_message_id=message.message_id, quote=True ) - def test_reply_media_group(self, monkeypatch, message): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_reply_media_group(self, monkeypatch, message): + async def make_assertion(*_, **kwargs): id_ = kwargs['chat_id'] == message.chat_id media = kwargs['media'] == 'reply_media_group' if kwargs.get('reply_to_message_id') is not None: @@ -820,17 +830,18 @@ def make_assertion(*_, **kwargs): assert check_shortcut_signature( Message.reply_media_group, Bot.send_media_group, ['chat_id'], ['quote'] ) - assert check_shortcut_call( + assert await check_shortcut_call( message.reply_media_group, message.get_bot(), 'send_media_group' ) - assert check_defaults_handling(message.reply_media_group, message.get_bot()) + assert await check_defaults_handling(message.reply_media_group, message.get_bot()) monkeypatch.setattr(message.get_bot(), 'send_media_group', make_assertion) - assert message.reply_media_group(media='reply_media_group') - assert message.reply_media_group(media='reply_media_group', quote=True) + assert await message.reply_media_group(media='reply_media_group') + assert await message.reply_media_group(media='reply_media_group', quote=True) - def test_reply_photo(self, monkeypatch, message): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_reply_photo(self, monkeypatch, message): + async def make_assertion(*_, **kwargs): id_ = kwargs['chat_id'] == message.chat_id photo = kwargs['photo'] == 'test_photo' if kwargs.get('reply_to_message_id') is not None: @@ -842,15 +853,16 @@ def make_assertion(*_, **kwargs): assert check_shortcut_signature( Message.reply_photo, Bot.send_photo, ['chat_id'], ['quote'] ) - assert check_shortcut_call(message.reply_photo, message.get_bot(), 'send_photo') - assert check_defaults_handling(message.reply_photo, message.get_bot()) + assert await check_shortcut_call(message.reply_photo, message.get_bot(), 'send_photo') + assert await check_defaults_handling(message.reply_photo, message.get_bot()) monkeypatch.setattr(message.get_bot(), 'send_photo', make_assertion) - assert message.reply_photo(photo='test_photo') - assert message.reply_photo(photo='test_photo', quote=True) + assert await message.reply_photo(photo='test_photo') + assert await message.reply_photo(photo='test_photo', quote=True) - def test_reply_audio(self, monkeypatch, message): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_reply_audio(self, monkeypatch, message): + async def make_assertion(*_, **kwargs): id_ = kwargs['chat_id'] == message.chat_id audio = kwargs['audio'] == 'test_audio' if kwargs.get('reply_to_message_id') is not None: @@ -862,15 +874,16 @@ def make_assertion(*_, **kwargs): assert check_shortcut_signature( Message.reply_audio, Bot.send_audio, ['chat_id'], ['quote'] ) - assert check_shortcut_call(message.reply_audio, message.get_bot(), 'send_audio') - assert check_defaults_handling(message.reply_audio, message.get_bot()) + assert await check_shortcut_call(message.reply_audio, message.get_bot(), 'send_audio') + assert await check_defaults_handling(message.reply_audio, message.get_bot()) monkeypatch.setattr(message.get_bot(), 'send_audio', make_assertion) - assert message.reply_audio(audio='test_audio') - assert message.reply_audio(audio='test_audio', quote=True) + assert await message.reply_audio(audio='test_audio') + assert await message.reply_audio(audio='test_audio', quote=True) - def test_reply_document(self, monkeypatch, message): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_reply_document(self, monkeypatch, message): + async def make_assertion(*_, **kwargs): id_ = kwargs['chat_id'] == message.chat_id document = kwargs['document'] == 'test_document' if kwargs.get('reply_to_message_id') is not None: @@ -882,15 +895,18 @@ def make_assertion(*_, **kwargs): assert check_shortcut_signature( Message.reply_document, Bot.send_document, ['chat_id'], ['quote'] ) - assert check_shortcut_call(message.reply_document, message.get_bot(), 'send_document') - assert check_defaults_handling(message.reply_document, message.get_bot()) + assert await check_shortcut_call( + message.reply_document, message.get_bot(), 'send_document' + ) + assert await check_defaults_handling(message.reply_document, message.get_bot()) monkeypatch.setattr(message.get_bot(), 'send_document', make_assertion) - assert message.reply_document(document='test_document') - assert message.reply_document(document='test_document', quote=True) + assert await message.reply_document(document='test_document') + assert await message.reply_document(document='test_document', quote=True) - def test_reply_animation(self, monkeypatch, message): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_reply_animation(self, monkeypatch, message): + async def make_assertion(*_, **kwargs): id_ = kwargs['chat_id'] == message.chat_id animation = kwargs['animation'] == 'test_animation' if kwargs.get('reply_to_message_id') is not None: @@ -902,15 +918,18 @@ def make_assertion(*_, **kwargs): assert check_shortcut_signature( Message.reply_animation, Bot.send_animation, ['chat_id'], ['quote'] ) - assert check_shortcut_call(message.reply_animation, message.get_bot(), 'send_animation') - assert check_defaults_handling(message.reply_animation, message.get_bot()) + assert await check_shortcut_call( + message.reply_animation, message.get_bot(), 'send_animation' + ) + assert await check_defaults_handling(message.reply_animation, message.get_bot()) monkeypatch.setattr(message.get_bot(), 'send_animation', make_assertion) - assert message.reply_animation(animation='test_animation') - assert message.reply_animation(animation='test_animation', quote=True) + assert await message.reply_animation(animation='test_animation') + assert await message.reply_animation(animation='test_animation', quote=True) - def test_reply_sticker(self, monkeypatch, message): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_reply_sticker(self, monkeypatch, message): + async def make_assertion(*_, **kwargs): id_ = kwargs['chat_id'] == message.chat_id sticker = kwargs['sticker'] == 'test_sticker' if kwargs.get('reply_to_message_id') is not None: @@ -922,15 +941,16 @@ def make_assertion(*_, **kwargs): assert check_shortcut_signature( Message.reply_sticker, Bot.send_sticker, ['chat_id'], ['quote'] ) - assert check_shortcut_call(message.reply_sticker, message.get_bot(), 'send_sticker') - assert check_defaults_handling(message.reply_sticker, message.get_bot()) + assert await check_shortcut_call(message.reply_sticker, message.get_bot(), 'send_sticker') + assert await check_defaults_handling(message.reply_sticker, message.get_bot()) monkeypatch.setattr(message.get_bot(), 'send_sticker', make_assertion) - assert message.reply_sticker(sticker='test_sticker') - assert message.reply_sticker(sticker='test_sticker', quote=True) + assert await message.reply_sticker(sticker='test_sticker') + assert await message.reply_sticker(sticker='test_sticker', quote=True) - def test_reply_video(self, monkeypatch, message): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_reply_video(self, monkeypatch, message): + async def make_assertion(*_, **kwargs): id_ = kwargs['chat_id'] == message.chat_id video = kwargs['video'] == 'test_video' if kwargs.get('reply_to_message_id') is not None: @@ -942,15 +962,16 @@ def make_assertion(*_, **kwargs): assert check_shortcut_signature( Message.reply_video, Bot.send_video, ['chat_id'], ['quote'] ) - assert check_shortcut_call(message.reply_video, message.get_bot(), 'send_video') - assert check_defaults_handling(message.reply_video, message.get_bot()) + assert await check_shortcut_call(message.reply_video, message.get_bot(), 'send_video') + assert await check_defaults_handling(message.reply_video, message.get_bot()) monkeypatch.setattr(message.get_bot(), 'send_video', make_assertion) - assert message.reply_video(video='test_video') - assert message.reply_video(video='test_video', quote=True) + assert await message.reply_video(video='test_video') + assert await message.reply_video(video='test_video', quote=True) - def test_reply_video_note(self, monkeypatch, message): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_reply_video_note(self, monkeypatch, message): + async def make_assertion(*_, **kwargs): id_ = kwargs['chat_id'] == message.chat_id video_note = kwargs['video_note'] == 'test_video_note' if kwargs.get('reply_to_message_id') is not None: @@ -962,15 +983,18 @@ def make_assertion(*_, **kwargs): assert check_shortcut_signature( Message.reply_video_note, Bot.send_video_note, ['chat_id'], ['quote'] ) - assert check_shortcut_call(message.reply_video_note, message.get_bot(), 'send_video_note') - assert check_defaults_handling(message.reply_video_note, message.get_bot()) + assert await check_shortcut_call( + message.reply_video_note, message.get_bot(), 'send_video_note' + ) + assert await check_defaults_handling(message.reply_video_note, message.get_bot()) monkeypatch.setattr(message.get_bot(), 'send_video_note', make_assertion) - assert message.reply_video_note(video_note='test_video_note') - assert message.reply_video_note(video_note='test_video_note', quote=True) + assert await message.reply_video_note(video_note='test_video_note') + assert await message.reply_video_note(video_note='test_video_note', quote=True) - def test_reply_voice(self, monkeypatch, message): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_reply_voice(self, monkeypatch, message): + async def make_assertion(*_, **kwargs): id_ = kwargs['chat_id'] == message.chat_id voice = kwargs['voice'] == 'test_voice' if kwargs.get('reply_to_message_id') is not None: @@ -982,15 +1006,16 @@ def make_assertion(*_, **kwargs): assert check_shortcut_signature( Message.reply_voice, Bot.send_voice, ['chat_id'], ['quote'] ) - assert check_shortcut_call(message.reply_voice, message.get_bot(), 'send_voice') - assert check_defaults_handling(message.reply_voice, message.get_bot()) + assert await check_shortcut_call(message.reply_voice, message.get_bot(), 'send_voice') + assert await check_defaults_handling(message.reply_voice, message.get_bot()) monkeypatch.setattr(message.get_bot(), 'send_voice', make_assertion) - assert message.reply_voice(voice='test_voice') - assert message.reply_voice(voice='test_voice', quote=True) + assert await message.reply_voice(voice='test_voice') + assert await message.reply_voice(voice='test_voice', quote=True) - def test_reply_location(self, monkeypatch, message): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_reply_location(self, monkeypatch, message): + async def make_assertion(*_, **kwargs): id_ = kwargs['chat_id'] == message.chat_id location = kwargs['location'] == 'test_location' if kwargs.get('reply_to_message_id') is not None: @@ -1002,15 +1027,18 @@ def make_assertion(*_, **kwargs): assert check_shortcut_signature( Message.reply_location, Bot.send_location, ['chat_id'], ['quote'] ) - assert check_shortcut_call(message.reply_location, message.get_bot(), 'send_location') - assert check_defaults_handling(message.reply_location, message.get_bot()) + assert await check_shortcut_call( + message.reply_location, message.get_bot(), 'send_location' + ) + assert await check_defaults_handling(message.reply_location, message.get_bot()) monkeypatch.setattr(message.get_bot(), 'send_location', make_assertion) - assert message.reply_location(location='test_location') - assert message.reply_location(location='test_location', quote=True) + assert await message.reply_location(location='test_location') + assert await message.reply_location(location='test_location', quote=True) - def test_reply_venue(self, monkeypatch, message): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_reply_venue(self, monkeypatch, message): + async def make_assertion(*_, **kwargs): id_ = kwargs['chat_id'] == message.chat_id venue = kwargs['venue'] == 'test_venue' if kwargs.get('reply_to_message_id') is not None: @@ -1022,15 +1050,16 @@ def make_assertion(*_, **kwargs): assert check_shortcut_signature( Message.reply_venue, Bot.send_venue, ['chat_id'], ['quote'] ) - assert check_shortcut_call(message.reply_venue, message.get_bot(), 'send_venue') - assert check_defaults_handling(message.reply_venue, message.get_bot()) + assert await check_shortcut_call(message.reply_venue, message.get_bot(), 'send_venue') + assert await check_defaults_handling(message.reply_venue, message.get_bot()) monkeypatch.setattr(message.get_bot(), 'send_venue', make_assertion) - assert message.reply_venue(venue='test_venue') - assert message.reply_venue(venue='test_venue', quote=True) + assert await message.reply_venue(venue='test_venue') + assert await message.reply_venue(venue='test_venue', quote=True) - def test_reply_contact(self, monkeypatch, message): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_reply_contact(self, monkeypatch, message): + async def make_assertion(*_, **kwargs): id_ = kwargs['chat_id'] == message.chat_id contact = kwargs['contact'] == 'test_contact' if kwargs.get('reply_to_message_id') is not None: @@ -1042,15 +1071,16 @@ def make_assertion(*_, **kwargs): assert check_shortcut_signature( Message.reply_contact, Bot.send_contact, ['chat_id'], ['quote'] ) - assert check_shortcut_call(message.reply_contact, message.get_bot(), 'send_contact') - assert check_defaults_handling(message.reply_contact, message.get_bot()) + assert await check_shortcut_call(message.reply_contact, message.get_bot(), 'send_contact') + assert await check_defaults_handling(message.reply_contact, message.get_bot()) monkeypatch.setattr(message.get_bot(), 'send_contact', make_assertion) - assert message.reply_contact(contact='test_contact') - assert message.reply_contact(contact='test_contact', quote=True) + assert await message.reply_contact(contact='test_contact') + assert await message.reply_contact(contact='test_contact', quote=True) - def test_reply_poll(self, monkeypatch, message): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_reply_poll(self, monkeypatch, message): + async def make_assertion(*_, **kwargs): id_ = kwargs['chat_id'] == message.chat_id question = kwargs['question'] == 'test_poll' options = kwargs['options'] == ['1', '2', '3'] @@ -1061,15 +1091,16 @@ def make_assertion(*_, **kwargs): return id_ and question and options and reply assert check_shortcut_signature(Message.reply_poll, Bot.send_poll, ['chat_id'], ['quote']) - assert check_shortcut_call(message.reply_poll, message.get_bot(), 'send_poll') - assert check_defaults_handling(message.reply_poll, message.get_bot()) + assert await check_shortcut_call(message.reply_poll, message.get_bot(), 'send_poll') + assert await check_defaults_handling(message.reply_poll, message.get_bot()) monkeypatch.setattr(message.get_bot(), 'send_poll', make_assertion) - assert message.reply_poll(question='test_poll', options=['1', '2', '3']) - assert message.reply_poll(question='test_poll', quote=True, options=['1', '2', '3']) + assert await message.reply_poll(question='test_poll', options=['1', '2', '3']) + assert await message.reply_poll(question='test_poll', quote=True, options=['1', '2', '3']) - def test_reply_dice(self, monkeypatch, message): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_reply_dice(self, monkeypatch, message): + async def make_assertion(*_, **kwargs): id_ = kwargs['chat_id'] == message.chat_id contact = kwargs['disable_notification'] is True if kwargs.get('reply_to_message_id') is not None: @@ -1079,15 +1110,16 @@ def make_assertion(*_, **kwargs): return id_ and contact and reply assert check_shortcut_signature(Message.reply_dice, Bot.send_dice, ['chat_id'], ['quote']) - assert check_shortcut_call(message.reply_dice, message.get_bot(), 'send_dice') - assert check_defaults_handling(message.reply_dice, message.get_bot()) + assert await check_shortcut_call(message.reply_dice, message.get_bot(), 'send_dice') + assert await check_defaults_handling(message.reply_dice, message.get_bot()) monkeypatch.setattr(message.get_bot(), 'send_dice', make_assertion) - assert message.reply_dice(disable_notification=True) - assert message.reply_dice(disable_notification=True, quote=True) + assert await message.reply_dice(disable_notification=True) + assert await message.reply_dice(disable_notification=True, quote=True) - def test_reply_action(self, monkeypatch, message: Message): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_reply_action(self, monkeypatch, message: Message): + async def make_assertion(*_, **kwargs): id_ = kwargs['chat_id'] == message.chat_id action = kwargs['action'] == ChatAction.TYPING return id_ and action @@ -1095,30 +1127,32 @@ def make_assertion(*_, **kwargs): assert check_shortcut_signature( Message.reply_chat_action, Bot.send_chat_action, ['chat_id'], [] ) - assert check_shortcut_call( + assert await check_shortcut_call( message.reply_chat_action, message.get_bot(), 'send_chat_action' ) - assert check_defaults_handling(message.reply_chat_action, message.get_bot()) + assert await check_defaults_handling(message.reply_chat_action, message.get_bot()) monkeypatch.setattr(message.get_bot(), 'send_chat_action', make_assertion) - assert message.reply_chat_action(action=ChatAction.TYPING) + assert await message.reply_chat_action(action=ChatAction.TYPING) - def test_reply_game(self, monkeypatch, message: Message): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_reply_game(self, monkeypatch, message): + async def make_assertion(*_, **kwargs): return ( kwargs['chat_id'] == message.chat_id and kwargs['game_short_name'] == 'test_game' ) assert check_shortcut_signature(Message.reply_game, Bot.send_game, ['chat_id'], ['quote']) - assert check_shortcut_call(message.reply_game, message.get_bot(), 'send_game') - assert check_defaults_handling(message.reply_game, message.get_bot()) + assert await check_shortcut_call(message.reply_game, message.get_bot(), 'send_game') + assert await check_defaults_handling(message.reply_game, message.get_bot()) monkeypatch.setattr(message.get_bot(), 'send_game', make_assertion) - assert message.reply_game(game_short_name='test_game') - assert message.reply_game(game_short_name='test_game', quote=True) + assert await message.reply_game(game_short_name='test_game') + assert await message.reply_game(game_short_name='test_game', quote=True) - def test_reply_invoice(self, monkeypatch, message: Message): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_reply_invoice(self, monkeypatch, message): + async def make_assertion(*_, **kwargs): title = kwargs['title'] == 'title' description = kwargs['description'] == 'description' payload = kwargs['payload'] == 'payload' @@ -1131,11 +1165,11 @@ def make_assertion(*_, **kwargs): assert check_shortcut_signature( Message.reply_invoice, Bot.send_invoice, ['chat_id'], ['quote'] ) - assert check_shortcut_call(message.reply_invoice, message.get_bot(), 'send_invoice') - assert check_defaults_handling(message.reply_invoice, message.get_bot()) + assert await check_shortcut_call(message.reply_invoice, message.get_bot(), 'send_invoice') + assert await check_defaults_handling(message.reply_invoice, message.get_bot()) monkeypatch.setattr(message.get_bot(), 'send_invoice', make_assertion) - assert message.reply_invoice( + assert await message.reply_invoice( 'title', 'description', 'payload', @@ -1143,7 +1177,7 @@ def make_assertion(*_, **kwargs): 'currency', 'prices', ) - assert message.reply_invoice( + assert await message.reply_invoice( 'title', 'description', 'payload', @@ -1154,8 +1188,9 @@ def make_assertion(*_, **kwargs): ) @pytest.mark.parametrize('disable_notification,protected', [(False, True), (True, False)]) - def test_forward(self, monkeypatch, message, disable_notification, protected): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_forward(self, monkeypatch, message, disable_notification, protected): + async def make_assertion(*_, **kwargs): chat_id = kwargs['chat_id'] == 123456 from_chat = kwargs['from_chat_id'] == message.chat_id message_id = kwargs['message_id'] == message.message_id @@ -1166,20 +1201,21 @@ def make_assertion(*_, **kwargs): assert check_shortcut_signature( Message.forward, Bot.forward_message, ['from_chat_id', 'message_id'], [] ) - assert check_shortcut_call(message.forward, message.get_bot(), 'forward_message') - assert check_defaults_handling(message.forward, message.get_bot()) + assert await check_shortcut_call(message.forward, message.get_bot(), 'forward_message') + assert await check_defaults_handling(message.forward, message.get_bot()) monkeypatch.setattr(message.get_bot(), 'forward_message', make_assertion) - assert message.forward( + assert await message.forward( 123456, disable_notification=disable_notification, protect_content=protected ) - assert not message.forward(635241) + assert not await message.forward(635241) @pytest.mark.parametrize('disable_notification,protected', [(True, False), (False, True)]) - def test_copy(self, monkeypatch, message, disable_notification, protected): + @pytest.mark.asyncio + async def test_copy(self, monkeypatch, message, disable_notification, protected): keyboard = [[1, 2]] - def make_assertion(*_, **kwargs): + async def make_assertion(*_, **kwargs): chat_id = kwargs['chat_id'] == 123456 from_chat = kwargs['from_chat_id'] == message.chat_id message_id = kwargs['message_id'] == message.message_id @@ -1201,27 +1237,27 @@ def make_assertion(*_, **kwargs): assert check_shortcut_signature( Message.copy, Bot.copy_message, ['from_chat_id', 'message_id'], [] ) - - assert check_shortcut_call(message.copy, message.get_bot(), 'copy_message') - assert check_defaults_handling(message.copy, message.get_bot()) + assert await check_shortcut_call(message.copy, message.get_bot(), 'copy_message') + assert await check_defaults_handling(message.copy, message.get_bot()) monkeypatch.setattr(message.get_bot(), 'copy_message', make_assertion) - assert message.copy( + assert await message.copy( 123456, disable_notification=disable_notification, protect_content=protected ) - assert message.copy( + assert await message.copy( 123456, reply_markup=keyboard, disable_notification=disable_notification, protect_content=protected, ) - assert not message.copy(635241) + assert not await message.copy(635241) @pytest.mark.parametrize('disable_notification,protected', [(True, False), (False, True)]) - def test_reply_copy(self, monkeypatch, message, disable_notification, protected): + @pytest.mark.asyncio + async def test_reply_copy(self, monkeypatch, message, disable_notification, protected): keyboard = [[1, 2]] - def make_assertion(*_, **kwargs): + async def make_assertion(*_, **kwargs): chat_id = kwargs['from_chat_id'] == 123456 from_chat = kwargs['chat_id'] == message.chat_id message_id = kwargs['message_id'] == 456789 @@ -1248,28 +1284,28 @@ def make_assertion(*_, **kwargs): assert check_shortcut_signature( Message.reply_copy, Bot.copy_message, ['chat_id'], ['quote'] ) - assert check_shortcut_call(message.copy, message.get_bot(), 'copy_message') - assert check_defaults_handling(message.copy, message.get_bot()) + assert await check_shortcut_call(message.copy, message.get_bot(), 'copy_message') + assert await check_defaults_handling(message.copy, message.get_bot()) monkeypatch.setattr(message.get_bot(), 'copy_message', make_assertion) - assert message.reply_copy( + assert await message.reply_copy( 123456, 456789, disable_notification=disable_notification, protect_content=protected ) - assert message.reply_copy( + assert await message.reply_copy( 123456, 456789, reply_markup=keyboard, disable_notification=disable_notification, protect_content=protected, ) - assert message.reply_copy( + assert await message.reply_copy( 123456, 456789, quote=True, disable_notification=disable_notification, protect_content=protected, ) - assert message.reply_copy( + assert await message.reply_copy( 123456, 456789, quote=True, @@ -1278,8 +1314,9 @@ def make_assertion(*_, **kwargs): protect_content=protected, ) - def test_edit_text(self, monkeypatch, message): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_edit_text(self, monkeypatch, message): + async def make_assertion(*_, **kwargs): chat_id = kwargs['chat_id'] == message.chat_id message_id = kwargs['message_id'] == message.message_id text = kwargs['text'] == 'test' @@ -1291,20 +1328,21 @@ def make_assertion(*_, **kwargs): ['chat_id', 'message_id', 'inline_message_id'], [], ) - assert check_shortcut_call( + assert await check_shortcut_call( message.edit_text, message.get_bot(), 'edit_message_text', skip_params=['inline_message_id'], shortcut_kwargs=['message_id', 'chat_id'], ) - assert check_defaults_handling(message.edit_text, message.get_bot()) + assert await check_defaults_handling(message.edit_text, message.get_bot()) monkeypatch.setattr(message.get_bot(), 'edit_message_text', make_assertion) - assert message.edit_text(text='test') + assert await message.edit_text(text='test') - def test_edit_caption(self, monkeypatch, message): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_edit_caption(self, monkeypatch, message): + async def make_assertion(*_, **kwargs): chat_id = kwargs['chat_id'] == message.chat_id message_id = kwargs['message_id'] == message.message_id caption = kwargs['caption'] == 'new caption' @@ -1316,20 +1354,21 @@ def make_assertion(*_, **kwargs): ['chat_id', 'message_id', 'inline_message_id'], [], ) - assert check_shortcut_call( + assert await check_shortcut_call( message.edit_caption, message.get_bot(), 'edit_message_caption', skip_params=['inline_message_id'], shortcut_kwargs=['message_id', 'chat_id'], ) - assert check_defaults_handling(message.edit_caption, message.get_bot()) + assert await check_defaults_handling(message.edit_caption, message.get_bot()) monkeypatch.setattr(message.get_bot(), 'edit_message_caption', make_assertion) - assert message.edit_caption(caption='new caption') + assert await message.edit_caption(caption='new caption') - def test_edit_media(self, monkeypatch, message): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_edit_media(self, monkeypatch, message): + async def make_assertion(*_, **kwargs): chat_id = kwargs['chat_id'] == message.chat_id message_id = kwargs['message_id'] == message.message_id media = kwargs['media'] == 'my_media' @@ -1341,20 +1380,21 @@ def make_assertion(*_, **kwargs): ['chat_id', 'message_id', 'inline_message_id'], [], ) - assert check_shortcut_call( + assert await check_shortcut_call( message.edit_media, message.get_bot(), 'edit_message_media', skip_params=['inline_message_id'], shortcut_kwargs=['message_id', 'chat_id'], ) - assert check_defaults_handling(message.edit_media, message.get_bot()) + assert await check_defaults_handling(message.edit_media, message.get_bot()) monkeypatch.setattr(message.get_bot(), 'edit_message_media', make_assertion) - assert message.edit_media('my_media') + assert await message.edit_media('my_media') - def test_edit_reply_markup(self, monkeypatch, message): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_edit_reply_markup(self, monkeypatch, message): + async def make_assertion(*_, **kwargs): chat_id = kwargs['chat_id'] == message.chat_id message_id = kwargs['message_id'] == message.message_id reply_markup = kwargs['reply_markup'] == [['1', '2']] @@ -1366,20 +1406,21 @@ def make_assertion(*_, **kwargs): ['chat_id', 'message_id', 'inline_message_id'], [], ) - assert check_shortcut_call( + assert await check_shortcut_call( message.edit_reply_markup, message.get_bot(), 'edit_message_reply_markup', skip_params=['inline_message_id'], shortcut_kwargs=['message_id', 'chat_id'], ) - assert check_defaults_handling(message.edit_reply_markup, message.get_bot()) + assert await check_defaults_handling(message.edit_reply_markup, message.get_bot()) monkeypatch.setattr(message.get_bot(), 'edit_message_reply_markup', make_assertion) - assert message.edit_reply_markup(reply_markup=[['1', '2']]) + assert await message.edit_reply_markup(reply_markup=[['1', '2']]) - def test_edit_live_location(self, monkeypatch, message): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_edit_live_location(self, monkeypatch, message): + async def make_assertion(*_, **kwargs): chat_id = kwargs['chat_id'] == message.chat_id message_id = kwargs['message_id'] == message.message_id latitude = kwargs['latitude'] == 1 @@ -1392,20 +1433,21 @@ def make_assertion(*_, **kwargs): ['chat_id', 'message_id', 'inline_message_id'], [], ) - assert check_shortcut_call( + assert await check_shortcut_call( message.edit_live_location, message.get_bot(), 'edit_message_live_location', skip_params=['inline_message_id'], shortcut_kwargs=['message_id', 'chat_id'], ) - assert check_defaults_handling(message.edit_live_location, message.get_bot()) + assert await check_defaults_handling(message.edit_live_location, message.get_bot()) monkeypatch.setattr(message.get_bot(), 'edit_message_live_location', make_assertion) - assert message.edit_live_location(latitude=1, longitude=2) + assert await message.edit_live_location(latitude=1, longitude=2) - def test_stop_live_location(self, monkeypatch, message): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_stop_live_location(self, monkeypatch, message): + async def make_assertion(*_, **kwargs): chat_id = kwargs['chat_id'] == message.chat_id message_id = kwargs['message_id'] == message.message_id return chat_id and message_id @@ -1416,20 +1458,21 @@ def make_assertion(*_, **kwargs): ['chat_id', 'message_id', 'inline_message_id'], [], ) - assert check_shortcut_call( + assert await check_shortcut_call( message.stop_live_location, message.get_bot(), 'stop_message_live_location', skip_params=['inline_message_id'], shortcut_kwargs=['message_id', 'chat_id'], ) - assert check_defaults_handling(message.stop_live_location, message.get_bot()) + assert await check_defaults_handling(message.stop_live_location, message.get_bot()) monkeypatch.setattr(message.get_bot(), 'stop_message_live_location', make_assertion) - assert message.stop_live_location() + assert await message.stop_live_location() - def test_set_game_score(self, monkeypatch, message): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_set_game_score(self, monkeypatch, message): + async def make_assertion(*_, **kwargs): chat_id = kwargs['chat_id'] == message.chat_id message_id = kwargs['message_id'] == message.message_id user_id = kwargs['user_id'] == 1 @@ -1442,20 +1485,21 @@ def make_assertion(*_, **kwargs): ['chat_id', 'message_id', 'inline_message_id'], [], ) - assert check_shortcut_call( + assert await check_shortcut_call( message.set_game_score, message.get_bot(), 'set_game_score', skip_params=['inline_message_id'], shortcut_kwargs=['message_id', 'chat_id'], ) - assert check_defaults_handling(message.set_game_score, message.get_bot()) + assert await check_defaults_handling(message.set_game_score, message.get_bot()) monkeypatch.setattr(message.get_bot(), 'set_game_score', make_assertion) - assert message.set_game_score(user_id=1, score=2) + assert await message.set_game_score(user_id=1, score=2) - def test_get_game_high_scores(self, monkeypatch, message): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_get_game_high_scores(self, monkeypatch, message): + async def make_assertion(*_, **kwargs): chat_id = kwargs['chat_id'] == message.chat_id message_id = kwargs['message_id'] == message.message_id user_id = kwargs['user_id'] == 1 @@ -1467,20 +1511,21 @@ def make_assertion(*_, **kwargs): ['chat_id', 'message_id', 'inline_message_id'], [], ) - assert check_shortcut_call( + assert await check_shortcut_call( message.get_game_high_scores, message.get_bot(), 'get_game_high_scores', skip_params=['inline_message_id'], shortcut_kwargs=['message_id', 'chat_id'], ) - assert check_defaults_handling(message.get_game_high_scores, message.get_bot()) + assert await check_defaults_handling(message.get_game_high_scores, message.get_bot()) monkeypatch.setattr(message.get_bot(), 'get_game_high_scores', make_assertion) - assert message.get_game_high_scores(user_id=1) + assert await message.get_game_high_scores(user_id=1) - def test_delete(self, monkeypatch, message): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_delete(self, monkeypatch, message): + async def make_assertion(*_, **kwargs): chat_id = kwargs['chat_id'] == message.chat_id message_id = kwargs['message_id'] == message.message_id return chat_id and message_id @@ -1488,14 +1533,15 @@ def make_assertion(*_, **kwargs): assert check_shortcut_signature( Message.delete, Bot.delete_message, ['chat_id', 'message_id'], [] ) - assert check_shortcut_call(message.delete, message.get_bot(), 'delete_message') - assert check_defaults_handling(message.delete, message.get_bot()) + assert await check_shortcut_call(message.delete, message.get_bot(), 'delete_message') + assert await check_defaults_handling(message.delete, message.get_bot()) monkeypatch.setattr(message.get_bot(), 'delete_message', make_assertion) - assert message.delete() + assert await message.delete() - def test_stop_poll(self, monkeypatch, message): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_stop_poll(self, monkeypatch, message): + async def make_assertion(*_, **kwargs): chat_id = kwargs['chat_id'] == message.chat_id message_id = kwargs['message_id'] == message.message_id return chat_id and message_id @@ -1503,14 +1549,15 @@ def make_assertion(*_, **kwargs): assert check_shortcut_signature( Message.stop_poll, Bot.stop_poll, ['chat_id', 'message_id'], [] ) - assert check_shortcut_call(message.stop_poll, message.get_bot(), 'stop_poll') - assert check_defaults_handling(message.stop_poll, message.get_bot()) + assert await check_shortcut_call(message.stop_poll, message.get_bot(), 'stop_poll') + assert await check_defaults_handling(message.stop_poll, message.get_bot()) monkeypatch.setattr(message.get_bot(), 'stop_poll', make_assertion) - assert message.stop_poll() + assert await message.stop_poll() - def test_pin(self, monkeypatch, message): - def make_assertion(*args, **kwargs): + @pytest.mark.asyncio + async def test_pin(self, monkeypatch, message): + async def make_assertion(*args, **kwargs): chat_id = kwargs['chat_id'] == message.chat_id message_id = kwargs['message_id'] == message.message_id return chat_id and message_id @@ -1518,14 +1565,15 @@ def make_assertion(*args, **kwargs): assert check_shortcut_signature( Message.pin, Bot.pin_chat_message, ['chat_id', 'message_id'], [] ) - assert check_shortcut_call(message.pin, message.get_bot(), 'pin_chat_message') - assert check_defaults_handling(message.pin, message.get_bot()) + assert await check_shortcut_call(message.pin, message.get_bot(), 'pin_chat_message') + assert await check_defaults_handling(message.pin, message.get_bot()) monkeypatch.setattr(message.get_bot(), 'pin_chat_message', make_assertion) - assert message.pin() + assert await message.pin() - def test_unpin(self, monkeypatch, message): - def make_assertion(*args, **kwargs): + @pytest.mark.asyncio + async def test_unpin(self, monkeypatch, message): + async def make_assertion(*args, **kwargs): chat_id = kwargs['chat_id'] == message.chat_id message_id = kwargs['message_id'] == message.message_id return chat_id and message_id @@ -1533,16 +1581,16 @@ def make_assertion(*args, **kwargs): assert check_shortcut_signature( Message.unpin, Bot.unpin_chat_message, ['chat_id', 'message_id'], [] ) - assert check_shortcut_call( + assert await check_shortcut_call( message.unpin, message.get_bot(), 'unpin_chat_message', shortcut_kwargs=['chat_id', 'message_id'], ) - assert check_defaults_handling(message.unpin, message.get_bot()) + assert await check_defaults_handling(message.unpin, message.get_bot()) monkeypatch.setattr(message.get_bot(), 'unpin_chat_message', make_assertion) - assert message.unpin() + assert await message.unpin() def test_default_quote(self, message): message.get_bot()._defaults = Defaults() diff --git a/tests/test_messagehandler.py b/tests/test_messagehandler.py deleted file mode 100644 index 1a1891072b7..00000000000 --- a/tests/test_messagehandler.py +++ /dev/null @@ -1,210 +0,0 @@ -#!/usr/bin/env python -# -# A library that provides a Python interface to the Telegram Bot API -# Copyright (C) 2015-2022 -# Leandro Toledo de Souza -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Lesser Public License for more details. -# -# You should have received a copy of the GNU Lesser Public License -# along with this program. If not, see [http://www.gnu.org/licenses/]. -import re -from queue import Queue - -import pytest - -from telegram import ( - Message, - Update, - Chat, - Bot, - User, - CallbackQuery, - InlineQuery, - ChosenInlineResult, - ShippingQuery, - PreCheckoutQuery, -) -from telegram.ext import filters, MessageHandler, CallbackContext, JobQueue - -message = Message(1, None, Chat(1, ''), from_user=User(1, '', False), text='Text') - -params = [ - {'callback_query': CallbackQuery(1, User(1, '', False), 'chat', message=message)}, - {'inline_query': InlineQuery(1, User(1, '', False), '', '')}, - {'chosen_inline_result': ChosenInlineResult('id', User(1, '', False), '')}, - {'shipping_query': ShippingQuery('id', User(1, '', False), '', None)}, - {'pre_checkout_query': PreCheckoutQuery('id', User(1, '', False), '', 0, '')}, - {'callback_query': CallbackQuery(1, User(1, '', False), 'chat')}, -] - -ids = ( - 'callback_query', - 'inline_query', - 'chosen_inline_result', - 'shipping_query', - 'pre_checkout_query', - 'callback_query_without_message', -) - - -@pytest.fixture(scope='class', params=params, ids=ids) -def false_update(request): - return Update(update_id=1, **request.param) - - -@pytest.fixture(scope='class') -def message(bot): - return Message(1, None, Chat(1, ''), from_user=User(1, '', False), bot=bot) - - -class TestMessageHandler: - test_flag = False - SRE_TYPE = type(re.match("", "")) - - def test_slot_behaviour(self, mro_slots): - handler = MessageHandler(filters.ALL, self.callback_context) - for attr in handler.__slots__: - assert getattr(handler, attr, 'err') != 'err', f"got extra slot '{attr}'" - assert len(mro_slots(handler)) == len(set(mro_slots(handler))), "duplicate slot" - - @pytest.fixture(autouse=True) - def reset(self): - self.test_flag = False - - def callback_context(self, update, context): - self.test_flag = ( - isinstance(context, CallbackContext) - and isinstance(context.bot, Bot) - and isinstance(update, Update) - and isinstance(context.update_queue, Queue) - and isinstance(context.job_queue, JobQueue) - and isinstance(context.chat_data, dict) - and isinstance(context.bot_data, dict) - and ( - ( - isinstance(context.user_data, dict) - and ( - isinstance(update.message, Message) - or isinstance(update.edited_message, Message) - ) - ) - or ( - context.user_data is None - and ( - isinstance(update.channel_post, Message) - or isinstance(update.edited_channel_post, Message) - ) - ) - ) - ) - - def callback_context_regex1(self, update, context): - if context.matches: - types = all(type(res) is self.SRE_TYPE for res in context.matches) - num = len(context.matches) == 1 - self.test_flag = types and num - - def callback_context_regex2(self, update, context): - if context.matches: - types = all(type(res) is self.SRE_TYPE for res in context.matches) - num = len(context.matches) == 2 - self.test_flag = types and num - - def test_with_filter(self, message): - handler = MessageHandler(filters.ChatType.GROUP, self.callback_context) - - message.chat.type = 'group' - assert handler.check_update(Update(0, message)) - - message.chat.type = 'private' - assert not handler.check_update(Update(0, message)) - - def test_callback_query_with_filter(self, message): - class TestFilter(filters.UpdateFilter): - flag = False - - def filter(self, u): - self.flag = True - - test_filter = TestFilter() - handler = MessageHandler(test_filter, self.callback_context) - - update = Update(1, callback_query=CallbackQuery(1, None, None, message=message)) - - assert update.effective_message - assert not handler.check_update(update) - assert not test_filter.flag - - def test_specific_filters(self, message): - f = ( - ~filters.UpdateType.MESSAGES - & ~filters.UpdateType.CHANNEL_POST - & filters.UpdateType.EDITED_CHANNEL_POST - ) - handler = MessageHandler(f, self.callback_context) - - assert not handler.check_update(Update(0, edited_message=message)) - assert not handler.check_update(Update(0, message=message)) - assert not handler.check_update(Update(0, channel_post=message)) - assert handler.check_update(Update(0, edited_channel_post=message)) - - def test_other_update_types(self, false_update): - handler = MessageHandler(None, self.callback_context) - assert not handler.check_update(false_update) - - def test_context(self, dp, message): - handler = MessageHandler( - None, - self.callback_context, - ) - dp.add_handler(handler) - - dp.process_update(Update(0, message=message)) - assert self.test_flag - - self.test_flag = False - dp.process_update(Update(0, edited_message=message)) - assert self.test_flag - - self.test_flag = False - dp.process_update(Update(0, channel_post=message)) - assert self.test_flag - - self.test_flag = False - dp.process_update(Update(0, edited_channel_post=message)) - assert self.test_flag - - def test_context_regex(self, dp, message): - handler = MessageHandler(filters.Regex('one two'), self.callback_context_regex1) - dp.add_handler(handler) - - message.text = 'not it' - dp.process_update(Update(0, message)) - assert not self.test_flag - - message.text += ' one two now it is' - dp.process_update(Update(0, message)) - assert self.test_flag - - def test_context_multiple_regex(self, dp, message): - handler = MessageHandler( - filters.Regex('one') & filters.Regex('two'), self.callback_context_regex2 - ) - dp.add_handler(handler) - - message.text = 'not it' - dp.process_update(Update(0, message)) - assert not self.test_flag - - message.text += ' one two now it is' - dp.process_update(Update(0, message)) - assert self.test_flag diff --git a/tests/test_official.py b/tests/test_official.py index bfe0f852bcf..4039eb9fd55 100644 --- a/tests/test_official.py +++ b/tests/test_official.py @@ -20,10 +20,9 @@ import inspect from typing import List -import certifi import pytest +import httpx from bs4 import BeautifulSoup -from telegram.vendor.ptb_urllib3 import urllib3 import telegram from tests.conftest import env_var_2_bool @@ -33,9 +32,10 @@ 'self', 'args', '_kwargs', - 'read_latency', - 'network_delay', - 'timeout', + 'read_timeout', + 'write_timeout', + 'connect_timeout', + 'pool_timeout', 'bot', 'api_kwargs', 'kwargs', @@ -194,9 +194,8 @@ def check_required_param( argvalues = [] names = [] -http = urllib3.PoolManager(cert_reqs='CERT_REQUIRED', ca_certs=certifi.where()) -request = http.request('GET', 'https://core.telegram.org/bots/api') -soup = BeautifulSoup(request.data.decode('utf-8'), 'html.parser') +request = httpx.get('https://core.telegram.org/bots/api') +soup = BeautifulSoup(request.text, 'html.parser') for thing in soup.select('h4 > a.anchor'): # Methods and types don't have spaces in them, luckily all other sections of the docs do diff --git a/tests/test_passport.py b/tests/test_passport.py index f546bd0e255..909add81d62 100644 --- a/tests/test_passport.py +++ b/tests/test_passport.py @@ -35,7 +35,7 @@ # Note: All classes in telegram.credentials (except EncryptedCredentials) aren't directly tested # here, although they are implicitly tested. Testing for those classes was too much work and not # worth it. - +from telegram.request import RequestData RAW_PASSPORT_DATA = { 'credentials': { @@ -427,9 +427,10 @@ def test_bot_init_invalid_key(self, bot): with pytest.raises(ValueError): Bot(bot.token, private_key=b'Invalid key!') - def test_passport_data_okay_with_non_crypto_bot(self, bot): - b = Bot(bot.token) - assert PassportData.de_json(RAW_PASSPORT_DATA, bot=b) + @pytest.mark.asyncio + async def test_passport_data_okay_with_non_crypto_bot(self, bot): + async with Bot(bot.token) as b: + assert PassportData.de_json(RAW_PASSPORT_DATA, bot=b) def test_wrong_hash(self, bot): data = deepcopy(RAW_PASSPORT_DATA) @@ -438,20 +439,22 @@ def test_wrong_hash(self, bot): with pytest.raises(PassportDecryptionError): assert passport_data.decrypted_data - def test_wrong_key(self, bot): + @pytest.mark.asyncio + async def test_wrong_key(self, bot): short_key = b"-----BEGIN RSA PRIVATE KEY-----\r\nMIIBOQIBAAJBAKU+OZ2jJm7sCA/ec4gngNZhXYPu+DZ/TAwSMl0W7vAPXAsLplBk\r\nO8l6IBHx8N0ZC4Bc65mO3b2G8YAzqndyqH8CAwEAAQJAWOx3jQFzeVXDsOaBPdAk\r\nYTncXVeIc6tlfUl9mOLyinSbRNCy1XicOiOZFgH1rRKOGIC1235QmqxFvdecySoY\r\nwQIhAOFeGgeX9CrEPuSsd9+kqUcA2avCwqdQgSdy2qggRFyJAiEAu7QHT8JQSkHU\r\nDELfzrzc24AhjyG0z1DpGZArM8COascCIDK42SboXj3Z2UXiQ0CEcMzYNiVgOisq\r\nBUd5pBi+2mPxAiAM5Z7G/Sv1HjbKrOGh29o0/sXPhtpckEuj5QMC6E0gywIgFY6S\r\nNjwrAA+cMmsgY0O2fAzEKkDc5YiFsiXaGaSS4eA=\r\n-----END RSA PRIVATE KEY-----" - b = Bot(bot.token, private_key=short_key) - passport_data = PassportData.de_json(RAW_PASSPORT_DATA, bot=b) - with pytest.raises(PassportDecryptionError): - assert passport_data.decrypted_data + async with Bot(bot.token, private_key=short_key) as b: + passport_data = PassportData.de_json(RAW_PASSPORT_DATA, bot=b) + with pytest.raises(PassportDecryptionError): + assert passport_data.decrypted_data wrong_key = b"-----BEGIN RSA PRIVATE KEY-----\r\nMIIEogIBAAKCAQB4qCFltuvHakZze86TUweU7E/SB3VLGEHAe7GJlBmrou9SSWsL\r\nH7E++157X6UqWFl54LOE9MeHZnoW7rZ+DxLKhk6NwAHTxXPnvw4CZlvUPC3OFxg3\r\nhEmNen6ojSM4sl4kYUIa7F+Q5uMEYaboxoBen9mbj4zzMGsG4aY/xBOb2ewrXQyL\r\nRh//tk1Px4ago+lUPisAvQVecz7/6KU4Xj4Lpv2z20f3cHlZX6bb7HlE1vixCMOf\r\nxvfC5SkWEGZMR/ZoWQUsoDkrDSITF/S3GtLfg083TgtCKaOF3mCT27sJ1og77npP\r\n0cH/qdlbdoFtdrRj3PvBpaj/TtXRhmdGcJBxAgMBAAECggEAYSq1Sp6XHo8dkV8B\r\nK2/QSURNu8y5zvIH8aUrgqo8Shb7OH9bryekrB3vJtgNwR5JYHdu2wHttcL3S4SO\r\nftJQxbyHgmxAjHUVNGqOM6yPA0o7cR70J7FnMoKVgdO3q68pVY7ll50IET9/T0X9\r\nDrTdKFb+/eILFsXFS1NpeSzExdsKq3zM0sP/vlJHHYVTmZDGaGEvny/eLAS+KAfG\r\nrKP96DeO4C/peXEJzALZ/mG1ReBB05Qp9Dx1xEC20yreRk5MnnBA5oiHVG5ZLOl9\r\nEEHINidqN+TMNSkxv67xMfQ6utNu5IpbklKv/4wqQOJOO50HZ+qBtSurTN573dky\r\nzslbCQKBgQDHDUBYyKN/v69VLmvNVcxTgrOcrdbqAfefJXb9C3dVXhS8/oRkCRU/\r\ndzxYWNT7hmQyWUKor/izh68rZ/M+bsTnlaa7IdAgyChzTfcZL/2pxG9pq05GF1Q4\r\nBSJ896ZEe3jEhbpJXRlWYvz7455svlxR0H8FooCTddTmkU3nsQSx0wKBgQCbLSa4\r\nyZs2QVstQQerNjxAtLi0IvV8cJkuvFoNC2Q21oqQc7BYU7NJL7uwriprZr5nwkCQ\r\nOFQXi4N3uqimNxuSng31ETfjFZPp+pjb8jf7Sce7cqU66xxR+anUzVZqBG1CJShx\r\nVxN7cWN33UZvIH34gA2Ax6AXNnJG42B5Gn1GKwKBgQCZ/oh/p4nGNXfiAK3qB6yy\r\nFvX6CwuvsqHt/8AUeKBz7PtCU+38roI/vXF0MBVmGky+HwxREQLpcdl1TVCERpIT\r\nUFXThI9OLUwOGI1IcTZf9tby+1LtKvM++8n4wGdjp9qAv6ylQV9u09pAzZItMwCd\r\nUx5SL6wlaQ2y60tIKk0lfQKBgBJS+56YmA6JGzY11qz+I5FUhfcnpauDNGOTdGLT\r\n9IqRPR2fu7RCdgpva4+KkZHLOTLReoRNUojRPb4WubGfEk93AJju5pWXR7c6k3Bt\r\novS2mrJk8GQLvXVksQxjDxBH44sLDkKMEM3j7uYJqDaZNKbyoCWT7TCwikAau5qx\r\naRevAoGAAKZV705dvrpJuyoHFZ66luANlrAwG/vNf6Q4mBEXB7guqMkokCsSkjqR\r\nhsD79E6q06zA0QzkLCavbCn5kMmDS/AbA80+B7El92iIN6d3jRdiNZiewkhlWhEG\r\nm4N0gQRfIu+rUjsS/4xk8UuQUT/Ossjn/hExi7ejpKdCc7N++bc=\r\n-----END RSA PRIVATE KEY-----" - b = Bot(bot.token, private_key=wrong_key) - passport_data = PassportData.de_json(RAW_PASSPORT_DATA, bot=b) - with pytest.raises(PassportDecryptionError): - assert passport_data.decrypted_data + async with Bot(bot.token, private_key=wrong_key) as b: + passport_data = PassportData.de_json(RAW_PASSPORT_DATA, bot=b) + with pytest.raises(PassportDecryptionError): + assert passport_data.decrypted_data - def test_mocked_download_passport_file(self, passport_data, monkeypatch): + @pytest.mark.asyncio + async def test_mocked_download_passport_file(self, passport_data, monkeypatch): # The files are not coming from our test bot, therefore the file id is invalid/wrong # when coming from this bot, so we monkeypatch the call, to make sure that Bot.get_file # at least gets called @@ -459,30 +462,32 @@ def test_mocked_download_passport_file(self, passport_data, monkeypatch): selfie = passport_data.decrypted_data[1].selfie # NOTE: file_unique_id is not used in the get_file method, so it is passed directly - def get_file(*_, **kwargs): + async def get_file(*_, **kwargs): return File(kwargs['file_id'], selfie.file_unique_id) monkeypatch.setattr(passport_data.get_bot(), 'get_file', get_file) - file = selfie.get_file() + file = await selfie.get_file() assert file.file_id == selfie.file_id assert file.file_unique_id == selfie.file_unique_id assert file._credentials.file_hash == self.driver_license_selfie_credentials_file_hash assert file._credentials.secret == self.driver_license_selfie_credentials_secret - def test_mocked_set_passport_data_errors(self, monkeypatch, bot, chat_id, passport_data): - def test(url, data, **kwargs): + @pytest.mark.asyncio + async def test_mocked_set_passport_data_errors(self, monkeypatch, bot, chat_id, passport_data): + async def make_assertion(url, request_data: RequestData, *args, **kwargs): + data = request_data.parameters return ( - data['user_id'] == chat_id + data['user_id'] == str(chat_id) and data['errors'][0]['file_hash'] == ( passport_data.decrypted_credentials.secure_data.driver_license.selfie.file_hash ) and data['errors'][1]['data_hash'] - == (passport_data.decrypted_credentials.secure_data.driver_license.data.data_hash) + == passport_data.decrypted_credentials.secure_data.driver_license.data.data_hash ) - monkeypatch.setattr(bot.request, 'post', test) - message = bot.set_passport_data_errors( + monkeypatch.setattr(bot.request, 'post', make_assertion) + message = await bot.set_passport_data_errors( chat_id, [ PassportElementErrorSelfie( diff --git a/tests/test_passportfile.py b/tests/test_passportfile.py index dd8b2386ca6..ffcc817af18 100644 --- a/tests/test_passportfile.py +++ b/tests/test_passportfile.py @@ -60,18 +60,21 @@ def test_to_dict(self, passport_file): assert passport_file_dict['file_size'] == passport_file.file_size assert passport_file_dict['file_date'] == passport_file.file_date - def test_get_file_instance_method(self, monkeypatch, passport_file): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_get_file_instance_method(self, monkeypatch, passport_file): + async def make_assertion(*_, **kwargs): result = kwargs['file_id'] == passport_file.file_id # we need to be a bit hacky here, b/c PF.get_file needs Bot.get_file to return a File return File(file_id=result, file_unique_id=result) assert check_shortcut_signature(PassportFile.get_file, Bot.get_file, ['file_id'], []) - assert check_shortcut_call(passport_file.get_file, passport_file.get_bot(), 'get_file') - assert check_defaults_handling(passport_file.get_file, passport_file.get_bot()) + assert await check_shortcut_call( + passport_file.get_file, passport_file.get_bot(), 'get_file' + ) + assert await check_defaults_handling(passport_file.get_file, passport_file.get_bot()) monkeypatch.setattr(passport_file.get_bot(), 'get_file', make_assertion) - assert passport_file.get_file().file_id == 'True' + assert (await passport_file.get_file()).file_id == 'True' def test_equality(self): a = PassportFile(self.file_id, self.file_unique_id, self.file_size, self.file_date) diff --git a/tests/test_persistence.py b/tests/test_persistence.py deleted file mode 100644 index fc7f32fcf97..00000000000 --- a/tests/test_persistence.py +++ /dev/null @@ -1,2371 +0,0 @@ -#!/usr/bin/env python -# -# A library that provides a Python interface to the Telegram Bot API -# Copyright (C) 2015-2022 -# Leandro Toledo de Souza -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Lesser Public License for more details. -# -# You should have received a copy of the GNU Lesser Public License -# along with this program. If not, see [http://www.gnu.org/licenses/]. -import logging -import os -import pickle -import gzip -import signal -import uuid -from collections.abc import Container -from collections import defaultdict -from pathlib import Path -from time import sleep -from threading import Lock - -import pytest - -try: - import ujson as json -except ImportError: - import json - -from telegram import Update, Message, User, Chat, MessageEntity, Bot -from telegram.ext import ( - BasePersistence, - ConversationHandler, - MessageHandler, - filters, - PicklePersistence, - CommandHandler, - DictPersistence, - TypeHandler, - JobQueue, - ContextTypes, - PersistenceInput, - UpdaterBuilder, - CallbackDataCache, -) -from telegram.ext._callbackdatacache import _KeyboardData - - -@pytest.fixture(autouse=True) -def change_directory(tmp_path: Path): - orig_dir = Path.cwd() - # Switch to a temporary directory, so we don't have to worry about cleaning up files - os.chdir(tmp_path) - yield - # Go back to original directory - os.chdir(orig_dir) - - -@pytest.fixture(autouse=True) -def reset_callback_data_cache(bot): - yield - bot.callback_data_cache.clear_callback_data() - bot.callback_data_cache.clear_callback_queries() - bot.arbitrary_callback_data = False - - -class OwnPersistence(BasePersistence): - def get_bot_data(self): - raise NotImplementedError - - def get_chat_data(self): - raise NotImplementedError - - def get_user_data(self): - raise NotImplementedError - - def get_conversations(self, name): - raise NotImplementedError - - def update_bot_data(self, data): - raise NotImplementedError - - def update_chat_data(self, chat_id, data): - raise NotImplementedError - - def update_conversation(self, name, key, new_state): - raise NotImplementedError - - def update_user_data(self, user_id, data): - raise NotImplementedError - - def get_callback_data(self): - raise NotImplementedError - - def drop_user_data(self, user_id): - raise NotImplementedError - - def drop_chat_data(self, chat_id): - raise NotImplementedError - - def refresh_user_data(self, user_id, user_data): - raise NotImplementedError - - def refresh_chat_data(self, chat_id, chat_data): - raise NotImplementedError - - def refresh_bot_data(self, bot_data): - raise NotImplementedError - - def update_callback_data(self, data): - raise NotImplementedError - - def flush(self): - raise NotImplementedError - - -@pytest.fixture(scope="function") -def base_persistence(): - return OwnPersistence() - - -@pytest.fixture(scope="function") -def bot_persistence(): - class BotPersistence(BasePersistence): - __slots__ = () - - def __init__(self): - super().__init__() - self.bot_data = None - self.chat_data = {} - self.user_data = {} - self.callback_data = None - - def get_bot_data(self): - return self.bot_data - - def get_chat_data(self): - return self.chat_data - - def get_user_data(self): - return self.user_data - - def get_callback_data(self): - return self.callback_data - - def get_conversations(self, name): - raise NotImplementedError - - def update_bot_data(self, data): - self.bot_data = data - - def update_chat_data(self, chat_id, data): - self.chat_data[chat_id] = data - - def update_user_data(self, user_id, data): - self.user_data[user_id] = data - - def update_callback_data(self, data): - self.callback_data = data - - def drop_user_data(self, user_id): - self.user_data.pop(user_id, None) - - def drop_chat_data(self, chat_id): - self.chat_data.pop(chat_id, None) - - def update_conversation(self, name, key, new_state): - raise NotImplementedError - - def refresh_user_data(self, user_id, user_data): - pass - - def refresh_chat_data(self, chat_id, chat_data): - pass - - def refresh_bot_data(self, bot_data): - pass - - def flush(self): - pass - - return BotPersistence() - - -@pytest.fixture(scope="function") -def bot_data(): - return {'test1': 'test2', 'test3': {'test4': 'test5'}} - - -@pytest.fixture(scope="function") -def chat_data(): - return {-12345: {'test1': 'test2', 'test3': {'test4': 'test5'}}, -67890: {3: 'test4'}} - - -@pytest.fixture(scope="function") -def user_data(): - return {12345: {'test1': 'test2', 'test3': {'test4': 'test5'}}, 67890: {3: 'test4'}} - - -@pytest.fixture(scope="function") -def callback_data(): - return [('test1', 1000, {'button1': 'test0', 'button2': 'test1'})], {'test1': 'test2'} - - -@pytest.fixture(scope='function') -def conversations(): - return { - 'name1': {(123, 123): 3, (456, 654): 4}, - 'name2': {(123, 321): 1, (890, 890): 2}, - 'name3': {(123, 321): 1, (890, 890): 2}, - } - - -@pytest.fixture(scope="function") -def updater(bot, base_persistence): - base_persistence.store_data = PersistenceInput(False, False, False, False) - u = UpdaterBuilder().bot(bot).persistence(base_persistence).build() - base_persistence.store_data = PersistenceInput() - return u - - -@pytest.fixture(scope='function') -def job_queue(bot): - jq = JobQueue() - yield jq - jq.stop() - - -def assert_data_in_cache(callback_data_cache: CallbackDataCache, data): - for val in callback_data_cache._keyboard_data.values(): - if data in val.button_data.values(): - return data - return False - - -class TestBasePersistence: - test_flag = False - - @pytest.fixture(scope='function', autouse=True) - def reset(self): - self.test_flag = False - - def test_slot_behaviour(self, bot_persistence, mro_slots): - inst = bot_persistence - for attr in inst.__slots__: - assert getattr(inst, attr, 'err') != 'err', f"got extra slot '{attr}'" - # assert not inst.__dict__, f"got missing slot(s): {inst.__dict__}" - # The below test fails if the child class doesn't define __slots__ (not a cause of concern) - assert len(mro_slots(inst)) == len(set(mro_slots(inst))), "duplicate slot" - - def test_creation(self, base_persistence): - assert base_persistence.store_data.chat_data - assert base_persistence.store_data.user_data - assert base_persistence.store_data.bot_data - assert base_persistence.store_data.callback_data - - def test_abstract_methods(self, base_persistence): - with pytest.raises( - TypeError, - match=( - 'drop_chat_data, drop_user_data, flush, get_bot_data, get_callback_data, ' - 'get_chat_data, get_conversations, ' - 'get_user_data, refresh_bot_data, refresh_chat_data, ' - 'refresh_user_data, update_bot_data, update_callback_data, ' - 'update_chat_data, update_conversation, update_user_data' - ), - ): - BasePersistence() - with pytest.raises(NotImplementedError): - base_persistence.get_callback_data() - with pytest.raises(NotImplementedError): - base_persistence.update_callback_data((None, {'foo': 'bar'})) - - def test_implementation(self, updater, base_persistence): - dp = updater.dispatcher - assert dp.persistence == base_persistence - - def test_conversationhandler_addition(self, dp, base_persistence): - with pytest.raises(ValueError, match="when handler is unnamed"): - ConversationHandler([], [], [], persistent=True) - with pytest.raises(ValueError, match="if dispatcher has no persistence"): - dp.add_handler(ConversationHandler([], {}, [], persistent=True, name="My Handler")) - dp.persistence = base_persistence - - def test_dispatcher_integration_init( - self, bot, base_persistence, chat_data, user_data, bot_data, callback_data - ): - # Bad data testing- - def bad_get_bot_data(): - return "test" - - def bad_get_callback_data(): - return "test" - - # Good data testing- - def good_get_user_data(): - return user_data - - def good_get_chat_data(): - return chat_data - - def good_get_bot_data(): - return bot_data - - def good_get_callback_data(): - return callback_data - - base_persistence.get_user_data = good_get_user_data # No errors to be tested so - base_persistence.get_chat_data = good_get_chat_data - base_persistence.get_bot_data = bad_get_bot_data - base_persistence.get_callback_data = bad_get_callback_data - - with pytest.raises(ValueError, match="bot_data must be of type dict"): - UpdaterBuilder().bot(bot).persistence(base_persistence).build() - - base_persistence.get_bot_data = good_get_bot_data - with pytest.raises(ValueError, match="callback_data must be a tuple of length 2"): - UpdaterBuilder().bot(bot).persistence(base_persistence).build() - - base_persistence.bot = None - base_persistence.get_callback_data = good_get_callback_data - u = UpdaterBuilder().bot(bot).persistence(base_persistence).build() - assert u.dispatcher.bot is base_persistence.bot - assert u.dispatcher.bot_data == bot_data - assert u.dispatcher.chat_data == chat_data - assert u.dispatcher.user_data == user_data - assert u.dispatcher.bot.callback_data_cache.persistence_data == callback_data - u.dispatcher.chat_data[442233]['test5'] = 'test6' - assert u.dispatcher.chat_data[442233]['test5'] == 'test6' - - @pytest.mark.parametrize('run_async', [True, False], ids=['run_async', 'synchronous']) - def test_dispatcher_integration_handlers( - self, - dp, - caplog, - bot, - base_persistence, - chat_data, - user_data, - bot_data, - callback_data, - run_async, - ): - def get_user_data(): - return user_data - - def get_chat_data(): - return chat_data - - def get_bot_data(): - return bot_data - - def get_callback_data(): - return callback_data - - base_persistence.get_user_data = get_user_data - base_persistence.get_chat_data = get_chat_data - base_persistence.get_bot_data = get_bot_data - base_persistence.get_callback_data = get_callback_data - base_persistence.refresh_bot_data = lambda x: x - base_persistence.refresh_chat_data = lambda x, y: x - base_persistence.refresh_user_data = lambda x, y: x - updater = UpdaterBuilder().bot(bot).persistence(base_persistence).build() - dp = updater.dispatcher - - def callback_known_user(update, context): - if not context.user_data['test1'] == 'test2': - pytest.fail('user_data corrupt') - if not context.bot_data == bot_data: - pytest.fail('bot_data corrupt') - - def callback_known_chat(update, context): - if not context.chat_data[3] == 'test4': - pytest.fail('chat_data corrupt') - if not context.bot_data == bot_data: - pytest.fail('bot_data corrupt') - - def callback_unknown_user_or_chat(update, context): - if not context.user_data == {}: - pytest.fail('user_data corrupt') - if not context.chat_data == {}: - pytest.fail('chat_data corrupt') - if not context.bot_data == bot_data: - pytest.fail('bot_data corrupt') - context.user_data[1] = 'test7' - context.chat_data[2] = 'test8' - context.bot_data['test0'] = 'test0' - # Let's now delete user1 and chat1 - context.dispatcher.drop_chat_data(-67890) - context.dispatcher.drop_user_data(12345) - # Test setting new keyboard callback data- - context.bot.callback_data_cache._keyboard_data['id'] = _KeyboardData( - 'id', button_data={'button3': 'test3'} - ) - - known_user = MessageHandler(filters.User(user_id=12345), callback_known_user) # user1 - known_chat = MessageHandler(filters.Chat(chat_id=-67890), callback_known_chat) # chat1 - unknown = MessageHandler(filters.ALL, callback_unknown_user_or_chat) # user2 and chat2 - dp.add_handler(known_user) - dp.add_handler(known_chat) - dp.add_handler(unknown) - user1 = User(id=12345, first_name='test user', is_bot=False) - user2 = User(id=54321, first_name='test user', is_bot=False) - chat1 = Chat(id=-67890, type='group') - chat2 = Chat(id=-987654, type='group') - m = Message(1, None, chat2, from_user=user1) - u_known_user = Update(0, m) - dp.process_update(u_known_user) - # 4 errors which arise since update_*_data are raising NotImplementedError here. - assert len(caplog.records) == 4 - m.from_user = user2 - m.chat = chat1 - u_known_chat = Update(1, m) - dp.process_update(u_known_chat) - m.chat = chat2 - u_unknown_user_or_chat = Update(2, m) - - def save_bot_data(data): - if 'test0' not in data: - pytest.fail() - - def save_chat_data(_id, data): - if 2 not in data: # data should be: {2: 'test8'} - pytest.fail() - - def save_user_data(_id, data): - if 1 not in data: # data should be: {1: 'test7'} - pytest.fail() - - def save_callback_data(data): - if not assert_data_in_cache(dp.bot.callback_data_cache, 'test3'): - pytest.fail() - - # Functions to check deletion- - def delete_user_data(user_id): - if 12345 != user_id: - pytest.fail("The id being deleted is not of user1's") - user_data.pop(user_id, None) - - def delete_chat_data(chat_id): - if -67890 != chat_id: - pytest.fail("The chat id being deleted is not of chat1's") - chat_data.pop(chat_id, None) - - base_persistence.update_chat_data = save_chat_data - base_persistence.update_user_data = save_user_data - base_persistence.update_bot_data = save_bot_data - base_persistence.update_callback_data = save_callback_data - base_persistence.drop_chat_data = delete_chat_data - base_persistence.drop_user_data = delete_user_data - dp.process_update(u_unknown_user_or_chat) - - # Test callback_unknown_user_or_chat worked correctly- - assert dp.user_data[54321][1] == 'test7' - assert dp.chat_data[-987654][2] == 'test8' - assert dp.bot_data['test0'] == 'test0' - assert assert_data_in_cache(dp.bot.callback_data_cache, 'test3') - assert 12345 not in dp.user_data # Tests if dp.drop_user_data worked or not - assert -67890 not in dp.chat_data - assert len(caplog.records) == 8 # Errors double since new update is processed. - for r in caplog.records: - assert issubclass(r.exc_info[0], NotImplementedError) - assert r.getMessage() == 'No error handlers are registered, logging exception.' - assert r.levelname == 'ERROR' - - def test_dispatcher_integration_migrate_chat_data(self, dp, bot_persistence): - dp.persistence = bot_persistence - dp.chat_data[1]['key'] = 'value' - dp.update_persistence() - assert bot_persistence.chat_data == {1: {'key': 'value'}} - - dp.migrate_chat_data(old_chat_id=1, new_chat_id=2) - assert bot_persistence.chat_data == {2: {'key': 'value'}} - - @pytest.mark.parametrize( - 'store_user_data', [True, False], ids=['store_user_data-True', 'store_user_data-False'] - ) - @pytest.mark.parametrize( - 'store_chat_data', [True, False], ids=['store_chat_data-True', 'store_chat_data-False'] - ) - @pytest.mark.parametrize( - 'store_bot_data', [True, False], ids=['store_bot_data-True', 'store_bot_data-False'] - ) - @pytest.mark.parametrize('run_async', [True, False], ids=['run_async', 'synchronous']) - def test_persistence_dispatcher_integration_refresh_data( - self, - dp, - base_persistence, - chat_data, - bot_data, - user_data, - store_bot_data, - store_chat_data, - store_user_data, - run_async, - ): - base_persistence.refresh_bot_data = lambda x: x.setdefault( - 'refreshed', x.get('refreshed', 0) + 1 - ) - # x is the user/chat_id - base_persistence.refresh_chat_data = lambda x, y: y.setdefault('refreshed', x) - base_persistence.refresh_user_data = lambda x, y: y.setdefault('refreshed', x) - base_persistence.store_data = PersistenceInput( - bot_data=store_bot_data, chat_data=store_chat_data, user_data=store_user_data - ) - dp.persistence = base_persistence - - self.test_flag = True - - def callback_with_user_and_chat(update, context): - if store_user_data: - if context.user_data.get('refreshed') != update.effective_user.id: - self.test_flag = 'user_data was not refreshed' - else: - if 'refreshed' in context.user_data: - self.test_flag = 'user_data was wrongly refreshed' - if store_chat_data: - if context.chat_data.get('refreshed') != update.effective_chat.id: - self.test_flag = 'chat_data was not refreshed' - else: - if 'refreshed' in context.chat_data: - self.test_flag = 'chat_data was wrongly refreshed' - if store_bot_data: - if context.bot_data.get('refreshed') != 1: - self.test_flag = 'bot_data was not refreshed' - else: - if 'refreshed' in context.bot_data: - self.test_flag = 'bot_data was wrongly refreshed' - - def callback_without_user_and_chat(_, context): - if store_bot_data: - if context.bot_data.get('refreshed') != 1: - self.test_flag = 'bot_data was not refreshed' - else: - if 'refreshed' in context.bot_data: - self.test_flag = 'bot_data was wrongly refreshed' - - with_user_and_chat = MessageHandler( - filters.User(user_id=12345), - callback_with_user_and_chat, - run_async=run_async, - ) - without_user_and_chat = MessageHandler( - filters.ALL, - callback_without_user_and_chat, - run_async=run_async, - ) - dp.add_handler(with_user_and_chat) - dp.add_handler(without_user_and_chat) - user = User(id=12345, first_name='test user', is_bot=False) - chat = Chat(id=-987654, type='group') - m = Message(1, None, chat, from_user=user) - - # has user and chat - u = Update(0, m) - dp.process_update(u) - - assert self.test_flag is True - - # has neither user nor hat - m.from_user = None - m.chat = None - u = Update(1, m) - dp.process_update(u) - - assert self.test_flag is True - - sleep(0.1) - - def test_persistence_dispatcher_arbitrary_update_types(self, dp, base_persistence, caplog): - # Updates used with TypeHandler doesn't necessarily have the proper attributes for - # persistence, makes sure it works anyways - - dp.persistence = base_persistence - - class MyUpdate: - pass - - dp.add_handler(TypeHandler(MyUpdate, lambda *_: None)) - - with caplog.at_level(logging.ERROR): - dp.process_update(MyUpdate()) - assert 'An uncaught error was raised while processing the update' not in caplog.text - - def test_bot_replace_insert_bot(self, bot, bot_persistence): - class CustomSlottedClass: - __slots__ = ('bot', '__dict__') - - def __init__(self): - self.bot = bot - self.not_in_slots = bot - - def __eq__(self, other): - if isinstance(other, CustomSlottedClass): - return self.bot is other.bot and self.not_in_slots is other.not_in_slots - return False - - class DictNotInSlots(Container): - """This classes parent has slots, but __dict__ is not in those slots.""" - - def __init__(self): - self.bot = bot - - def __contains__(self, item): - return True - - def __eq__(self, other): - if isinstance(other, DictNotInSlots): - return self.bot is other.bot - return False - - class CustomClass: - def __init__(self): - self.bot = bot - self.slotted_object = CustomSlottedClass() - self.dict_not_in_slots_object = DictNotInSlots() - self.list_ = [1, 2, bot] - self.tuple_ = tuple(self.list_) - self.set_ = set(self.list_) - self.frozenset_ = frozenset(self.list_) - self.dict_ = {item: item for item in self.list_} - self.defaultdict_ = defaultdict(dict, self.dict_) - - @staticmethod - def replace_bot(): - cc = CustomClass() - cc.bot = BasePersistence.REPLACED_BOT - cc.slotted_object.bot = BasePersistence.REPLACED_BOT - cc.slotted_object.not_in_slots = BasePersistence.REPLACED_BOT - cc.dict_not_in_slots_object.bot = BasePersistence.REPLACED_BOT - cc.list_ = [1, 2, BasePersistence.REPLACED_BOT] - cc.tuple_ = tuple(cc.list_) - cc.set_ = set(cc.list_) - cc.frozenset_ = frozenset(cc.list_) - cc.dict_ = {item: item for item in cc.list_} - cc.defaultdict_ = defaultdict(dict, cc.dict_) - return cc - - def __eq__(self, other): - if isinstance(other, CustomClass): - return ( - self.bot is other.bot - and self.slotted_object == other.slotted_object - and self.dict_not_in_slots_object == other.dict_not_in_slots_object - and self.list_ == other.list_ - and self.tuple_ == other.tuple_ - and self.set_ == other.set_ - and self.frozenset_ == other.frozenset_ - and self.dict_ == other.dict_ - and self.defaultdict_ == other.defaultdict_ - ) - return False - - persistence = bot_persistence - persistence.set_bot(bot) - cc = CustomClass() - - persistence.update_bot_data({1: cc}) - assert persistence.bot_data[1].bot == BasePersistence.REPLACED_BOT - assert persistence.bot_data[1] == cc.replace_bot() - - persistence.update_chat_data(123, {1: cc}) - assert persistence.chat_data[123][1].bot == BasePersistence.REPLACED_BOT - assert persistence.chat_data[123][1] == cc.replace_bot() - - persistence.update_user_data(123, {1: cc}) - assert persistence.user_data[123][1].bot == BasePersistence.REPLACED_BOT - assert persistence.user_data[123][1] == cc.replace_bot() - - persistence.update_callback_data(([('1', 2, {0: cc})], {'1': '2'})) - assert persistence.callback_data[0][0][2][0].bot == BasePersistence.REPLACED_BOT - assert persistence.callback_data[0][0][2][0] == cc.replace_bot() - - assert persistence.get_bot_data()[1] == cc - assert persistence.get_bot_data()[1].bot is bot - assert persistence.get_chat_data()[123][1] == cc - assert persistence.get_chat_data()[123][1].bot is bot - assert persistence.get_user_data()[123][1] == cc - assert persistence.get_user_data()[123][1].bot is bot - assert persistence.get_callback_data()[0][0][2][0].bot is bot - assert persistence.get_callback_data()[0][0][2][0] == cc - - def test_bot_replace_insert_bot_unpickable_objects(self, bot, bot_persistence, recwarn): - """Here check that unpickable objects are just returned verbatim.""" - persistence = bot_persistence - persistence.set_bot(bot) - - class CustomClass: - def __copy__(self): - raise TypeError('UnhandledException') - - lock = Lock() - - persistence.update_bot_data({1: lock}) - assert persistence.bot_data[1] is lock - persistence.update_chat_data(123, {1: lock}) - assert persistence.chat_data[123][1] is lock - persistence.update_user_data(123, {1: lock}) - assert persistence.user_data[123][1] is lock - persistence.update_callback_data(([('1', 2, {0: lock})], {'1': '2'})) - assert persistence.callback_data[0][0][2][0] is lock - - assert persistence.get_bot_data()[1] is lock - assert persistence.get_chat_data()[123][1] is lock - assert persistence.get_user_data()[123][1] is lock - assert persistence.get_callback_data()[0][0][2][0] is lock - - cc = CustomClass() - - persistence.update_bot_data({1: cc}) - assert persistence.bot_data[1] is cc - persistence.update_chat_data(123, {1: cc}) - assert persistence.chat_data[123][1] is cc - persistence.update_user_data(123, {1: cc}) - assert persistence.user_data[123][1] is cc - persistence.update_callback_data(([('1', 2, {0: cc})], {'1': '2'})) - assert persistence.callback_data[0][0][2][0] is cc - - assert persistence.get_bot_data()[1] is cc - assert persistence.get_chat_data()[123][1] is cc - assert persistence.get_user_data()[123][1] is cc - assert persistence.get_callback_data()[0][0][2][0] is cc - - assert len(recwarn) == 2 - assert str(recwarn[0].message).startswith( - "BasePersistence.replace_bot does not handle objects that can not be copied." - ) - assert str(recwarn[1].message).startswith( - "BasePersistence.insert_bot does not handle objects that can not be copied." - ) - - def test_bot_replace_insert_bot_unparsable_objects(self, bot, bot_persistence, recwarn): - """Here check that objects in __dict__ or __slots__ that can't - be parsed are just returned verbatim.""" - persistence = bot_persistence - persistence.set_bot(bot) - - uuid_obj = uuid.uuid4() - - persistence.update_bot_data({1: uuid_obj}) - assert persistence.bot_data[1] is uuid_obj - persistence.update_chat_data(123, {1: uuid_obj}) - assert persistence.chat_data[123][1] is uuid_obj - persistence.update_user_data(123, {1: uuid_obj}) - assert persistence.user_data[123][1] is uuid_obj - persistence.update_callback_data(([('1', 2, {0: uuid_obj})], {'1': '2'})) - assert persistence.callback_data[0][0][2][0] is uuid_obj - - assert persistence.get_bot_data()[1] is uuid_obj - assert persistence.get_chat_data()[123][1] is uuid_obj - assert persistence.get_user_data()[123][1] is uuid_obj - assert persistence.get_callback_data()[0][0][2][0] is uuid_obj - - assert len(recwarn) == 2 - assert str(recwarn[0].message).startswith( - "Parsing of an object failed with the following exception: " - ) - assert str(recwarn[1].message).startswith( - "Parsing of an object failed with the following exception: " - ) - - def test_bot_replace_insert_bot_classes(self, bot, bot_persistence, recwarn): - """Here check that classes are just returned verbatim.""" - persistence = bot_persistence - persistence.set_bot(bot) - - class CustomClass: - pass - - persistence.update_bot_data({1: CustomClass}) - assert persistence.bot_data[1] is CustomClass - persistence.update_chat_data(123, {1: CustomClass}) - assert persistence.chat_data[123][1] is CustomClass - persistence.update_user_data(123, {1: CustomClass}) - assert persistence.user_data[123][1] is CustomClass - - assert persistence.get_bot_data()[1] is CustomClass - assert persistence.get_chat_data()[123][1] is CustomClass - assert persistence.get_user_data()[123][1] is CustomClass - - assert len(recwarn) == 2 - assert str(recwarn[0].message).startswith( - "BasePersistence.replace_bot does not handle classes such as 'CustomClass'" - ) - assert str(recwarn[1].message).startswith( - "BasePersistence.insert_bot does not handle classes such as 'CustomClass'" - ) - - def test_bot_replace_insert_bot_objects_with_faulty_equality(self, bot, bot_persistence): - """Here check that trying to compare obj == self.REPLACED_BOT doesn't lead to problems.""" - persistence = bot_persistence - persistence.set_bot(bot) - - class CustomClass: - def __init__(self, data): - self.data = data - - def __eq__(self, other): - raise RuntimeError("Can't be compared") - - cc = CustomClass({1: bot, 2: 'foo'}) - expected = {1: BasePersistence.REPLACED_BOT, 2: 'foo'} - - persistence.update_bot_data({1: cc}) - assert persistence.bot_data[1].data == expected - persistence.update_chat_data(123, {1: cc}) - assert persistence.chat_data[123][1].data == expected - persistence.update_user_data(123, {1: cc}) - assert persistence.user_data[123][1].data == expected - persistence.update_callback_data(([('1', 2, {0: cc})], {'1': '2'})) - assert persistence.callback_data[0][0][2][0].data == expected - - expected = {1: bot, 2: 'foo'} - - assert persistence.get_bot_data()[1].data == expected - assert persistence.get_chat_data()[123][1].data == expected - assert persistence.get_user_data()[123][1].data == expected - assert persistence.get_callback_data()[0][0][2][0].data == expected - - @pytest.mark.filterwarnings('ignore:BasePersistence') - def test_replace_insert_bot_item_identity(self, bot, bot_persistence): - persistence = bot_persistence - persistence.set_bot(bot) - - class CustomSlottedClass: - __slots__ = ('value',) - - def __init__(self): - self.value = 5 - - class CustomClass: - pass - - slot_object = CustomSlottedClass() - dict_object = CustomClass() - lock = Lock() - list_ = [slot_object, dict_object, lock] - tuple_ = (1, 2, 3) - dict_ = {1: slot_object, 2: dict_object} - - data = { - 'bot_1': bot, - 'bot_2': bot, - 'list_1': list_, - 'list_2': list_, - 'tuple_1': tuple_, - 'tuple_2': tuple_, - 'dict_1': dict_, - 'dict_2': dict_, - } - - def make_assertion(data_): - return ( - data_['bot_1'] is data_['bot_2'] - and data_['list_1'] is data_['list_2'] - and data_['list_1'][0] is data_['list_2'][0] - and data_['list_1'][1] is data_['list_2'][1] - and data_['list_1'][2] is data_['list_2'][2] - and data_['tuple_1'] is data_['tuple_2'] - and data_['dict_1'] is data_['dict_2'] - and data_['dict_1'][1] is data_['dict_2'][1] - and data_['dict_1'][1] is data_['list_1'][0] - and data_['dict_1'][2] is data_['list_1'][1] - and data_['dict_1'][2] is data_['dict_2'][2] - ) - - persistence.update_bot_data(data) - assert make_assertion(persistence.bot_data) - assert make_assertion(persistence.get_bot_data()) - - def test_set_bot_exception(self, bot): - non_ext_bot = Bot(bot.token) - persistence = OwnPersistence() - with pytest.raises(TypeError, match='callback_data can only be stored'): - persistence.set_bot(non_ext_bot) - - -@pytest.fixture(scope='function') -def pickle_persistence(): - return PicklePersistence( - filepath='pickletest', - single_file=False, - on_flush=False, - ) - - -@pytest.fixture(scope='function') -def pickle_persistence_only_bot(): - return PicklePersistence( - filepath='pickletest', - store_data=PersistenceInput(callback_data=False, user_data=False, chat_data=False), - single_file=False, - on_flush=False, - ) - - -@pytest.fixture(scope='function') -def pickle_persistence_only_chat(): - return PicklePersistence( - filepath='pickletest', - store_data=PersistenceInput(callback_data=False, user_data=False, bot_data=False), - single_file=False, - on_flush=False, - ) - - -@pytest.fixture(scope='function') -def pickle_persistence_only_user(): - return PicklePersistence( - filepath='pickletest', - store_data=PersistenceInput(callback_data=False, chat_data=False, bot_data=False), - single_file=False, - on_flush=False, - ) - - -@pytest.fixture(scope='function') -def pickle_persistence_only_callback(): - return PicklePersistence( - filepath='pickletest', - store_data=PersistenceInput(user_data=False, chat_data=False, bot_data=False), - single_file=False, - on_flush=False, - ) - - -@pytest.fixture(scope='function') -def bad_pickle_files(): - for name in [ - 'pickletest_user_data', - 'pickletest_chat_data', - 'pickletest_bot_data', - 'pickletest_callback_data', - 'pickletest_conversations', - 'pickletest', - ]: - Path(name).write_text('(())') - yield True - - -@pytest.fixture(scope='function') -def invalid_pickle_files(): - for name in [ - 'pickletest_user_data', - 'pickletest_chat_data', - 'pickletest_bot_data', - 'pickletest_callback_data', - 'pickletest_conversations', - 'pickletest', - ]: - # Just a random way to trigger pickle.UnpicklingError - # see https://stackoverflow.com/a/44422239/10606962 - with gzip.open(name, 'wb') as file: - pickle.dump([1, 2, 3], file) - yield True - - -@pytest.fixture(scope='function') -def good_pickle_files(user_data, chat_data, bot_data, callback_data, conversations): - data = { - 'user_data': user_data, - 'chat_data': chat_data, - 'bot_data': bot_data, - 'callback_data': callback_data, - 'conversations': conversations, - } - with Path('pickletest_user_data').open('wb') as f: - pickle.dump(user_data, f) - with Path('pickletest_chat_data').open('wb') as f: - pickle.dump(chat_data, f) - with Path('pickletest_bot_data').open('wb') as f: - pickle.dump(bot_data, f) - with Path('pickletest_callback_data').open('wb') as f: - pickle.dump(callback_data, f) - with Path('pickletest_conversations').open('wb') as f: - pickle.dump(conversations, f) - with Path('pickletest').open('wb') as f: - pickle.dump(data, f) - yield True - - -@pytest.fixture(scope='function') -def pickle_files_wo_bot_data(user_data, chat_data, callback_data, conversations): - data = { - 'user_data': user_data, - 'chat_data': chat_data, - 'conversations': conversations, - 'callback_data': callback_data, - } - with Path('pickletest_user_data').open('wb') as f: - pickle.dump(user_data, f) - with Path('pickletest_chat_data').open('wb') as f: - pickle.dump(chat_data, f) - with Path('pickletest_callback_data').open('wb') as f: - pickle.dump(callback_data, f) - with Path('pickletest_conversations').open('wb') as f: - pickle.dump(conversations, f) - with Path('pickletest').open('wb') as f: - pickle.dump(data, f) - yield True - - -@pytest.fixture(scope='function') -def pickle_files_wo_callback_data(user_data, chat_data, bot_data, conversations): - data = { - 'user_data': user_data, - 'chat_data': chat_data, - 'bot_data': bot_data, - 'conversations': conversations, - } - with Path('pickletest_user_data').open('wb') as f: - pickle.dump(user_data, f) - with Path('pickletest_chat_data').open('wb') as f: - pickle.dump(chat_data, f) - with Path('pickletest_bot_data').open('wb') as f: - pickle.dump(bot_data, f) - with Path('pickletest_conversations').open('wb') as f: - pickle.dump(conversations, f) - with Path('pickletest').open('wb') as f: - pickle.dump(data, f) - yield True - - -@pytest.fixture(scope='function') -def update(bot): - user = User(id=321, first_name='test_user', is_bot=False) - chat = Chat(id=123, type='group') - message = Message(1, None, chat, from_user=user, text="Hi there", bot=bot) - return Update(0, message=message) - - -class TestPicklePersistence: - def test_slot_behaviour(self, mro_slots, pickle_persistence): - inst = pickle_persistence - for attr in inst.__slots__: - assert getattr(inst, attr, 'err') != 'err', f"got extra slot '{attr}'" - assert len(mro_slots(inst)) == len(set(mro_slots(inst))), "duplicate slot" - - def test_pickle_behaviour_with_slots(self, pickle_persistence): - bot_data = pickle_persistence.get_bot_data() - bot_data['message'] = Message(3, None, Chat(2, type='supergroup')) - pickle_persistence.update_bot_data(bot_data) - retrieved = pickle_persistence.get_bot_data() - assert retrieved == bot_data - - def test_no_files_present_multi_file(self, pickle_persistence): - assert pickle_persistence.get_user_data() == {} - assert pickle_persistence.get_chat_data() == {} - assert pickle_persistence.get_bot_data() == {} - assert pickle_persistence.get_callback_data() is None - assert pickle_persistence.get_conversations('noname') == {} - - def test_no_files_present_single_file(self, pickle_persistence): - pickle_persistence.single_file = True - assert pickle_persistence.get_user_data() == {} - assert pickle_persistence.get_chat_data() == {} - assert pickle_persistence.get_bot_data() == {} - assert pickle_persistence.get_callback_data() is None - assert pickle_persistence.get_conversations('noname') == {} - - def test_with_bad_multi_file(self, pickle_persistence, bad_pickle_files): - with pytest.raises(TypeError, match='pickletest_user_data'): - pickle_persistence.get_user_data() - with pytest.raises(TypeError, match='pickletest_chat_data'): - pickle_persistence.get_chat_data() - with pytest.raises(TypeError, match='pickletest_bot_data'): - pickle_persistence.get_bot_data() - with pytest.raises(TypeError, match='pickletest_callback_data'): - pickle_persistence.get_callback_data() - with pytest.raises(TypeError, match='pickletest_conversations'): - pickle_persistence.get_conversations('name') - - def test_with_invalid_multi_file(self, pickle_persistence, invalid_pickle_files): - with pytest.raises(TypeError, match='pickletest_user_data does not contain'): - pickle_persistence.get_user_data() - with pytest.raises(TypeError, match='pickletest_chat_data does not contain'): - pickle_persistence.get_chat_data() - with pytest.raises(TypeError, match='pickletest_bot_data does not contain'): - pickle_persistence.get_bot_data() - with pytest.raises(TypeError, match='pickletest_callback_data does not contain'): - pickle_persistence.get_callback_data() - with pytest.raises(TypeError, match='pickletest_conversations does not contain'): - pickle_persistence.get_conversations('name') - - def test_with_bad_single_file(self, pickle_persistence, bad_pickle_files): - pickle_persistence.single_file = True - with pytest.raises(TypeError, match='pickletest'): - pickle_persistence.get_user_data() - with pytest.raises(TypeError, match='pickletest'): - pickle_persistence.get_chat_data() - with pytest.raises(TypeError, match='pickletest'): - pickle_persistence.get_bot_data() - with pytest.raises(TypeError, match='pickletest'): - pickle_persistence.get_callback_data() - with pytest.raises(TypeError, match='pickletest'): - pickle_persistence.get_conversations('name') - - def test_with_invalid_single_file(self, pickle_persistence, invalid_pickle_files): - pickle_persistence.single_file = True - with pytest.raises(TypeError, match='pickletest does not contain'): - pickle_persistence.get_user_data() - with pytest.raises(TypeError, match='pickletest does not contain'): - pickle_persistence.get_chat_data() - with pytest.raises(TypeError, match='pickletest does not contain'): - pickle_persistence.get_bot_data() - with pytest.raises(TypeError, match='pickletest does not contain'): - pickle_persistence.get_callback_data() - with pytest.raises(TypeError, match='pickletest does not contain'): - pickle_persistence.get_conversations('name') - - def test_with_good_multi_file(self, pickle_persistence, good_pickle_files): - user_data = pickle_persistence.get_user_data() - assert isinstance(user_data, dict) - assert user_data[12345]['test1'] == 'test2' - assert user_data[67890][3] == 'test4' - - chat_data = pickle_persistence.get_chat_data() - assert isinstance(chat_data, dict) - assert chat_data[-12345]['test1'] == 'test2' - assert chat_data[-67890][3] == 'test4' - - bot_data = pickle_persistence.get_bot_data() - assert isinstance(bot_data, dict) - assert bot_data['test1'] == 'test2' - assert bot_data['test3']['test4'] == 'test5' - assert 'test0' not in bot_data - - callback_data = pickle_persistence.get_callback_data() - assert isinstance(callback_data, tuple) - assert callback_data[0] == [('test1', 1000, {'button1': 'test0', 'button2': 'test1'})] - assert callback_data[1] == {'test1': 'test2'} - - conversation1 = pickle_persistence.get_conversations('name1') - assert isinstance(conversation1, dict) - assert conversation1[(123, 123)] == 3 - assert conversation1[(456, 654)] == 4 - with pytest.raises(KeyError): - conversation1[(890, 890)] - conversation2 = pickle_persistence.get_conversations('name2') - assert isinstance(conversation1, dict) - assert conversation2[(123, 321)] == 1 - assert conversation2[(890, 890)] == 2 - with pytest.raises(KeyError): - conversation2[(123, 123)] - - def test_with_good_single_file(self, pickle_persistence, good_pickle_files): - pickle_persistence.single_file = True - user_data = pickle_persistence.get_user_data() - assert isinstance(user_data, dict) - assert user_data[12345]['test1'] == 'test2' - assert user_data[67890][3] == 'test4' - - chat_data = pickle_persistence.get_chat_data() - assert isinstance(chat_data, dict) - assert chat_data[-12345]['test1'] == 'test2' - assert chat_data[-67890][3] == 'test4' - - bot_data = pickle_persistence.get_bot_data() - assert isinstance(bot_data, dict) - assert bot_data['test1'] == 'test2' - assert bot_data['test3']['test4'] == 'test5' - assert 'test0' not in bot_data - - callback_data = pickle_persistence.get_callback_data() - assert isinstance(callback_data, tuple) - assert callback_data[0] == [('test1', 1000, {'button1': 'test0', 'button2': 'test1'})] - assert callback_data[1] == {'test1': 'test2'} - - conversation1 = pickle_persistence.get_conversations('name1') - assert isinstance(conversation1, dict) - assert conversation1[(123, 123)] == 3 - assert conversation1[(456, 654)] == 4 - with pytest.raises(KeyError): - conversation1[(890, 890)] - conversation2 = pickle_persistence.get_conversations('name2') - assert isinstance(conversation1, dict) - assert conversation2[(123, 321)] == 1 - assert conversation2[(890, 890)] == 2 - with pytest.raises(KeyError): - conversation2[(123, 123)] - - def test_with_multi_file_wo_bot_data(self, pickle_persistence, pickle_files_wo_bot_data): - user_data = pickle_persistence.get_user_data() - assert isinstance(user_data, dict) - assert user_data[12345]['test1'] == 'test2' - assert user_data[67890][3] == 'test4' - - chat_data = pickle_persistence.get_chat_data() - assert isinstance(chat_data, dict) - assert chat_data[-12345]['test1'] == 'test2' - assert chat_data[-67890][3] == 'test4' - - bot_data = pickle_persistence.get_bot_data() - assert isinstance(bot_data, dict) - assert not bot_data.keys() - - callback_data = pickle_persistence.get_callback_data() - assert isinstance(callback_data, tuple) - assert callback_data[0] == [('test1', 1000, {'button1': 'test0', 'button2': 'test1'})] - assert callback_data[1] == {'test1': 'test2'} - - conversation1 = pickle_persistence.get_conversations('name1') - assert isinstance(conversation1, dict) - assert conversation1[(123, 123)] == 3 - assert conversation1[(456, 654)] == 4 - with pytest.raises(KeyError): - conversation1[(890, 890)] - conversation2 = pickle_persistence.get_conversations('name2') - assert isinstance(conversation1, dict) - assert conversation2[(123, 321)] == 1 - assert conversation2[(890, 890)] == 2 - with pytest.raises(KeyError): - conversation2[(123, 123)] - - def test_with_multi_file_wo_callback_data( - self, pickle_persistence, pickle_files_wo_callback_data - ): - user_data = pickle_persistence.get_user_data() - assert isinstance(user_data, dict) - assert user_data[12345]['test1'] == 'test2' - assert user_data[67890][3] == 'test4' - - chat_data = pickle_persistence.get_chat_data() - assert isinstance(chat_data, dict) - assert chat_data[-12345]['test1'] == 'test2' - assert chat_data[-67890][3] == 'test4' - - bot_data = pickle_persistence.get_bot_data() - assert isinstance(bot_data, dict) - assert bot_data['test1'] == 'test2' - assert bot_data['test3']['test4'] == 'test5' - assert 'test0' not in bot_data - - callback_data = pickle_persistence.get_callback_data() - assert callback_data is None - - conversation1 = pickle_persistence.get_conversations('name1') - assert isinstance(conversation1, dict) - assert conversation1[(123, 123)] == 3 - assert conversation1[(456, 654)] == 4 - with pytest.raises(KeyError): - conversation1[(890, 890)] - conversation2 = pickle_persistence.get_conversations('name2') - assert isinstance(conversation1, dict) - assert conversation2[(123, 321)] == 1 - assert conversation2[(890, 890)] == 2 - with pytest.raises(KeyError): - conversation2[(123, 123)] - - def test_with_single_file_wo_bot_data(self, pickle_persistence, pickle_files_wo_bot_data): - pickle_persistence.single_file = True - user_data = pickle_persistence.get_user_data() - assert isinstance(user_data, dict) - assert user_data[12345]['test1'] == 'test2' - assert user_data[67890][3] == 'test4' - - chat_data = pickle_persistence.get_chat_data() - assert isinstance(chat_data, dict) - assert chat_data[-12345]['test1'] == 'test2' - assert chat_data[-67890][3] == 'test4' - - bot_data = pickle_persistence.get_bot_data() - assert isinstance(bot_data, dict) - assert not bot_data.keys() - - callback_data = pickle_persistence.get_callback_data() - assert isinstance(callback_data, tuple) - assert callback_data[0] == [('test1', 1000, {'button1': 'test0', 'button2': 'test1'})] - assert callback_data[1] == {'test1': 'test2'} - - conversation1 = pickle_persistence.get_conversations('name1') - assert isinstance(conversation1, dict) - assert conversation1[(123, 123)] == 3 - assert conversation1[(456, 654)] == 4 - with pytest.raises(KeyError): - conversation1[(890, 890)] - conversation2 = pickle_persistence.get_conversations('name2') - assert isinstance(conversation1, dict) - assert conversation2[(123, 321)] == 1 - assert conversation2[(890, 890)] == 2 - with pytest.raises(KeyError): - conversation2[(123, 123)] - - def test_with_single_file_wo_callback_data( - self, pickle_persistence, pickle_files_wo_callback_data - ): - user_data = pickle_persistence.get_user_data() - assert isinstance(user_data, dict) - assert user_data[12345]['test1'] == 'test2' - assert user_data[67890][3] == 'test4' - - chat_data = pickle_persistence.get_chat_data() - assert isinstance(chat_data, dict) - assert chat_data[-12345]['test1'] == 'test2' - assert chat_data[-67890][3] == 'test4' - - bot_data = pickle_persistence.get_bot_data() - assert isinstance(bot_data, dict) - assert bot_data['test1'] == 'test2' - assert bot_data['test3']['test4'] == 'test5' - assert 'test0' not in bot_data - - callback_data = pickle_persistence.get_callback_data() - assert callback_data is None - - conversation1 = pickle_persistence.get_conversations('name1') - assert isinstance(conversation1, dict) - assert conversation1[(123, 123)] == 3 - assert conversation1[(456, 654)] == 4 - with pytest.raises(KeyError): - conversation1[(890, 890)] - conversation2 = pickle_persistence.get_conversations('name2') - assert isinstance(conversation1, dict) - assert conversation2[(123, 321)] == 1 - assert conversation2[(890, 890)] == 2 - with pytest.raises(KeyError): - conversation2[(123, 123)] - - def test_updating_multi_file(self, pickle_persistence, good_pickle_files): - user_data = pickle_persistence.get_user_data() - user_data[12345]['test3']['test4'] = 'test6' - assert not pickle_persistence.user_data == user_data - pickle_persistence.update_user_data(12345, user_data[12345]) - user_data[12345]['test3']['test4'] = 'test7' - assert not pickle_persistence.user_data == user_data - pickle_persistence.update_user_data(12345, user_data[12345]) - assert pickle_persistence.user_data == user_data - with Path('pickletest_user_data').open('rb') as f: - user_data_test = dict(pickle.load(f)) - assert user_data_test == user_data - pickle_persistence.drop_user_data(67890) - assert 67890 not in pickle_persistence.get_user_data() - - chat_data = pickle_persistence.get_chat_data() - chat_data[-12345]['test3']['test4'] = 'test6' - assert not pickle_persistence.chat_data == chat_data - pickle_persistence.update_chat_data(-12345, chat_data[-12345]) - chat_data[-12345]['test3']['test4'] = 'test7' - assert not pickle_persistence.chat_data == chat_data - pickle_persistence.update_chat_data(-12345, chat_data[-12345]) - assert pickle_persistence.chat_data == chat_data - with Path('pickletest_chat_data').open('rb') as f: - chat_data_test = dict(pickle.load(f)) - assert chat_data_test == chat_data - pickle_persistence.drop_chat_data(-67890) - assert -67890 not in pickle_persistence.get_chat_data() - - bot_data = pickle_persistence.get_bot_data() - bot_data['test3']['test4'] = 'test6' - assert not pickle_persistence.bot_data == bot_data - pickle_persistence.update_bot_data(bot_data) - bot_data['test3']['test4'] = 'test7' - assert not pickle_persistence.bot_data == bot_data - pickle_persistence.update_bot_data(bot_data) - assert pickle_persistence.bot_data == bot_data - with Path('pickletest_bot_data').open('rb') as f: - bot_data_test = pickle.load(f) - assert bot_data_test == bot_data - - callback_data = pickle_persistence.get_callback_data() - callback_data[1]['test3'] = 'test4' - assert not pickle_persistence.callback_data == callback_data - pickle_persistence.update_callback_data(callback_data) - callback_data[1]['test3'] = 'test5' - assert not pickle_persistence.callback_data == callback_data - pickle_persistence.update_callback_data(callback_data) - assert pickle_persistence.callback_data == callback_data - with Path('pickletest_callback_data').open('rb') as f: - callback_data_test = pickle.load(f) - assert callback_data_test == callback_data - - conversation1 = pickle_persistence.get_conversations('name1') - conversation1[(123, 123)] = 5 - assert not pickle_persistence.conversations['name1'] == conversation1 - pickle_persistence.update_conversation('name1', (123, 123), 5) - assert pickle_persistence.conversations['name1'] == conversation1 - assert pickle_persistence.get_conversations('name1') == conversation1 - with Path('pickletest_conversations').open('rb') as f: - conversations_test = dict(pickle.load(f)) - assert conversations_test['name1'] == conversation1 - - pickle_persistence.conversations = None - pickle_persistence.update_conversation('name1', (123, 123), 5) - assert pickle_persistence.conversations['name1'] == {(123, 123): 5} - assert pickle_persistence.get_conversations('name1') == {(123, 123): 5} - - def test_updating_single_file(self, pickle_persistence, good_pickle_files): - pickle_persistence.single_file = True - - user_data = pickle_persistence.get_user_data() - user_data[12345]['test3']['test4'] = 'test6' - assert not pickle_persistence.user_data == user_data - pickle_persistence.update_user_data(12345, user_data[12345]) - user_data[12345]['test3']['test4'] = 'test7' - assert not pickle_persistence.user_data == user_data - pickle_persistence.update_user_data(12345, user_data[12345]) - assert pickle_persistence.user_data == user_data - with Path('pickletest').open('rb') as f: - user_data_test = dict(pickle.load(f))['user_data'] - assert user_data_test == user_data - pickle_persistence.drop_user_data(67890) - assert 67890 not in pickle_persistence.get_user_data() - - chat_data = pickle_persistence.get_chat_data() - chat_data[-12345]['test3']['test4'] = 'test6' - assert not pickle_persistence.chat_data == chat_data - pickle_persistence.update_chat_data(-12345, chat_data[-12345]) - chat_data[-12345]['test3']['test4'] = 'test7' - assert not pickle_persistence.chat_data == chat_data - pickle_persistence.update_chat_data(-12345, chat_data[-12345]) - assert pickle_persistence.chat_data == chat_data - with Path('pickletest').open('rb') as f: - chat_data_test = dict(pickle.load(f))['chat_data'] - assert chat_data_test == chat_data - pickle_persistence.drop_chat_data(-67890) - assert -67890 not in pickle_persistence.get_chat_data() - - bot_data = pickle_persistence.get_bot_data() - bot_data['test3']['test4'] = 'test6' - assert not pickle_persistence.bot_data == bot_data - pickle_persistence.update_bot_data(bot_data) - bot_data['test3']['test4'] = 'test7' - assert not pickle_persistence.bot_data == bot_data - pickle_persistence.update_bot_data(bot_data) - assert pickle_persistence.bot_data == bot_data - with Path('pickletest').open('rb') as f: - bot_data_test = pickle.load(f)['bot_data'] - assert bot_data_test == bot_data - - callback_data = pickle_persistence.get_callback_data() - callback_data[1]['test3'] = 'test4' - assert not pickle_persistence.callback_data == callback_data - pickle_persistence.update_callback_data(callback_data) - callback_data[1]['test3'] = 'test5' - assert not pickle_persistence.callback_data == callback_data - pickle_persistence.update_callback_data(callback_data) - assert pickle_persistence.callback_data == callback_data - with Path('pickletest').open('rb') as f: - callback_data_test = pickle.load(f)['callback_data'] - assert callback_data_test == callback_data - - conversation1 = pickle_persistence.get_conversations('name1') - conversation1[(123, 123)] = 5 - assert not pickle_persistence.conversations['name1'] == conversation1 - pickle_persistence.update_conversation('name1', (123, 123), 5) - assert pickle_persistence.conversations['name1'] == conversation1 - assert pickle_persistence.get_conversations('name1') == conversation1 - with Path('pickletest').open('rb') as f: - conversations_test = dict(pickle.load(f))['conversations'] - assert conversations_test['name1'] == conversation1 - - pickle_persistence.conversations = None - pickle_persistence.update_conversation('name1', (123, 123), 5) - assert pickle_persistence.conversations['name1'] == {(123, 123): 5} - assert pickle_persistence.get_conversations('name1') == {(123, 123): 5} - - def test_updating_single_file_no_data(self, pickle_persistence): - pickle_persistence.single_file = True - assert not any( - [ - pickle_persistence.user_data, - pickle_persistence.chat_data, - pickle_persistence.bot_data, - pickle_persistence.callback_data, - pickle_persistence.conversations, - ] - ) - pickle_persistence.flush() - with pytest.raises(FileNotFoundError, match='pickletest'): - open('pickletest', 'rb') - - def test_save_on_flush_multi_files(self, pickle_persistence, good_pickle_files): - # Should run without error - pickle_persistence.flush() - pickle_persistence.on_flush = True - - user_data = pickle_persistence.get_user_data() - user_data[54321] = {} - user_data[54321]['test9'] = 'test 10' - assert not pickle_persistence.user_data == user_data - - pickle_persistence.update_user_data(54321, user_data[54321]) - assert pickle_persistence.user_data == user_data - - pickle_persistence.drop_user_data(0) - assert pickle_persistence.user_data == user_data - - with Path('pickletest_user_data').open('rb') as f: - user_data_test = dict(pickle.load(f)) - assert not user_data_test == user_data - - chat_data = pickle_persistence.get_chat_data() - chat_data[54321] = {} - chat_data[54321]['test9'] = 'test 10' - assert not pickle_persistence.chat_data == chat_data - - pickle_persistence.update_chat_data(54321, chat_data[54321]) - assert pickle_persistence.chat_data == chat_data - - pickle_persistence.drop_chat_data(0) - assert pickle_persistence.user_data == user_data - - with Path('pickletest_chat_data').open('rb') as f: - chat_data_test = dict(pickle.load(f)) - assert not chat_data_test == chat_data - - bot_data = pickle_persistence.get_bot_data() - bot_data['test6'] = 'test 7' - assert not pickle_persistence.bot_data == bot_data - - pickle_persistence.update_bot_data(bot_data) - assert pickle_persistence.bot_data == bot_data - - with Path('pickletest_bot_data').open('rb') as f: - bot_data_test = pickle.load(f) - assert not bot_data_test == bot_data - - callback_data = pickle_persistence.get_callback_data() - callback_data[1]['test3'] = 'test4' - assert not pickle_persistence.callback_data == callback_data - - pickle_persistence.update_callback_data(callback_data) - assert pickle_persistence.callback_data == callback_data - - with Path('pickletest_callback_data').open('rb') as f: - callback_data_test = pickle.load(f) - assert not callback_data_test == callback_data - - conversation1 = pickle_persistence.get_conversations('name1') - conversation1[(123, 123)] = 5 - assert not pickle_persistence.conversations['name1'] == conversation1 - - pickle_persistence.update_conversation('name1', (123, 123), 5) - assert pickle_persistence.conversations['name1'] == conversation1 - - with Path('pickletest_conversations').open('rb') as f: - conversations_test = dict(pickle.load(f)) - assert not conversations_test['name1'] == conversation1 - - pickle_persistence.flush() - with Path('pickletest_user_data').open('rb') as f: - user_data_test = dict(pickle.load(f)) - assert user_data_test == user_data - - with Path('pickletest_chat_data').open('rb') as f: - chat_data_test = dict(pickle.load(f)) - assert chat_data_test == chat_data - - with Path('pickletest_bot_data').open('rb') as f: - bot_data_test = pickle.load(f) - assert bot_data_test == bot_data - - with Path('pickletest_conversations').open('rb') as f: - conversations_test = dict(pickle.load(f)) - assert conversations_test['name1'] == conversation1 - - def test_save_on_flush_single_files(self, pickle_persistence, good_pickle_files): - # Should run without error - pickle_persistence.flush() - - pickle_persistence.on_flush = True - pickle_persistence.single_file = True - - user_data = pickle_persistence.get_user_data() - user_data[54321] = {} - user_data[54321]['test9'] = 'test 10' - assert not pickle_persistence.user_data == user_data - pickle_persistence.update_user_data(54321, user_data[54321]) - assert pickle_persistence.user_data == user_data - with Path('pickletest').open('rb') as f: - user_data_test = dict(pickle.load(f))['user_data'] - assert not user_data_test == user_data - - chat_data = pickle_persistence.get_chat_data() - chat_data[54321] = {} - chat_data[54321]['test9'] = 'test 10' - assert not pickle_persistence.chat_data == chat_data - pickle_persistence.update_chat_data(54321, chat_data[54321]) - assert pickle_persistence.chat_data == chat_data - with Path('pickletest').open('rb') as f: - chat_data_test = dict(pickle.load(f))['chat_data'] - assert not chat_data_test == chat_data - - bot_data = pickle_persistence.get_bot_data() - bot_data['test6'] = 'test 7' - assert not pickle_persistence.bot_data == bot_data - pickle_persistence.update_bot_data(bot_data) - assert pickle_persistence.bot_data == bot_data - with Path('pickletest').open('rb') as f: - bot_data_test = pickle.load(f)['bot_data'] - assert not bot_data_test == bot_data - - callback_data = pickle_persistence.get_callback_data() - callback_data[1]['test3'] = 'test4' - assert not pickle_persistence.callback_data == callback_data - pickle_persistence.update_callback_data(callback_data) - assert pickle_persistence.callback_data == callback_data - with Path('pickletest').open('rb') as f: - callback_data_test = pickle.load(f)['callback_data'] - assert not callback_data_test == callback_data - - conversation1 = pickle_persistence.get_conversations('name1') - conversation1[(123, 123)] = 5 - assert not pickle_persistence.conversations['name1'] == conversation1 - pickle_persistence.update_conversation('name1', (123, 123), 5) - assert pickle_persistence.conversations['name1'] == conversation1 - with Path('pickletest').open('rb') as f: - conversations_test = dict(pickle.load(f))['conversations'] - assert not conversations_test['name1'] == conversation1 - - pickle_persistence.flush() - with Path('pickletest').open('rb') as f: - user_data_test = dict(pickle.load(f))['user_data'] - assert user_data_test == user_data - - with Path('pickletest').open('rb') as f: - chat_data_test = dict(pickle.load(f))['chat_data'] - assert chat_data_test == chat_data - - with Path('pickletest').open('rb') as f: - bot_data_test = pickle.load(f)['bot_data'] - assert bot_data_test == bot_data - - with Path('pickletest').open('rb') as f: - conversations_test = dict(pickle.load(f))['conversations'] - assert conversations_test['name1'] == conversation1 - - def test_with_handler(self, bot, update, bot_data, pickle_persistence, good_pickle_files): - u = UpdaterBuilder().bot(bot).persistence(pickle_persistence).build() - dp = u.dispatcher - bot.callback_data_cache.clear_callback_data() - bot.callback_data_cache.clear_callback_queries() - - def first(update, context): - if not context.user_data == {}: - pytest.fail() - if not context.chat_data == {}: - pytest.fail() - if not context.bot_data == bot_data: - pytest.fail() - if not context.bot.callback_data_cache.persistence_data == ([], {}): - pytest.fail() - context.user_data['test1'] = 'test2' - context.chat_data['test3'] = 'test4' - context.bot_data['test1'] = 'test0' - context.bot.callback_data_cache._callback_queries['test1'] = 'test0' - - def second(update, context): - if not context.user_data['test1'] == 'test2': - pytest.fail() - if not context.chat_data['test3'] == 'test4': - pytest.fail() - if not context.bot_data['test1'] == 'test0': - pytest.fail() - if not context.bot.callback_data_cache.persistence_data == ([], {'test1': 'test0'}): - pytest.fail() - - h1 = MessageHandler(None, first) - h2 = MessageHandler(None, second) - dp.add_handler(h1) - dp.process_update(update) - pickle_persistence_2 = PicklePersistence( - filepath='pickletest', - single_file=False, - on_flush=False, - ) - u = UpdaterBuilder().bot(bot).persistence(pickle_persistence_2).build() - dp = u.dispatcher - dp.add_handler(h2) - dp.process_update(update) - - def test_flush_on_stop(self, bot, update, pickle_persistence): - u = UpdaterBuilder().bot(bot).persistence(pickle_persistence).build() - dp = u.dispatcher - u.running = True - dp.user_data[4242424242]['my_test'] = 'Working!' - dp.chat_data[-4242424242]['my_test2'] = 'Working2!' - dp.bot_data['test'] = 'Working3!' - dp.bot.callback_data_cache._callback_queries['test'] = 'Working4!' - u._signal_handler(signal.SIGINT, None) - pickle_persistence_2 = PicklePersistence( - filepath='pickletest', - single_file=False, - on_flush=False, - ) - assert pickle_persistence_2.get_user_data()[4242424242]['my_test'] == 'Working!' - assert pickle_persistence_2.get_chat_data()[-4242424242]['my_test2'] == 'Working2!' - assert pickle_persistence_2.get_bot_data()['test'] == 'Working3!' - data = pickle_persistence_2.get_callback_data()[1] - assert data['test'] == 'Working4!' - - def test_flush_on_stop_only_bot(self, bot, update, pickle_persistence_only_bot): - u = UpdaterBuilder().bot(bot).persistence(pickle_persistence_only_bot).build() - dp = u.dispatcher - u.running = True - dp.user_data[4242424242]['my_test'] = 'Working!' - dp.chat_data[-4242424242]['my_test2'] = 'Working2!' - dp.bot_data['my_test3'] = 'Working3!' - dp.bot.callback_data_cache._callback_queries['test'] = 'Working4!' - u._signal_handler(signal.SIGINT, None) - pickle_persistence_2 = PicklePersistence( - filepath='pickletest', - store_data=PersistenceInput(callback_data=False, chat_data=False, user_data=False), - single_file=False, - on_flush=False, - ) - assert pickle_persistence_2.get_user_data() == {} - assert pickle_persistence_2.get_chat_data() == {} - assert pickle_persistence_2.get_bot_data()['my_test3'] == 'Working3!' - assert pickle_persistence_2.get_callback_data() is None - - def test_flush_on_stop_only_chat(self, bot, update, pickle_persistence_only_chat): - u = UpdaterBuilder().bot(bot).persistence(pickle_persistence_only_chat).build() - dp = u.dispatcher - u.running = True - dp.user_data[4242424242]['my_test'] = 'Working!' - dp.chat_data[-4242424242]['my_test2'] = 'Working2!' - dp.bot_data['my_test3'] = 'Working3!' - dp.bot.callback_data_cache._callback_queries['test'] = 'Working4!' - u._signal_handler(signal.SIGINT, None) - pickle_persistence_2 = PicklePersistence( - filepath='pickletest', - store_data=PersistenceInput(callback_data=False, user_data=False, bot_data=False), - single_file=False, - on_flush=False, - ) - assert pickle_persistence_2.get_user_data() == {} - assert pickle_persistence_2.get_chat_data()[-4242424242]['my_test2'] == 'Working2!' - assert pickle_persistence_2.get_bot_data() == {} - assert pickle_persistence_2.get_callback_data() is None - - def test_flush_on_stop_only_user(self, bot, update, pickle_persistence_only_user): - u = UpdaterBuilder().bot(bot).persistence(pickle_persistence_only_user).build() - dp = u.dispatcher - u.running = True - dp.user_data[4242424242]['my_test'] = 'Working!' - dp.chat_data[-4242424242]['my_test2'] = 'Working2!' - dp.bot_data['my_test3'] = 'Working3!' - dp.bot.callback_data_cache._callback_queries['test'] = 'Working4!' - u._signal_handler(signal.SIGINT, None) - pickle_persistence_2 = PicklePersistence( - filepath='pickletest', - store_data=PersistenceInput(callback_data=False, chat_data=False, bot_data=False), - single_file=False, - on_flush=False, - ) - assert pickle_persistence_2.get_user_data()[4242424242]['my_test'] == 'Working!' - assert pickle_persistence_2.get_chat_data() == {} - assert pickle_persistence_2.get_bot_data() == {} - assert pickle_persistence_2.get_callback_data() is None - - def test_flush_on_stop_only_callback(self, bot, update, pickle_persistence_only_callback): - u = UpdaterBuilder().bot(bot).persistence(pickle_persistence_only_callback).build() - dp = u.dispatcher - u.running = True - dp.user_data[4242424242]['my_test'] = 'Working!' - dp.chat_data[-4242424242]['my_test2'] = 'Working2!' - dp.bot_data['my_test3'] = 'Working3!' - dp.bot.callback_data_cache._callback_queries['test'] = 'Working4!' - u._signal_handler(signal.SIGINT, None) - del dp - del u - del pickle_persistence_only_callback - pickle_persistence_2 = PicklePersistence( - filepath='pickletest', - store_data=PersistenceInput(user_data=False, chat_data=False, bot_data=False), - single_file=False, - on_flush=False, - ) - assert pickle_persistence_2.get_user_data() == {} - assert pickle_persistence_2.get_chat_data() == {} - assert pickle_persistence_2.get_bot_data() == {} - data = pickle_persistence_2.get_callback_data()[1] - assert data['test'] == 'Working4!' - - def test_with_conversation_handler(self, dp, update, good_pickle_files, pickle_persistence): - dp.persistence = pickle_persistence - NEXT, NEXT2 = range(2) - - def start(update, context): - return NEXT - - start = CommandHandler('start', start) - - def next_callback(update, context): - return NEXT2 - - next_handler = MessageHandler(None, next_callback) - - def next2(update, context): - return ConversationHandler.END - - next2 = MessageHandler(None, next2) - - ch = ConversationHandler( - [start], {NEXT: [next_handler], NEXT2: [next2]}, [], name='name2', persistent=True - ) - dp.add_handler(ch) - assert ch.conversations[ch._get_key(update)] == 1 - dp.process_update(update) - assert ch._get_key(update) not in ch.conversations - update.message.text = '/start' - update.message.entities = [MessageEntity(MessageEntity.BOT_COMMAND, 0, 6)] - dp.process_update(update) - assert ch.conversations[ch._get_key(update)] == 0 - assert ch.conversations == pickle_persistence.conversations['name2'] - - def test_with_nested_conversationHandler( - self, dp, update, good_pickle_files, pickle_persistence - ): - dp.persistence = pickle_persistence - NEXT2, NEXT3 = range(1, 3) - - def start(update, context): - return NEXT2 - - start = CommandHandler('start', start) - - def next_callback(update, context): - return NEXT2 - - next_handler = MessageHandler(None, next_callback) - - def next2(update, context): - return ConversationHandler.END - - next2 = MessageHandler(None, next2) - - nested_ch = ConversationHandler( - [next_handler], - {NEXT2: [next2]}, - [], - name='name3', - persistent=True, - map_to_parent={ConversationHandler.END: ConversationHandler.END}, - ) - - ch = ConversationHandler( - [start], {NEXT2: [nested_ch], NEXT3: []}, [], name='name2', persistent=True - ) - dp.add_handler(ch) - assert ch.conversations[ch._get_key(update)] == 1 - assert nested_ch.conversations[nested_ch._get_key(update)] == 1 - dp.process_update(update) - assert ch._get_key(update) not in ch.conversations - assert nested_ch._get_key(update) not in nested_ch.conversations - update.message.text = '/start' - update.message.entities = [MessageEntity(MessageEntity.BOT_COMMAND, 0, 6)] - dp.process_update(update) - assert ch.conversations[ch._get_key(update)] == 1 - assert ch.conversations == pickle_persistence.conversations['name2'] - assert nested_ch._get_key(update) not in nested_ch.conversations - dp.process_update(update) - assert ch.conversations[ch._get_key(update)] == 1 - assert ch.conversations == pickle_persistence.conversations['name2'] - assert nested_ch.conversations[nested_ch._get_key(update)] == 1 - assert nested_ch.conversations == pickle_persistence.conversations['name3'] - - @pytest.mark.parametrize( - 'filepath', - ['pickletest', Path('pickletest')], - ids=['str filepath', 'pathlib.Path filepath'], - ) - def test_filepath_argument_types(self, filepath): - pick_persist = PicklePersistence( - filepath=filepath, - on_flush=False, - ) - pick_persist.update_user_data(1, 1) - - assert pick_persist.get_user_data()[1] == 1 - assert Path(filepath).is_file() - - def test_with_job(self, job_queue, dp, pickle_persistence): - dp.bot.arbitrary_callback_data = True - - def job_callback(context): - context.bot_data['test1'] = '456' - context.dispatcher.chat_data[123]['test2'] = '789' - context.dispatcher.user_data[789]['test3'] = '123' - context.bot.callback_data_cache._callback_queries['test'] = 'Working4!' - - dp.persistence = pickle_persistence - job_queue.set_dispatcher(dp) - job_queue.start() - job_queue.run_once(job_callback, 0.01) - sleep(0.5) - bot_data = pickle_persistence.get_bot_data() - assert bot_data == {'test1': '456'} - chat_data = pickle_persistence.get_chat_data() - assert chat_data[123] == {'test2': '789'} - user_data = pickle_persistence.get_user_data() - assert user_data[789] == {'test3': '123'} - data = pickle_persistence.get_callback_data()[1] - assert data['test'] == 'Working4!' - - @pytest.mark.parametrize('singlefile', [True, False]) - @pytest.mark.parametrize('ud', [int, float, complex]) - @pytest.mark.parametrize('cd', [int, float, complex]) - @pytest.mark.parametrize('bd', [int, float, complex]) - def test_with_context_types(self, ud, cd, bd, singlefile): - cc = ContextTypes(user_data=ud, chat_data=cd, bot_data=bd) - persistence = PicklePersistence('pickletest', single_file=singlefile, context_types=cc) - - assert isinstance(persistence.get_bot_data(), bd) - assert persistence.get_bot_data() == 0 - - persistence.user_data = None - persistence.chat_data = None - persistence.drop_user_data(123) - persistence.drop_chat_data(123) - assert isinstance(persistence.get_user_data(), dict) - assert isinstance(persistence.get_chat_data(), dict) - persistence.user_data = None - persistence.chat_data = None - persistence.update_user_data(1, ud(1)) - persistence.update_chat_data(1, cd(1)) - persistence.update_bot_data(bd(1)) - assert persistence.get_user_data()[1] == 1 - assert persistence.get_chat_data()[1] == 1 - assert persistence.get_bot_data() == 1 - - persistence.flush() - persistence = PicklePersistence('pickletest', single_file=singlefile, context_types=cc) - assert isinstance(persistence.get_user_data()[1], ud) - assert persistence.get_user_data()[1] == 1 - assert isinstance(persistence.get_chat_data()[1], cd) - assert persistence.get_chat_data()[1] == 1 - assert isinstance(persistence.get_bot_data(), bd) - assert persistence.get_bot_data() == 1 - - -@pytest.fixture(scope='function') -def user_data_json(user_data): - return json.dumps(user_data) - - -@pytest.fixture(scope='function') -def chat_data_json(chat_data): - return json.dumps(chat_data) - - -@pytest.fixture(scope='function') -def bot_data_json(bot_data): - return json.dumps(bot_data) - - -@pytest.fixture(scope='function') -def callback_data_json(callback_data): - return json.dumps(callback_data) - - -@pytest.fixture(scope='function') -def conversations_json(conversations): - return """{"name1": {"[123, 123]": 3, "[456, 654]": 4}, "name2": - {"[123, 321]": 1, "[890, 890]": 2}, "name3": - {"[123, 321]": 1, "[890, 890]": 2}}""" - - -class TestDictPersistence: - def test_slot_behaviour(self, mro_slots, recwarn): - inst = DictPersistence() - for attr in inst.__slots__: - assert getattr(inst, attr, 'err') != 'err', f"got extra slot '{attr}'" - assert len(mro_slots(inst)) == len(set(mro_slots(inst))), "duplicate slot" - - def test_no_json_given(self): - dict_persistence = DictPersistence() - assert dict_persistence.get_user_data() == {} - assert dict_persistence.get_chat_data() == {} - assert dict_persistence.get_bot_data() == {} - assert dict_persistence.get_callback_data() is None - assert dict_persistence.get_conversations('noname') == {} - - def test_bad_json_string_given(self): - bad_user_data = 'thisisnojson99900()))(' - bad_chat_data = 'thisisnojson99900()))(' - bad_bot_data = 'thisisnojson99900()))(' - bad_callback_data = 'thisisnojson99900()))(' - bad_conversations = 'thisisnojson99900()))(' - with pytest.raises(TypeError, match='user_data'): - DictPersistence(user_data_json=bad_user_data) - with pytest.raises(TypeError, match='chat_data'): - DictPersistence(chat_data_json=bad_chat_data) - with pytest.raises(TypeError, match='bot_data'): - DictPersistence(bot_data_json=bad_bot_data) - with pytest.raises(TypeError, match='callback_data'): - DictPersistence(callback_data_json=bad_callback_data) - with pytest.raises(TypeError, match='conversations'): - DictPersistence(conversations_json=bad_conversations) - - def test_invalid_json_string_given(self, pickle_persistence, bad_pickle_files): - bad_user_data = '["this", "is", "json"]' - bad_chat_data = '["this", "is", "json"]' - bad_bot_data = '["this", "is", "json"]' - bad_conversations = '["this", "is", "json"]' - bad_callback_data_1 = '[[["str", 3.14, {"di": "ct"}]], "is"]' - bad_callback_data_2 = '[[["str", "non-float", {"di": "ct"}]], {"di": "ct"}]' - bad_callback_data_3 = '[[[{"not": "a str"}, 3.14, {"di": "ct"}]], {"di": "ct"}]' - bad_callback_data_4 = '[[["wrong", "length"]], {"di": "ct"}]' - bad_callback_data_5 = '["this", "is", "json"]' - with pytest.raises(TypeError, match='user_data'): - DictPersistence(user_data_json=bad_user_data) - with pytest.raises(TypeError, match='chat_data'): - DictPersistence(chat_data_json=bad_chat_data) - with pytest.raises(TypeError, match='bot_data'): - DictPersistence(bot_data_json=bad_bot_data) - for bad_callback_data in [ - bad_callback_data_1, - bad_callback_data_2, - bad_callback_data_3, - bad_callback_data_4, - bad_callback_data_5, - ]: - with pytest.raises(TypeError, match='callback_data'): - DictPersistence(callback_data_json=bad_callback_data) - with pytest.raises(TypeError, match='conversations'): - DictPersistence(conversations_json=bad_conversations) - - def test_good_json_input( - self, user_data_json, chat_data_json, bot_data_json, conversations_json, callback_data_json - ): - dict_persistence = DictPersistence( - user_data_json=user_data_json, - chat_data_json=chat_data_json, - bot_data_json=bot_data_json, - conversations_json=conversations_json, - callback_data_json=callback_data_json, - ) - user_data = dict_persistence.get_user_data() - assert isinstance(user_data, dict) - assert user_data[12345]['test1'] == 'test2' - assert user_data[67890][3] == 'test4' - - chat_data = dict_persistence.get_chat_data() - assert isinstance(chat_data, dict) - assert chat_data[-12345]['test1'] == 'test2' - assert chat_data[-67890][3] == 'test4' - - bot_data = dict_persistence.get_bot_data() - assert isinstance(bot_data, dict) - assert bot_data['test1'] == 'test2' - assert bot_data['test3']['test4'] == 'test5' - assert 'test6' not in bot_data - - callback_data = dict_persistence.get_callback_data() - - assert isinstance(callback_data, tuple) - assert callback_data[0] == [('test1', 1000, {'button1': 'test0', 'button2': 'test1'})] - assert callback_data[1] == {'test1': 'test2'} - - conversation1 = dict_persistence.get_conversations('name1') - assert isinstance(conversation1, dict) - assert conversation1[(123, 123)] == 3 - assert conversation1[(456, 654)] == 4 - with pytest.raises(KeyError): - conversation1[(890, 890)] - conversation2 = dict_persistence.get_conversations('name2') - assert isinstance(conversation1, dict) - assert conversation2[(123, 321)] == 1 - assert conversation2[(890, 890)] == 2 - with pytest.raises(KeyError): - conversation2[(123, 123)] - - def test_good_json_input_callback_data_none(self): - dict_persistence = DictPersistence(callback_data_json='null') - assert dict_persistence.callback_data is None - assert dict_persistence.callback_data_json == 'null' - - def test_dict_outputs( - self, - user_data, - user_data_json, - chat_data, - chat_data_json, - bot_data, - bot_data_json, - callback_data_json, - conversations, - conversations_json, - ): - dict_persistence = DictPersistence( - user_data_json=user_data_json, - chat_data_json=chat_data_json, - bot_data_json=bot_data_json, - callback_data_json=callback_data_json, - conversations_json=conversations_json, - ) - assert dict_persistence.user_data == user_data - assert dict_persistence.chat_data == chat_data - assert dict_persistence.bot_data == bot_data - assert dict_persistence.bot_data == bot_data - assert dict_persistence.conversations == conversations - - def test_json_outputs( - self, user_data_json, chat_data_json, bot_data_json, callback_data_json, conversations_json - ): - dict_persistence = DictPersistence( - user_data_json=user_data_json, - chat_data_json=chat_data_json, - bot_data_json=bot_data_json, - callback_data_json=callback_data_json, - conversations_json=conversations_json, - ) - assert dict_persistence.user_data_json == user_data_json - assert dict_persistence.chat_data_json == chat_data_json - assert dict_persistence.callback_data_json == callback_data_json - assert dict_persistence.conversations_json == conversations_json - - def test_updating( - self, - user_data_json, - chat_data_json, - bot_data_json, - callback_data, - callback_data_json, - conversations, - conversations_json, - ): - dict_persistence = DictPersistence( - user_data_json=user_data_json, - chat_data_json=chat_data_json, - bot_data_json=bot_data_json, - callback_data_json=callback_data_json, - conversations_json=conversations_json, - ) - - user_data = dict_persistence.get_user_data() - user_data[12345]['test3']['test4'] = 'test6' - assert not dict_persistence.user_data == user_data - assert not dict_persistence.user_data_json == json.dumps(user_data) - dict_persistence.update_user_data(12345, user_data[12345]) - user_data[12345]['test3']['test4'] = 'test7' - assert not dict_persistence.user_data == user_data - assert not dict_persistence.user_data_json == json.dumps(user_data) - dict_persistence.update_user_data(12345, user_data[12345]) - assert dict_persistence.user_data == user_data - assert dict_persistence.user_data_json == json.dumps(user_data) - dict_persistence.drop_user_data(67890) - assert 67890 not in dict_persistence.user_data - dict_persistence._user_data = None - dict_persistence.drop_user_data(123) - assert isinstance(dict_persistence.get_user_data(), dict) - - chat_data = dict_persistence.get_chat_data() - chat_data[-12345]['test3']['test4'] = 'test6' - assert not dict_persistence.chat_data == chat_data - assert not dict_persistence.chat_data_json == json.dumps(chat_data) - dict_persistence.update_chat_data(-12345, chat_data[-12345]) - chat_data[-12345]['test3']['test4'] = 'test7' - assert not dict_persistence.chat_data == chat_data - assert not dict_persistence.chat_data_json == json.dumps(chat_data) - dict_persistence.update_chat_data(-12345, chat_data[-12345]) - assert dict_persistence.chat_data == chat_data - assert dict_persistence.chat_data_json == json.dumps(chat_data) - dict_persistence.drop_chat_data(-67890) - assert -67890 not in dict_persistence.chat_data - dict_persistence._chat_data = None - dict_persistence.drop_chat_data(123) - assert isinstance(dict_persistence.get_chat_data(), dict) - - bot_data = dict_persistence.get_bot_data() - bot_data['test3']['test4'] = 'test6' - assert not dict_persistence.bot_data == bot_data - assert not dict_persistence.bot_data_json == json.dumps(bot_data) - dict_persistence.update_bot_data(bot_data) - bot_data['test3']['test4'] = 'test7' - assert not dict_persistence.bot_data == bot_data - assert not dict_persistence.bot_data_json == json.dumps(bot_data) - dict_persistence.update_bot_data(bot_data) - assert dict_persistence.bot_data == bot_data - assert dict_persistence.bot_data_json == json.dumps(bot_data) - - callback_data = dict_persistence.get_callback_data() - callback_data[1]['test3'] = 'test4' - callback_data[0][0][2]['button2'] = 'test41' - assert not dict_persistence.callback_data == callback_data - assert not dict_persistence.callback_data_json == json.dumps(callback_data) - dict_persistence.update_callback_data(callback_data) - callback_data[1]['test3'] = 'test5' - callback_data[0][0][2]['button2'] = 'test42' - assert not dict_persistence.callback_data == callback_data - assert not dict_persistence.callback_data_json == json.dumps(callback_data) - dict_persistence.update_callback_data(callback_data) - assert dict_persistence.callback_data == callback_data - assert dict_persistence.callback_data_json == json.dumps(callback_data) - - conversation1 = dict_persistence.get_conversations('name1') - conversation1[(123, 123)] = 5 - assert not dict_persistence.conversations['name1'] == conversation1 - dict_persistence.update_conversation('name1', (123, 123), 5) - assert dict_persistence.conversations['name1'] == conversation1 - conversations['name1'][(123, 123)] = 5 - assert ( - dict_persistence.conversations_json - == DictPersistence._encode_conversations_to_json(conversations) - ) - assert dict_persistence.get_conversations('name1') == conversation1 - - dict_persistence._conversations = None - dict_persistence.update_conversation('name1', (123, 123), 5) - assert dict_persistence.conversations['name1'] == {(123, 123): 5} - assert dict_persistence.get_conversations('name1') == {(123, 123): 5} - assert ( - dict_persistence.conversations_json - == DictPersistence._encode_conversations_to_json({"name1": {(123, 123): 5}}) - ) - - def test_with_handler(self, bot, update): - dict_persistence = DictPersistence() - u = UpdaterBuilder().bot(bot).persistence(dict_persistence).build() - dp = u.dispatcher - - def first(update, context): - if not context.user_data == {}: - pytest.fail() - if not context.chat_data == {}: - pytest.fail() - if not context.bot_data == {}: - pytest.fail() - if not context.bot.callback_data_cache.persistence_data == ([], {}): - pytest.fail() - context.user_data['test1'] = 'test2' - context.chat_data[3] = 'test4' - context.bot_data['test1'] = 'test0' - context.bot.callback_data_cache._callback_queries['test1'] = 'test0' - - def second(update, context): - if not context.user_data['test1'] == 'test2': - pytest.fail() - if not context.chat_data[3] == 'test4': - pytest.fail() - if not context.bot_data['test1'] == 'test0': - pytest.fail() - if not context.bot.callback_data_cache.persistence_data == ([], {'test1': 'test0'}): - pytest.fail() - - h1 = MessageHandler(filters.ALL, first) - h2 = MessageHandler(filters.ALL, second) - dp.add_handler(h1) - dp.process_update(update) - user_data = dict_persistence.user_data_json - chat_data = dict_persistence.chat_data_json - bot_data = dict_persistence.bot_data_json - callback_data = dict_persistence.callback_data_json - dict_persistence_2 = DictPersistence( - user_data_json=user_data, - chat_data_json=chat_data, - bot_data_json=bot_data, - callback_data_json=callback_data, - ) - - u = UpdaterBuilder().bot(bot).persistence(dict_persistence_2).build() - dp = u.dispatcher - dp.add_handler(h2) - dp.process_update(update) - - def test_with_conversationHandler(self, dp, update, conversations_json): - dict_persistence = DictPersistence(conversations_json=conversations_json) - dp.persistence = dict_persistence - NEXT, NEXT2 = range(2) - - def start(update, context): - return NEXT - - start = CommandHandler('start', start) - - def next_callback(update, context): - return NEXT2 - - next_handler = MessageHandler(None, next_callback) - - def next2(update, context): - return ConversationHandler.END - - next2 = MessageHandler(None, next2) - - ch = ConversationHandler( - [start], {NEXT: [next_handler], NEXT2: [next2]}, [], name='name2', persistent=True - ) - dp.add_handler(ch) - assert ch.conversations[ch._get_key(update)] == 1 - dp.process_update(update) - assert ch._get_key(update) not in ch.conversations - update.message.text = '/start' - update.message.entities = [MessageEntity(MessageEntity.BOT_COMMAND, 0, 6)] - dp.process_update(update) - assert ch.conversations[ch._get_key(update)] == 0 - assert ch.conversations == dict_persistence.conversations['name2'] - - def test_with_nested_conversationHandler(self, dp, update, conversations_json): - dict_persistence = DictPersistence(conversations_json=conversations_json) - dp.persistence = dict_persistence - NEXT2, NEXT3 = range(1, 3) - - def start(update, context): - return NEXT2 - - start = CommandHandler('start', start) - - def next_callback(update, context): - return NEXT2 - - next_handler = MessageHandler(None, next_callback) - - def next2(update, context): - return ConversationHandler.END - - next2 = MessageHandler(None, next2) - - nested_ch = ConversationHandler( - [next_handler], - {NEXT2: [next2]}, - [], - name='name3', - persistent=True, - map_to_parent={ConversationHandler.END: ConversationHandler.END}, - ) - - ch = ConversationHandler( - [start], {NEXT2: [nested_ch], NEXT3: []}, [], name='name2', persistent=True - ) - dp.add_handler(ch) - assert ch.conversations[ch._get_key(update)] == 1 - assert nested_ch.conversations[nested_ch._get_key(update)] == 1 - dp.process_update(update) - assert ch._get_key(update) not in ch.conversations - assert nested_ch._get_key(update) not in nested_ch.conversations - update.message.text = '/start' - update.message.entities = [MessageEntity(MessageEntity.BOT_COMMAND, 0, 6)] - dp.process_update(update) - assert ch.conversations[ch._get_key(update)] == 1 - assert ch.conversations == dict_persistence.conversations['name2'] - assert nested_ch._get_key(update) not in nested_ch.conversations - dp.process_update(update) - assert ch.conversations[ch._get_key(update)] == 1 - assert ch.conversations == dict_persistence.conversations['name2'] - assert nested_ch.conversations[nested_ch._get_key(update)] == 1 - assert nested_ch.conversations == dict_persistence.conversations['name3'] - - def test_with_job(self, job_queue, dp): - dp.bot.arbitrary_callback_data = True - - def job_callback(context): - context.bot_data['test1'] = '456' - context.dispatcher.chat_data[123]['test2'] = '789' - context.dispatcher.user_data[789]['test3'] = '123' - context.bot.callback_data_cache._callback_queries['test'] = 'Working4!' - - dict_persistence = DictPersistence() - dp.persistence = dict_persistence - job_queue.set_dispatcher(dp) - job_queue.start() - job_queue.run_once(job_callback, 0.01) - sleep(0.8) - bot_data = dict_persistence.get_bot_data() - assert bot_data == {'test1': '456'} - chat_data = dict_persistence.get_chat_data() - assert chat_data[123] == {'test2': '789'} - user_data = dict_persistence.get_user_data() - assert user_data[789] == {'test3': '123'} - data = dict_persistence.get_callback_data()[1] - assert data['test'] == 'Working4!' diff --git a/tests/test_photo.py b/tests/test_photo.py index ac4e5fdc748..27eb9dcd0ba 100644 --- a/tests/test_photo.py +++ b/tests/test_photo.py @@ -25,6 +25,7 @@ from telegram import Sticker, PhotoSize, InputFile, MessageEntity, Bot from telegram.error import BadRequest, TelegramError from telegram.helpers import escape_markdown +from telegram.request import RequestData from tests.conftest import ( expect_bad_request, check_shortcut_call, @@ -42,12 +43,16 @@ def photo_file(): @pytest.fixture(scope='class') -def _photo(bot, chat_id): - def func(): +@pytest.mark.asyncio +async def _photo(bot, chat_id): + async def func(): with data_file('telegram.jpg').open('rb') as f: - return bot.send_photo(chat_id, photo=f, timeout=50).photo + photo = (await bot.send_photo(chat_id, photo=f, read_timeout=50)).photo + return photo - return expect_bad_request(func, 'Type of file mismatch', 'Telegram did not accept the file.') + return await expect_bad_request( + func, 'Type of file mismatch', 'Telegram did not accept the file.' + ) @pytest.fixture(scope='class') @@ -57,20 +62,18 @@ def thumb(_photo): @pytest.fixture(scope='class') def photo(_photo): - return _photo[1] + print([ps.to_json() for ps in _photo]) + return _photo[-1] class TestPhoto: - width = 320 - height = 320 + width = 800 + height = 800 caption = 'PhotoTest - *Caption*' - photo_file_url = 'https://python-telegram-bot.org/static/testfiles/telegram.jpg' - file_size = 29176 - - def test_slot_behaviour(self, photo, mro_slots): - for attr in photo.__slots__: - assert getattr(photo, attr, 'err') != 'err', f"got extra slot '{attr}'" - assert len(mro_slots(photo)) == len(set(mro_slots(photo))), "duplicate slot" + photo_file_url = 'https://python-telegram-bot.org/static/testfiles/telegram_new.jpg' + # For some reason the file size is not the same after switching to httpx + # so we accept three different sizes here. Shouldn't be too much + file_size = [29176, 27662] def test_creation(self, thumb, photo): # Make sure file has been uploaded. @@ -87,16 +90,17 @@ def test_creation(self, thumb, photo): assert thumb.file_unique_id != '' def test_expected_values(self, photo, thumb): - # We used to test for file_size as well, but TG apparently at some point apparently changed - # the compression method and it's not really our job anyway ... assert photo.width == self.width assert photo.height == self.height + assert photo.file_size in self.file_size assert thumb.width == 90 assert thumb.height == 90 + assert thumb.file_size == 1477 @flaky(3, 1) - def test_send_photo_all_args(self, bot, chat_id, photo_file, thumb, photo): - message = bot.send_photo( + @pytest.mark.asyncio + async def test_send_photo_all_args(self, bot, chat_id, photo_file, thumb, photo): + message = await bot.send_photo( chat_id, photo_file, caption=self.caption, @@ -105,93 +109,83 @@ def test_send_photo_all_args(self, bot, chat_id, photo_file, thumb, photo): parse_mode='Markdown', ) - assert isinstance(message.photo[0], PhotoSize) - assert isinstance(message.photo[0].file_id, str) - assert isinstance(message.photo[0].file_unique_id, str) - assert message.photo[0].file_id != '' - assert message.photo[0].file_unique_id != '' - assert message.photo[0].width == thumb.width - assert message.photo[0].height == thumb.height - assert message.photo[0].file_size == thumb.file_size - - assert isinstance(message.photo[1], PhotoSize) - assert isinstance(message.photo[1].file_id, str) - assert isinstance(message.photo[1].file_unique_id, str) - assert message.photo[1].file_id != '' - assert message.photo[1].file_unique_id != '' - assert message.photo[1].width == photo.width - assert message.photo[1].height == photo.height - assert message.photo[1].file_size == photo.file_size + assert isinstance(message.photo[-2], PhotoSize) + assert isinstance(message.photo[-2].file_id, str) + assert isinstance(message.photo[-2].file_unique_id, str) + assert message.photo[-2].file_id != '' + assert message.photo[-2].file_unique_id != '' + + assert isinstance(message.photo[-1], PhotoSize) + assert isinstance(message.photo[-1].file_id, str) + assert isinstance(message.photo[-1].file_unique_id, str) + assert message.photo[-1].file_id != '' + assert message.photo[-1].file_unique_id != '' assert message.caption == TestPhoto.caption.replace('*', '') assert message.has_protected_content @flaky(3, 1) - def test_send_photo_custom_filename(self, bot, chat_id, photo_file, monkeypatch): - def make_assertion(url, data, **kwargs): - return data['photo'].filename == 'custom_filename' + @pytest.mark.asyncio + async def test_send_photo_custom_filename(self, bot, chat_id, photo_file, monkeypatch): + async def make_assertion(url, request_data: RequestData, *args, **kwargs): + return list(request_data.multipart_data.values())[0][0] == 'custom_filename' monkeypatch.setattr(bot.request, 'post', make_assertion) - assert bot.send_photo(chat_id, photo_file, filename='custom_filename') + assert await bot.send_photo(chat_id, photo_file, filename='custom_filename') @flaky(3, 1) - def test_send_photo_parse_mode_markdown(self, bot, chat_id, photo_file, thumb, photo): - message = bot.send_photo(chat_id, photo_file, caption=self.caption, parse_mode='Markdown') - assert isinstance(message.photo[0], PhotoSize) - assert isinstance(message.photo[0].file_id, str) - assert isinstance(message.photo[0].file_unique_id, str) - assert message.photo[0].file_id != '' - assert message.photo[0].file_unique_id != '' - assert message.photo[0].width == thumb.width - assert message.photo[0].height == thumb.height - assert message.photo[0].file_size == thumb.file_size - - assert isinstance(message.photo[1], PhotoSize) - assert isinstance(message.photo[1].file_id, str) - assert isinstance(message.photo[1].file_unique_id, str) - assert message.photo[1].file_id != '' - assert message.photo[1].file_unique_id != '' - assert message.photo[1].width == photo.width - assert message.photo[1].height == photo.height - assert message.photo[1].file_size == photo.file_size + @pytest.mark.asyncio + async def test_send_photo_parse_mode_markdown(self, bot, chat_id, photo_file, thumb, photo): + message = await bot.send_photo( + chat_id, photo_file, caption=self.caption, parse_mode='Markdown' + ) + assert isinstance(message.photo[-2], PhotoSize) + assert isinstance(message.photo[-2].file_id, str) + assert isinstance(message.photo[-2].file_unique_id, str) + assert message.photo[-2].file_id != '' + assert message.photo[-2].file_unique_id != '' + + assert isinstance(message.photo[-1], PhotoSize) + assert isinstance(message.photo[-1].file_id, str) + assert isinstance(message.photo[-1].file_unique_id, str) + assert message.photo[-1].file_id != '' + assert message.photo[-1].file_unique_id != '' assert message.caption == TestPhoto.caption.replace('*', '') assert len(message.caption_entities) == 1 @flaky(3, 1) - def test_send_photo_parse_mode_html(self, bot, chat_id, photo_file, thumb, photo): - message = bot.send_photo(chat_id, photo_file, caption=self.caption, parse_mode='HTML') - assert isinstance(message.photo[0], PhotoSize) - assert isinstance(message.photo[0].file_id, str) - assert isinstance(message.photo[0].file_unique_id, str) - assert message.photo[0].file_id != '' - assert message.photo[0].file_unique_id != '' - assert message.photo[0].width == thumb.width - assert message.photo[0].height == thumb.height - assert message.photo[0].file_size == thumb.file_size - - assert isinstance(message.photo[1], PhotoSize) - assert isinstance(message.photo[1].file_id, str) - assert isinstance(message.photo[1].file_unique_id, str) - assert message.photo[1].file_id != '' - assert message.photo[1].file_unique_id != '' - assert message.photo[1].width == photo.width - assert message.photo[1].height == photo.height - assert message.photo[1].file_size == photo.file_size + @pytest.mark.asyncio + async def test_send_photo_parse_mode_html(self, bot, chat_id, photo_file, thumb, photo): + message = await bot.send_photo( + chat_id, photo_file, caption=self.caption, parse_mode='HTML' + ) + assert isinstance(message.photo[-2], PhotoSize) + assert isinstance(message.photo[-2].file_id, str) + assert isinstance(message.photo[-2].file_unique_id, str) + assert message.photo[-2].file_id != '' + assert message.photo[-2].file_unique_id != '' + + assert isinstance(message.photo[-1], PhotoSize) + assert isinstance(message.photo[-1].file_id, str) + assert isinstance(message.photo[-1].file_unique_id, str) + assert message.photo[-1].file_id != '' + assert message.photo[-1].file_unique_id != '' assert message.caption == TestPhoto.caption.replace('', '').replace('', '') assert len(message.caption_entities) == 1 @flaky(3, 1) - def test_send_photo_caption_entities(self, bot, chat_id, photo_file, thumb, photo): + @pytest.mark.asyncio + async def test_send_photo_caption_entities(self, bot, chat_id, photo_file, thumb, photo): test_string = 'Italic Bold Code' entities = [ MessageEntity(MessageEntity.ITALIC, 0, 6), MessageEntity(MessageEntity.ITALIC, 7, 4), MessageEntity(MessageEntity.ITALIC, 12, 4), ] - message = bot.send_photo( + message = await bot.send_photo( chat_id, photo_file, caption=test_string, caption_entities=entities ) @@ -200,20 +194,26 @@ def test_send_photo_caption_entities(self, bot, chat_id, photo_file, thumb, phot @flaky(3, 1) @pytest.mark.parametrize('default_bot', [{'parse_mode': 'Markdown'}], indirect=True) - def test_send_photo_default_parse_mode_1(self, default_bot, chat_id, photo_file, thumb, photo): + @pytest.mark.asyncio + async def test_send_photo_default_parse_mode_1( + self, default_bot, chat_id, photo_file, thumb, photo + ): test_string = 'Italic Bold Code' test_markdown_string = '_Italic_ *Bold* `Code`' - message = default_bot.send_photo(chat_id, photo_file, caption=test_markdown_string) + message = await default_bot.send_photo(chat_id, photo_file, caption=test_markdown_string) assert message.caption_markdown == test_markdown_string assert message.caption == test_string @flaky(3, 1) @pytest.mark.parametrize('default_bot', [{'parse_mode': 'Markdown'}], indirect=True) - def test_send_photo_default_parse_mode_2(self, default_bot, chat_id, photo_file, thumb, photo): + @pytest.mark.asyncio + async def test_send_photo_default_parse_mode_2( + self, default_bot, chat_id, photo_file, thumb, photo + ): test_markdown_string = '_Italic_ *Bold* `Code`' - message = default_bot.send_photo( + message = await default_bot.send_photo( chat_id, photo_file, caption=test_markdown_string, parse_mode=None ) assert message.caption == test_markdown_string @@ -221,37 +221,41 @@ def test_send_photo_default_parse_mode_2(self, default_bot, chat_id, photo_file, @flaky(3, 1) @pytest.mark.parametrize('default_bot', [{'parse_mode': 'Markdown'}], indirect=True) - def test_send_photo_default_parse_mode_3(self, default_bot, chat_id, photo_file, thumb, photo): + @pytest.mark.asyncio + async def test_send_photo_default_parse_mode_3( + self, default_bot, chat_id, photo_file, thumb, photo + ): test_markdown_string = '_Italic_ *Bold* `Code`' - message = default_bot.send_photo( + message = await default_bot.send_photo( chat_id, photo_file, caption=test_markdown_string, parse_mode='HTML' ) assert message.caption == test_markdown_string assert message.caption_markdown == escape_markdown(test_markdown_string) @flaky(3, 1) + @pytest.mark.asyncio @pytest.mark.parametrize('default_bot', [{'protect_content': True}], indirect=True) - def test_send_photo_default_protect_content(self, chat_id, default_bot, photo): - protected = default_bot.send_photo(chat_id, photo) + async def test_send_photo_default_protect_content(self, chat_id, default_bot, photo): + protected = await default_bot.send_photo(chat_id, photo) assert protected.has_protected_content - unprotected = default_bot.send_photo(chat_id, photo, protect_content=False) + unprotected = await default_bot.send_photo(chat_id, photo, protect_content=False) assert not unprotected.has_protected_content - def test_send_photo_local_files(self, monkeypatch, bot, chat_id): + @pytest.mark.asyncio + async def test_send_photo_local_files(self, monkeypatch, bot, chat_id): # For just test that the correct paths are passed as we have no local bot API set up test_flag = False file = data_file('telegram.jpg') expected = file.as_uri() - def make_assertion(_, data, *args, **kwargs): + async def make_assertion(_, data, *args, **kwargs): nonlocal test_flag test_flag = data.get('photo') == expected monkeypatch.setattr(bot, '_post', make_assertion) - bot.send_photo(chat_id, file) + await bot.send_photo(chat_id, file) assert test_flag - monkeypatch.delattr(bot, '_post') @flaky(3, 1) @pytest.mark.parametrize( @@ -263,13 +267,14 @@ def make_assertion(_, data, *args, **kwargs): ], indirect=['default_bot'], ) - def test_send_photo_default_allow_sending_without_reply( + @pytest.mark.asyncio + async def test_send_photo_default_allow_sending_without_reply( self, default_bot, chat_id, photo_file, thumb, photo, custom ): - reply_to_message = default_bot.send_message(chat_id, 'test') - reply_to_message.delete() + reply_to_message = await default_bot.send_message(chat_id, 'test') + await reply_to_message.delete() if custom is not None: - message = default_bot.send_photo( + message = await default_bot.send_photo( chat_id, photo_file, allow_sending_without_reply=custom, @@ -277,51 +282,54 @@ def test_send_photo_default_allow_sending_without_reply( ) assert message.reply_to_message is None elif default_bot.defaults.allow_sending_without_reply: - message = default_bot.send_photo( + message = await default_bot.send_photo( chat_id, photo_file, reply_to_message_id=reply_to_message.message_id ) assert message.reply_to_message is None else: with pytest.raises(BadRequest, match='message not found'): - default_bot.send_photo( + await default_bot.send_photo( chat_id, photo_file, reply_to_message_id=reply_to_message.message_id ) @flaky(3, 1) - def test_get_and_download(self, bot, photo): - new_file = bot.getFile(photo.file_id) + @pytest.mark.asyncio + async def test_get_and_download(self, bot, photo): + path = Path('telegram.jpg') + if path.is_file(): + path.unlink() + + new_file = await bot.getFile(photo.file_id) assert new_file.file_size == photo.file_size assert new_file.file_unique_id == photo.file_unique_id assert new_file.file_path.startswith('https://') is True - new_file.download('telegram.jpg') + await new_file.download('telegram.jpg') - assert Path('telegram.jpg').is_file() + assert path.is_file() @flaky(3, 1) - def test_send_url_jpg_file(self, bot, chat_id, thumb, photo): - message = bot.send_photo(chat_id, photo=self.photo_file_url) - - assert isinstance(message.photo[0], PhotoSize) - assert isinstance(message.photo[0].file_id, str) - assert isinstance(message.photo[0].file_unique_id, str) - assert message.photo[0].file_id != '' - assert message.photo[0].file_unique_id != '' - # We used to test for width, height and file_size, but TG apparently started to treat - # sending by URL and sending by upload differently and it's not really our job anyway ... - - assert isinstance(message.photo[1], PhotoSize) - assert isinstance(message.photo[1].file_id, str) - assert isinstance(message.photo[1].file_unique_id, str) - assert message.photo[1].file_id != '' - assert message.photo[1].file_unique_id != '' - # We used to test for width, height and file_size, but TG apparently started to treat - # sending by URL and sending by upload differently and it's not really our job anyway ... + @pytest.mark.asyncio + async def test_send_url_jpg_file(self, bot, chat_id, thumb, photo): + message = await bot.send_photo(chat_id, photo=self.photo_file_url) + + assert isinstance(message.photo[-2], PhotoSize) + assert isinstance(message.photo[-2].file_id, str) + assert isinstance(message.photo[-2].file_unique_id, str) + assert message.photo[-2].file_id != '' + assert message.photo[-2].file_unique_id != '' + + assert isinstance(message.photo[-1], PhotoSize) + assert isinstance(message.photo[-1].file_id, str) + assert isinstance(message.photo[-1].file_unique_id, str) + assert message.photo[-1].file_id != '' + assert message.photo[-1].file_unique_id != '' @flaky(3, 1) - def test_send_url_png_file(self, bot, chat_id): - message = bot.send_photo( + @pytest.mark.asyncio + async def test_send_url_png_file(self, bot, chat_id): + message = await bot.send_photo( photo='http://dummyimage.com/600x400/000/fff.png&text=telegram', chat_id=chat_id ) @@ -334,8 +342,9 @@ def test_send_url_png_file(self, bot, chat_id): assert photo.file_unique_id != '' @flaky(3, 1) - def test_send_url_gif_file(self, bot, chat_id): - message = bot.send_photo( + @pytest.mark.asyncio + async def test_send_url_gif_file(self, bot, chat_id): + message = await bot.send_photo( photo='http://dummyimage.com/600x400/000/fff.png&text=telegram', chat_id=chat_id ) @@ -348,12 +357,13 @@ def test_send_url_gif_file(self, bot, chat_id): assert photo.file_unique_id != '' @flaky(3, 1) - def test_send_file_unicode_filename(self, bot, chat_id): + @pytest.mark.asyncio + async def test_send_file_unicode_filename(self, bot, chat_id): """ Regression test for https://github.com/python-telegram-bot/python-telegram-bot/issues/1202 """ with data_file('测试.png').open('rb') as f: - message = bot.send_photo(photo=f, chat_id=chat_id) + message = await bot.send_photo(photo=f, chat_id=chat_id) photo = message.photo[-1] @@ -364,7 +374,8 @@ def test_send_file_unicode_filename(self, bot, chat_id): assert photo.file_unique_id != '' @flaky(3, 1) - def test_send_bytesio_jpg_file(self, bot, chat_id): + @pytest.mark.asyncio + async def test_send_bytesio_jpg_file(self, bot, chat_id): filepath = data_file('telegram_no_standard_header.jpg') # raw image bytes @@ -380,7 +391,7 @@ def test_send_bytesio_jpg_file(self, bot, chat_id): # send raw photo raw_bytes = BytesIO(filepath.read_bytes()) - message = bot.send_photo(chat_id, photo=raw_bytes) + message = await bot.send_photo(chat_id, photo=raw_bytes) photo = message.photo[-1] assert isinstance(photo.file_id, str) assert isinstance(photo.file_unique_id, str) @@ -391,37 +402,31 @@ def test_send_bytesio_jpg_file(self, bot, chat_id): assert photo.height == 720 assert photo.file_size == 33372 - def test_send_with_photosize(self, monkeypatch, bot, chat_id, photo): - def test(url, data, **kwargs): - return data['photo'] == photo.file_id + @pytest.mark.asyncio + async def test_send_with_photosize(self, monkeypatch, bot, chat_id, photo): + async def make_assertion(url, request_data: RequestData, *args, **kwargs): + return request_data.json_parameters['photo'] == photo.file_id - monkeypatch.setattr(bot.request, 'post', test) - message = bot.send_photo(photo=photo, chat_id=chat_id) + monkeypatch.setattr(bot.request, 'post', make_assertion) + message = await bot.send_photo(photo=photo, chat_id=chat_id) assert message @flaky(3, 1) - def test_resend(self, bot, chat_id, photo): - message = bot.send_photo(chat_id=chat_id, photo=photo.file_id) - - thumb, photo, _ = message.photo - - assert isinstance(message.photo[0], PhotoSize) - assert isinstance(message.photo[0].file_id, str) - assert isinstance(message.photo[0].file_unique_id, str) - assert message.photo[0].file_id != '' - assert message.photo[0].file_unique_id != '' - assert message.photo[0].width == thumb.width - assert message.photo[0].height == thumb.height - assert message.photo[0].file_size == thumb.file_size - - assert isinstance(message.photo[1], PhotoSize) - assert isinstance(message.photo[1].file_id, str) - assert isinstance(message.photo[1].file_unique_id, str) - assert message.photo[1].file_id != '' - assert message.photo[1].file_unique_id != '' - assert message.photo[1].width == photo.width - assert message.photo[1].height == photo.height - assert message.photo[1].file_size == photo.file_size + @pytest.mark.asyncio + async def test_resend(self, bot, chat_id, photo, thumb): + message = await bot.send_photo(chat_id=chat_id, photo=photo.file_id) + + assert isinstance(message.photo[-2], PhotoSize) + assert isinstance(message.photo[-2].file_id, str) + assert isinstance(message.photo[-2].file_unique_id, str) + assert message.photo[-2].file_id != '' + assert message.photo[-2].file_unique_id != '' + + assert isinstance(message.photo[-1], PhotoSize) + assert isinstance(message.photo[-1].file_id, str) + assert isinstance(message.photo[-1].file_unique_id, str) + assert message.photo[-1].file_id != '' + assert message.photo[-1].file_unique_id != '' def test_de_json(self, bot, photo): json_dict = { @@ -450,29 +455,33 @@ def test_to_dict(self, photo): assert photo_dict['file_size'] == photo.file_size @flaky(3, 1) - def test_error_send_empty_file(self, bot, chat_id): + @pytest.mark.asyncio + async def test_error_send_empty_file(self, bot, chat_id): with pytest.raises(TelegramError): - bot.send_photo(chat_id=chat_id, photo=open(os.devnull, 'rb')) + await bot.send_photo(chat_id=chat_id, photo=open(os.devnull, 'rb')) @flaky(3, 1) - def test_error_send_empty_file_id(self, bot, chat_id): + @pytest.mark.asyncio + async def test_error_send_empty_file_id(self, bot, chat_id): with pytest.raises(TelegramError): - bot.send_photo(chat_id=chat_id, photo='') + await bot.send_photo(chat_id=chat_id, photo='') - def test_error_without_required_args(self, bot, chat_id): + @pytest.mark.asyncio + async def test_error_without_required_args(self, bot, chat_id): with pytest.raises(TypeError): - bot.send_photo(chat_id=chat_id) + await bot.send_photo(chat_id=chat_id) - def test_get_file_instance_method(self, monkeypatch, photo): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_get_file_instance_method(self, monkeypatch, photo): + async def make_assertion(*_, **kwargs): return kwargs['file_id'] == photo.file_id assert check_shortcut_signature(PhotoSize.get_file, Bot.get_file, ['file_id'], []) - assert check_shortcut_call(photo.get_file, photo.get_bot(), 'get_file') - assert check_defaults_handling(photo.get_file, photo.get_bot()) + assert await check_shortcut_call(photo.get_file, photo.get_bot(), 'get_file') + assert await check_defaults_handling(photo.get_file, photo.get_bot()) monkeypatch.setattr(photo.get_bot(), 'get_file', make_assertion) - assert photo.get_file() + assert await photo.get_file() def test_equality(self, photo): a = PhotoSize(photo.file_id, photo.file_unique_id, self.width, self.height) diff --git a/tests/test_pollanswerhandler.py b/tests/test_pollanswerhandler.py deleted file mode 100644 index a0704d700db..00000000000 --- a/tests/test_pollanswerhandler.py +++ /dev/null @@ -1,109 +0,0 @@ -#!/usr/bin/env python -# -# A library that provides a Python interface to the Telegram Bot API -# Copyright (C) 2015-2022 -# Leandro Toledo de Souza -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Lesser Public License for more details. -# -# You should have received a copy of the GNU Lesser Public License -# along with this program. If not, see [http://www.gnu.org/licenses/]. -from queue import Queue - -import pytest - -from telegram import ( - Update, - CallbackQuery, - Bot, - Message, - User, - Chat, - PollAnswer, - ChosenInlineResult, - ShippingQuery, - PreCheckoutQuery, -) -from telegram.ext import PollAnswerHandler, CallbackContext, JobQueue - -message = Message(1, None, Chat(1, ''), from_user=User(1, '', False), text='Text') - -params = [ - {'message': message}, - {'edited_message': message}, - {'callback_query': CallbackQuery(1, User(1, '', False), 'chat', message=message)}, - {'channel_post': message}, - {'edited_channel_post': message}, - {'chosen_inline_result': ChosenInlineResult('id', User(1, '', False), '')}, - {'shipping_query': ShippingQuery('id', User(1, '', False), '', None)}, - {'pre_checkout_query': PreCheckoutQuery('id', User(1, '', False), '', 0, '')}, - {'callback_query': CallbackQuery(1, User(1, '', False), 'chat')}, -] - -ids = ( - 'message', - 'edited_message', - 'callback_query', - 'channel_post', - 'edited_channel_post', - 'chosen_inline_result', - 'shipping_query', - 'pre_checkout_query', - 'callback_query_without_message', -) - - -@pytest.fixture(scope='class', params=params, ids=ids) -def false_update(request): - return Update(update_id=2, **request.param) - - -@pytest.fixture(scope='function') -def poll_answer(bot): - return Update(0, poll_answer=PollAnswer(1, User(2, 'test user', False), [0, 1])) - - -class TestPollAnswerHandler: - test_flag = False - - def test_slot_behaviour(self, mro_slots): - handler = PollAnswerHandler(self.callback_context) - for attr in handler.__slots__: - assert getattr(handler, attr, 'err') != 'err', f"got extra slot '{attr}'" - assert len(mro_slots(handler)) == len(set(mro_slots(handler))), "duplicate slot" - - @pytest.fixture(autouse=True) - def reset(self): - self.test_flag = False - - def callback_context(self, update, context): - self.test_flag = ( - isinstance(context, CallbackContext) - and isinstance(context.bot, Bot) - and isinstance(update, Update) - and isinstance(context.update_queue, Queue) - and isinstance(context.job_queue, JobQueue) - and isinstance(context.user_data, dict) - and context.chat_data is None - and isinstance(context.bot_data, dict) - and isinstance(update.poll_answer, PollAnswer) - ) - - def test_other_update_types(self, false_update): - handler = PollAnswerHandler(self.callback_context) - assert not handler.check_update(false_update) - - def test_context(self, dp, poll_answer): - handler = PollAnswerHandler(self.callback_context) - dp.add_handler(handler) - - dp.process_update(poll_answer) - assert self.test_flag diff --git a/tests/test_pollhandler.py b/tests/test_pollhandler.py deleted file mode 100644 index 83fafee9370..00000000000 --- a/tests/test_pollhandler.py +++ /dev/null @@ -1,122 +0,0 @@ -#!/usr/bin/env python -# -# A library that provides a Python interface to the Telegram Bot API -# Copyright (C) 2015-2022 -# Leandro Toledo de Souza -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Lesser Public License for more details. -# -# You should have received a copy of the GNU Lesser Public License -# along with this program. If not, see [http://www.gnu.org/licenses/]. -from queue import Queue - -import pytest - -from telegram import ( - Update, - Poll, - PollOption, - Bot, - Message, - User, - Chat, - CallbackQuery, - ChosenInlineResult, - ShippingQuery, - PreCheckoutQuery, -) -from telegram.ext import PollHandler, CallbackContext, JobQueue - -message = Message(1, None, Chat(1, ''), from_user=User(1, '', False), text='Text') - -params = [ - {'message': message}, - {'edited_message': message}, - {'callback_query': CallbackQuery(1, User(1, '', False), 'chat', message=message)}, - {'channel_post': message}, - {'edited_channel_post': message}, - {'chosen_inline_result': ChosenInlineResult('id', User(1, '', False), '')}, - {'shipping_query': ShippingQuery('id', User(1, '', False), '', None)}, - {'pre_checkout_query': PreCheckoutQuery('id', User(1, '', False), '', 0, '')}, - {'callback_query': CallbackQuery(1, User(1, '', False), 'chat')}, -] - -ids = ( - 'message', - 'edited_message', - 'callback_query', - 'channel_post', - 'edited_channel_post', - 'chosen_inline_result', - 'shipping_query', - 'pre_checkout_query', - 'callback_query_without_message', -) - - -@pytest.fixture(scope='class', params=params, ids=ids) -def false_update(request): - return Update(update_id=2, **request.param) - - -@pytest.fixture(scope='function') -def poll(bot): - return Update( - 0, - poll=Poll( - 1, - 'question', - [PollOption('1', 0), PollOption('2', 0)], - 0, - False, - False, - Poll.REGULAR, - True, - ), - ) - - -class TestPollHandler: - test_flag = False - - def test_slot_behaviour(self, mro_slots): - inst = PollHandler(self.callback_context) - for attr in inst.__slots__: - assert getattr(inst, attr, 'err') != 'err', f"got extra slot '{attr}'" - assert len(mro_slots(inst)) == len(set(mro_slots(inst))), "duplicate slot" - - @pytest.fixture(autouse=True) - def reset(self): - self.test_flag = False - - def callback_context(self, update, context): - self.test_flag = ( - isinstance(context, CallbackContext) - and isinstance(context.bot, Bot) - and isinstance(update, Update) - and isinstance(context.update_queue, Queue) - and isinstance(context.job_queue, JobQueue) - and context.user_data is None - and context.chat_data is None - and isinstance(context.bot_data, dict) - and isinstance(update.poll, Poll) - ) - - def test_other_update_types(self, false_update): - handler = PollHandler(self.callback_context) - assert not handler.check_update(false_update) - - def test_context(self, dp, poll): - handler = PollHandler(self.callback_context) - dp.add_handler(handler) - - dp.process_update(poll) - assert self.test_flag diff --git a/tests/test_precheckoutquery.py b/tests/test_precheckoutquery.py index 05721238a82..c782e066729 100644 --- a/tests/test_precheckoutquery.py +++ b/tests/test_precheckoutquery.py @@ -16,6 +16,7 @@ # # You should have received a copy of the GNU Lesser Public License # along with this program. If not, see [http://www.gnu.org/licenses/]. + import pytest from telegram import Update, User, PreCheckoutQuery, OrderInfo, Bot @@ -45,12 +46,6 @@ class TestPreCheckoutQuery: from_user = User(0, '', False) order_info = OrderInfo() - def test_slot_behaviour(self, pre_checkout_query, mro_slots): - inst = pre_checkout_query - for attr in inst.__slots__: - assert getattr(inst, attr, 'err') != 'err', f"got extra slot '{attr}'" - assert len(mro_slots(inst)) == len(set(mro_slots(inst))), "duplicate slot" - def test_de_json(self, bot): json_dict = { 'id': self.id_, @@ -84,24 +79,27 @@ def test_to_dict(self, pre_checkout_query): assert pre_checkout_query_dict['from'] == pre_checkout_query.from_user.to_dict() assert pre_checkout_query_dict['order_info'] == pre_checkout_query.order_info.to_dict() - def test_answer(self, monkeypatch, pre_checkout_query): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_answer(self, monkeypatch, pre_checkout_query): + async def make_assertion(*_, **kwargs): return kwargs['pre_checkout_query_id'] == pre_checkout_query.id assert check_shortcut_signature( PreCheckoutQuery.answer, Bot.answer_pre_checkout_query, ['pre_checkout_query_id'], [] ) - assert check_shortcut_call( + assert await check_shortcut_call( pre_checkout_query.answer, pre_checkout_query.get_bot(), 'answer_pre_checkout_query', ) - assert check_defaults_handling(pre_checkout_query.answer, pre_checkout_query.get_bot()) + assert await check_defaults_handling( + pre_checkout_query.answer, pre_checkout_query.get_bot() + ) monkeypatch.setattr( pre_checkout_query.get_bot(), 'answer_pre_checkout_query', make_assertion ) - assert pre_checkout_query.answer(ok=True) + assert await pre_checkout_query.answer(ok=True) def test_equality(self): a = PreCheckoutQuery( diff --git a/tests/test_precheckoutqueryhandler.py b/tests/test_precheckoutqueryhandler.py deleted file mode 100644 index 4635e2928cd..00000000000 --- a/tests/test_precheckoutqueryhandler.py +++ /dev/null @@ -1,114 +0,0 @@ -#!/usr/bin/env python -# -# A library that provides a Python interface to the Telegram Bot API -# Copyright (C) 2015-2022 -# Leandro Toledo de Souza -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Lesser Public License for more details. -# -# You should have received a copy of the GNU Lesser Public License -# along with this program. If not, see [http://www.gnu.org/licenses/]. -from queue import Queue - -import pytest - -from telegram import ( - Update, - Chat, - Bot, - ChosenInlineResult, - User, - Message, - CallbackQuery, - InlineQuery, - ShippingQuery, - PreCheckoutQuery, -) -from telegram.ext import PreCheckoutQueryHandler, CallbackContext, JobQueue - -message = Message(1, None, Chat(1, ''), from_user=User(1, '', False), text='Text') - -params = [ - {'message': message}, - {'edited_message': message}, - {'callback_query': CallbackQuery(1, User(1, '', False), 'chat', message=message)}, - {'channel_post': message}, - {'edited_channel_post': message}, - {'inline_query': InlineQuery(1, User(1, '', False), '', '')}, - {'chosen_inline_result': ChosenInlineResult('id', User(1, '', False), '')}, - {'shipping_query': ShippingQuery('id', User(1, '', False), '', None)}, - {'callback_query': CallbackQuery(1, User(1, '', False), 'chat')}, -] - -ids = ( - 'message', - 'edited_message', - 'callback_query', - 'channel_post', - 'edited_channel_post', - 'inline_query', - 'chosen_inline_result', - 'shipping_query', - 'callback_query_without_message', -) - - -@pytest.fixture(scope='class', params=params, ids=ids) -def false_update(request): - return Update(update_id=1, **request.param) - - -@pytest.fixture(scope='class') -def pre_checkout_query(): - return Update( - 1, - pre_checkout_query=PreCheckoutQuery( - 'id', User(1, 'test user', False), 'EUR', 223, 'invoice_payload' - ), - ) - - -class TestPreCheckoutQueryHandler: - test_flag = False - - def test_slot_behaviour(self, mro_slots): - inst = PreCheckoutQueryHandler(self.callback_context) - for attr in inst.__slots__: - assert getattr(inst, attr, 'err') != 'err', f"got extra slot '{attr}'" - assert len(mro_slots(inst)) == len(set(mro_slots(inst))), "duplicate slot" - - @pytest.fixture(autouse=True) - def reset(self): - self.test_flag = False - - def callback_context(self, update, context): - self.test_flag = ( - isinstance(context, CallbackContext) - and isinstance(context.bot, Bot) - and isinstance(update, Update) - and isinstance(context.update_queue, Queue) - and isinstance(context.job_queue, JobQueue) - and isinstance(context.user_data, dict) - and context.chat_data is None - and isinstance(context.bot_data, dict) - and isinstance(update.pre_checkout_query, PreCheckoutQuery) - ) - - def test_other_update_types(self, false_update): - handler = PreCheckoutQueryHandler(self.callback_context) - assert not handler.check_update(false_update) - - def test_context(self, dp, pre_checkout_query): - handler = PreCheckoutQueryHandler(self.callback_context) - dp.add_handler(handler) - - dp.process_update(pre_checkout_query) - assert self.test_flag diff --git a/tests/test_promise.py b/tests/test_promise.py deleted file mode 100644 index 5862357a34f..00000000000 --- a/tests/test_promise.py +++ /dev/null @@ -1,149 +0,0 @@ -#!/usr/bin/env python -# -# A library that provides a Python interface to the Telegram Bot API -# Copyright (C) 2015-2022 -# Leandro Toledo de Souza -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Lesser Public License for more details. -# -# You should have received a copy of the GNU Lesser Public License -# along with this program. If not, see [http://www.gnu.org/licenses/]. -import logging -import pytest - -from telegram.error import TelegramError -from telegram.ext._utils.promise import Promise - - -class TestPromise: - """ - Here we just test the things that are not covered by the other tests anyway - """ - - test_flag = False - - def test_slot_behaviour(self, mro_slots): - inst = Promise(self.test_call, [], {}) - for attr in inst.__slots__: - assert getattr(inst, attr, 'err') != 'err', f"got extra slot '{attr}'" - assert len(mro_slots(inst)) == len(set(mro_slots(inst))), "duplicate slot" - - @pytest.fixture(autouse=True) - def reset(self): - self.test_flag = False - - def test_call(self): - def callback(): - self.test_flag = True - - promise = Promise(callback, [], {}) - promise() - - assert promise.done - assert self.test_flag - - def test_run_with_exception(self): - def callback(): - raise TelegramError('Error') - - promise = Promise(callback, [], {}) - promise.run() - - assert promise.done - assert not self.test_flag - assert isinstance(promise.exception, TelegramError) - - def test_wait_for_exception(self): - def callback(): - raise TelegramError('Error') - - promise = Promise(callback, [], {}) - promise.run() - - with pytest.raises(TelegramError, match='Error'): - promise.result() - - def test_done_cb_after_run(self): - def callback(): - return "done!" - - def done_callback(_): - self.test_flag = True - - promise = Promise(callback, [], {}) - promise.run() - promise.add_done_callback(done_callback) - assert promise.result() == "done!" - assert self.test_flag is True - - def test_done_cb_after_run_excp(self): - def callback(): - return "done!" - - def done_callback(_): - raise Exception("Error!") - - promise = Promise(callback, [], {}) - promise.run() - assert promise.result() == "done!" - with pytest.raises(Exception) as err: - promise.add_done_callback(done_callback) - assert str(err) == "Error!" - - def test_done_cb_before_run(self): - def callback(): - return "done!" - - def done_callback(_): - self.test_flag = True - - promise = Promise(callback, [], {}) - promise.add_done_callback(done_callback) - assert promise.result(0) != "done!" - assert self.test_flag is False - promise.run() - assert promise.result() == "done!" - assert self.test_flag is True - - def test_done_cb_before_run_excp(self, caplog): - def callback(): - return "done!" - - def done_callback(_): - raise Exception("Error!") - - promise = Promise(callback, [], {}) - promise.add_done_callback(done_callback) - assert promise.result(0) != "done!" - caplog.clear() - with caplog.at_level(logging.WARNING): - promise.run() - assert len(caplog.records) == 2 - assert caplog.records[0].message == ( - "`done_callback` of a Promise raised the following exception." - " The exception won't be handled by error handlers." - ) - assert caplog.records[1].message.startswith("Full traceback:") - assert promise.result() == "done!" - - def test_done_cb_not_run_on_excp(self): - def callback(): - raise TelegramError('Error') - - def done_callback(_): - self.test_flag = True - - promise = Promise(callback, [], {}) - promise.add_done_callback(done_callback) - promise.run() - assert isinstance(promise.exception, TelegramError) - assert promise.done - assert self.test_flag is False diff --git a/tests/test_replykeyboardmarkup.py b/tests/test_replykeyboardmarkup.py index d7627860c8a..405a85cd78a 100644 --- a/tests/test_replykeyboardmarkup.py +++ b/tests/test_replykeyboardmarkup.py @@ -16,6 +16,7 @@ # # You should have received a copy of the GNU Lesser Public License # along with this program. If not, see [http://www.gnu.org/licenses/]. + import pytest from flaky import flaky @@ -29,7 +30,6 @@ def reply_keyboard_markup(): resize_keyboard=TestReplyKeyboardMarkup.resize_keyboard, one_time_keyboard=TestReplyKeyboardMarkup.one_time_keyboard, selective=TestReplyKeyboardMarkup.selective, - input_field_placeholder=TestReplyKeyboardMarkup.input_field_placeholder, ) @@ -38,23 +38,22 @@ class TestReplyKeyboardMarkup: resize_keyboard = True one_time_keyboard = True selective = True - input_field_placeholder = 'lol a keyboard' - - def test_slot_behaviour(self, reply_keyboard_markup, mro_slots): - inst = reply_keyboard_markup - for attr in inst.__slots__: - assert getattr(inst, attr, 'err') != 'err', f"got extra slot '{attr}'" - assert len(mro_slots(inst)) == len(set(mro_slots(inst))), "duplicate slot" @flaky(3, 1) - def test_send_message_with_reply_keyboard_markup(self, bot, chat_id, reply_keyboard_markup): - message = bot.send_message(chat_id, 'Text', reply_markup=reply_keyboard_markup) + @pytest.mark.asyncio + async def test_send_message_with_reply_keyboard_markup( + self, bot, chat_id, reply_keyboard_markup + ): + message = await bot.send_message(chat_id, 'Text', reply_markup=reply_keyboard_markup) assert message.text == 'Text' @flaky(3, 1) - def test_send_message_with_data_markup(self, bot, chat_id): - message = bot.send_message(chat_id, 'text 2', reply_markup={'keyboard': [['1', '2']]}) + @pytest.mark.asyncio + async def test_send_message_with_data_markup(self, bot, chat_id): + message = await bot.send_message( + chat_id, 'text 2', reply_markup={'keyboard': [['1', '2']]} + ) assert message.text == 'text 2' @@ -100,13 +99,6 @@ def test_expected_values(self, reply_keyboard_markup): assert reply_keyboard_markup.resize_keyboard == self.resize_keyboard assert reply_keyboard_markup.one_time_keyboard == self.one_time_keyboard assert reply_keyboard_markup.selective == self.selective - assert reply_keyboard_markup.input_field_placeholder == self.input_field_placeholder - - def test_wrong_keyboard_inputs(self): - with pytest.raises(ValueError): - ReplyKeyboardMarkup([[KeyboardButton('b1')], 'b2']) - with pytest.raises(ValueError): - ReplyKeyboardMarkup(KeyboardButton('b1')) def test_to_dict(self, reply_keyboard_markup): reply_keyboard_markup_dict = reply_keyboard_markup.to_dict() @@ -128,10 +120,6 @@ def test_to_dict(self, reply_keyboard_markup): == reply_keyboard_markup.one_time_keyboard ) assert reply_keyboard_markup_dict['selective'] == reply_keyboard_markup.selective - assert ( - reply_keyboard_markup_dict['input_field_placeholder'] - == reply_keyboard_markup.input_field_placeholder - ) def test_equality(self): a = ReplyKeyboardMarkup.from_column(['button1', 'button2', 'button3']) diff --git a/tests/test_replykeyboardremove.py b/tests/test_replykeyboardremove.py index 0a516d6de9d..89768f16e71 100644 --- a/tests/test_replykeyboardremove.py +++ b/tests/test_replykeyboardremove.py @@ -38,8 +38,11 @@ def test_slot_behaviour(self, reply_keyboard_remove, mro_slots): assert len(mro_slots(inst)) == len(set(mro_slots(inst))), "duplicate slot" @flaky(3, 1) - def test_send_message_with_reply_keyboard_remove(self, bot, chat_id, reply_keyboard_remove): - message = bot.send_message(chat_id, 'Text', reply_markup=reply_keyboard_remove) + @pytest.mark.asyncio + async def test_send_message_with_reply_keyboard_remove( + self, bot, chat_id, reply_keyboard_remove + ): + message = await bot.send_message(chat_id, 'Text', reply_markup=reply_keyboard_remove) assert message.text == 'Text' diff --git a/tests/test_request.py b/tests/test_request.py index cc1c3ba2bd7..dd813e4e016 100644 --- a/tests/test_request.py +++ b/tests/test_request.py @@ -16,51 +16,444 @@ # # You should have received a copy of the GNU Lesser Public License # along with this program. If not, see [http://www.gnu.org/licenses/]. -from pathlib import Path +"""Here we run tests directly with HTTPXRequest because that's easier than providing dummy +implementations for BaseRequest and we want to test HTTPXRequest anyway.""" +import json +from http import HTTPStatus +from typing import Tuple, Any, Coroutine, Callable +import httpx import pytest -from telegram.error import TelegramError -from telegram.request import Request -from tests.conftest import data_file +from telegram._utils.defaultvalue import DEFAULT_NONE +from telegram._utils.types import ODVInput +from telegram.error import ( + TelegramError, + ChatMigrated, + RetryAfter, + NetworkError, + Forbidden, + InvalidToken, + BadRequest, + Conflict, + TimedOut, +) +from telegram.request import BaseRequest, RequestData +from telegram.request._httpxrequest import HTTPXRequest +# We only need the first fixture, but it uses the others, so pytest needs us to import them as well +from .test_requestdata import ( # noqa: F401 + mixed_rqs, + mixed_params, + file_params, + simple_params, + inputfile, + input_media_video, + input_media_photo, +) -def test_slot_behaviour(mro_slots): - inst = Request() - for attr in inst.__slots__: - assert getattr(inst, attr, 'err') != 'err', f"got extra slot '{attr}'" - assert len(mro_slots(inst)) == len(set(mro_slots(inst))), "duplicate slot" +def mocker_factory( + response: bytes, return_code: int = HTTPStatus.OK +) -> Callable[[Tuple[Any]], Coroutine[Any, Any, Tuple[int, bytes]]]: + async def make_assertion(*args, **kwargs): + return return_code, response -def test_replaced_unprintable_char(): - """ - Clients can send arbitrary bytes in callback data. - Make sure the correct error is raised in this case. - """ - server_response = b'{"invalid utf-8": "\x80", "result": "KUKU"}' + return make_assertion - assert Request._parse(server_response) == 'KUKU' +@pytest.fixture(scope='function') +@pytest.mark.asyncio +async def httpx_request(): + async with HTTPXRequest() as rq: + yield rq -def test_parse_illegal_json(): - """ - Clients can send arbitrary bytes in callback data. - Make sure the correct error is raised in this case. - """ - server_response = b'{"invalid utf-8": "\x80", result: "KUKU"}' - with pytest.raises(TelegramError, match='Invalid server response'): - Request._parse(server_response) +# TODO: Test timeouts -@pytest.mark.parametrize( - "destination_path_type", - [str, Path], - ids=['str destination_path', 'pathlib.Path destination_path'], -) -def test_download(destination_path_type): - destination_filepath = data_file('downloaded_request.txt') - request = Request() - request.download("http://google.com", destination_path_type(destination_filepath)) - assert destination_filepath.is_file() - destination_filepath.unlink() +class TestRequest: + test_flag = None + + @pytest.fixture(autouse=True) + def reset(self): + self.test_flag = None + + def test_slot_behaviour(self, mro_slots): + inst = HTTPXRequest() + for attr in inst.__slots__: + assert getattr(inst, attr, 'err') != 'err', f"got extra slot '{attr}'" + assert len(mro_slots(inst)) == len(set(mro_slots(inst))), "duplicate slot" + + @pytest.mark.asyncio + async def test_context_manager(self, monkeypatch): + async def initialize(): + self.test_flag = ['initialize'] + + async def stop(): + self.test_flag.append('stop') + + httpx_request = HTTPXRequest() + + monkeypatch.setattr(httpx_request, 'initialize', initialize) + monkeypatch.setattr(httpx_request, 'stop', stop) + + async with httpx_request: + pass + + assert self.test_flag == ['initialize', 'stop'] + + @pytest.mark.asyncio + async def test_context_manager_exception_on_init(self, monkeypatch): + async def initialize(): + raise RuntimeError('initialize') + + async def stop(): + self.test_flag = 'stop' + + httpx_request = HTTPXRequest() + + monkeypatch.setattr(httpx_request, 'initialize', initialize) + monkeypatch.setattr(httpx_request, 'stop', stop) + + with pytest.raises(RuntimeError, match='initialize'): + async with httpx_request: + pass + + assert self.test_flag == 'stop' + + @pytest.mark.asyncio + async def test_replaced_unprintable_char(self, monkeypatch, httpx_request): + """Clients can send arbitrary bytes in callback data. Make sure that we just replace + those + """ + server_response = b'{"result": "test_string\x80"}' + + monkeypatch.setattr(httpx_request, 'do_request', mocker_factory(response=server_response)) + + assert await httpx_request.post(None, None, None) == 'test_string�' + + @pytest.mark.asyncio + async def test_illegal_json_response(self, monkeypatch, httpx_request: HTTPXRequest): + # for proper JSON it should be `"result":` instead of `result:` + server_response = b'{result: "test_string"}' + + monkeypatch.setattr(httpx_request, 'do_request', mocker_factory(response=server_response)) + + with pytest.raises(TelegramError, match='Invalid server response'): + await httpx_request.post(None, None, None) + + @pytest.mark.asyncio + async def test_chat_migrated(self, monkeypatch, httpx_request: HTTPXRequest): + server_response = b'{"ok": "False", "parameters": {"migrate_to_chat_id": "123"}}' + + monkeypatch.setattr( + httpx_request, + 'do_request', + mocker_factory(response=server_response, return_code=HTTPStatus.BAD_REQUEST), + ) + + with pytest.raises(ChatMigrated, match='New chat id: 123') as exc_info: + await httpx_request.post(None, None, None) + + assert exc_info.value.new_chat_id == 123 + + @pytest.mark.asyncio + async def test_retry_after(self, monkeypatch, httpx_request: HTTPXRequest): + server_response = b'{"ok": "False", "parameters": {"retry_after": "42"}}' + + monkeypatch.setattr( + httpx_request, + 'do_request', + mocker_factory(response=server_response, return_code=HTTPStatus.BAD_REQUEST), + ) + + with pytest.raises(RetryAfter, match='Retry in 42.0') as exc_info: + await httpx_request.post(None, None, None) + + assert exc_info.value.retry_after == 42.0 + + @pytest.mark.asyncio + @pytest.mark.parametrize('description', [True, False]) + async def test_error_description(self, monkeypatch, httpx_request: HTTPXRequest, description): + response_data = {"ok": "False"} + if description: + match = 'ErrorDescription' + response_data['description'] = match + else: + match = 'Unknown HTTPError' + + server_response = json.dumps(response_data).encode('utf-8') + + monkeypatch.setattr( + httpx_request, + 'do_request', + mocker_factory(response=server_response, return_code=-1), + ) + + with pytest.raises(NetworkError, match=match): + await httpx_request.post(None, None, None) + + # Special casing for bad gateway + if not description: + monkeypatch.setattr( + httpx_request, + 'do_request', + mocker_factory(response=server_response, return_code=HTTPStatus.BAD_GATEWAY), + ) + + with pytest.raises(NetworkError, match='Bad Gateway'): + await httpx_request.post(None, None, None) + + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'code, exception_class', + [ + (HTTPStatus.FORBIDDEN, Forbidden), + (HTTPStatus.NOT_FOUND, InvalidToken), + (HTTPStatus.UNAUTHORIZED, InvalidToken), + (HTTPStatus.BAD_REQUEST, BadRequest), + (HTTPStatus.CONFLICT, Conflict), + (HTTPStatus.BAD_GATEWAY, NetworkError), + (-1, NetworkError), + ], + ) + async def test_special_errors( + self, monkeypatch, httpx_request: HTTPXRequest, code, exception_class + ): + server_response = b'{"ok": "False", "description": "Test Message"}' + + monkeypatch.setattr( + httpx_request, + 'do_request', + mocker_factory(response=server_response, return_code=code), + ) + + with pytest.raises(exception_class, match='Test Message'): + await httpx_request.post(None, None, None) + + @pytest.mark.asyncio + @pytest.mark.parametrize( + ['exception', 'catch_class', 'match'], + [ + (TelegramError('TelegramError'), TelegramError, 'TelegramError'), + (RuntimeError('CustomError'), Exception, 'HTTP implementation: CustomError'), + ], + ) + async def test_exceptions_in_do_request( + self, monkeypatch, httpx_request: HTTPXRequest, exception, catch_class, match + ): + async def do_request(*args, **kwargs): + raise exception + + monkeypatch.setattr( + httpx_request, + 'do_request', + do_request, + ) + + with pytest.raises(catch_class, match=match): + await httpx_request.post(None, None, None) + + @pytest.mark.asyncio + async def test_retrieve(self, monkeypatch, httpx_request): + """Here we just test that retrieve gives us the raw bytes instead of trying to parse them + as json + """ + server_response = b'{"result": "test_string\x80"}' + + monkeypatch.setattr(httpx_request, 'do_request', mocker_factory(response=server_response)) + + assert await httpx_request.retrieve(None, None) == server_response + + def test_connection_pool_size(self): + class Request(BaseRequest): + async def do_request(self, *args, **kwargs): + pass + + async def initialize(self, *args, **kwargs): + pass + + async def shutdown(self, *args, **kwargs): + pass + + with pytest.raises(NotImplementedError): + Request().connection_pool_size + + @pytest.mark.asyncio + async def test_timeout_propagation(self, monkeypatch, httpx_request): + """Here we just test that retrieve gives us the raw bytes instead of trying to parse them + as json + """ + + async def make_assertion( + method: str, + url: str, + request_data: RequestData = None, + read_timeout: ODVInput[float] = DEFAULT_NONE, + *args, + **kwargs, + ): + self.test_flag = read_timeout + return HTTPStatus.OK, b'{"ok": "True", "result": {}}' + + monkeypatch.setattr(httpx_request, 'do_request', make_assertion) + + await httpx_request.post('url', None, read_timeout=42.314) + assert self.test_flag == 42.314 + + +class TestHTTPXRequest: + test_flag = None + + @pytest.fixture(autouse=True) + def reset(self): + self.test_flag = None + + def test_init(self): + request = HTTPXRequest() + assert request.connection_pool_size == 1 + assert request._client.timeout == httpx.Timeout(connect=5.0, read=5.0, write=5.0, pool=1.0) + + request = HTTPXRequest( + connection_pool_size=42, + connect_timeout=43, + read_timeout=44, + write_timeout=45, + pool_timeout=46, + ) + assert request.connection_pool_size == 42 + assert request._client.timeout == httpx.Timeout(connect=43, read=44, write=45, pool=46) + + @pytest.mark.asyncio + async def test_context_manager(self, monkeypatch): + async def initialize(): + self.test_flag = ['initialize'] + + async def aclose(*args): + self.test_flag.append('stop') + + httpx_request = HTTPXRequest() + + monkeypatch.setattr(httpx_request, 'initialize', initialize) + monkeypatch.setattr(httpx.AsyncClient, 'aclose', aclose) + + async with httpx_request: + pass + + assert self.test_flag == ['initialize', 'stop'] + + @pytest.mark.asyncio + async def test_context_manager_exception_on_init(self, monkeypatch): + async def initialize(): + raise RuntimeError('initialize') + + async def aclose(*args): + self.test_flag = 'stop' + + httpx_request = HTTPXRequest() + + monkeypatch.setattr(httpx_request, 'initialize', initialize) + monkeypatch.setattr(httpx.AsyncClient, 'aclose', aclose) + + with pytest.raises(RuntimeError, match='initialize'): + async with httpx_request: + pass + + assert self.test_flag == 'stop' + + @pytest.mark.asyncio + async def test_do_request_default_timeouts(self, monkeypatch, httpx_request): + default_timeouts = httpx.Timeout(connect=5.0, read=5.0, write=5.0, pool=1.0) + + async def make_assertion(self, method, url, headers, timeout, files, data): + self.test_flag = timeout == default_timeouts + return httpx.Response(HTTPStatus.OK) + + monkeypatch.setattr(httpx.AsyncClient, 'request', make_assertion) + await httpx_request.do_request('GET', 'URL') + assert httpx_request._client.timeout == default_timeouts + + @pytest.mark.asyncio + async def test_do_request_manual_timeouts(self, monkeypatch, httpx_request): + default_timeouts = httpx.Timeout(connect=5.0, read=5.0, write=5.0, pool=1.0) + + async def make_assertion(self, method, url, headers, timeout, files, data): + self.test_flag = timeout == httpx.Timeout(connect=5.0, read=5.5, write=5.6, pool=1.0) + return httpx.Response(HTTPStatus.OK) + + monkeypatch.setattr(httpx.AsyncClient, 'request', make_assertion) + await httpx_request.do_request('GET', 'URL', read_timeout=5.5, write_timeout=5.6) + assert httpx_request._client.timeout == default_timeouts + + @pytest.mark.asyncio + async def test_do_request_params_no_data(self, monkeypatch, httpx_request): + async def make_assertion(self, method, url, headers, timeout, files, data): + method_assertion = method == 'method' + url_assertion = url == 'url' + files_assertion = files is None + data_assertion = data is None + if method_assertion and url_assertion and files_assertion and data_assertion: + return httpx.Response(HTTPStatus.OK) + return httpx.Response(HTTPStatus.BAD_REQUEST) + + monkeypatch.setattr(httpx.AsyncClient, 'request', make_assertion) + code, _ = await httpx_request.do_request( + 'method', 'url', read_timeout=5.5, write_timeout=5.6 + ) + assert code == HTTPStatus.OK + + @pytest.mark.asyncio + async def test_do_request_params_with_data( + self, monkeypatch, httpx_request, mixed_rqs # noqa: 9811 + ): + async def make_assertion(self, method, url, headers, timeout, files, data): + method_assertion = method == 'method' + url_assertion = url == 'url' + files_assertion = files == mixed_rqs.multipart_data + data_assertion = data == mixed_rqs.json_parameters + if method_assertion and url_assertion and files_assertion and data_assertion: + return httpx.Response(HTTPStatus.OK) + return httpx.Response(HTTPStatus.BAD_REQUEST) + + monkeypatch.setattr(httpx.AsyncClient, 'request', make_assertion) + code, _ = await httpx_request.do_request( + 'method', + 'url', + read_timeout=5.5, + write_timeout=5.6, + request_data=mixed_rqs, + ) + assert code == HTTPStatus.OK + + @pytest.mark.asyncio + async def test_do_request_return_value(self, monkeypatch, httpx_request): + async def make_assertion(self, method, url, headers, timeout, files, data): + return httpx.Response(123, content=b'content') + + monkeypatch.setattr(httpx.AsyncClient, 'request', make_assertion) + code, content = await httpx_request.do_request( + 'method', + 'url', + ) + assert code == 123 + assert content == b'content' + + @pytest.mark.asyncio + @pytest.mark.parametrize( + ['raised_class', 'expected_class'], + [(httpx.TimeoutException, TimedOut), (httpx.HTTPError, NetworkError)], + ) + async def test_do_request_exceptions( + self, monkeypatch, httpx_request, raised_class, expected_class + ): + async def make_assertion(self, method, url, headers, timeout, files, data): + raise raised_class('message') + + monkeypatch.setattr(httpx.AsyncClient, 'request', make_assertion) + + with pytest.raises(expected_class): + await httpx_request.do_request( + 'method', + 'url', + ) diff --git a/tests/test_requestdata.py b/tests/test_requestdata.py new file mode 100644 index 00000000000..a8b0356c195 --- /dev/null +++ b/tests/test_requestdata.py @@ -0,0 +1,208 @@ +#!/usr/bin/env python +# +# A library that provides a Python interface to the Telegram Bot API +# Copyright (C) 2015-2022 +# Leandro Toledo de Souza +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser Public License for more details. +# +# You should have received a copy of the GNU Lesser Public License +# along with this program. If not, see [http://www.gnu.org/licenses/]. +from urllib.parse import quote + +try: + import ujson as json +except ImportError: + import json +from typing import Any, Dict + +import pytest + +from telegram import InputFile, MessageEntity, InputMediaPhoto, InputMediaVideo +from telegram.request import RequestData +from telegram.request._requestparameter import RequestParameter +from tests.conftest import data_file + + +@pytest.fixture(scope='module') +def inputfile() -> InputFile: + return InputFile(data_file('telegram.jpg').read_bytes()) + + +@pytest.fixture(scope='module') +def input_media_video() -> InputMediaVideo: + return InputMediaVideo( + media=data_file('telegram.mp4').read_bytes(), + thumb=data_file('telegram.jpg').read_bytes(), + parse_mode=None, + ) + + +@pytest.fixture(scope='module') +def input_media_photo() -> InputMediaPhoto: + return InputMediaPhoto( + media=data_file('telegram.jpg').read_bytes(), + parse_mode=None, + ) + + +@pytest.fixture(scope='module') +def simple_params() -> Dict[str, Any]: + return { + 'string': 'string', + 'integer': 1, + 'tg_object': MessageEntity('type', 1, 1).to_dict(), + 'list': [1, 'string', MessageEntity('type', 1, 1).to_dict()], + } + + +@pytest.fixture(scope='module') +def simple_jsons() -> Dict[str, Any]: + return { + 'string': 'string', + 'integer': json.dumps(1), + 'tg_object': MessageEntity('type', 1, 1).to_json(), + 'list': json.dumps([1, 'string', MessageEntity('type', 1, 1).to_dict()]), + } + + +@pytest.fixture(scope='module') +def simple_rqs(simple_params) -> RequestData: + return RequestData( + [RequestParameter.from_input(key, value) for key, value in simple_params.items()] + ) + + +@pytest.fixture(scope='module') +def file_params(inputfile, input_media_video, input_media_photo) -> Dict[str, Any]: + return { + 'inputfile': inputfile, + 'inputmedia': input_media_video, + 'inputmedia_list': [input_media_video, input_media_photo], + } + + +@pytest.fixture(scope='module') +def file_jsons(inputfile, input_media_video, input_media_photo) -> Dict[str, Any]: + input_media_video_dict = input_media_video.to_dict() + input_media_video_dict['media'] = input_media_video.media.attach_uri + input_media_video_dict['thumb'] = input_media_video.thumb.attach_uri + input_media_photo_dict = input_media_photo.to_dict() + input_media_photo_dict['media'] = input_media_photo.media.attach_uri + return { + 'inputfile': inputfile.attach_uri, + 'inputmedia': json.dumps(input_media_video_dict), + 'inputmedia_list': json.dumps([input_media_video_dict, input_media_photo_dict]), + } + + +@pytest.fixture(scope='module') +def file_rqs(file_params) -> RequestData: + return RequestData( + [RequestParameter.from_input(key, value) for key, value in file_params.items()] + ) + + +@pytest.fixture() +def mixed_params(file_params, simple_params) -> Dict[str, Any]: + both = file_params.copy() + both.update(simple_params) + return both + + +@pytest.fixture() +def mixed_jsons(file_jsons, simple_jsons) -> Dict[str, Any]: + both = file_jsons.copy() + both.update(simple_jsons) + return both + + +@pytest.fixture() +def mixed_rqs(mixed_params) -> RequestData: + return RequestData( + [RequestParameter.from_input(key, value) for key, value in mixed_params.items()] + ) + + +class TestRequestData: + def test_contains_files(self, simple_rqs, file_rqs, mixed_rqs): + assert not simple_rqs.contains_files + assert file_rqs.contains_files + assert mixed_rqs.contains_files + + def test_parameters( + self, + simple_rqs, + simple_params, # file_rqs, mixed_rqs, file_params, mixed_params + ): + assert simple_rqs.parameters == simple_params + # We don't test these for now since that's a struggle + # And the conversation part is already being tested in test_requestparameter.py + # assert file_rqs.parameters == file_params + # assert mixed_rqs.parameters == mixed_params + + def test_json_parameters( + self, simple_rqs, file_rqs, mixed_rqs, simple_jsons, file_jsons, mixed_jsons + ): + assert simple_rqs.json_parameters == simple_jsons + assert file_rqs.json_parameters == file_jsons + assert mixed_rqs.json_parameters == mixed_jsons + + def test_json_payload( + self, simple_rqs, file_rqs, mixed_rqs, simple_jsons, file_jsons, mixed_jsons + ): + assert simple_rqs.json_payload == json.dumps(simple_jsons).encode() + assert file_rqs.json_payload == json.dumps(file_jsons).encode() + assert mixed_rqs.json_payload == json.dumps(mixed_jsons).encode() + + def test_multipart_data( + self, + simple_rqs, + file_rqs, + mixed_rqs, + inputfile, + input_media_video, + input_media_photo, + ): + expected = { + inputfile.attach_name: inputfile.field_tuple, + input_media_photo.media.attach_name: input_media_photo.media.field_tuple, + input_media_video.media.attach_name: input_media_video.media.field_tuple, + input_media_video.thumb.attach_name: input_media_video.thumb.field_tuple, + } + assert simple_rqs.multipart_data == {} + assert file_rqs.multipart_data == expected + assert mixed_rqs.multipart_data == expected + + def test_url_encoding(self, monkeypatch): + data = RequestData( + [ + RequestParameter.from_input('chat_id', 123), + RequestParameter.from_input('text', 'Hello there/!'), + ] + ) + expected_params = 'chat_id=123&text=Hello+there%2F%21' + expected_url = 'https://te.st/method?' + expected_params + assert data.url_encoded_parameters() == expected_params + assert data.build_parametrized_url('https://te.st/method') == expected_url + + expected_params = 'chat_id=123&text=Hello%20there/!' + expected_url = 'https://te.st/method?' + expected_params + assert ( + data.url_encoded_parameters(encode_kwargs={'quote_via': quote, 'safe': '/!'}) + == expected_params + ) + assert ( + data.build_parametrized_url( + 'https://te.st/method', encode_kwargs={'quote_via': quote, 'safe': '/!'} + ) + == expected_url + ) diff --git a/tests/test_requestparameter.py b/tests/test_requestparameter.py new file mode 100644 index 00000000000..aaf9ea75027 --- /dev/null +++ b/tests/test_requestparameter.py @@ -0,0 +1,133 @@ +#!/usr/bin/env python +# +# A library that provides a Python interface to the Telegram Bot API +# Copyright (C) 2015-2022 +# Leandro Toledo de Souza +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser Public License for more details. +# +# You should have received a copy of the GNU Lesser Public License +# along with this program. If not, see [http://www.gnu.org/licenses/]. +import datetime + +import pytest + +from telegram import InputFile, MessageEntity, InputMediaPhoto, InputMediaVideo +from telegram.constants import ChatType +from telegram.request._requestparameter import RequestParameter +from tests.conftest import data_file + + +class TestRequestParameter: + def test_init(self): + request_parameter = RequestParameter('name', 'value', [1, 2]) + assert request_parameter.name == 'name' + assert request_parameter.value == 'value' + assert request_parameter.input_files == [1, 2] + + request_parameter = RequestParameter('name', 'value', None) + assert request_parameter.name == 'name' + assert request_parameter.value == 'value' + assert request_parameter.input_files is None + + @pytest.mark.parametrize( + 'value, expected', + [ + (1, '1'), + ('one', 'one'), + (True, 'true'), + (None, 'null'), + ([1, '1'], '[1, "1"]'), + ({True: None}, '{"true": null}'), + ((1,), '[1]'), + ], + ) + def test_json_value(self, value, expected): + request_parameter = RequestParameter('name', value, None) + assert request_parameter.json_value == expected + + def test_multipart_data(self): + assert RequestParameter('name', 'value', []).multipart_data is None + + input_file_1 = InputFile(data_file('telegram.jpg').read_bytes()) + input_file_2 = InputFile(data_file('telegram.jpg').read_bytes(), filename='custom') + request_parameter = RequestParameter('value', 'name', [input_file_1, input_file_2]) + files = request_parameter.multipart_data + assert files[input_file_1.attach_name] == input_file_1.field_tuple + assert files[input_file_2.attach_name] == input_file_2.field_tuple + + @pytest.mark.parametrize( + ('value', 'expected_value'), + [ + (True, True), + ('str', 'str'), + ({1: 1.0}, {1: 1.0}), + (ChatType.PRIVATE, 'private'), + (MessageEntity('type', 1, 1), {'type': 'type', 'offset': 1, 'length': 1}), + (datetime.datetime(2019, 11, 11, 0, 26, 16, 10 ** 5), 1573431976), + ( + [ + True, + 'str', + MessageEntity('type', 1, 1), + ChatType.PRIVATE, + datetime.datetime(2019, 11, 11, 0, 26, 16, 10 ** 5), + ], + [True, 'str', {'type': 'type', 'offset': 1, 'length': 1}, 'private', 1573431976], + ), + ], + ) + def test_from_input_no_media(self, value, expected_value): + request_parameter = RequestParameter.from_input('key', value) + assert request_parameter.value == expected_value + assert request_parameter.input_files is None + + def test_from_input_inputfile(self): + inputfile_1 = InputFile(data_file('telegram.jpg').read_bytes(), 'inputfile_1') + inputfile_2 = InputFile(data_file('telegram.mp4').read_bytes(), 'inputfile_2') + + request_parameter = RequestParameter.from_input('key', inputfile_1) + assert request_parameter.value == inputfile_1.attach_uri + assert request_parameter.input_files == [inputfile_1] + + request_parameter = RequestParameter.from_input('key', [inputfile_1, inputfile_2]) + assert request_parameter.value == [inputfile_1.attach_uri, inputfile_2.attach_uri] + assert request_parameter.input_files == [inputfile_1, inputfile_2] + + def test_from_input_input_media(self): + input_media_no_thumb = InputMediaPhoto(media=data_file('telegram.jpg').read_bytes()) + input_media_thumb = InputMediaVideo( + media=data_file('telegram.mp4').read_bytes(), + thumb=data_file('telegram.jpg').read_bytes(), + ) + + request_parameter = RequestParameter.from_input('key', input_media_no_thumb) + expected_no_thumb = input_media_no_thumb.to_dict() + expected_no_thumb.update({'media': input_media_no_thumb.media.attach_uri}) + assert request_parameter.value == expected_no_thumb + assert request_parameter.input_files == [input_media_no_thumb.media] + + request_parameter = RequestParameter.from_input('key', input_media_thumb) + expected_thumb = input_media_thumb.to_dict() + expected_thumb.update({'media': input_media_thumb.media.attach_uri}) + expected_thumb.update({'thumb': input_media_thumb.thumb.attach_uri}) + assert request_parameter.value == expected_thumb + assert request_parameter.input_files == [input_media_thumb.media, input_media_thumb.thumb] + + request_parameter = RequestParameter.from_input( + 'key', [input_media_thumb, input_media_no_thumb] + ) + assert request_parameter.value == [expected_thumb, expected_no_thumb] + assert request_parameter.input_files == [ + input_media_thumb.media, + input_media_thumb.thumb, + input_media_no_thumb.media, + ] diff --git a/tests/test_shippingquery.py b/tests/test_shippingquery.py index 8a42fa7af92..d9415436a6d 100644 --- a/tests/test_shippingquery.py +++ b/tests/test_shippingquery.py @@ -16,6 +16,7 @@ # # You should have received a copy of the GNU Lesser Public License # along with this program. If not, see [http://www.gnu.org/licenses/]. + import pytest from telegram import Update, User, ShippingAddress, ShippingQuery, Bot @@ -39,12 +40,6 @@ class TestShippingQuery: from_user = User(0, '', False) shipping_address = ShippingAddress('GB', '', 'London', '12 Grimmauld Place', '', 'WC1') - def test_slot_behaviour(self, shipping_query, mro_slots): - inst = shipping_query - for attr in inst.__slots__: - assert getattr(inst, attr, 'err') != 'err', f"got extra slot '{attr}'" - assert len(mro_slots(inst)) == len(set(mro_slots(inst))), "duplicate slot" - def test_de_json(self, bot): json_dict = { 'id': TestShippingQuery.id_, @@ -69,20 +64,21 @@ def test_to_dict(self, shipping_query): assert shipping_query_dict['from'] == shipping_query.from_user.to_dict() assert shipping_query_dict['shipping_address'] == shipping_query.shipping_address.to_dict() - def test_answer(self, monkeypatch, shipping_query): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_answer(self, monkeypatch, shipping_query): + async def make_assertion(*_, **kwargs): return kwargs['shipping_query_id'] == shipping_query.id assert check_shortcut_signature( ShippingQuery.answer, Bot.answer_shipping_query, ['shipping_query_id'], [] ) - assert check_shortcut_call( + assert await check_shortcut_call( shipping_query.answer, shipping_query._bot, 'answer_shipping_query' ) - assert check_defaults_handling(shipping_query.answer, shipping_query._bot) + assert await check_defaults_handling(shipping_query.answer, shipping_query._bot) monkeypatch.setattr(shipping_query._bot, 'answer_shipping_query', make_assertion) - assert shipping_query.answer(ok=True) + assert await shipping_query.answer(ok=True) def test_equality(self): a = ShippingQuery(self.id_, self.from_user, self.invoice_payload, self.shipping_address) diff --git a/tests/test_shippingqueryhandler.py b/tests/test_shippingqueryhandler.py deleted file mode 100644 index f43e7f9ab66..00000000000 --- a/tests/test_shippingqueryhandler.py +++ /dev/null @@ -1,118 +0,0 @@ -#!/usr/bin/env python -# -# A library that provides a Python interface to the Telegram Bot API -# Copyright (C) 2015-2022 -# Leandro Toledo de Souza -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Lesser Public License for more details. -# -# You should have received a copy of the GNU Lesser Public License -# along with this program. If not, see [http://www.gnu.org/licenses/]. -from queue import Queue - -import pytest - -from telegram import ( - Update, - Chat, - Bot, - ChosenInlineResult, - User, - Message, - CallbackQuery, - InlineQuery, - ShippingQuery, - PreCheckoutQuery, - ShippingAddress, -) -from telegram.ext import ShippingQueryHandler, CallbackContext, JobQueue - -message = Message(1, None, Chat(1, ''), from_user=User(1, '', False), text='Text') - -params = [ - {'message': message}, - {'edited_message': message}, - {'callback_query': CallbackQuery(1, User(1, '', False), 'chat', message=message)}, - {'channel_post': message}, - {'edited_channel_post': message}, - {'inline_query': InlineQuery(1, User(1, '', False), '', '')}, - {'chosen_inline_result': ChosenInlineResult('id', User(1, '', False), '')}, - {'pre_checkout_query': PreCheckoutQuery('id', User(1, '', False), '', 0, '')}, - {'callback_query': CallbackQuery(1, User(1, '', False), 'chat')}, -] - -ids = ( - 'message', - 'edited_message', - 'callback_query', - 'channel_post', - 'edited_channel_post', - 'inline_query', - 'chosen_inline_result', - 'pre_checkout_query', - 'callback_query_without_message', -) - - -@pytest.fixture(scope='class', params=params, ids=ids) -def false_update(request): - return Update(update_id=1, **request.param) - - -@pytest.fixture(scope='class') -def shiping_query(): - return Update( - 1, - shipping_query=ShippingQuery( - 42, - User(1, 'test user', False), - 'invoice_payload', - ShippingAddress('EN', 'my_state', 'my_city', 'steer_1', '', 'post_code'), - ), - ) - - -class TestShippingQueryHandler: - test_flag = False - - def test_slot_behaviour(self, mro_slots): - inst = ShippingQueryHandler(self.callback_context) - for attr in inst.__slots__: - assert getattr(inst, attr, 'err') != 'err', f"got extra slot '{attr}'" - assert len(mro_slots(inst)) == len(set(mro_slots(inst))), "duplicate slot" - - @pytest.fixture(autouse=True) - def reset(self): - self.test_flag = False - - def callback_context(self, update, context): - self.test_flag = ( - isinstance(context, CallbackContext) - and isinstance(context.bot, Bot) - and isinstance(update, Update) - and isinstance(context.update_queue, Queue) - and isinstance(context.job_queue, JobQueue) - and isinstance(context.user_data, dict) - and context.chat_data is None - and isinstance(context.bot_data, dict) - and isinstance(update.shipping_query, ShippingQuery) - ) - - def test_other_update_types(self, false_update): - handler = ShippingQueryHandler(self.callback_context) - assert not handler.check_update(false_update) - - def test_context(self, dp, shiping_query): - handler = ShippingQueryHandler(self.callback_context) - dp.add_handler(handler) - - dp.process_update(shiping_query) - assert self.test_flag diff --git a/tests/test_slots.py b/tests/test_slots.py index f1168c34c24..512f64df2a3 100644 --- a/tests/test_slots.py +++ b/tests/test_slots.py @@ -19,10 +19,8 @@ import importlib import os from pathlib import Path - import inspect - included = { # These modules/classes intentionally have __dict__. 'CallbackContext', 'BasePersistence', @@ -30,7 +28,7 @@ def test_class_has_slots_and_no_dict(): - tg_paths = [p for p in Path('telegram').rglob("*.py") if 'vendor' not in str(p)] + tg_paths = Path('telegram').rglob("*.py") for path in tg_paths: if '__' in str(path): # Exclude __init__, __main__, etc diff --git a/tests/test_stack.py b/tests/test_stack.py index 4e484b3f287..b01bdd2670d 100644 --- a/tests/test_stack.py +++ b/tests/test_stack.py @@ -32,4 +32,4 @@ def test_called_by_current_file(self): assert was_called_by(frame, file) # Testing a call by a different file is somewhat hard but it's covered in - # TestUpdater/Dispatcher.test_manual_init_warning + # TestUpdater/Application.test_manual_init_warning diff --git a/tests/test_sticker.py b/tests/test_sticker.py index 5d4043dc424..022093ba162 100644 --- a/tests/test_sticker.py +++ b/tests/test_sticker.py @@ -25,6 +25,7 @@ from telegram import Sticker, PhotoSize, StickerSet, Audio, MaskPosition, Bot from telegram.error import BadRequest, TelegramError +from telegram.request import RequestData from tests.conftest import ( check_shortcut_call, check_shortcut_signature, @@ -40,9 +41,10 @@ def sticker_file(): @pytest.fixture(scope='class') -def sticker(bot, chat_id): +@pytest.mark.asyncio +async def sticker(bot, chat_id): with data_file('telegram.webp').open('rb') as f: - return bot.send_sticker(chat_id, sticker=f, timeout=50).sticker + return (await bot.send_sticker(chat_id, sticker=f, read_timeout=50)).sticker @pytest.fixture(scope='function') @@ -52,9 +54,10 @@ def animated_sticker_file(): @pytest.fixture(scope='class') -def animated_sticker(bot, chat_id): +@pytest.mark.asyncio +async def animated_sticker(bot, chat_id): with data_file('telegram_animated_sticker.tgs').open('rb') as f: - return bot.send_sticker(chat_id, sticker=f, timeout=50).sticker + return (await bot.send_sticker(chat_id, sticker=f, read_timeout=50)).sticker @pytest.fixture(scope='function') @@ -119,8 +122,9 @@ def test_expected_values(self, sticker): assert sticker.thumb.file_size == self.thumb_file_size @flaky(3, 1) - def test_send_all_args(self, bot, chat_id, sticker_file, sticker): - message = bot.send_sticker( + @pytest.mark.asyncio + async def test_send_all_args(self, bot, chat_id, sticker_file, sticker): + message = await bot.send_sticker( chat_id, sticker=sticker_file, disable_notification=False, protect_content=True ) @@ -146,34 +150,42 @@ def test_send_all_args(self, bot, chat_id, sticker_file, sticker): assert message.has_protected_content @flaky(3, 1) - def test_get_and_download(self, bot, sticker): - new_file = bot.get_file(sticker.file_id) + @pytest.mark.asyncio + async def test_get_and_download(self, bot, sticker): + path = Path('telegram.webp') + if path.is_file(): + path.unlink() + + new_file = await bot.get_file(sticker.file_id) assert new_file.file_size == sticker.file_size assert new_file.file_id == sticker.file_id assert new_file.file_unique_id == sticker.file_unique_id assert new_file.file_path.startswith('https://') - new_file.download('telegram.webp') + await new_file.download('telegram.webp') - assert Path('telegram.webp').is_file() + assert path.is_file() @flaky(3, 1) - def test_resend(self, bot, chat_id, sticker): - message = bot.send_sticker(chat_id=chat_id, sticker=sticker.file_id) + @pytest.mark.asyncio + async def test_resend(self, bot, chat_id, sticker): + message = await bot.send_sticker(chat_id=chat_id, sticker=sticker.file_id) assert message.sticker == sticker @flaky(3, 1) - def test_send_on_server_emoji(self, bot, chat_id): + @pytest.mark.asyncio + async def test_send_on_server_emoji(self, bot, chat_id): server_file_id = 'CAADAQADHAADyIsGAAFZfq1bphjqlgI' - message = bot.send_sticker(chat_id=chat_id, sticker=server_file_id) + message = await bot.send_sticker(chat_id=chat_id, sticker=server_file_id) sticker = message.sticker assert sticker.emoji == self.emoji @flaky(3, 1) - def test_send_from_url(self, bot, chat_id): - message = bot.send_sticker(chat_id=chat_id, sticker=self.sticker_file_url) + @pytest.mark.asyncio + async def test_send_from_url(self, bot, chat_id): + message = await bot.send_sticker(chat_id=chat_id, sticker=self.sticker_file_url) sticker = message.sticker assert isinstance(message.sticker, Sticker) @@ -220,26 +232,28 @@ def test_de_json(self, bot, sticker): assert json_sticker.file_size == self.file_size assert json_sticker.thumb == sticker.thumb - def test_send_with_sticker(self, monkeypatch, bot, chat_id, sticker): - def test(url, data, **kwargs): - return data['sticker'] == sticker.file_id + @pytest.mark.asyncio + async def test_send_with_sticker(self, monkeypatch, bot, chat_id, sticker): + async def make_assertion(url, request_data: RequestData, *args, **kwargs): + return request_data.json_parameters['sticker'] == sticker.file_id - monkeypatch.setattr(bot.request, 'post', test) - message = bot.send_sticker(sticker=sticker, chat_id=chat_id) + monkeypatch.setattr(bot.request, 'post', make_assertion) + message = await bot.send_sticker(sticker=sticker, chat_id=chat_id) assert message - def test_send_sticker_local_files(self, monkeypatch, bot, chat_id): + @pytest.mark.asyncio + async def test_send_sticker_local_files(self, monkeypatch, bot, chat_id): # For just test that the correct paths are passed as we have no local bot API set up test_flag = False file = data_file('telegram.jpg') expected = file.as_uri() - def make_assertion(_, data, *args, **kwargs): + async def make_assertion(_, data, *args, **kwargs): nonlocal test_flag test_flag = data.get('sticker') == expected monkeypatch.setattr(bot, '_post', make_assertion) - bot.send_sticker(chat_id, file) + await bot.send_sticker(chat_id, file) assert test_flag monkeypatch.delattr(bot, '_post') @@ -253,13 +267,14 @@ def make_assertion(_, data, *args, **kwargs): ], indirect=['default_bot'], ) - def test_send_sticker_default_allow_sending_without_reply( + @pytest.mark.asyncio + async def test_send_sticker_default_allow_sending_without_reply( self, default_bot, chat_id, sticker, custom ): - reply_to_message = default_bot.send_message(chat_id, 'test') - reply_to_message.delete() + reply_to_message = await default_bot.send_message(chat_id, 'test') + await reply_to_message.delete() if custom is not None: - message = default_bot.send_sticker( + message = await default_bot.send_sticker( chat_id, sticker, allow_sending_without_reply=custom, @@ -267,22 +282,23 @@ def test_send_sticker_default_allow_sending_without_reply( ) assert message.reply_to_message is None elif default_bot.defaults.allow_sending_without_reply: - message = default_bot.send_sticker( + message = await default_bot.send_sticker( chat_id, sticker, reply_to_message_id=reply_to_message.message_id ) assert message.reply_to_message is None else: with pytest.raises(BadRequest, match='message not found'): - default_bot.send_sticker( + await default_bot.send_sticker( chat_id, sticker, reply_to_message_id=reply_to_message.message_id ) @flaky(3, 1) + @pytest.mark.asyncio @pytest.mark.parametrize('default_bot', [{'protect_content': True}], indirect=True) - def test_send_sticker_default_protect_content(self, chat_id, sticker, default_bot): - protected = default_bot.send_sticker(chat_id, sticker) + async def test_send_sticker_default_protect_content(self, chat_id, sticker, default_bot): + protected = await default_bot.send_sticker(chat_id, sticker) assert protected.has_protected_content - unprotected = default_bot.send_sticker(chat_id, sticker, protect_content=False) + unprotected = await default_bot.send_sticker(chat_id, sticker, protect_content=False) assert not unprotected.has_protected_content def test_to_dict(self, sticker): @@ -299,18 +315,21 @@ def test_to_dict(self, sticker): assert sticker_dict['thumb'] == sticker.thumb.to_dict() @flaky(3, 1) - def test_error_send_empty_file(self, bot, chat_id): + @pytest.mark.asyncio + async def test_error_send_empty_file(self, bot, chat_id): with pytest.raises(TelegramError): - bot.send_sticker(chat_id, open(os.devnull, 'rb')) + await bot.send_sticker(chat_id, open(os.devnull, 'rb')) @flaky(3, 1) - def test_error_send_empty_file_id(self, bot, chat_id): + @pytest.mark.asyncio + async def test_error_send_empty_file_id(self, bot, chat_id): with pytest.raises(TelegramError): - bot.send_sticker(chat_id, '') + await bot.send_sticker(chat_id, '') - def test_error_without_required_args(self, bot, chat_id): + @pytest.mark.asyncio + async def test_error_without_required_args(self, bot, chat_id): with pytest.raises(TypeError): - bot.send_sticker(chat_id) + await bot.send_sticker(chat_id) def test_equality(self, sticker): a = Sticker( @@ -345,12 +364,12 @@ def test_equality(self, sticker): @pytest.fixture(scope='function') -def sticker_set(bot): - ss = bot.get_sticker_set(f'test_by_{bot.username}') +async def sticker_set(bot): + ss = await bot.get_sticker_set(f'test_by_{bot.username}') if len(ss.stickers) > 100: try: for i in range(1, 50): - bot.delete_sticker_from_set(ss.stickers[-i].file_id) + await bot.delete_sticker_from_set(ss.stickers[-i].file_id) except BadRequest as e: if e.message == 'Stickerset_not_modified': return ss @@ -359,8 +378,9 @@ def sticker_set(bot): @pytest.fixture(scope='function') -def animated_sticker_set(bot): - ss = bot.get_sticker_set(f'animated_test_by_{bot.username}') +@pytest.mark.asyncio +async def animated_sticker_set(bot): + ss = await bot.get_sticker_set(f'animated_test_by_{bot.username}') if len(ss.stickers) > 100: try: for i in range(1, 50): @@ -373,12 +393,13 @@ def animated_sticker_set(bot): @pytest.fixture(scope='function') -def video_sticker_set(bot): - ss = bot.get_sticker_set(f'video_test_by_{bot.username}') +@pytest.mark.asyncio +async def video_sticker_set(bot): + ss = await bot.get_sticker_set(f'video_test_by_{bot.username}') if len(ss.stickers) > 100: try: for i in range(1, 50): - bot.delete_sticker_from_set(ss.stickers[-i].file_id) + await bot.delete_sticker_from_set(ss.stickers[-i].file_id) except BadRequest as e: if e.message == 'Stickerset_not_modified': return ss @@ -421,7 +442,8 @@ def test_de_json(self, bot, sticker): assert sticker_set.stickers == self.stickers assert sticker_set.thumb == sticker.thumb - def test_create_sticker_set( + @pytest.mark.asyncio + async def test_create_sticker_set( self, bot, chat_id, sticker_file, animated_sticker_file, video_sticker_file ): """Creates the sticker set (if needed) which is required for tests. Make sure that this @@ -430,13 +452,13 @@ def test_create_sticker_set( test_by = f"test_by_{bot.username}" for sticker_set in [test_by, f'animated_{test_by}', f'video_{test_by}']: try: - bot.get_sticker_set(sticker_set) + await bot.get_sticker_set(sticker_set) except BadRequest as e: if not e.message == "Stickerset_invalid": raise e if sticker_set.startswith(test_by): - s = bot.create_new_sticker_set( + s = await bot.create_new_sticker_set( chat_id, name=sticker_set, title="Sticker Test", @@ -445,7 +467,7 @@ def test_create_sticker_set( ) assert s elif sticker_set.startswith("animated"): - a = bot.create_new_sticker_set( + a = await bot.create_new_sticker_set( chat_id, name=sticker_set, title="Animated Test", @@ -454,7 +476,7 @@ def test_create_sticker_set( ) assert a elif sticker_set.startswith("video"): - v = bot.create_new_sticker_set( + v = await bot.create_new_sticker_set( chat_id, name=sticker_set, title="Video Test", @@ -464,16 +486,17 @@ def test_create_sticker_set( assert v @flaky(3, 1) - def test_bot_methods_1_png(self, bot, chat_id, sticker_file): + @pytest.mark.asyncio + async def test_bot_methods_1_png(self, bot, chat_id, sticker_file): with data_file('telegram_sticker.png').open('rb') as f: # chat_id was hardcoded as 95205500 but it stopped working for some reason - file = bot.upload_sticker_file(chat_id, f) + file = await bot.upload_sticker_file(chat_id, f) assert file - assert bot.add_sticker_to_set( + assert await bot.add_sticker_to_set( chat_id, f'test_by_{bot.username}', png_sticker=file.file_id, emojis='😄' ) # Also test with file input and mask - assert bot.add_sticker_to_set( + assert await bot.add_sticker_to_set( chat_id, f'test_by_{bot.username}', png_sticker=sticker_file, @@ -482,8 +505,9 @@ def test_bot_methods_1_png(self, bot, chat_id, sticker_file): ) @flaky(3, 1) - def test_bot_methods_1_tgs(self, bot, chat_id): - assert bot.add_sticker_to_set( + @pytest.mark.asyncio + async def test_bot_methods_1_tgs(self, bot, chat_id): + assert await bot.add_sticker_to_set( chat_id, f'animated_test_by_{bot.username}', tgs_sticker=data_file('telegram_animated_sticker.tgs').open('rb'), @@ -491,9 +515,10 @@ def test_bot_methods_1_tgs(self, bot, chat_id): ) @flaky(3, 1) - def test_bot_methods_1_webm(self, bot, chat_id): + @pytest.mark.asyncio + async def test_bot_methods_1_webm(self, bot, chat_id): with Path('tests/data/telegram_video_sticker.webm').open('rb') as f: - assert bot.add_sticker_to_set( + assert await bot.add_sticker_to_set( chat_id, f'video_test_by_{bot.username}', webm_sticker=f, emojis='🤔' ) @@ -509,35 +534,42 @@ def test_sticker_set_to_dict(self, sticker_set): assert sticker_set_dict['stickers'][0] == sticker_set.stickers[0].to_dict() @flaky(3, 1) - def test_bot_methods_2_png(self, bot, sticker_set): + @pytest.mark.asyncio + async def test_bot_methods_2_png(self, bot, sticker_set): file_id = sticker_set.stickers[0].file_id - assert bot.set_sticker_position_in_set(file_id, 1) + assert await bot.set_sticker_position_in_set(file_id, 1) @flaky(3, 1) - def test_bot_methods_2_tgs(self, bot, animated_sticker_set): + @pytest.mark.asyncio + async def test_bot_methods_2_tgs(self, bot, animated_sticker_set): file_id = animated_sticker_set.stickers[0].file_id - assert bot.set_sticker_position_in_set(file_id, 1) + assert await bot.set_sticker_position_in_set(file_id, 1) @flaky(3, 1) - def test_bot_methods_2_webm(self, bot, video_sticker_set): + @pytest.mark.asyncio + async def test_bot_methods_2_webm(self, bot, video_sticker_set): file_id = video_sticker_set.stickers[0].file_id - assert bot.set_sticker_position_in_set(file_id, 1) + assert await bot.set_sticker_position_in_set(file_id, 1) @flaky(10, 1) - def test_bot_methods_3_png(self, bot, chat_id, sticker_set_thumb_file): + @pytest.mark.asyncio + async def test_bot_methods_3_png(self, bot, chat_id, sticker_set_thumb_file): sleep(1) - assert bot.set_sticker_set_thumb( + assert await bot.set_sticker_set_thumb( f'test_by_{bot.username}', chat_id, sticker_set_thumb_file ) @flaky(10, 1) - def test_bot_methods_3_tgs(self, bot, chat_id, animated_sticker_file, animated_sticker_set): + @pytest.mark.asyncio + async def test_bot_methods_3_tgs( + self, bot, chat_id, animated_sticker_file, animated_sticker_set + ): sleep(1) animated_test = f'animated_test_by_{bot.username}' - assert bot.set_sticker_set_thumb(animated_test, chat_id, animated_sticker_file) + assert await bot.set_sticker_set_thumb(animated_test, chat_id, animated_sticker_file) file_id = animated_sticker_set.stickers[-1].file_id # also test with file input and mask - assert bot.set_sticker_set_thumb(animated_test, chat_id, file_id) + assert await bot.set_sticker_set_thumb(animated_test, chat_id, file_id) # TODO: Try the below by creating a custom .webm and not by downloading another pack's thumb @pytest.mark.skip( @@ -548,45 +580,50 @@ def test_bot_methods_3_webm(self, bot, chat_id, video_sticker_file, video_sticke pass @flaky(10, 1) - def test_bot_methods_4_png(self, bot, sticker_set): + @pytest.mark.asyncio + async def test_bot_methods_4_png(self, bot, sticker_set): sleep(1) file_id = sticker_set.stickers[-1].file_id - assert bot.delete_sticker_from_set(file_id) + assert await bot.delete_sticker_from_set(file_id) @flaky(10, 1) - def test_bot_methods_4_tgs(self, bot, animated_sticker_set): + @pytest.mark.asyncio + async def test_bot_methods_4_tgs(self, bot, animated_sticker_set): sleep(1) file_id = animated_sticker_set.stickers[-1].file_id - assert bot.delete_sticker_from_set(file_id) + assert await bot.delete_sticker_from_set(file_id) @flaky(10, 1) - def test_bot_methods_4_webm(self, bot, video_sticker_set): + @pytest.mark.asyncio + async def test_bot_methods_4_webm(self, bot, video_sticker_set): sleep(1) file_id = video_sticker_set.stickers[-1].file_id - assert bot.delete_sticker_from_set(file_id) + assert await bot.delete_sticker_from_set(file_id) - def test_upload_sticker_file_local_files(self, monkeypatch, bot, chat_id): + @pytest.mark.asyncio + async def test_upload_sticker_file_local_files(self, monkeypatch, bot, chat_id): # For just test that the correct paths are passed as we have no local bot API set up test_flag = False file = data_file('telegram.jpg') expected = file.as_uri() - def make_assertion(_, data, *args, **kwargs): + async def make_assertion(_, data, *args, **kwargs): nonlocal test_flag test_flag = data.get('png_sticker') == expected monkeypatch.setattr(bot, '_post', make_assertion) - bot.upload_sticker_file(chat_id, file) + await bot.upload_sticker_file(chat_id, file) assert test_flag monkeypatch.delattr(bot, '_post') - def test_create_new_sticker_set_local_files(self, monkeypatch, bot, chat_id): + @pytest.mark.asyncio + async def test_create_new_sticker_set_local_files(self, monkeypatch, bot, chat_id): # For just test that the correct paths are passed as we have no local bot API set up test_flag = False file = data_file('telegram.jpg') expected = file.as_uri() - def make_assertion(_, data, *args, **kwargs): + async def make_assertion(_, data, *args, **kwargs): nonlocal test_flag test_flag = ( data.get('png_sticker') == expected @@ -595,7 +632,7 @@ def make_assertion(_, data, *args, **kwargs): ) monkeypatch.setattr(bot, '_post', make_assertion) - bot.create_new_sticker_set( + await bot.create_new_sticker_set( chat_id, 'name', 'title', @@ -607,46 +644,49 @@ def make_assertion(_, data, *args, **kwargs): assert test_flag monkeypatch.delattr(bot, '_post') - def test_add_sticker_to_set_local_files(self, monkeypatch, bot, chat_id): + @pytest.mark.asyncio + async def test_add_sticker_to_set_local_files(self, monkeypatch, bot, chat_id): # For just test that the correct paths are passed as we have no local bot API set up test_flag = False file = data_file('telegram.jpg') expected = file.as_uri() - def make_assertion(_, data, *args, **kwargs): + async def make_assertion(_, data, *args, **kwargs): nonlocal test_flag test_flag = data.get('png_sticker') == expected and data.get('tgs_sticker') == expected monkeypatch.setattr(bot, '_post', make_assertion) - bot.add_sticker_to_set(chat_id, 'name', 'emoji', png_sticker=file, tgs_sticker=file) + await bot.add_sticker_to_set(chat_id, 'name', 'emoji', png_sticker=file, tgs_sticker=file) assert test_flag monkeypatch.delattr(bot, '_post') - def test_set_sticker_set_thumb_local_files(self, monkeypatch, bot, chat_id): + @pytest.mark.asyncio + async def test_set_sticker_set_thumb_local_files(self, monkeypatch, bot, chat_id): # For just test that the correct paths are passed as we have no local bot API set up test_flag = False file = data_file('telegram.jpg') expected = file.as_uri() - def make_assertion(_, data, *args, **kwargs): + async def make_assertion(_, data, *args, **kwargs): nonlocal test_flag test_flag = data.get('thumb') == expected monkeypatch.setattr(bot, '_post', make_assertion) - bot.set_sticker_set_thumb('name', chat_id, thumb=file) + await bot.set_sticker_set_thumb('name', chat_id, thumb=file) assert test_flag monkeypatch.delattr(bot, '_post') - def test_get_file_instance_method(self, monkeypatch, sticker): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_get_file_instance_method(self, monkeypatch, sticker): + async def make_assertion(*_, **kwargs): return kwargs['file_id'] == sticker.file_id assert check_shortcut_signature(Sticker.get_file, Bot.get_file, ['file_id'], []) - assert check_shortcut_call(sticker.get_file, sticker.get_bot(), 'get_file') - assert check_defaults_handling(sticker.get_file, sticker.get_bot()) + assert await check_shortcut_call(sticker.get_file, sticker.get_bot(), 'get_file') + assert await check_defaults_handling(sticker.get_file, sticker.get_bot()) monkeypatch.setattr(sticker.get_bot(), 'get_file', make_assertion) - assert sticker.get_file() + assert await sticker.get_file() def test_equality(self): a = StickerSet( diff --git a/tests/test_stringcommandhandler.py b/tests/test_stringcommandhandler.py deleted file mode 100644 index 6aca8211088..00000000000 --- a/tests/test_stringcommandhandler.py +++ /dev/null @@ -1,118 +0,0 @@ -#!/usr/bin/env python -# -# A library that provides a Python interface to the Telegram Bot API -# Copyright (C) 2015-2022 -# Leandro Toledo de Souza -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Lesser Public License for more details. -# -# You should have received a copy of the GNU Lesser Public License -# along with this program. If not, see [http://www.gnu.org/licenses/]. -from queue import Queue - -import pytest - -from telegram import ( - Bot, - Update, - Message, - User, - Chat, - CallbackQuery, - InlineQuery, - ChosenInlineResult, - ShippingQuery, - PreCheckoutQuery, -) -from telegram.ext import StringCommandHandler, CallbackContext, JobQueue - -message = Message(1, None, Chat(1, ''), from_user=User(1, '', False), text='Text') - -params = [ - {'message': message}, - {'edited_message': message}, - {'callback_query': CallbackQuery(1, User(1, '', False), 'chat', message=message)}, - {'channel_post': message}, - {'edited_channel_post': message}, - {'inline_query': InlineQuery(1, User(1, '', False), '', '')}, - {'chosen_inline_result': ChosenInlineResult('id', User(1, '', False), '')}, - {'shipping_query': ShippingQuery('id', User(1, '', False), '', None)}, - {'pre_checkout_query': PreCheckoutQuery('id', User(1, '', False), '', 0, '')}, - {'callback_query': CallbackQuery(1, User(1, '', False), 'chat')}, -] - -ids = ( - 'message', - 'edited_message', - 'callback_query', - 'channel_post', - 'edited_channel_post', - 'inline_query', - 'chosen_inline_result', - 'shipping_query', - 'pre_checkout_query', - 'callback_query_without_message', -) - - -@pytest.fixture(scope='class', params=params, ids=ids) -def false_update(request): - return Update(update_id=1, **request.param) - - -class TestStringCommandHandler: - test_flag = False - - def test_slot_behaviour(self, mro_slots): - inst = StringCommandHandler('sleepy', self.callback_context) - for attr in inst.__slots__: - assert getattr(inst, attr, 'err') != 'err', f"got extra slot '{attr}'" - assert len(mro_slots(inst)) == len(set(mro_slots(inst))), "duplicate slot" - - @pytest.fixture(autouse=True) - def reset(self): - self.test_flag = False - - def callback_context(self, update, context): - self.test_flag = ( - isinstance(context, CallbackContext) - and isinstance(context.bot, Bot) - and isinstance(update, str) - and isinstance(context.update_queue, Queue) - and isinstance(context.job_queue, JobQueue) - and context.user_data is None - and context.chat_data is None - and isinstance(context.bot_data, dict) - ) - - def callback_context_args(self, update, context): - self.test_flag = context.args == ['one', 'two'] - - def test_other_update_types(self, false_update): - handler = StringCommandHandler('test', self.callback_context) - assert not handler.check_update(false_update) - - def test_context(self, dp): - handler = StringCommandHandler('test', self.callback_context) - dp.add_handler(handler) - - dp.process_update('/test') - assert self.test_flag - - def test_context_args(self, dp): - handler = StringCommandHandler('test', self.callback_context_args) - dp.add_handler(handler) - - dp.process_update('/test') - assert not self.test_flag - - dp.process_update('/test one two') - assert self.test_flag diff --git a/tests/test_stringregexhandler.py b/tests/test_stringregexhandler.py deleted file mode 100644 index 472ac36b9a9..00000000000 --- a/tests/test_stringregexhandler.py +++ /dev/null @@ -1,132 +0,0 @@ -#!/usr/bin/env python -# -# A library that provides a Python interface to the Telegram Bot API -# Copyright (C) 2015-2022 -# Leandro Toledo de Souza -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Lesser Public License for more details. -# -# You should have received a copy of the GNU Lesser Public License -# along with this program. If not, see [http://www.gnu.org/licenses/]. -from queue import Queue - -import pytest - -from telegram import ( - Bot, - Update, - Message, - User, - Chat, - CallbackQuery, - InlineQuery, - ChosenInlineResult, - ShippingQuery, - PreCheckoutQuery, -) -from telegram.ext import StringRegexHandler, CallbackContext, JobQueue - -message = Message(1, None, Chat(1, ''), from_user=User(1, '', False), text='Text') - -params = [ - {'message': message}, - {'edited_message': message}, - {'callback_query': CallbackQuery(1, User(1, '', False), 'chat', message=message)}, - {'channel_post': message}, - {'edited_channel_post': message}, - {'inline_query': InlineQuery(1, User(1, '', False), '', '')}, - {'chosen_inline_result': ChosenInlineResult('id', User(1, '', False), '')}, - {'shipping_query': ShippingQuery('id', User(1, '', False), '', None)}, - {'pre_checkout_query': PreCheckoutQuery('id', User(1, '', False), '', 0, '')}, - {'callback_query': CallbackQuery(1, User(1, '', False), 'chat')}, -] - -ids = ( - 'message', - 'edited_message', - 'callback_query', - 'channel_post', - 'edited_channel_post', - 'inline_query', - 'chosen_inline_result', - 'shipping_query', - 'pre_checkout_query', - 'callback_query_without_message', -) - - -@pytest.fixture(scope='class', params=params, ids=ids) -def false_update(request): - return Update(update_id=1, **request.param) - - -class TestStringRegexHandler: - test_flag = False - - def test_slot_behaviour(self, mro_slots): - inst = StringRegexHandler('pfft', self.callback_context) - for attr in inst.__slots__: - assert getattr(inst, attr, 'err') != 'err', f"got extra slot '{attr}'" - assert len(mro_slots(inst)) == len(set(mro_slots(inst))), "duplicate slot" - - @pytest.fixture(autouse=True) - def reset(self): - self.test_flag = False - - def callback_context(self, update, context): - self.test_flag = ( - isinstance(context, CallbackContext) - and isinstance(context.bot, Bot) - and isinstance(update, str) - and isinstance(context.update_queue, Queue) - and isinstance(context.job_queue, JobQueue) - ) - - def callback_context_pattern(self, update, context): - if context.matches[0].groups(): - self.test_flag = context.matches[0].groups() == ('t', ' message') - if context.matches[0].groupdict(): - self.test_flag = context.matches[0].groupdict() == {'begin': 't', 'end': ' message'} - - def test_basic(self, dp): - handler = StringRegexHandler('(?P.*)est(?P.*)', self.callback_context) - dp.add_handler(handler) - - assert handler.check_update('test message') - dp.process_update('test message') - assert self.test_flag - - assert not handler.check_update('does not match') - - def test_other_update_types(self, false_update): - handler = StringRegexHandler('test', self.callback_context) - assert not handler.check_update(false_update) - - def test_context(self, dp): - handler = StringRegexHandler(r'(t)est(.*)', self.callback_context) - dp.add_handler(handler) - - dp.process_update('test message') - assert self.test_flag - - def test_context_pattern(self, dp): - handler = StringRegexHandler(r'(t)est(.*)', self.callback_context_pattern) - dp.add_handler(handler) - - dp.process_update('test message') - assert self.test_flag - - dp.remove_handler(handler) - handler = StringRegexHandler(r'(t)est(.*)', self.callback_context_pattern) - dp.add_handler(handler) - - dp.process_update('test message') - assert self.test_flag diff --git a/tests/test_typehandler.py b/tests/test_typehandler.py deleted file mode 100644 index 70398eae23b..00000000000 --- a/tests/test_typehandler.py +++ /dev/null @@ -1,73 +0,0 @@ -#!/usr/bin/env python -# -# A library that provides a Python interface to the Telegram Bot API -# Copyright (C) 2015-2022 -# Leandro Toledo de Souza -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Lesser Public License for more details. -# -# You should have received a copy of the GNU Lesser Public License -# along with this program. If not, see [http://www.gnu.org/licenses/]. -from collections import OrderedDict -from queue import Queue - -import pytest - -from telegram import Bot -from telegram.ext import TypeHandler, CallbackContext, JobQueue - - -class TestTypeHandler: - test_flag = False - - def test_slot_behaviour(self, mro_slots): - inst = TypeHandler(dict, self.callback_context) - for attr in inst.__slots__: - assert getattr(inst, attr, 'err') != 'err', f"got extra slot '{attr}'" - assert len(mro_slots(inst)) == len(set(mro_slots(inst))), "duplicate slot" - - @pytest.fixture(autouse=True) - def reset(self): - self.test_flag = False - - def callback_context(self, update, context): - self.test_flag = ( - isinstance(context, CallbackContext) - and isinstance(context.bot, Bot) - and isinstance(update, dict) - and isinstance(context.update_queue, Queue) - and isinstance(context.job_queue, JobQueue) - and context.user_data is None - and context.chat_data is None - and isinstance(context.bot_data, dict) - ) - - def test_basic(self, dp): - handler = TypeHandler(dict, self.callback_context) - dp.add_handler(handler) - - assert handler.check_update({'a': 1, 'b': 2}) - assert not handler.check_update('not a dict') - dp.process_update({'a': 1, 'b': 2}) - assert self.test_flag - - def test_strict(self): - handler = TypeHandler(dict, self.callback_context, strict=True) - o = OrderedDict({'a': 1, 'b': 2}) - assert handler.check_update({'a': 1, 'b': 2}) - assert not handler.check_update(o) - - def test_context(self, dp): - handler = TypeHandler(dict, self.callback_context) - dp.add_handler(handler) - - dp.process_update({'a': 1, 'b': 2}) - assert self.test_flag diff --git a/tests/test_updater.py b/tests/test_updater.py deleted file mode 100644 index 3e7920d0468..00000000000 --- a/tests/test_updater.py +++ /dev/null @@ -1,654 +0,0 @@ -#!/usr/bin/env python -# -# A library that provides a Python interface to the Telegram Bot API -# Copyright (C) 2015-2022 -# Leandro Toledo de Souza -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Lesser Public License for more details. -# -# You should have received a copy of the GNU Lesser Public License -# along with this program. If not, see [http://www.gnu.org/licenses/]. -import asyncio -import logging -import os -import signal -import sys -import threading -from contextlib import contextmanager -from pathlib import Path - -from flaky import flaky -from functools import partial -from queue import Queue -from random import randrange -from threading import Thread, Event -from time import sleep - -from urllib.request import Request, urlopen -from urllib.error import HTTPError - -import pytest -from .conftest import DictBot - -from telegram import ( - Message, - User, - Chat, - Update, - Bot, - InlineKeyboardMarkup, - InlineKeyboardButton, -) -from telegram.error import Unauthorized, InvalidToken, TimedOut, RetryAfter, TelegramError -from telegram.ext import ( - InvalidCallbackData, - ExtBot, - Updater, - UpdaterBuilder, - DispatcherBuilder, -) -from telegram.ext._utils.webhookhandler import WebhookServer - -signalskip = pytest.mark.skipif( - sys.platform == 'win32', - reason="Can't send signals without stopping whole process on windows", -) - - -ASYNCIO_LOCK = threading.Lock() - - -@contextmanager -def set_asyncio_event_loop(loop): - with ASYNCIO_LOCK: - try: - orig_lop = asyncio.get_event_loop() - except RuntimeError: - orig_lop = None - asyncio.set_event_loop(loop) - try: - yield - finally: - asyncio.set_event_loop(orig_lop) - - -class TestUpdater: - message_count = 0 - received = None - attempts = 0 - err_handler_called = Event() - cb_handler_called = Event() - offset = 0 - test_flag = False - - @pytest.fixture(autouse=True) - def reset(self): - self.message_count = 0 - self.received = None - self.attempts = 0 - self.err_handler_called.clear() - self.cb_handler_called.clear() - self.test_flag = False - - def error_handler(self, update, context): - self.received = context.error.message - self.err_handler_called.set() - - def callback(self, update, context): - self.received = update.message.text - self.cb_handler_called.set() - - def test_slot_behaviour(self, updater, mro_slots): - for at in updater.__slots__: - at = f"_Updater{at}" if at.startswith('__') and not at.endswith('__') else at - assert getattr(updater, at, 'err') != 'err', f"got extra slot '{at}'" - assert len(mro_slots(updater)) == len(set(mro_slots(updater))), "duplicate slot" - - def test_manual_init_warning(self, recwarn): - Updater( - bot=None, - dispatcher=None, - update_queue=None, - exception_event=None, - user_signal_handler=None, - ) - assert len(recwarn) == 1 - assert ( - str(recwarn[-1].message) - == '`Updater` instances should be built via the `UpdaterBuilder`.' - ) - assert recwarn[0].filename == __file__, "stacklevel is incorrect!" - - def test_builder(self, updater): - builder_1 = updater.builder() - builder_2 = updater.builder() - assert isinstance(builder_1, UpdaterBuilder) - assert isinstance(builder_2, UpdaterBuilder) - assert builder_1 is not builder_2 - - # Make sure that setting a token doesn't raise an exception - # i.e. check that the builders are "empty"/new - builder_1.token(updater.bot.token) - builder_2.token(updater.bot.token) - - def test_warn_con_pool(self, bot, recwarn, dp): - DispatcherBuilder().bot(bot).workers(5).build() - UpdaterBuilder().bot(bot).workers(8).build() - UpdaterBuilder().bot(bot).workers(2).build() - assert len(recwarn) == 2 - for idx, value in enumerate((9, 12)): - warning = ( - 'The Connection pool of Request object is smaller (8) than the ' - f'recommended value of {value}.' - ) - assert str(recwarn[idx].message) == warning - assert recwarn[idx].filename == __file__, "wrong stacklevel!" - - @pytest.mark.parametrize( - ('error',), - argvalues=[(TelegramError('Test Error 2'),), (Unauthorized('Test Unauthorized'),)], - ids=('TelegramError', 'Unauthorized'), - ) - def test_get_updates_normal_err(self, monkeypatch, updater, error): - def test(*args, **kwargs): - raise error - - monkeypatch.setattr(updater.bot, 'get_updates', test) - monkeypatch.setattr(updater.bot, 'set_webhook', lambda *args, **kwargs: True) - updater.dispatcher.add_error_handler(self.error_handler) - updater.start_polling(0.01) - - # Make sure that the error handler was called - self.err_handler_called.wait() - assert self.received == error.message - - # Make sure that Updater polling thread keeps running - self.err_handler_called.clear() - self.err_handler_called.wait() - - @pytest.mark.filterwarnings('ignore:.*:pytest.PytestUnhandledThreadExceptionWarning') - def test_get_updates_bailout_err(self, monkeypatch, updater, caplog): - error = InvalidToken() - - def test(*args, **kwargs): - raise error - - with caplog.at_level(logging.DEBUG): - monkeypatch.setattr(updater.bot, 'get_updates', test) - monkeypatch.setattr(updater.bot, 'set_webhook', lambda *args, **kwargs: True) - updater.dispatcher.add_error_handler(self.error_handler) - updater.start_polling(0.01) - assert self.err_handler_called.wait(1) is not True - - sleep(1) - # NOTE: This test might hit a race condition and fail (though the 1 seconds delay above - # should work around it). - # NOTE: Checking Updater.running is problematic because it is not set to False when there's - # an unhandled exception. - # TODO: We should have a way to poll Updater status and decide if it's running or not. - import pprint - - pprint.pprint([rec.getMessage() for rec in caplog.get_records('call')]) - assert any( - f'unhandled exception in Bot:{updater.bot.id}:updater' in rec.getMessage() - for rec in caplog.get_records('call') - ) - - @pytest.mark.parametrize( - ('error',), argvalues=[(RetryAfter(0.01),), (TimedOut(),)], ids=('RetryAfter', 'TimedOut') - ) - def test_get_updates_retries(self, monkeypatch, updater, error): - event = Event() - - def test(*args, **kwargs): - event.set() - raise error - - monkeypatch.setattr(updater.bot, 'get_updates', test) - monkeypatch.setattr(updater.bot, 'set_webhook', lambda *args, **kwargs: True) - updater.dispatcher.add_error_handler(self.error_handler) - updater.start_polling(0.01) - - # Make sure that get_updates was called, but not the error handler - event.wait() - assert self.err_handler_called.wait(0.5) is not True - assert self.received != error.message - - # Make sure that Updater polling thread keeps running - event.clear() - event.wait() - assert self.err_handler_called.wait(0.5) is not True - - @pytest.mark.parametrize('ext_bot', [True, False]) - def test_webhook(self, monkeypatch, updater, ext_bot): - # Testing with both ExtBot and Bot to make sure any logic in WebhookHandler - # that depends on this distinction works - if ext_bot and not isinstance(updater.bot, ExtBot): - updater.bot = ExtBot(updater.bot.token) - if not ext_bot and not type(updater.bot) is Bot: - updater.bot = DictBot(updater.bot.token) - - q = Queue() - monkeypatch.setattr(updater.bot, 'set_webhook', lambda *args, **kwargs: True) - monkeypatch.setattr(updater.bot, 'delete_webhook', lambda *args, **kwargs: True) - monkeypatch.setattr('telegram.ext.Dispatcher.process_update', lambda _, u: q.put(u)) - - ip = '127.0.0.1' - port = randrange(1024, 49152) # Select random port - updater.start_webhook(ip, port, url_path='TOKEN') - sleep(0.2) - try: - # Now, we send an update to the server via urlopen - update = Update( - 1, - message=Message( - 1, None, Chat(1, ''), from_user=User(1, '', False), text='Webhook' - ), - ) - self._send_webhook_msg(ip, port, update.to_json(), 'TOKEN') - sleep(0.2) - assert q.get(False) == update - - # Returns 404 if path is incorrect - with pytest.raises(HTTPError) as excinfo: - self._send_webhook_msg(ip, port, None, 'webookhandler.py') - assert excinfo.value.code == 404 - - with pytest.raises(HTTPError) as excinfo: - self._send_webhook_msg( - ip, port, None, 'webookhandler.py', get_method=lambda: 'HEAD' - ) - assert excinfo.value.code == 404 - - # Test multiple shutdown() calls - updater.httpd.shutdown() - finally: - updater.httpd.shutdown() - sleep(0.2) - assert not updater.httpd.is_running - updater.stop() - - @pytest.mark.parametrize('invalid_data', [True, False]) - def test_webhook_arbitrary_callback_data(self, monkeypatch, updater, invalid_data): - """Here we only test one simple setup. telegram.ext.ExtBot.insert_callback_data is tested - extensively in test_bot.py in conjunction with get_updates.""" - updater.bot.arbitrary_callback_data = True - try: - q = Queue() - monkeypatch.setattr(updater.bot, 'set_webhook', lambda *args, **kwargs: True) - monkeypatch.setattr(updater.bot, 'delete_webhook', lambda *args, **kwargs: True) - monkeypatch.setattr('telegram.ext.Dispatcher.process_update', lambda _, u: q.put(u)) - - ip = '127.0.0.1' - port = randrange(1024, 49152) # Select random port - updater.start_webhook(ip, port, url_path='TOKEN') - sleep(0.2) - try: - # Now, we send an update to the server via urlopen - reply_markup = InlineKeyboardMarkup.from_button( - InlineKeyboardButton(text='text', callback_data='callback_data') - ) - if not invalid_data: - reply_markup = updater.bot.callback_data_cache.process_keyboard(reply_markup) - - message = Message( - 1, - None, - None, - reply_markup=reply_markup, - ) - update = Update(1, message=message) - self._send_webhook_msg(ip, port, update.to_json(), 'TOKEN') - sleep(0.2) - received_update = q.get(False) - assert received_update == update - - button = received_update.message.reply_markup.inline_keyboard[0][0] - if invalid_data: - assert isinstance(button.callback_data, InvalidCallbackData) - else: - assert button.callback_data == 'callback_data' - - # Test multiple shutdown() calls - updater.httpd.shutdown() - finally: - updater.httpd.shutdown() - sleep(0.2) - assert not updater.httpd.is_running - updater.stop() - finally: - updater.bot.arbitrary_callback_data = False - updater.bot.callback_data_cache.clear_callback_data() - updater.bot.callback_data_cache.clear_callback_queries() - - @pytest.mark.parametrize('use_dispatcher', (True, False)) - def test_start_webhook_no_warning_or_error_logs( - self, caplog, updater, monkeypatch, use_dispatcher - ): - if not use_dispatcher: - updater.dispatcher = None - - self.test_flag = 0 - - def set_flag(): - self.test_flag += 1 - - monkeypatch.setattr(updater.bot, 'set_webhook', lambda *args, **kwargs: True) - monkeypatch.setattr(updater.bot, 'delete_webhook', lambda *args, **kwargs: True) - monkeypatch.setattr(updater.bot._request, 'stop', lambda *args, **kwargs: set_flag()) - # prevent api calls from @info decorator when updater.bot.id is used in thread names - monkeypatch.setattr(updater.bot, '_bot', User(id=123, first_name='bot', is_bot=True)) - - ip = '127.0.0.1' - port = randrange(1024, 49152) # Select random port - with caplog.at_level(logging.WARNING): - updater.start_webhook(ip, port) - updater.stop() - assert not caplog.records - # Make sure that bot.request.stop() has been called exactly once - assert self.test_flag == 1 - - def test_webhook_ssl(self, monkeypatch, updater): - monkeypatch.setattr(updater.bot, 'set_webhook', lambda *args, **kwargs: True) - monkeypatch.setattr(updater.bot, 'delete_webhook', lambda *args, **kwargs: True) - ip = '127.0.0.1' - port = randrange(1024, 49152) # Select random port - tg_err = False - try: - updater._start_webhook( - ip, - port, - url_path='TOKEN', - cert=Path(__file__).as_posix(), - key=Path(__file__).as_posix(), - bootstrap_retries=0, - drop_pending_updates=False, - webhook_url=None, - allowed_updates=None, - ) - except TelegramError: - tg_err = True - assert tg_err - - def test_webhook_no_ssl(self, monkeypatch, updater): - q = Queue() - monkeypatch.setattr(updater.bot, 'set_webhook', lambda *args, **kwargs: True) - monkeypatch.setattr(updater.bot, 'delete_webhook', lambda *args, **kwargs: True) - monkeypatch.setattr('telegram.ext.Dispatcher.process_update', lambda _, u: q.put(u)) - - ip = '127.0.0.1' - port = randrange(1024, 49152) # Select random port - updater.start_webhook(ip, port, webhook_url=None) - sleep(0.2) - - # Now, we send an update to the server via urlopen - update = Update( - 1, - message=Message(1, None, Chat(1, ''), from_user=User(1, '', False), text='Webhook 2'), - ) - self._send_webhook_msg(ip, port, update.to_json()) - sleep(0.2) - assert q.get(False) == update - updater.stop() - - def test_webhook_ssl_just_for_telegram(self, monkeypatch, updater): - q = Queue() - - def set_webhook(**kwargs): - self.test_flag.append(bool(kwargs.get('certificate'))) - return True - - orig_wh_server_init = WebhookServer.__init__ - - def webhook_server_init(*args): - self.test_flag = [args[-1] is None] - orig_wh_server_init(*args) - - monkeypatch.setattr(updater.bot, 'set_webhook', set_webhook) - monkeypatch.setattr(updater.bot, 'delete_webhook', lambda *args, **kwargs: True) - monkeypatch.setattr('telegram.ext.Dispatcher.process_update', lambda _, u: q.put(u)) - monkeypatch.setattr( - 'telegram.ext._utils.webhookhandler.WebhookServer.__init__', webhook_server_init - ) - - ip = '127.0.0.1' - port = randrange(1024, 49152) # Select random port - updater.start_webhook(ip, port, webhook_url=None, cert=Path(__file__).as_posix()) - sleep(0.2) - - # Now, we send an update to the server via urlopen - update = Update( - 1, - message=Message(1, None, Chat(1, ''), from_user=User(1, '', False), text='Webhook 2'), - ) - self._send_webhook_msg(ip, port, update.to_json()) - sleep(0.2) - assert q.get(False) == update - updater.stop() - assert self.test_flag == [True, True] - - @pytest.mark.parametrize('pass_max_connections', [True, False]) - def test_webhook_max_connections(self, monkeypatch, updater, pass_max_connections): - q = Queue() - max_connections = 42 - - def set_webhook(**kwargs): - print(kwargs) - self.test_flag = kwargs.get('max_connections') == ( - max_connections if pass_max_connections else 40 - ) - return True - - monkeypatch.setattr(updater.bot, 'set_webhook', set_webhook) - monkeypatch.setattr(updater.bot, 'delete_webhook', lambda *args, **kwargs: True) - monkeypatch.setattr('telegram.ext.Dispatcher.process_update', lambda _, u: q.put(u)) - - ip = '127.0.0.1' - port = randrange(1024, 49152) # Select random port - if pass_max_connections: - updater.start_webhook(ip, port, webhook_url=None, max_connections=max_connections) - else: - updater.start_webhook(ip, port, webhook_url=None) - - sleep(0.2) - - # Now, we send an update to the server via urlopen - update = Update( - 1, - message=Message(1, None, Chat(1, ''), from_user=User(1, '', False), text='Webhook 2'), - ) - self._send_webhook_msg(ip, port, update.to_json()) - sleep(0.2) - assert q.get(False) == update - updater.stop() - assert self.test_flag is True - - @pytest.mark.parametrize(('error',), argvalues=[(TelegramError(''),)], ids=('TelegramError',)) - def test_bootstrap_retries_success(self, monkeypatch, updater, error): - retries = 2 - - def attempt(*args, **kwargs): - if self.attempts < retries: - self.attempts += 1 - raise error - - monkeypatch.setattr(updater.bot, 'set_webhook', attempt) - - updater.running = True - updater._bootstrap(retries, False, 'path', None, bootstrap_interval=0) - assert self.attempts == retries - - @pytest.mark.parametrize( - ('error', 'attempts'), - argvalues=[(TelegramError(''), 2), (Unauthorized(''), 1), (InvalidToken(), 1)], - ids=('TelegramError', 'Unauthorized', 'InvalidToken'), - ) - def test_bootstrap_retries_error(self, monkeypatch, updater, error, attempts): - retries = 1 - - def attempt(*args, **kwargs): - self.attempts += 1 - raise error - - monkeypatch.setattr(updater.bot, 'set_webhook', attempt) - - updater.running = True - with pytest.raises(type(error)): - updater._bootstrap(retries, False, 'path', None, bootstrap_interval=0) - assert self.attempts == attempts - - @pytest.mark.parametrize('drop_pending_updates', (True, False)) - def test_bootstrap_clean_updates(self, monkeypatch, updater, drop_pending_updates): - # As dropping pending updates is done by passing `drop_pending_updates` to - # set_webhook, we just check that we pass the correct value - self.test_flag = False - - def delete_webhook(**kwargs): - self.test_flag = kwargs.get('drop_pending_updates') == drop_pending_updates - - monkeypatch.setattr(updater.bot, 'delete_webhook', delete_webhook) - - updater.running = True - updater._bootstrap( - 1, - drop_pending_updates=drop_pending_updates, - webhook_url=None, - allowed_updates=None, - bootstrap_interval=0, - ) - assert self.test_flag is True - - @flaky(3, 1) - def test_webhook_invalid_posts(self, updater): - ip = '127.0.0.1' - port = randrange(1024, 49152) # select random port for travis - thr = Thread( - target=updater._start_webhook, args=(ip, port, '', None, None, 0, False, None, None) - ) - thr.start() - - sleep(0.2) - - try: - with pytest.raises(HTTPError) as excinfo: - self._send_webhook_msg( - ip, port, 'data', content_type='application/xml' - ) - assert excinfo.value.code == 403 - - with pytest.raises(HTTPError) as excinfo: - self._send_webhook_msg(ip, port, 'dummy-payload', content_len=-2) - assert excinfo.value.code == 500 - - # TODO: prevent urllib or the underlying from adding content-length - # with pytest.raises(HTTPError) as excinfo: - # self._send_webhook_msg(ip, port, 'dummy-payload', content_len=None) - # assert excinfo.value.code == 411 - - with pytest.raises(HTTPError): - self._send_webhook_msg(ip, port, 'dummy-payload', content_len='not-a-number') - assert excinfo.value.code == 500 - - finally: - updater.httpd.shutdown() - thr.join() - - def _send_webhook_msg( - self, - ip, - port, - payload_str, - url_path='', - content_len=-1, - content_type='application/json', - get_method=None, - ): - headers = { - 'content-type': content_type, - } - - if not payload_str: - content_len = None - payload = None - else: - payload = bytes(payload_str, encoding='utf-8') - - if content_len == -1: - content_len = len(payload) - - if content_len is not None: - headers['content-length'] = str(content_len) - - url = f'http://{ip}:{port}/{url_path}' - - req = Request(url, data=payload, headers=headers) - - if get_method is not None: - req.get_method = get_method - - return urlopen(req) - - def signal_sender(self, updater): - sleep(0.2) - while not updater.running: - sleep(0.2) - - os.kill(os.getpid(), signal.SIGTERM) - - @signalskip - def test_idle(self, updater, caplog): - updater.start_polling(0.01) - Thread(target=partial(self.signal_sender, updater=updater)).start() - - with caplog.at_level(logging.INFO): - updater.idle() - - # There is a chance of a conflict when getting updates since there can be many tests - # (bots) running simultaneously while testing in github actions. - records = caplog.records.copy() # To avoid iterating and removing at same time - for idx, log in enumerate(records): - print(idx, log) - msg = log.getMessage() - if msg.startswith('Error while getting Updates: Conflict'): - caplog.records.remove(log) # For stability - - elif msg.startswith('No error handlers are registered'): - caplog.records.remove(log) - - assert len(caplog.records) == 2, caplog.records - - rec = caplog.records[-2] - assert rec.getMessage().startswith(f'Received signal {signal.SIGTERM}') - assert rec.levelname == 'INFO' - - rec = caplog.records[-1] - assert rec.getMessage().startswith('Scheduler has been shut down') - assert rec.levelname == 'INFO' - - # If we get this far, idle() ran through - sleep(0.5) - assert updater.running is False - - @signalskip - def test_user_signal(self, updater): - temp_var = {'a': 0} - - def user_signal_inc(signum, frame): - temp_var['a'] = 1 - - updater.user_signal_handler = user_signal_inc - updater.start_polling(0.01) - Thread(target=partial(self.signal_sender, updater=updater)).start() - updater.idle() - # If we get this far, idle() ran through - sleep(0.5) - assert updater.running is False - assert temp_var['a'] != 0 diff --git a/tests/test_user.py b/tests/test_user.py index 3d375a6145d..d4f621d3ec2 100644 --- a/tests/test_user.py +++ b/tests/test_user.py @@ -65,11 +65,6 @@ class TestUser: can_read_all_group_messages = True supports_inline_queries = False - def test_slot_behaviour(self, user, mro_slots): - for attr in user.__slots__: - assert getattr(user, attr, 'err') != 'err', f"got extra slot '{attr}'" - assert len(mro_slots(user)) == len(set(mro_slots(user))), "duplicate slot" - def test_de_json(self, json_dict, bot): user = User.de_json(json_dict, bot) @@ -133,165 +128,179 @@ def test_link(self, user): user.username = None assert user.link is None - def test_instance_method_get_profile_photos(self, monkeypatch, user): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_instance_method_get_profile_photos(self, monkeypatch, user): + async def make_assertion(*_, **kwargs): return kwargs['user_id'] == user.id assert check_shortcut_signature( User.get_profile_photos, Bot.get_user_profile_photos, ['user_id'], [] ) - assert check_shortcut_call( + assert await check_shortcut_call( user.get_profile_photos, user.get_bot(), 'get_user_profile_photos' ) - assert check_defaults_handling(user.get_profile_photos, user.get_bot()) + assert await check_defaults_handling(user.get_profile_photos, user.get_bot()) monkeypatch.setattr(user.get_bot(), 'get_user_profile_photos', make_assertion) - assert user.get_profile_photos() + assert await user.get_profile_photos() - def test_instance_method_pin_message(self, monkeypatch, user): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_instance_method_pin_message(self, monkeypatch, user): + async def make_assertion(*_, **kwargs): return kwargs['chat_id'] == user.id assert check_shortcut_signature(User.pin_message, Bot.pin_chat_message, ['chat_id'], []) - assert check_shortcut_call(user.pin_message, user.get_bot(), 'pin_chat_message') - assert check_defaults_handling(user.pin_message, user.get_bot()) + assert await check_shortcut_call(user.pin_message, user.get_bot(), 'pin_chat_message') + assert await check_defaults_handling(user.pin_message, user.get_bot()) monkeypatch.setattr(user.get_bot(), 'pin_chat_message', make_assertion) - assert user.pin_message(1) + assert await user.pin_message(1) - def test_instance_method_unpin_message(self, monkeypatch, user): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_instance_method_unpin_message(self, monkeypatch, user): + async def make_assertion(*_, **kwargs): return kwargs['chat_id'] == user.id assert check_shortcut_signature( User.unpin_message, Bot.unpin_chat_message, ['chat_id'], [] ) - assert check_shortcut_call(user.unpin_message, user.get_bot(), 'unpin_chat_message') - assert check_defaults_handling(user.unpin_message, user.get_bot()) + assert await check_shortcut_call(user.unpin_message, user.get_bot(), 'unpin_chat_message') + assert await check_defaults_handling(user.unpin_message, user.get_bot()) monkeypatch.setattr(user.get_bot(), 'unpin_chat_message', make_assertion) - assert user.unpin_message() + assert await user.unpin_message() - def test_instance_method_unpin_all_messages(self, monkeypatch, user): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_instance_method_unpin_all_messages(self, monkeypatch, user): + async def make_assertion(*_, **kwargs): return kwargs['chat_id'] == user.id assert check_shortcut_signature( User.unpin_all_messages, Bot.unpin_all_chat_messages, ['chat_id'], [] ) - assert check_shortcut_call( + assert await check_shortcut_call( user.unpin_all_messages, user.get_bot(), 'unpin_all_chat_messages' ) - assert check_defaults_handling(user.unpin_all_messages, user.get_bot()) + assert await check_defaults_handling(user.unpin_all_messages, user.get_bot()) monkeypatch.setattr(user.get_bot(), 'unpin_all_chat_messages', make_assertion) - assert user.unpin_all_messages() + assert await user.unpin_all_messages() - def test_instance_method_send_message(self, monkeypatch, user): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_instance_method_send_message(self, monkeypatch, user): + async def make_assertion(*_, **kwargs): return kwargs['chat_id'] == user.id and kwargs['text'] == 'test' assert check_shortcut_signature(User.send_message, Bot.send_message, ['chat_id'], []) - assert check_shortcut_call(user.send_message, user.get_bot(), 'send_message') - assert check_defaults_handling(user.send_message, user.get_bot()) + assert await check_shortcut_call(user.send_message, user.get_bot(), 'send_message') + assert await check_defaults_handling(user.send_message, user.get_bot()) monkeypatch.setattr(user.get_bot(), 'send_message', make_assertion) - assert user.send_message('test') + assert await user.send_message('test') - def test_instance_method_send_photo(self, monkeypatch, user): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_instance_method_send_photo(self, monkeypatch, user): + async def make_assertion(*_, **kwargs): return kwargs['chat_id'] == user.id and kwargs['photo'] == 'test_photo' assert check_shortcut_signature(User.send_photo, Bot.send_photo, ['chat_id'], []) - assert check_shortcut_call(user.send_photo, user.get_bot(), 'send_photo') - assert check_defaults_handling(user.send_photo, user.get_bot()) + assert await check_shortcut_call(user.send_photo, user.get_bot(), 'send_photo') + assert await check_defaults_handling(user.send_photo, user.get_bot()) monkeypatch.setattr(user.get_bot(), 'send_photo', make_assertion) - assert user.send_photo('test_photo') + assert await user.send_photo('test_photo') - def test_instance_method_send_media_group(self, monkeypatch, user): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_instance_method_send_media_group(self, monkeypatch, user): + async def make_assertion(*_, **kwargs): return kwargs['chat_id'] == user.id and kwargs['media'] == 'test_media_group' assert check_shortcut_signature( User.send_media_group, Bot.send_media_group, ['chat_id'], [] ) - assert check_shortcut_call(user.send_media_group, user.get_bot(), 'send_media_group') - assert check_defaults_handling(user.send_media_group, user.get_bot()) + assert await check_shortcut_call(user.send_media_group, user.get_bot(), 'send_media_group') + assert await check_defaults_handling(user.send_media_group, user.get_bot()) monkeypatch.setattr(user.get_bot(), 'send_media_group', make_assertion) - assert user.send_media_group('test_media_group') + assert await user.send_media_group('test_media_group') - def test_instance_method_send_audio(self, monkeypatch, user): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_instance_method_send_audio(self, monkeypatch, user): + async def make_assertion(*_, **kwargs): return kwargs['chat_id'] == user.id and kwargs['audio'] == 'test_audio' assert check_shortcut_signature(User.send_audio, Bot.send_audio, ['chat_id'], []) - assert check_shortcut_call(user.send_audio, user.get_bot(), 'send_audio') - assert check_defaults_handling(user.send_audio, user.get_bot()) + assert await check_shortcut_call(user.send_audio, user.get_bot(), 'send_audio') + assert await check_defaults_handling(user.send_audio, user.get_bot()) monkeypatch.setattr(user.get_bot(), 'send_audio', make_assertion) - assert user.send_audio('test_audio') + assert await user.send_audio('test_audio') - def test_instance_method_send_chat_action(self, monkeypatch, user): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_instance_method_send_chat_action(self, monkeypatch, user): + async def make_assertion(*_, **kwargs): return kwargs['chat_id'] == user.id and kwargs['action'] == 'test_chat_action' assert check_shortcut_signature( User.send_chat_action, Bot.send_chat_action, ['chat_id'], [] ) - assert check_shortcut_call(user.send_chat_action, user.get_bot(), 'send_chat_action') - assert check_defaults_handling(user.send_chat_action, user.get_bot()) + assert await check_shortcut_call(user.send_chat_action, user.get_bot(), 'send_chat_action') + assert await check_defaults_handling(user.send_chat_action, user.get_bot()) monkeypatch.setattr(user.get_bot(), 'send_chat_action', make_assertion) - assert user.send_chat_action('test_chat_action') + assert await user.send_chat_action('test_chat_action') - def test_instance_method_send_contact(self, monkeypatch, user): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_instance_method_send_contact(self, monkeypatch, user): + async def make_assertion(*_, **kwargs): return kwargs['chat_id'] == user.id and kwargs['phone_number'] == 'test_contact' assert check_shortcut_signature(User.send_contact, Bot.send_contact, ['chat_id'], []) - assert check_shortcut_call(user.send_contact, user.get_bot(), 'send_contact') - assert check_defaults_handling(user.send_contact, user.get_bot()) + assert await check_shortcut_call(user.send_contact, user.get_bot(), 'send_contact') + assert await check_defaults_handling(user.send_contact, user.get_bot()) monkeypatch.setattr(user.get_bot(), 'send_contact', make_assertion) - assert user.send_contact(phone_number='test_contact') + assert await user.send_contact(phone_number='test_contact') - def test_instance_method_send_dice(self, monkeypatch, user): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_instance_method_send_dice(self, monkeypatch, user): + async def make_assertion(*_, **kwargs): return kwargs['chat_id'] == user.id and kwargs['emoji'] == 'test_dice' assert check_shortcut_signature(User.send_dice, Bot.send_dice, ['chat_id'], []) - assert check_shortcut_call(user.send_dice, user.get_bot(), 'send_dice') - assert check_defaults_handling(user.send_dice, user.get_bot()) + assert await check_shortcut_call(user.send_dice, user.get_bot(), 'send_dice') + assert await check_defaults_handling(user.send_dice, user.get_bot()) monkeypatch.setattr(user.get_bot(), 'send_dice', make_assertion) - assert user.send_dice(emoji='test_dice') + assert await user.send_dice(emoji='test_dice') - def test_instance_method_send_document(self, monkeypatch, user): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_instance_method_send_document(self, monkeypatch, user): + async def make_assertion(*_, **kwargs): return kwargs['chat_id'] == user.id and kwargs['document'] == 'test_document' assert check_shortcut_signature(User.send_document, Bot.send_document, ['chat_id'], []) - assert check_shortcut_call(user.send_document, user.get_bot(), 'send_document') - assert check_defaults_handling(user.send_document, user.get_bot()) + assert await check_shortcut_call(user.send_document, user.get_bot(), 'send_document') + assert await check_defaults_handling(user.send_document, user.get_bot()) monkeypatch.setattr(user.get_bot(), 'send_document', make_assertion) - assert user.send_document('test_document') + assert await user.send_document('test_document') - def test_instance_method_send_game(self, monkeypatch, user): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_instance_method_send_game(self, monkeypatch, user): + async def make_assertion(*_, **kwargs): return kwargs['chat_id'] == user.id and kwargs['game_short_name'] == 'test_game' assert check_shortcut_signature(User.send_game, Bot.send_game, ['chat_id'], []) - assert check_shortcut_call(user.send_game, user.get_bot(), 'send_game') - assert check_defaults_handling(user.send_game, user.get_bot()) + assert await check_shortcut_call(user.send_game, user.get_bot(), 'send_game') + assert await check_defaults_handling(user.send_game, user.get_bot()) monkeypatch.setattr(user.get_bot(), 'send_game', make_assertion) - assert user.send_game(game_short_name='test_game') + assert await user.send_game(game_short_name='test_game') - def test_instance_method_send_invoice(self, monkeypatch, user): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_instance_method_send_invoice(self, monkeypatch, user): + async def make_assertion(*_, **kwargs): title = kwargs['title'] == 'title' description = kwargs['description'] == 'description' payload = kwargs['payload'] == 'payload' @@ -302,11 +311,11 @@ def make_assertion(*_, **kwargs): return kwargs['chat_id'] == user.id and args assert check_shortcut_signature(User.send_invoice, Bot.send_invoice, ['chat_id'], []) - assert check_shortcut_call(user.send_invoice, user.get_bot(), 'send_invoice') - assert check_defaults_handling(user.send_invoice, user.get_bot()) + assert await check_shortcut_call(user.send_invoice, user.get_bot(), 'send_invoice') + assert await check_defaults_handling(user.send_invoice, user.get_bot()) monkeypatch.setattr(user.get_bot(), 'send_invoice', make_assertion) - assert user.send_invoice( + assert await user.send_invoice( 'title', 'description', 'payload', @@ -315,124 +324,135 @@ def make_assertion(*_, **kwargs): 'prices', ) - def test_instance_method_send_location(self, monkeypatch, user): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_instance_method_send_location(self, monkeypatch, user): + async def make_assertion(*_, **kwargs): return kwargs['chat_id'] == user.id and kwargs['latitude'] == 'test_location' assert check_shortcut_signature(User.send_location, Bot.send_location, ['chat_id'], []) - assert check_shortcut_call(user.send_location, user.get_bot(), 'send_location') - assert check_defaults_handling(user.send_location, user.get_bot()) + assert await check_shortcut_call(user.send_location, user.get_bot(), 'send_location') + assert await check_defaults_handling(user.send_location, user.get_bot()) monkeypatch.setattr(user.get_bot(), 'send_location', make_assertion) - assert user.send_location('test_location') + assert await user.send_location('test_location') - def test_instance_method_send_sticker(self, monkeypatch, user): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_instance_method_send_sticker(self, monkeypatch, user): + async def make_assertion(*_, **kwargs): return kwargs['chat_id'] == user.id and kwargs['sticker'] == 'test_sticker' assert check_shortcut_signature(User.send_sticker, Bot.send_sticker, ['chat_id'], []) - assert check_shortcut_call(user.send_sticker, user.get_bot(), 'send_sticker') - assert check_defaults_handling(user.send_sticker, user.get_bot()) + assert await check_shortcut_call(user.send_sticker, user.get_bot(), 'send_sticker') + assert await check_defaults_handling(user.send_sticker, user.get_bot()) monkeypatch.setattr(user.get_bot(), 'send_sticker', make_assertion) - assert user.send_sticker('test_sticker') + assert await user.send_sticker('test_sticker') - def test_instance_method_send_video(self, monkeypatch, user): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_instance_method_send_video(self, monkeypatch, user): + async def make_assertion(*_, **kwargs): return kwargs['chat_id'] == user.id and kwargs['video'] == 'test_video' assert check_shortcut_signature(User.send_video, Bot.send_video, ['chat_id'], []) - assert check_shortcut_call(user.send_video, user.get_bot(), 'send_video') - assert check_defaults_handling(user.send_video, user.get_bot()) + assert await check_shortcut_call(user.send_video, user.get_bot(), 'send_video') + assert await check_defaults_handling(user.send_video, user.get_bot()) monkeypatch.setattr(user.get_bot(), 'send_video', make_assertion) - assert user.send_video('test_video') + assert await user.send_video('test_video') - def test_instance_method_send_venue(self, monkeypatch, user): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_instance_method_send_venue(self, monkeypatch, user): + async def make_assertion(*_, **kwargs): return kwargs['chat_id'] == user.id and kwargs['title'] == 'test_venue' assert check_shortcut_signature(User.send_venue, Bot.send_venue, ['chat_id'], []) - assert check_shortcut_call(user.send_venue, user.get_bot(), 'send_venue') - assert check_defaults_handling(user.send_venue, user.get_bot()) + assert await check_shortcut_call(user.send_venue, user.get_bot(), 'send_venue') + assert await check_defaults_handling(user.send_venue, user.get_bot()) monkeypatch.setattr(user.get_bot(), 'send_venue', make_assertion) - assert user.send_venue(title='test_venue') + assert await user.send_venue(title='test_venue') - def test_instance_method_send_video_note(self, monkeypatch, user): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_instance_method_send_video_note(self, monkeypatch, user): + async def make_assertion(*_, **kwargs): return kwargs['chat_id'] == user.id and kwargs['video_note'] == 'test_video_note' assert check_shortcut_signature(User.send_video_note, Bot.send_video_note, ['chat_id'], []) - assert check_shortcut_call(user.send_video_note, user.get_bot(), 'send_video_note') - assert check_defaults_handling(user.send_video_note, user.get_bot()) + assert await check_shortcut_call(user.send_video_note, user.get_bot(), 'send_video_note') + assert await check_defaults_handling(user.send_video_note, user.get_bot()) monkeypatch.setattr(user.get_bot(), 'send_video_note', make_assertion) - assert user.send_video_note('test_video_note') + assert await user.send_video_note('test_video_note') - def test_instance_method_send_voice(self, monkeypatch, user): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_instance_method_send_voice(self, monkeypatch, user): + async def make_assertion(*_, **kwargs): return kwargs['chat_id'] == user.id and kwargs['voice'] == 'test_voice' assert check_shortcut_signature(User.send_voice, Bot.send_voice, ['chat_id'], []) - assert check_shortcut_call(user.send_voice, user.get_bot(), 'send_voice') - assert check_defaults_handling(user.send_voice, user.get_bot()) + assert await check_shortcut_call(user.send_voice, user.get_bot(), 'send_voice') + assert await check_defaults_handling(user.send_voice, user.get_bot()) monkeypatch.setattr(user.get_bot(), 'send_voice', make_assertion) - assert user.send_voice('test_voice') + assert await user.send_voice('test_voice') - def test_instance_method_send_animation(self, monkeypatch, user): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_instance_method_send_animation(self, monkeypatch, user): + async def make_assertion(*_, **kwargs): return kwargs['chat_id'] == user.id and kwargs['animation'] == 'test_animation' assert check_shortcut_signature(User.send_animation, Bot.send_animation, ['chat_id'], []) - assert check_shortcut_call(user.send_animation, user.get_bot(), 'send_animation') - assert check_defaults_handling(user.send_animation, user.get_bot()) + assert await check_shortcut_call(user.send_animation, user.get_bot(), 'send_animation') + assert await check_defaults_handling(user.send_animation, user.get_bot()) monkeypatch.setattr(user.get_bot(), 'send_animation', make_assertion) - assert user.send_animation('test_animation') + assert await user.send_animation('test_animation') - def test_instance_method_send_poll(self, monkeypatch, user): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_instance_method_send_poll(self, monkeypatch, user): + async def make_assertion(*_, **kwargs): return kwargs['chat_id'] == user.id and kwargs['question'] == 'test_poll' assert check_shortcut_signature(User.send_poll, Bot.send_poll, ['chat_id'], []) - assert check_shortcut_call(user.send_poll, user.get_bot(), 'send_poll') - assert check_defaults_handling(user.send_poll, user.get_bot()) + assert await check_shortcut_call(user.send_poll, user.get_bot(), 'send_poll') + assert await check_defaults_handling(user.send_poll, user.get_bot()) monkeypatch.setattr(user.get_bot(), 'send_poll', make_assertion) - assert user.send_poll(question='test_poll', options=[1, 2]) + assert await user.send_poll(question='test_poll', options=[1, 2]) - def test_instance_method_send_copy(self, monkeypatch, user): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_instance_method_send_copy(self, monkeypatch, user): + async def make_assertion(*_, **kwargs): user_id = kwargs['chat_id'] == user.id message_id = kwargs['message_id'] == 'message_id' from_chat_id = kwargs['from_chat_id'] == 'from_chat_id' return from_chat_id and message_id and user_id assert check_shortcut_signature(User.send_copy, Bot.copy_message, ['chat_id'], []) - assert check_shortcut_call(user.copy_message, user.get_bot(), 'copy_message') - assert check_defaults_handling(user.copy_message, user.get_bot()) + assert await check_shortcut_call(user.copy_message, user.get_bot(), 'copy_message') + assert await check_defaults_handling(user.copy_message, user.get_bot()) monkeypatch.setattr(user.get_bot(), 'copy_message', make_assertion) - assert user.send_copy(from_chat_id='from_chat_id', message_id='message_id') + assert await user.send_copy(from_chat_id='from_chat_id', message_id='message_id') - def test_instance_method_copy_message(self, monkeypatch, user): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_instance_method_copy_message(self, monkeypatch, user): + async def make_assertion(*_, **kwargs): chat_id = kwargs['chat_id'] == 'chat_id' message_id = kwargs['message_id'] == 'message_id' user_id = kwargs['from_chat_id'] == user.id return chat_id and message_id and user_id assert check_shortcut_signature(User.copy_message, Bot.copy_message, ['from_chat_id'], []) - assert check_shortcut_call(user.copy_message, user.get_bot(), 'copy_message') - assert check_defaults_handling(user.copy_message, user.get_bot()) + assert await check_shortcut_call(user.copy_message, user.get_bot(), 'copy_message') + assert await check_defaults_handling(user.copy_message, user.get_bot()) monkeypatch.setattr(user.get_bot(), 'copy_message', make_assertion) - assert user.copy_message(chat_id='chat_id', message_id='message_id') + assert await user.copy_message(chat_id='chat_id', message_id='message_id') - def test_instance_method_approve_join_request(self, monkeypatch, user): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_instance_method_approve_join_request(self, monkeypatch, user): + async def make_assertion(*_, **kwargs): chat_id = kwargs['chat_id'] == 'chat_id' user_id = kwargs['user_id'] == user.id return chat_id and user_id @@ -440,16 +460,17 @@ def make_assertion(*_, **kwargs): assert check_shortcut_signature( User.approve_join_request, Bot.approve_chat_join_request, ['user_id'], [] ) - assert check_shortcut_call( + assert await check_shortcut_call( user.approve_join_request, user.get_bot(), 'approve_chat_join_request' ) - assert check_defaults_handling(user.approve_join_request, user.get_bot()) + assert await check_defaults_handling(user.approve_join_request, user.get_bot()) monkeypatch.setattr(user.get_bot(), 'approve_chat_join_request', make_assertion) - assert user.approve_join_request(chat_id='chat_id') + assert await user.approve_join_request(chat_id='chat_id') - def test_instance_method_decline_join_request(self, monkeypatch, user): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_instance_method_decline_join_request(self, monkeypatch, user): + async def make_assertion(*_, **kwargs): chat_id = kwargs['chat_id'] == 'chat_id' user_id = kwargs['user_id'] == user.id return chat_id and user_id @@ -457,15 +478,16 @@ def make_assertion(*_, **kwargs): assert check_shortcut_signature( User.decline_join_request, Bot.decline_chat_join_request, ['user_id'], [] ) - assert check_shortcut_call( + assert await check_shortcut_call( user.decline_join_request, user.get_bot(), 'decline_chat_join_request' ) - assert check_defaults_handling(user.decline_join_request, user.get_bot()) + assert await check_defaults_handling(user.decline_join_request, user.get_bot()) monkeypatch.setattr(user.get_bot(), 'decline_chat_join_request', make_assertion) - assert user.decline_join_request(chat_id='chat_id') + assert await user.decline_join_request(chat_id='chat_id') - def test_mention_html(self, user): + @pytest.mark.asyncio + async def test_mention_html(self, user): expected = '{}' assert user.mention_html() == expected.format(user.id, user.full_name) @@ -490,7 +512,8 @@ def test_mention_markdown(self, user): ) assert user.mention_markdown(user.username) == expected.format(user.username, user.id) - def test_mention_markdown_v2(self, user): + @pytest.mark.asyncio + async def test_mention_markdown_v2(self, user): user.first_name = 'first{name' user.last_name = 'last_name' diff --git a/tests/test_venue.py b/tests/test_venue.py index 736116dc6f3..dbca6458076 100644 --- a/tests/test_venue.py +++ b/tests/test_venue.py @@ -21,6 +21,7 @@ from telegram import Location, Venue from telegram.error import BadRequest +from telegram.request import RequestData @pytest.fixture(scope='class') @@ -70,11 +71,13 @@ def test_de_json(self, bot): assert venue.google_place_id == self.google_place_id assert venue.google_place_type == self.google_place_type - def test_send_with_venue(self, monkeypatch, bot, chat_id, venue): - def test(url, data, **kwargs): + @pytest.mark.asyncio + async def test_send_with_venue(self, monkeypatch, bot, chat_id, venue): + async def make_assertion(url, request_data: RequestData, *args, **kwargs): + data = request_data.json_parameters return ( - data['longitude'] == self.location.longitude - and data['latitude'] == self.location.latitude + data['longitude'] == str(self.location.longitude) + and data['latitude'] == str(self.location.latitude) and data['title'] == self.title and data['address'] == self.address and data['foursquare_id'] == self.foursquare_id @@ -83,8 +86,8 @@ def test(url, data, **kwargs): and data['google_place_type'] == self.google_place_type ) - monkeypatch.setattr(bot.request, 'post', test) - message = bot.send_venue(chat_id, venue=venue) + monkeypatch.setattr(bot.request, 'post', make_assertion) + message = await bot.send_venue(chat_id, venue=venue) assert message @flaky(3, 1) @@ -97,13 +100,14 @@ def test(url, data, **kwargs): ], indirect=['default_bot'], ) - def test_send_venue_default_allow_sending_without_reply( + @pytest.mark.asyncio + async def test_send_venue_default_allow_sending_without_reply( self, default_bot, chat_id, venue, custom ): - reply_to_message = default_bot.send_message(chat_id, 'test') - reply_to_message.delete() + reply_to_message = await default_bot.send_message(chat_id, 'test') + await reply_to_message.delete() if custom is not None: - message = default_bot.send_venue( + message = await default_bot.send_venue( chat_id, venue=venue, allow_sending_without_reply=custom, @@ -111,27 +115,29 @@ def test_send_venue_default_allow_sending_without_reply( ) assert message.reply_to_message is None elif default_bot.defaults.allow_sending_without_reply: - message = default_bot.send_venue( + message = await default_bot.send_venue( chat_id, venue=venue, reply_to_message_id=reply_to_message.message_id ) assert message.reply_to_message is None else: with pytest.raises(BadRequest, match='message not found'): - default_bot.send_venue( + await default_bot.send_venue( chat_id, venue=venue, reply_to_message_id=reply_to_message.message_id ) @flaky(3, 1) + @pytest.mark.asyncio @pytest.mark.parametrize('default_bot', [{'protect_content': True}], indirect=True) - def test_send_venue_default_protect_content(self, default_bot, chat_id, venue): - protected = default_bot.send_venue(chat_id, venue=venue) + async def test_send_venue_default_protect_content(self, default_bot, chat_id, venue): + protected = await default_bot.send_venue(chat_id, venue=venue) assert protected.has_protected_content - unprotected = default_bot.send_venue(chat_id, venue=venue, protect_content=False) + unprotected = await default_bot.send_venue(chat_id, venue=venue, protect_content=False) assert not unprotected.has_protected_content - def test_send_venue_without_required(self, bot, chat_id): + @pytest.mark.asyncio + async def test_send_venue_without_required(self, bot, chat_id): with pytest.raises(ValueError, match='Either venue or latitude, longitude, address and'): - bot.send_venue(chat_id=chat_id) + await bot.send_venue(chat_id=chat_id) def test_to_dict(self, venue): venue_dict = venue.to_dict() diff --git a/tests/test_video.py b/tests/test_video.py index 2574304cadf..141f20a2668 100644 --- a/tests/test_video.py +++ b/tests/test_video.py @@ -25,6 +25,7 @@ from telegram import Video, Voice, PhotoSize, MessageEntity, Bot from telegram.error import BadRequest, TelegramError from telegram.helpers import escape_markdown +from telegram.request import RequestData from tests.conftest import ( check_shortcut_call, check_shortcut_signature, @@ -41,9 +42,10 @@ def video_file(): @pytest.fixture(scope='class') -def video(bot, chat_id): +@pytest.mark.asyncio +async def video(bot, chat_id): with data_file('telegram.mp4').open('rb') as f: - return bot.send_video(chat_id, video=f, timeout=50).video + return (await bot.send_video(chat_id, video=f, read_timeout=50)).video class TestVideo: @@ -65,11 +67,6 @@ class TestVideo: video_file_id = '5a3128a4d2a04750b5b58397f3b5e812' video_file_unique_id = 'adc3145fd2e84d95b64d68eaa22aa33e' - def test_slot_behaviour(self, video, mro_slots): - for attr in video.__slots__: - assert getattr(video, attr, 'err') != 'err', f"got extra slot '{attr}'" - assert len(mro_slots(video)) == len(set(mro_slots(video))), "duplicate slot" - def test_creation(self, video): # Make sure file has been uploaded. assert isinstance(video, Video) @@ -92,8 +89,9 @@ def test_expected_values(self, video): assert video.mime_type == self.mime_type @flaky(3, 1) - def test_send_all_args(self, bot, chat_id, video_file, video, thumb_file): - message = bot.send_video( + @pytest.mark.asyncio + async def test_send_all_args(self, bot, chat_id, video_file, video, thumb_file): + message = await bot.send_video( chat_id, video_file, duration=self.duration, @@ -127,30 +125,37 @@ def test_send_all_args(self, bot, chat_id, video_file, video, thumb_file): assert message.has_protected_content @flaky(3, 1) - def test_send_video_custom_filename(self, bot, chat_id, video_file, monkeypatch): - def make_assertion(url, data, **kwargs): - return data['video'].filename == 'custom_filename' + @pytest.mark.asyncio + async def test_send_video_custom_filename(self, bot, chat_id, video_file, monkeypatch): + async def make_assertion(url, request_data: RequestData, *args, **kwargs): + return list(request_data.multipart_data.values())[0][0] == 'custom_filename' monkeypatch.setattr(bot.request, 'post', make_assertion) - assert bot.send_video(chat_id, video_file, filename='custom_filename') + assert await bot.send_video(chat_id, video_file, filename='custom_filename') @flaky(3, 1) - def test_get_and_download(self, bot, video): - new_file = bot.get_file(video.file_id) + @pytest.mark.asyncio + async def test_get_and_download(self, bot, video): + path = Path('telegram.mp4') + if path.is_file(): + path.unlink() + + new_file = await bot.get_file(video.file_id) assert new_file.file_size == self.file_size assert new_file.file_id == video.file_id assert new_file.file_unique_id == video.file_unique_id assert new_file.file_path.startswith('https://') - new_file.download('telegram.mp4') + await new_file.download('telegram.mp4') - assert Path('telegram.mp4').is_file() + assert path.is_file() @flaky(3, 1) - def test_send_mp4_file_url(self, bot, chat_id, video): - message = bot.send_video(chat_id, self.video_file_url, caption=self.caption) + @pytest.mark.asyncio + async def test_send_mp4_file_url(self, bot, chat_id, video): + message = await bot.send_video(chat_id, self.video_file_url, caption=self.caption) assert isinstance(message.video, Video) assert isinstance(message.video.file_id, str) @@ -174,48 +179,55 @@ def test_send_mp4_file_url(self, bot, chat_id, video): assert message.caption == self.caption @flaky(3, 1) - def test_send_video_caption_entities(self, bot, chat_id, video): + @pytest.mark.asyncio + async def test_send_video_caption_entities(self, bot, chat_id, video): test_string = 'Italic Bold Code' entities = [ MessageEntity(MessageEntity.ITALIC, 0, 6), MessageEntity(MessageEntity.ITALIC, 7, 4), MessageEntity(MessageEntity.ITALIC, 12, 4), ] - message = bot.send_video(chat_id, video, caption=test_string, caption_entities=entities) + message = await bot.send_video( + chat_id, video, caption=test_string, caption_entities=entities + ) assert message.caption == test_string assert message.caption_entities == entities @flaky(3, 1) - def test_resend(self, bot, chat_id, video): - message = bot.send_video(chat_id, video.file_id) + @pytest.mark.asyncio + async def test_resend(self, bot, chat_id, video): + message = await bot.send_video(chat_id, video.file_id) assert message.video == video - def test_send_with_video(self, monkeypatch, bot, chat_id, video): - def test(url, data, **kwargs): - return data['video'] == video.file_id + @pytest.mark.asyncio + async def test_send_with_video(self, monkeypatch, bot, chat_id, video): + async def make_assertion(url, request_data: RequestData, *args, **kwargs): + return request_data.json_parameters['video'] == video.file_id - monkeypatch.setattr(bot.request, 'post', test) - message = bot.send_video(chat_id, video=video) + monkeypatch.setattr(bot.request, 'post', make_assertion) + message = await bot.send_video(chat_id, video=video) assert message @flaky(3, 1) @pytest.mark.parametrize('default_bot', [{'parse_mode': 'Markdown'}], indirect=True) - def test_send_video_default_parse_mode_1(self, default_bot, chat_id, video): + @pytest.mark.asyncio + async def test_send_video_default_parse_mode_1(self, default_bot, chat_id, video): test_string = 'Italic Bold Code' test_markdown_string = '_Italic_ *Bold* `Code`' - message = default_bot.send_video(chat_id, video, caption=test_markdown_string) + message = await default_bot.send_video(chat_id, video, caption=test_markdown_string) assert message.caption_markdown == test_markdown_string assert message.caption == test_string @flaky(3, 1) @pytest.mark.parametrize('default_bot', [{'parse_mode': 'Markdown'}], indirect=True) - def test_send_video_default_parse_mode_2(self, default_bot, chat_id, video): + @pytest.mark.asyncio + async def test_send_video_default_parse_mode_2(self, default_bot, chat_id, video): test_markdown_string = '_Italic_ *Bold* `Code`' - message = default_bot.send_video( + message = await default_bot.send_video( chat_id, video, caption=test_markdown_string, parse_mode=None ) assert message.caption == test_markdown_string @@ -223,37 +235,39 @@ def test_send_video_default_parse_mode_2(self, default_bot, chat_id, video): @flaky(3, 1) @pytest.mark.parametrize('default_bot', [{'parse_mode': 'Markdown'}], indirect=True) - def test_send_video_default_parse_mode_3(self, default_bot, chat_id, video): + @pytest.mark.asyncio + async def test_send_video_default_parse_mode_3(self, default_bot, chat_id, video): test_markdown_string = '_Italic_ *Bold* `Code`' - message = default_bot.send_video( + message = await default_bot.send_video( chat_id, video, caption=test_markdown_string, parse_mode='HTML' ) assert message.caption == test_markdown_string assert message.caption_markdown == escape_markdown(test_markdown_string) @flaky(3, 1) + @pytest.mark.asyncio @pytest.mark.parametrize('default_bot', [{'protect_content': True}], indirect=True) - def test_send_video_default_protect_content(self, chat_id, default_bot, video): - protected = default_bot.send_video(chat_id, video) + async def test_send_video_default_protect_content(self, chat_id, default_bot, video): + protected = await default_bot.send_video(chat_id, video) assert protected.has_protected_content - unprotected = default_bot.send_video(chat_id, video, protect_content=False) + unprotected = await default_bot.send_video(chat_id, video, protect_content=False) assert not unprotected.has_protected_content - def test_send_video_local_files(self, monkeypatch, bot, chat_id): + @pytest.mark.asyncio + async def test_send_video_local_files(self, monkeypatch, bot, chat_id): # For just test that the correct paths are passed as we have no local bot API set up test_flag = False file = data_file('telegram.jpg') expected = file.as_uri() - def make_assertion(_, data, *args, **kwargs): + async def make_assertion(_, data, *args, **kwargs): nonlocal test_flag test_flag = data.get('video') == expected and data.get('thumb') == expected monkeypatch.setattr(bot, '_post', make_assertion) - bot.send_video(chat_id, file, thumb=file) + await bot.send_video(chat_id, file, thumb=file) assert test_flag - monkeypatch.delattr(bot, '_post') @flaky(3, 1) @pytest.mark.parametrize( @@ -265,13 +279,14 @@ def make_assertion(_, data, *args, **kwargs): ], indirect=['default_bot'], ) - def test_send_video_default_allow_sending_without_reply( + @pytest.mark.asyncio + async def test_send_video_default_allow_sending_without_reply( self, default_bot, chat_id, video, custom ): - reply_to_message = default_bot.send_message(chat_id, 'test') - reply_to_message.delete() + reply_to_message = await default_bot.send_message(chat_id, 'test') + await reply_to_message.delete() if custom is not None: - message = default_bot.send_video( + message = await default_bot.send_video( chat_id, video, allow_sending_without_reply=custom, @@ -279,13 +294,13 @@ def test_send_video_default_allow_sending_without_reply( ) assert message.reply_to_message is None elif default_bot.defaults.allow_sending_without_reply: - message = default_bot.send_video( + message = await default_bot.send_video( chat_id, video, reply_to_message_id=reply_to_message.message_id ) assert message.reply_to_message is None else: with pytest.raises(BadRequest, match='message not found'): - default_bot.send_video( + await default_bot.send_video( chat_id, video, reply_to_message_id=reply_to_message.message_id ) @@ -325,29 +340,33 @@ def test_to_dict(self, video): assert video_dict['file_name'] == video.file_name @flaky(3, 1) - def test_error_send_empty_file(self, bot, chat_id): + @pytest.mark.asyncio + async def test_error_send_empty_file(self, bot, chat_id): with pytest.raises(TelegramError): - bot.send_video(chat_id, open(os.devnull, 'rb')) + await bot.send_video(chat_id, open(os.devnull, 'rb')) @flaky(3, 1) - def test_error_send_empty_file_id(self, bot, chat_id): + @pytest.mark.asyncio + async def test_error_send_empty_file_id(self, bot, chat_id): with pytest.raises(TelegramError): - bot.send_video(chat_id, '') + await bot.send_video(chat_id, '') - def test_error_without_required_args(self, bot, chat_id): + @pytest.mark.asyncio + async def test_error_without_required_args(self, bot, chat_id): with pytest.raises(TypeError): - bot.send_video(chat_id=chat_id) + await bot.send_video(chat_id=chat_id) - def test_get_file_instance_method(self, monkeypatch, video): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_get_file_instance_method(self, monkeypatch, video): + async def make_assertion(*_, **kwargs): return kwargs['file_id'] == video.file_id assert check_shortcut_signature(Video.get_file, Bot.get_file, ['file_id'], []) - assert check_shortcut_call(video.get_file, video.get_bot(), 'get_file') - assert check_defaults_handling(video.get_file, video.get_bot()) + assert await check_shortcut_call(video.get_file, video.get_bot(), 'get_file') + assert await check_defaults_handling(video.get_file, video.get_bot()) monkeypatch.setattr(video.get_bot(), 'get_file', make_assertion) - assert video.get_file() + assert await video.get_file() def test_equality(self, video): a = Video(video.file_id, video.file_unique_id, self.width, self.height, self.duration) diff --git a/tests/test_videonote.py b/tests/test_videonote.py index db09d7436f9..915a0e88615 100644 --- a/tests/test_videonote.py +++ b/tests/test_videonote.py @@ -24,6 +24,7 @@ from telegram import VideoNote, Voice, PhotoSize, Bot from telegram.error import BadRequest, TelegramError +from telegram.request import RequestData from tests.conftest import ( check_shortcut_call, check_shortcut_signature, @@ -40,9 +41,10 @@ def video_note_file(): @pytest.fixture(scope='class') -def video_note(bot, chat_id): +@pytest.mark.asyncio +async def video_note(bot, chat_id): with data_file('telegram2.mp4').open('rb') as f: - return bot.send_video_note(chat_id, video_note=f, timeout=50).video_note + return (await bot.send_video_note(chat_id, video_note=f, read_timeout=50)).video_note class TestVideoNote: @@ -58,11 +60,6 @@ class TestVideoNote: videonote_file_id = '5a3128a4d2a04750b5b58397f3b5e812' videonote_file_unique_id = 'adc3145fd2e84d95b64d68eaa22aa33e' - def test_slot_behaviour(self, video_note, mro_slots): - for attr in video_note.__slots__: - assert getattr(video_note, attr, 'err') != 'err', f"got extra slot '{attr}'" - assert len(mro_slots(video_note)) == len(set(mro_slots(video_note))), "duplicate slot" - def test_creation(self, video_note): # Make sure file has been uploaded. assert isinstance(video_note, VideoNote) @@ -83,8 +80,9 @@ def test_expected_values(self, video_note): assert video_note.file_size == self.file_size @flaky(3, 1) - def test_send_all_args(self, bot, chat_id, video_note_file, video_note, thumb_file): - message = bot.send_video_note( + @pytest.mark.asyncio + async def test_send_all_args(self, bot, chat_id, video_note_file, video_note, thumb_file): + message = await bot.send_video_note( chat_id, video_note_file, duration=self.duration, @@ -109,39 +107,49 @@ def test_send_all_args(self, bot, chat_id, video_note_file, video_note, thumb_fi assert message.has_protected_content @flaky(3, 1) - def test_send_video_note_custom_filename(self, bot, chat_id, video_note_file, monkeypatch): - def make_assertion(url, data, **kwargs): - return data['video_note'].filename == 'custom_filename' + @pytest.mark.asyncio + async def test_send_video_note_custom_filename( + self, bot, chat_id, video_note_file, monkeypatch + ): + async def make_assertion(url, request_data: RequestData, *args, **kwargs): + return list(request_data.multipart_data.values())[0][0] == 'custom_filename' monkeypatch.setattr(bot.request, 'post', make_assertion) - assert bot.send_video_note(chat_id, video_note_file, filename='custom_filename') + assert await bot.send_video_note(chat_id, video_note_file, filename='custom_filename') @flaky(3, 1) - def test_get_and_download(self, bot, video_note): - new_file = bot.get_file(video_note.file_id) + @pytest.mark.asyncio + async def test_get_and_download(self, bot, video_note): + path = Path('telegram2.mp4') + if path.is_file(): + path.unlink() + + new_file = await bot.get_file(video_note.file_id) assert new_file.file_size == self.file_size assert new_file.file_id == video_note.file_id assert new_file.file_unique_id == video_note.file_unique_id assert new_file.file_path.startswith('https://') - new_file.download('telegram2.mp4') + await new_file.download('telegram2.mp4') - assert Path('telegram2.mp4').is_file() + assert path.is_file() @flaky(3, 1) - def test_resend(self, bot, chat_id, video_note): - message = bot.send_video_note(chat_id, video_note.file_id) + @pytest.mark.asyncio + async def test_resend(self, bot, chat_id, video_note): + message = await bot.send_video_note(chat_id, video_note.file_id) assert message.video_note == video_note - def test_send_with_video_note(self, monkeypatch, bot, chat_id, video_note): - def test(url, data, **kwargs): - return data['video_note'] == video_note.file_id + @pytest.mark.asyncio + async def test_send_with_video_note(self, monkeypatch, bot, chat_id, video_note): + async def make_assertion(url, request_data: RequestData, *args, **kwargs): + return request_data.json_parameters['video_note'] == video_note.file_id - monkeypatch.setattr(bot.request, 'post', test) - message = bot.send_video_note(chat_id, video_note=video_note) + monkeypatch.setattr(bot.request, 'post', make_assertion) + message = await bot.send_video_note(chat_id, video_note=video_note) assert message def test_de_json(self, bot): @@ -170,20 +178,20 @@ def test_to_dict(self, video_note): assert video_note_dict['duration'] == video_note.duration assert video_note_dict['file_size'] == video_note.file_size - def test_send_video_note_local_files(self, monkeypatch, bot, chat_id): + @pytest.mark.asyncio + async def test_send_video_note_local_files(self, monkeypatch, bot, chat_id): # For just test that the correct paths are passed as we have no local bot API set up test_flag = False file = data_file('telegram.jpg') expected = file.as_uri() - def make_assertion(_, data, *args, **kwargs): + async def make_assertion(_, data, *args, **kwargs): nonlocal test_flag test_flag = data.get('video_note') == expected and data.get('thumb') == expected monkeypatch.setattr(bot, '_post', make_assertion) - bot.send_video_note(chat_id, file, thumb=file) + await bot.send_video_note(chat_id, file, thumb=file) assert test_flag - monkeypatch.delattr(bot, '_post') @flaky(3, 1) @pytest.mark.parametrize( @@ -195,13 +203,14 @@ def make_assertion(_, data, *args, **kwargs): ], indirect=['default_bot'], ) - def test_send_video_note_default_allow_sending_without_reply( + @pytest.mark.asyncio + async def test_send_video_note_default_allow_sending_without_reply( self, default_bot, chat_id, video_note, custom ): - reply_to_message = default_bot.send_message(chat_id, 'test') - reply_to_message.delete() + reply_to_message = await default_bot.send_message(chat_id, 'test') + await reply_to_message.delete() if custom is not None: - message = default_bot.send_video_note( + message = await default_bot.send_video_note( chat_id, video_note, allow_sending_without_reply=custom, @@ -209,48 +218,53 @@ def test_send_video_note_default_allow_sending_without_reply( ) assert message.reply_to_message is None elif default_bot.defaults.allow_sending_without_reply: - message = default_bot.send_video_note( + message = await default_bot.send_video_note( chat_id, video_note, reply_to_message_id=reply_to_message.message_id ) assert message.reply_to_message is None else: with pytest.raises(BadRequest, match='message not found'): - default_bot.send_video_note( + await default_bot.send_video_note( chat_id, video_note, reply_to_message_id=reply_to_message.message_id ) @flaky(3, 1) + @pytest.mark.asyncio @pytest.mark.parametrize('default_bot', [{'protect_content': True}], indirect=True) - def test_send_video_note_default_protect_content(self, chat_id, default_bot, video_note): - protected = default_bot.send_video_note(chat_id, video_note) + async def test_send_video_note_default_protect_content(self, chat_id, default_bot, video_note): + protected = await default_bot.send_video_note(chat_id, video_note) assert protected.has_protected_content - unprotected = default_bot.send_video_note(chat_id, video_note, protect_content=False) + unprotected = await default_bot.send_video_note(chat_id, video_note, protect_content=False) assert not unprotected.has_protected_content @flaky(3, 1) - def test_error_send_empty_file(self, bot, chat_id): + @pytest.mark.asyncio + async def test_error_send_empty_file(self, bot, chat_id): with pytest.raises(TelegramError): - bot.send_video_note(chat_id, open(os.devnull, 'rb')) + await bot.send_video_note(chat_id, open(os.devnull, 'rb')) @flaky(3, 1) - def test_error_send_empty_file_id(self, bot, chat_id): + @pytest.mark.asyncio + async def test_error_send_empty_file_id(self, bot, chat_id): with pytest.raises(TelegramError): - bot.send_video_note(chat_id, '') + await bot.send_video_note(chat_id, '') - def test_error_without_required_args(self, bot, chat_id): + @pytest.mark.asyncio + async def test_error_without_required_args(self, bot, chat_id): with pytest.raises(TypeError): - bot.send_video_note(chat_id=chat_id) + await bot.send_video_note(chat_id=chat_id) - def test_get_file_instance_method(self, monkeypatch, video_note): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_get_file_instance_method(self, monkeypatch, video_note): + async def make_assertion(*_, **kwargs): return kwargs['file_id'] == video_note.file_id assert check_shortcut_signature(VideoNote.get_file, Bot.get_file, ['file_id'], []) - assert check_shortcut_call(video_note.get_file, video_note.get_bot(), 'get_file') - assert check_defaults_handling(video_note.get_file, video_note.get_bot()) + assert await check_shortcut_call(video_note.get_file, video_note.get_bot(), 'get_file') + assert await check_defaults_handling(video_note.get_file, video_note.get_bot()) monkeypatch.setattr(video_note.get_bot(), 'get_file', make_assertion) - assert video_note.get_file() + assert await video_note.get_file() def test_equality(self, video_note): a = VideoNote(video_note.file_id, video_note.file_unique_id, self.length, self.duration) diff --git a/tests/test_voice.py b/tests/test_voice.py index bea17f87477..4190e95cdcf 100644 --- a/tests/test_voice.py +++ b/tests/test_voice.py @@ -17,6 +17,7 @@ # You should have received a copy of the GNU Lesser Public License # along with this program. If not, see [http://www.gnu.org/licenses/]. import os +from pathlib import Path import pytest from flaky import flaky @@ -24,6 +25,7 @@ from telegram import Audio, Voice, MessageEntity, Bot from telegram.error import BadRequest, TelegramError from telegram.helpers import escape_markdown +from telegram.request import RequestData from tests.conftest import ( check_shortcut_call, check_shortcut_signature, @@ -40,9 +42,9 @@ def voice_file(): @pytest.fixture(scope='class') -def voice(bot, chat_id): +async def voice(bot, chat_id): with data_file('telegram.ogg').open('rb') as f: - return bot.send_voice(chat_id, voice=f, timeout=50).voice + return (await bot.send_voice(chat_id, voice=f, read_timeout=50)).voice class TestVoice: @@ -56,12 +58,8 @@ class TestVoice: voice_file_id = '5a3128a4d2a04750b5b58397f3b5e812' voice_file_unique_id = 'adc3145fd2e84d95b64d68eaa22aa33e' - def test_slot_behaviour(self, voice, mro_slots): - for attr in voice.__slots__: - assert getattr(voice, attr, 'err') != 'err', f"got extra slot '{attr}'" - assert len(mro_slots(voice)) == len(set(mro_slots(voice))), "duplicate slot" - - def test_creation(self, voice): + @pytest.mark.asyncio + async def test_creation(self, voice): # Make sure file has been uploaded. assert isinstance(voice, Voice) assert isinstance(voice.file_id, str) @@ -75,8 +73,9 @@ def test_expected_values(self, voice): assert voice.file_size == self.file_size @flaky(3, 1) - def test_send_all_args(self, bot, chat_id, voice_file, voice): - message = bot.send_voice( + @pytest.mark.asyncio + async def test_send_all_args(self, bot, chat_id, voice_file, voice): + message = await bot.send_voice( chat_id, voice_file, duration=self.duration, @@ -98,30 +97,37 @@ def test_send_all_args(self, bot, chat_id, voice_file, voice): assert message.has_protected_content @flaky(3, 1) - def test_send_voice_custom_filename(self, bot, chat_id, voice_file, monkeypatch): - def make_assertion(url, data, **kwargs): - return data['voice'].filename == 'custom_filename' + @pytest.mark.asyncio + async def test_send_voice_custom_filename(self, bot, chat_id, voice_file, monkeypatch): + async def make_assertion(url, request_data: RequestData, *args, **kwargs): + return list(request_data.multipart_data.values())[0][0] == 'custom_filename' monkeypatch.setattr(bot.request, 'post', make_assertion) - assert bot.send_voice(chat_id, voice_file, filename='custom_filename') + assert await bot.send_voice(chat_id, voice_file, filename='custom_filename') @flaky(3, 1) - def test_get_and_download(self, bot, voice): - new_file = bot.get_file(voice.file_id) + @pytest.mark.asyncio + async def test_get_and_download(self, bot, voice): + path = Path('telegram.ogg') + if path.is_file(): + path.unlink() + + new_file = await bot.get_file(voice.file_id) assert new_file.file_size == voice.file_size assert new_file.file_id == voice.file_id assert new_file.file_unique_id == voice.file_unique_id assert new_file.file_path.startswith('https://') - new_filepath = new_file.download('telegram.ogg') + await new_file.download('telegram.ogg') - assert new_filepath.is_file() + assert path.is_file() @flaky(3, 1) - def test_send_ogg_url_file(self, bot, chat_id, voice): - message = bot.sendVoice(chat_id, self.voice_file_url, duration=self.duration) + @pytest.mark.asyncio + async def test_send_ogg_url_file(self, bot, chat_id, voice): + message = await bot.sendVoice(chat_id, self.voice_file_url, duration=self.duration) assert isinstance(message.voice, Voice) assert isinstance(message.voice.file_id, str) @@ -133,28 +139,31 @@ def test_send_ogg_url_file(self, bot, chat_id, voice): assert message.voice.file_size == voice.file_size @flaky(3, 1) - def test_resend(self, bot, chat_id, voice): - message = bot.sendVoice(chat_id, voice.file_id) + @pytest.mark.asyncio + async def test_resend(self, bot, chat_id, voice): + message = await bot.sendVoice(chat_id, voice.file_id) assert message.voice == voice - def test_send_with_voice(self, monkeypatch, bot, chat_id, voice): - def test(url, data, **kwargs): - return data['voice'] == voice.file_id + @pytest.mark.asyncio + async def test_send_with_voice(self, monkeypatch, bot, chat_id, voice): + async def make_assertion(url, request_data: RequestData, *args, **kwargs): + return request_data.json_parameters['voice'] == voice.file_id - monkeypatch.setattr(bot.request, 'post', test) - message = bot.send_voice(chat_id, voice=voice) + monkeypatch.setattr(bot.request, 'post', make_assertion) + message = await bot.send_voice(chat_id, voice=voice) assert message @flaky(3, 1) - def test_send_voice_caption_entities(self, bot, chat_id, voice_file): + @pytest.mark.asyncio + async def test_send_voice_caption_entities(self, bot, chat_id, voice_file): test_string = 'Italic Bold Code' entities = [ MessageEntity(MessageEntity.ITALIC, 0, 6), MessageEntity(MessageEntity.ITALIC, 7, 4), MessageEntity(MessageEntity.ITALIC, 12, 4), ] - message = bot.send_voice( + message = await bot.send_voice( chat_id, voice_file, caption=test_string, caption_entities=entities ) @@ -163,20 +172,22 @@ def test_send_voice_caption_entities(self, bot, chat_id, voice_file): @flaky(3, 1) @pytest.mark.parametrize('default_bot', [{'parse_mode': 'Markdown'}], indirect=True) - def test_send_voice_default_parse_mode_1(self, default_bot, chat_id, voice): + @pytest.mark.asyncio + async def test_send_voice_default_parse_mode_1(self, default_bot, chat_id, voice): test_string = 'Italic Bold Code' test_markdown_string = '_Italic_ *Bold* `Code`' - message = default_bot.send_voice(chat_id, voice, caption=test_markdown_string) + message = await default_bot.send_voice(chat_id, voice, caption=test_markdown_string) assert message.caption_markdown == test_markdown_string assert message.caption == test_string @flaky(3, 1) @pytest.mark.parametrize('default_bot', [{'parse_mode': 'Markdown'}], indirect=True) - def test_send_voice_default_parse_mode_2(self, default_bot, chat_id, voice): + @pytest.mark.asyncio + async def test_send_voice_default_parse_mode_2(self, default_bot, chat_id, voice): test_markdown_string = '_Italic_ *Bold* `Code`' - message = default_bot.send_voice( + message = await default_bot.send_voice( chat_id, voice, caption=test_markdown_string, parse_mode=None ) assert message.caption == test_markdown_string @@ -184,37 +195,39 @@ def test_send_voice_default_parse_mode_2(self, default_bot, chat_id, voice): @flaky(3, 1) @pytest.mark.parametrize('default_bot', [{'parse_mode': 'Markdown'}], indirect=True) - def test_send_voice_default_parse_mode_3(self, default_bot, chat_id, voice): + @pytest.mark.asyncio + async def test_send_voice_default_parse_mode_3(self, default_bot, chat_id, voice): test_markdown_string = '_Italic_ *Bold* `Code`' - message = default_bot.send_voice( + message = await default_bot.send_voice( chat_id, voice, caption=test_markdown_string, parse_mode='HTML' ) assert message.caption == test_markdown_string assert message.caption_markdown == escape_markdown(test_markdown_string) @flaky(3, 1) + @pytest.mark.asyncio @pytest.mark.parametrize('default_bot', [{'protect_content': True}], indirect=True) - def test_send_voice_default_protect_content(self, chat_id, default_bot, voice): - protected = default_bot.send_voice(chat_id, voice) + async def test_send_voice_default_protect_content(self, chat_id, default_bot, voice): + protected = await default_bot.send_voice(chat_id, voice) assert protected.has_protected_content - unprotected = default_bot.send_voice(chat_id, voice, protect_content=False) + unprotected = await default_bot.send_voice(chat_id, voice, protect_content=False) assert not unprotected.has_protected_content - def test_send_voice_local_files(self, monkeypatch, bot, chat_id): + @pytest.mark.asyncio + async def test_send_voice_local_files(self, monkeypatch, bot, chat_id): # For just test that the correct paths are passed as we have no local bot API set up test_flag = False file = data_file('telegram.jpg') expected = file.as_uri() - def make_assertion(_, data, *args, **kwargs): + async def make_assertion(_, data, *args, **kwargs): nonlocal test_flag test_flag = data.get('voice') == expected monkeypatch.setattr(bot, '_post', make_assertion) - bot.send_voice(chat_id, file) + await bot.send_voice(chat_id, file) assert test_flag - monkeypatch.delattr(bot, '_post') @flaky(3, 1) @pytest.mark.parametrize( @@ -226,13 +239,14 @@ def make_assertion(_, data, *args, **kwargs): ], indirect=['default_bot'], ) - def test_send_voice_default_allow_sending_without_reply( + @pytest.mark.asyncio + async def test_send_voice_default_allow_sending_without_reply( self, default_bot, chat_id, voice, custom ): - reply_to_message = default_bot.send_message(chat_id, 'test') - reply_to_message.delete() + reply_to_message = await default_bot.send_message(chat_id, 'test') + await reply_to_message.delete() if custom is not None: - message = default_bot.send_voice( + message = await default_bot.send_voice( chat_id, voice, allow_sending_without_reply=custom, @@ -240,13 +254,13 @@ def test_send_voice_default_allow_sending_without_reply( ) assert message.reply_to_message is None elif default_bot.defaults.allow_sending_without_reply: - message = default_bot.send_voice( + message = await default_bot.send_voice( chat_id, voice, reply_to_message_id=reply_to_message.message_id ) assert message.reply_to_message is None else: with pytest.raises(BadRequest, match='message not found'): - default_bot.send_voice( + await default_bot.send_voice( chat_id, voice, reply_to_message_id=reply_to_message.message_id ) @@ -278,29 +292,33 @@ def test_to_dict(self, voice): assert voice_dict['file_size'] == voice.file_size @flaky(3, 1) - def test_error_send_empty_file(self, bot, chat_id): + @pytest.mark.asyncio + async def test_error_send_empty_file(self, bot, chat_id): with pytest.raises(TelegramError): - bot.sendVoice(chat_id, open(os.devnull, 'rb')) + await bot.sendVoice(chat_id, open(os.devnull, 'rb')) @flaky(3, 1) - def test_error_send_empty_file_id(self, bot, chat_id): + @pytest.mark.asyncio + async def test_error_send_empty_file_id(self, bot, chat_id): with pytest.raises(TelegramError): - bot.sendVoice(chat_id, '') + await bot.sendVoice(chat_id, '') - def test_error_without_required_args(self, bot, chat_id): + @pytest.mark.asyncio + async def test_error_without_required_args(self, bot, chat_id): with pytest.raises(TypeError): - bot.sendVoice(chat_id) + await bot.sendVoice(chat_id) - def test_get_file_instance_method(self, monkeypatch, voice): - def make_assertion(*_, **kwargs): + @pytest.mark.asyncio + async def test_get_file_instance_method(self, monkeypatch, voice): + async def make_assertion(*_, **kwargs): return kwargs['file_id'] == voice.file_id assert check_shortcut_signature(Voice.get_file, Bot.get_file, ['file_id'], []) - assert check_shortcut_call(voice.get_file, voice.get_bot(), 'get_file') - assert check_defaults_handling(voice.get_file, voice.get_bot()) + assert await check_shortcut_call(voice.get_file, voice.get_bot(), 'get_file') + assert await check_defaults_handling(voice.get_file, voice.get_bot()) monkeypatch.setattr(voice.get_bot(), 'get_file', make_assertion) - assert voice.get_file() + assert await voice.get_file() def test_equality(self, voice): a = Voice(voice.file_id, voice.file_unique_id, self.duration) From ea703887d624da44fdf03c134a0ceea66925d8fe Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Tue, 8 Feb 2022 22:04:14 +0100 Subject: [PATCH 002/153] Get CH to roughly work --- telegram/ext/_application.py | 41 +++--- telegram/ext/_builders.py | 13 +- telegram/ext/_conversationhandler.py | 198 +++++++++++++++++---------- 3 files changed, 157 insertions(+), 95 deletions(-) diff --git a/telegram/ext/_application.py b/telegram/ext/_application.py index da73641cf44..e8203dc2ee5 100644 --- a/telegram/ext/_application.py +++ b/telegram/ext/_application.py @@ -336,10 +336,6 @@ async def _initialize_persistence(self) -> None: if not self.persistence: return - # This raises an exception if persistence.store_data.callback_data is True - # but self.bot is not an instance of ExtBot - so no need to check that later on - self.persistence.set_bot(self.bot) - if self.persistence.store_data.user_data: cast(TrackingDefaultDict, self._user_data).update_no_track( await self.persistence.get_user_data() @@ -950,32 +946,39 @@ async def __update_persistence(self) -> None: else: coroutines.add(self.persistence.drop_user_data(user_id)) + # Unfortunately due to circular imports this has to be here + # pylint: disable=import-outside-toplevel + from telegram.ext._conversationhandler import PendingState + for name, (key, new_state) in itertools.chain.from_iterable( zip(itertools.repeat(name), states_dict.pop_accessed_write_items()) for name, states_dict in self._conversation_handler_conversations.items() ): - if isinstance(new_state, tuple) and isinstance(new_state[1], asyncio.Task): + if isinstance(new_state, PendingState): # If the handler was running non-blocking, we check if the new state is already # available. Otherwise, we update with the old state, which is the next best # guess. # Note that when updating the persistence one last time during self.stop(), # *all* tasks will be done. - try: - result = new_state[1].result() - coroutines.add( - self.persistence.update_conversation(name=name, key=key, new_state=result) + if not new_state.done(): + # TODO: Try to test that this doesn't happen on shutdown + _logger.warning( + 'A ConversationHandlers state was not yet resolved. Updating the ' + 'persistence with the current state.' ) - except (asyncio.InvalidStateError, asyncio.CancelledError): - effective_new_state = ( - None if new_state[0] is TrackingDefaultDict.DELETED else new_state[0] - ) - coroutines.add( - self.persistence.update_conversation( - name=name, - key=key, - new_state=effective_new_state, - ) + result = new_state.old_state + else: + result = new_state.resolve() + + effective_new_state = None if result is TrackingDefaultDict.DELETED else result + print(name, key, effective_new_state) + # TODO: Test that we actually pass `None` here in case the conversation had ended, + # i.e. effective_new_state is TrackingDefaultDict.DELETED + coroutines.add( + self.persistence.update_conversation( + name=name, key=key, new_state=effective_new_state ) + ) results = await asyncio.gather(*coroutines, return_exceptions=True) _logger.debug('Finished updating persistence.') diff --git a/telegram/ext/_builders.py b/telegram/ext/_builders.py index a495a161730..b6f87d17388 100644 --- a/telegram/ext/_builders.py +++ b/telegram/ext/_builders.py @@ -181,7 +181,7 @@ def __init__(self: 'InitApplicationBuilder'): self._updater: ODVInput[Updater] = DEFAULT_NONE def _build_request(self, get_updates: bool) -> BaseRequest: - prefix = 'get_updates_' if get_updates else '' + prefix = '_get_updates_' if get_updates else '_' if not isinstance(getattr(self, f'{prefix}request'), DefaultValue): return getattr(self, f'{prefix}request') @@ -228,10 +228,14 @@ def build( ) -> Application[BT, CCT, UD, CD, BD, JQ]: """Builds a :class:`telegram.ext.Application` with the provided arguments. + Calls :meth:`telegram.ext.JobQueue.set_application` and + :meth:`telegram.ext.BasePersistence.set_bot` if appropriate. + Returns: :class:`telegram.ext.Application` """ job_queue = DefaultValue.get_value(self._job_queue) + persistence = DefaultValue.get_value(self._persistence) if isinstance(self._updater, DefaultValue) or self._updater is None: if isinstance(self._bot, DefaultValue): @@ -255,7 +259,7 @@ def build( updater=updater, concurrent_updates=DefaultValue.get_value(self._concurrent_updates), job_queue=job_queue, - persistence=DefaultValue.get_value(self._persistence), + persistence=persistence, context_types=DefaultValue.get_value(self._context_types), **self._application_kwargs, ) @@ -263,6 +267,11 @@ def build( if job_queue is not None: job_queue.set_application(application) + if persistence is not None: + # This raises an exception if persistence.store_data.callback_data is True + # but self.bot is not an instance of ExtBot - so no need to check that later on + persistence.set_bot(bot) + return application def application_class( diff --git a/telegram/ext/_conversationhandler.py b/telegram/ext/_conversationhandler.py index 32cb0a9bea2..cf944220f6c 100644 --- a/telegram/ext/_conversationhandler.py +++ b/telegram/ext/_conversationhandler.py @@ -20,8 +20,9 @@ """This module contains the ConversationHandler.""" import asyncio import logging -import functools import datetime +import threading +from dataclasses import dataclass from typing import ( # pylint: disable=unused-import # for the "Any" import TYPE_CHECKING, Dict, @@ -34,6 +35,7 @@ ClassVar, Any, Set, + Generic, ) from telegram import Update @@ -57,21 +59,50 @@ from telegram.ext import Application, Job, JobQueue CheckUpdateType = Tuple[object, Tuple[int, ...], Handler, object] +_logger = logging.getLogger(__name__) -class _ConversationTimeoutContext: + +@dataclass +class _ConversationTimeoutContext(Generic[CCT]): __slots__ = ('conversation_key', 'update', 'application', 'callback_context') - def __init__( - self, - conversation_key: Tuple[int, ...], - update: Update, - application: 'Application[Any, CCT, Any, Any, Any, JobQueue]', - callback_context: CallbackContext, - ): - self.conversation_key = conversation_key - self.update = update - self.application = application - self.callback_context = callback_context + conversation_key: Tuple[int, ...] + update: Update + application: 'Application[Any, CCT, Any, Any, Any, JobQueue]' + callback_context: CallbackContext + + +@dataclass +class PendingState: + """Thin wrapper around asyncio.Task to handle block=False handlers. Note that this is a + public class of this module, since `Application.update_persistence` needs to access it.""" + + __slots__ = ('task', 'old_state') + + task: asyncio.Task + old_state: object + + def done(self) -> bool: + return self.task.done() + + def resolve(self) -> object: + if not self.task.done(): + raise RuntimeError('New state is not yet available') + + exc = self.task.exception() + if exc: + _logger.exception( + "Task function raised exception. Falling back to old state %s", + self.old_state, + exc_info=exc, + ) + return self.old_state + + res = self.task.result() + if res is None and self.old_state is None: + res = ConversationHandler.END + + return res class ConversationHandler(Handler[Update, CCT]): @@ -79,6 +110,11 @@ class ConversationHandler(Handler[Update, CCT]): A handler to hold a conversation with a single or multiple users through Telegram updates by managing four collections of other handlers. + Warning: + :class:`ConversationHandler` heavily relies on incoming updates being processed one by one. + When using this handler, :attr:`telegram.ext.Application.concurrent_updates` should be + :obj:`False`. + Note: ``ConversationHandler`` will only accept updates that are (subclass-)instances of :class:`telegram.Update`. This is, because depending on the :attr:`per_user` and @@ -179,7 +215,7 @@ class ConversationHandler(Handler[Update, CCT]): block (:obj:`bool`, optional): Pass :obj:`False` to *overrule* the :attr:`Handler.block` setting of all handlers (in :attr:`entry_points`, :attr:`states` and :attr:`fallbacks`). - Defaults to :obj:`True`. + Defaults to :obj:`True`, in which case the handlers setting will be respected. .. versionadded:: 13.2 .. versionchanged:: 14.0 @@ -267,15 +303,14 @@ def __init__( self.timeout_jobs: Dict[Tuple[int, ...], 'Job'] = {} self._timeout_jobs_lock = asyncio.Lock() self._conversations: ConversationDict = {} - self._conversations_lock = asyncio.Lock() + # TODO: Do we still need this lock? + self._conversations_lock = threading.Lock() self._child_conversations: Set['ConversationHandler'] = set() if persistent and not self.name: raise ValueError("Conversations can't be persistent when handler is unnamed.") self.persistent: bool = persistent - self._logger = logging.getLogger(__name__) - if not any((self.per_user, self.per_chat, self.per_message)): raise ValueError("'per_user', 'per_chat' and 'per_message' can't all be 'False'") @@ -541,49 +576,54 @@ def _get_key(self, update: Update) -> Tuple[int, ...]: return tuple(key) - def _resolve_task(self, state: Tuple[object, asyncio.Task]) -> object: - old_state, new_state = state - res = new_state.result() - res = res if res is not None else old_state - - exc = new_state.exception() - if exc: - self._logger.exception("Task function raised exception") - self._logger.exception("%s", exc) - res = old_state - - if res is None and old_state is None: - res = self.END - - return res + async def _schedule_job_delayed( + self, + new_state: asyncio.Task, + application: 'Application[Any, CCT, Any, Any, Any, JobQueue]', + update: Update, + context: CallbackContext, + conversation_key: Tuple[int, ...], + ) -> None: + try: + effective_new_state = await new_state + except Exception as exc: + _logger.debug( + 'Non-blocking handler callback raised exception. Not scheduling conversation ' + 'timeout.', + exc_info=exc, + ) + return + return self._schedule_job( + new_state=effective_new_state, + application=application, + update=update, + context=context, + conversation_key=conversation_key, + ) def _schedule_job( self, - new_state: Union[object, asyncio.Task], + new_state: object, application: 'Application[Any, CCT, Any, Any, Any, JobQueue]', update: Update, context: CallbackContext, conversation_key: Tuple[int, ...], ) -> None: - if isinstance(new_state, asyncio.Task): - new_state = new_state.result() - - if new_state != self.END: - try: - # both job_queue & conversation_timeout are checked before calling _schedule_job - j_queue = application.job_queue - self.timeout_jobs[conversation_key] = j_queue.run_once( - self._trigger_timeout, - self.conversation_timeout, # type: ignore[arg-type] - context=_ConversationTimeoutContext( - conversation_key, update, application, context - ), - ) - except Exception as exc: - self._logger.exception( - "Failed to schedule timeout job due to the following exception:" - ) - self._logger.exception("%s", exc) + if new_state == self.END: + return + + try: + # both job_queue & conversation_timeout are checked before calling _schedule_job + j_queue = application.job_queue + self.timeout_jobs[conversation_key] = j_queue.run_once( + self._trigger_timeout, + self.conversation_timeout, # type: ignore[arg-type] + context=_ConversationTimeoutContext( + conversation_key, update, application, context + ), + ) + except Exception as exc: + _logger.exception("Failed to schedule timeout.", exc_info=exc) # pylint: disable=too-many-return-statements def check_update(self, update: object) -> Optional[CheckUpdateType]: @@ -615,12 +655,12 @@ def check_update(self, update: object) -> Optional[CheckUpdateType]: state = self._conversations.get(key) # Resolve promises - if isinstance(state, tuple) and len(state) == 2 and isinstance(state[1], asyncio.Task): - self._logger.warning('Waiting for asyncio Task to finish ...') + if isinstance(state, PendingState): + _logger.debug('Waiting for asyncio Task to finish ...') # check if promise is finished or not - if state[1].done(): - res = self._resolve_task(state) # type: ignore[arg-type] + if state.done(): + res = state.resolve() self._update_state(res, key) with self._conversations_lock: state = self._conversations.get(key) @@ -634,7 +674,7 @@ def check_update(self, update: object) -> Optional[CheckUpdateType]: return self.WAITING, key, handler_, check return None - self._logger.debug('Selecting conversation %s with state %s', str(key), str(state)) + _logger.debug('Selecting conversation %s with state %s', str(key), str(state)) handler: Optional[Handler] = None @@ -693,7 +733,7 @@ async def handle_update( # type: ignore[override] current_state, conversation_key, handler, handler_check_result = check_result raise_dp_handler_stop = False - with self._timeout_jobs_lock: + async with self._timeout_jobs_lock: # Remove the old timeout job (if present) timeout_job = self.timeout_jobs.pop(conversation_key, None) @@ -701,27 +741,34 @@ async def handle_update( # type: ignore[override] timeout_job.schedule_removal() try: # TODO handle non-blocking handlers correctly - new_state: object = await handler.handle_update( - update, application, handler_check_result, context - ) + block = self.block and handler.block + if block: + new_state: object = await handler.handle_update( + update, application, handler_check_result, context + ) + else: + new_state = application.create_task( + coroutine=handler.handle_update( + update, application, handler_check_result, context + ), + update=update, + ) except ApplicationHandlerStop as exception: new_state = exception.state raise_dp_handler_stop = True - with self._timeout_jobs_lock: + async with self._timeout_jobs_lock: if self.conversation_timeout: if application.job_queue is not None: # Add the new timeout job + # checking if the new state is self.END is done in _schedule_job if isinstance(new_state, asyncio.Task): - new_state.add_done_callback( - functools.partial( - self._schedule_job, - application=application, - update=update, - context=context, - conversation_key=conversation_key, - ) + application.create_task( + self._schedule_job_delayed( + new_state, application, update, context, conversation_key + ), + update=update, ) - elif new_state != self.END: + else: self._schedule_job( new_state, application, update, context, conversation_key ) @@ -754,13 +801,16 @@ def _update_state(self, new_state: object, key: Tuple[int, ...]) -> None: elif isinstance(new_state, asyncio.Task): with self._conversations_lock: - self._conversations[key] = (self._conversations.get(key), new_state) + self._conversations[key] = PendingState( + old_state=self._conversations.get(key), task=new_state + ) elif new_state is not None: if new_state not in self.states: warn( f"Handler returned state {new_state} which is unknown to the " f"ConversationHandler{' ' + self.name if self.name is not None else ''}.", + stacklevel=2, ) with self._conversations_lock: self._conversations[key] = new_state @@ -769,13 +819,13 @@ async def _trigger_timeout(self, context: CallbackContext) -> None: job = cast('Job', context.job) ctxt = cast(_ConversationTimeoutContext, job.context) - self._logger.debug( + _logger.debug( 'Conversation timeout was triggered for conversation %s!', ctxt.conversation_key ) callback_context = ctxt.callback_context - with self._timeout_jobs_lock: + async with self._timeout_jobs_lock: found_job = self.timeout_jobs.get(ctxt.conversation_key) if found_job is not job: # The timeout has been cancelled in handle_update From 362e7b2dd257daeeb82e442a15e0c1b6c6930a77 Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Wed, 9 Feb 2022 17:54:37 +0100 Subject: [PATCH 003/153] Start fixing some sphinx warnings --- docs/source/telegram.ext.application.rst | 2 +- .../telegram.ext.applicationhandlerstop.rst | 2 +- telegram/ext/_basepersistence.py | 22 +++++++++---------- telegram/ext/_dictpersistence.py | 2 +- telegram/ext/_jobqueue.py | 7 +++--- telegram/ext/_picklepersistence.py | 2 +- telegram/request/_requestdata.py | 4 ++-- 7 files changed, 21 insertions(+), 20 deletions(-) diff --git a/docs/source/telegram.ext.application.rst b/docs/source/telegram.ext.application.rst index 8009517b743..b2fc8ff4113 100644 --- a/docs/source/telegram.ext.application.rst +++ b/docs/source/telegram.ext.application.rst @@ -1,7 +1,7 @@ :github_url: https://github.com/python-telegram-bot/python-telegram-bot/blob/master/telegram/ext/_application.py telegram.ext.Application -======================= +======================== .. autoclass:: telegram.ext.Application :members: diff --git a/docs/source/telegram.ext.applicationhandlerstop.rst b/docs/source/telegram.ext.applicationhandlerstop.rst index 15ad832cca6..b2ee0c6ed31 100644 --- a/docs/source/telegram.ext.applicationhandlerstop.rst +++ b/docs/source/telegram.ext.applicationhandlerstop.rst @@ -1,7 +1,7 @@ :github_url: https://github.com/python-telegram-bot/python-telegram-bot/blob/master/telegram/ext/_application.py telegram.ext.ApplicationHandlerStop -================================== +=================================== .. autoclass:: telegram.ext.ApplicationHandlerStop :members: diff --git a/telegram/ext/_basepersistence.py b/telegram/ext/_basepersistence.py index 93a53f676d0..56cbc01fcc7 100644 --- a/telegram/ext/_basepersistence.py +++ b/telegram/ext/_basepersistence.py @@ -71,8 +71,8 @@ class BasePersistence(Generic[UD, CD, BD], ABC): Attention: The interface provided by this class is intended to be accessed exclusively by - :class:`~telegram.ext.Dispatcher`. Calling any of the methods below manually might - interfere with the integration of persistence into :class:`~telegram.ext.Dispatcher`. + :class:`~telegram.ext.Application`. Calling any of the methods below manually might + interfere with the integration of persistence into :class:`~telegram.ext.Application`. All relevant methods must be overwritten. This includes: @@ -114,7 +114,7 @@ class BasePersistence(Generic[UD, CD, BD], ABC): store_data (:class:`PersistenceInput`, optional): Specifies which kinds of data will be saved by this persistence instance. By default, all available kinds of data will be saved. - update_interval (:obj:`int` | :obj:`float:, optional): The + update_interval (:obj:`int` | :obj:`float`, optional): The :class:`~telegram.ext.Application` will update the persistence in regular intervals. This parameter specifies the time (in seconds) to wait between two consecutive runs of updating the persistence. Defaults to 60 seconds. @@ -607,8 +607,8 @@ async def drop_user_data(self, user_id: int) -> None: @abstractmethod async def refresh_user_data(self, user_id: int, user_data: UD) -> None: """Will be called by the :class:`telegram.ext.Application` before passing the - :attr:`~telegram.ext.Dispatcher.user_data` to a callback. Can be used to update data stored - in :attr:`~telegram.ext.Dispatcher.user_data` from an external source. + :attr:`~telegram.ext.Application.user_data` to a callback. Can be used to update data + stored in :attr:`~telegram.ext.Application.user_data` from an external source. .. versionadded:: 13.6 @@ -616,7 +616,7 @@ async def refresh_user_data(self, user_id: int, user_data: UD) -> None: Changed this method into an ``@abstractmethod``. Args: - user_id (:obj:`int`): The user ID this :attr:`~telegram.ext.Dispatcher.user_data` is + user_id (:obj:`int`): The user ID this :attr:`~telegram.ext.Application.user_data` is associated with. user_data (:obj:`dict` | :attr:`telegram.ext.ContextTypes.user_data`): The ``user_data`` of a single user. @@ -625,8 +625,8 @@ async def refresh_user_data(self, user_id: int, user_data: UD) -> None: @abstractmethod async def refresh_chat_data(self, chat_id: int, chat_data: CD) -> None: """Will be called by the :class:`telegram.ext.Application` before passing the - :attr:`~telegram.ext.Dispatcher.chat_data` to a callback. Can be used to update data stored - in :attr:`~telegram.ext.Dispatcher.chat_data` from an external source. + :attr:`~telegram.ext.Application.chat_data` to a callback. Can be used to update data + stored in :attr:`~telegram.ext.Application.chat_data` from an external source. .. versionadded:: 13.6 @@ -634,7 +634,7 @@ async def refresh_chat_data(self, chat_id: int, chat_data: CD) -> None: Changed this method into an ``@abstractmethod``. Args: - chat_id (:obj:`int`): The chat ID this :attr:`~telegram.ext.Dispatcher.chat_data` is + chat_id (:obj:`int`): The chat ID this :attr:`~telegram.ext.Application.chat_data` is associated with. chat_data (:obj:`dict` | :attr:`telegram.ext.ContextTypes.chat_data`): The ``chat_data`` of a single chat. @@ -643,8 +643,8 @@ async def refresh_chat_data(self, chat_id: int, chat_data: CD) -> None: @abstractmethod async def refresh_bot_data(self, bot_data: BD) -> None: """Will be called by the :class:`telegram.ext.Application` before passing the - :attr:`~telegram.ext.Dispatcher.bot_data` to a callback. Can be used to update data stored - in :attr:`~telegram.ext.Dispatcher.bot_data` from an external source. + :attr:`~telegram.ext.Application.bot_data` to a callback. Can be used to update data stored + in :attr:`~telegram.ext.Application.bot_data` from an external source. .. versionadded:: 13.6 diff --git a/telegram/ext/_dictpersistence.py b/telegram/ext/_dictpersistence.py index c152d5aedd6..aaceef9356b 100644 --- a/telegram/ext/_dictpersistence.py +++ b/telegram/ext/_dictpersistence.py @@ -60,7 +60,7 @@ class DictPersistence(BasePersistence): store_data (:class:`PersistenceInput`, optional): Specifies which kinds of data will be saved by this persistence instance. By default, all available kinds of data will be saved. - update_interval (:obj:`int` | :obj:`float:, optional): The + update_interval (:obj:`int` | :obj:`float`, optional): The :class:`~telegram.ext.Application` will update the persistence in regular intervals. This parameter specifies the time (in seconds) to wait between two consecutive runs of updating the persistence. Defaults to 60 seconds. diff --git a/telegram/ext/_jobqueue.py b/telegram/ext/_jobqueue.py index 241bdb67641..ce775927cad 100644 --- a/telegram/ext/_jobqueue.py +++ b/telegram/ext/_jobqueue.py @@ -41,9 +41,10 @@ class JobQueue: Attributes: scheduler (:class:`apscheduler.schedulers.asyncio.AsyncIOScheduler`): The scheduler. - ..versionchanged:: 14.0 - Use :class:`apscheduler.schedulers.asyncio.AsyncIOScheduler` instead of - :class:`apscheduler.schedulers.background.BackgroundScheduler` + + .. versionchanged:: 14.0 + Use :class:`~apscheduler.schedulers.asyncio.AsyncIOScheduler` instead of + :class:`~apscheduler.schedulers.background.BackgroundScheduler` """ diff --git a/telegram/ext/_picklepersistence.py b/telegram/ext/_picklepersistence.py index f79db222785..64603017964 100644 --- a/telegram/ext/_picklepersistence.py +++ b/telegram/ext/_picklepersistence.py @@ -64,7 +64,7 @@ class PicklePersistence(BasePersistence[UD, CD, BD]): store_data (:class:`PersistenceInput`, optional): Specifies which kinds of data will be saved by this persistence instance. By default, all available kinds of data will be saved. - update_interval (:obj:`int` | :obj:`float:, optional): The + update_interval (:obj:`int` | :obj:`float`, optional): The :class:`~telegram.ext.Application` will update the persistence in regular intervals. This parameter specifies the time (in seconds) to wait between two consecutive runs of updating the persistence. Defaults to 60 seconds. diff --git a/telegram/request/_requestdata.py b/telegram/request/_requestdata.py index f84a9eb1259..156f4e364ee 100644 --- a/telegram/request/_requestdata.py +++ b/telegram/request/_requestdata.py @@ -83,12 +83,12 @@ def url_encoded_parameters(self, encode_kwargs: Dict[str, Any] = None) -> str: def parametrized_url(self, url: str, encode_kwargs: Dict[str, Any] = None) -> str: """Shortcut for attaching the return value of :meth:`url_encoded_parameters` to the - :attr:`url`. + :paramref:`url`. Args: url (:obj:`str`): The URL the parameters will be attached to. encode_kwargs (Dict[:obj:`str`, any], optional): Additional keyword arguments to pass - along to :meth:`urllib.parse.urlencode`. + along to :attr:`urllib.parse.urlencode`. """ url_parameters = self.url_encoded_parameters(encode_kwargs=encode_kwargs) return f'{url}?{url_parameters}' From ecd909d359385f93952a9b366989fd5f8d032b5a Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Wed, 9 Feb 2022 22:28:07 +0100 Subject: [PATCH 004/153] More reference fixes --- docs/source/conf.py | 28 ++++++++++----------- telegram/_bot.py | 2 +- telegram/ext/_application.py | 39 +++++++++++++++++++----------- telegram/ext/_builders.py | 21 +++++++++------- telegram/ext/_callbackcontext.py | 23 ++++++------------ telegram/ext/_defaults.py | 5 ++-- telegram/ext/_dictpersistence.py | 4 +-- telegram/ext/_picklepersistence.py | 4 +-- telegram/request/_baserequest.py | 12 ++++----- telegram/request/_httpxrequest.py | 6 ++--- telegram/request/_requestdata.py | 6 ++--- 11 files changed, 76 insertions(+), 74 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index 1a5ff48f111..46adc79744f 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -521,24 +521,24 @@ def autodoc_process_bases(app, name, obj, option, bases: list): bases[idx] = f':class:`{base}`' # Now convert `telegram._message.Message` to `telegram.Message` etc - match = re.search(pattern=r"(telegram(\.ext|))\.", string=base) - if match and '_utils' not in base: - base = base.rstrip("'>") - parts = base.rsplit(".", maxsplit=2) + match = re.search(pattern=r"(telegram(\.ext|))\.[_\w\.]+", string=base) + if not match or '_utils' in base: + return - # Replace private base classes with their respective parent - parts[-1] = PRIVATE_BASE_CLASSES.get(parts[-1], parts[-1]) + parts = match.group(0).split(".") - # To make sure that e.g. `telegram.ext.filters.BaseFilter` keeps the `filters` part - if not parts[-2].startswith('_') and '_' not in parts[0]: - base = '.'.join(parts[-2:]) - else: - base = parts[-1] + # Remove private paths + for index, part in enumerate(parts): + if part.startswith("_"): + parts = parts[:index] + parts[-1:] + break - # add `telegram(.ext).` back in front - base = f'{match.group(0)}{base}' + # Replace private base classes with their respective parent + parts = [PRIVATE_BASE_CLASSES.get(part, part) for part in parts] - bases[idx] = f':class:`{base}`' + base = ".".join(parts) + + bases[idx] = f':class:`{base}`' def setup(app: Sphinx): diff --git a/telegram/_bot.py b/telegram/_bot.py index ab624fac705..ba3eb91ad31 100644 --- a/telegram/_bot.py +++ b/telegram/_bot.py @@ -363,7 +363,7 @@ async def initialize(self) -> None: async def shutdown(self) -> None: """Stop & clear resources used by this class. Currently just calls - :meth:`telegram.request.BaseRequest.stop` for :attr:`request`. + :meth:`telegram.request.BaseRequest.shutdown` for the request objects used by this bot. """ if self._initialized: await asyncio.gather(self._request[0].shutdown(), self._request[1].shutdown()) diff --git a/telegram/ext/_application.py b/telegram/ext/_application.py index e8203dc2ee5..cb91ee93048 100644 --- a/telegram/ext/_application.py +++ b/telegram/ext/_application.py @@ -117,10 +117,11 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ]): bot (:class:`telegram.Bot`): The bot object that should be passed to the handlers. update_queue (:class:`asyncio.Queue`): The synchronized queue that will contain the updates. + updater (:class:`telegram.ext.Updater`, optional): The updater used by this application. job_queue (:class:`telegram.ext.JobQueue`): Optional. The :class:`telegram.ext.JobQueue` instance to pass onto handler callbacks. - concurrent_updates (:obj:`int`, optional): Number of maximum concurrent worker threads for - the ``@run_async`` decorator and :meth:`run_async`. + concurrent_updates (:obj:`int`, optional): Number updates that may be processed in + parallel. chat_data (:obj:`types.MappingProxyType`): A dictionary handlers can use to store data for the chat. @@ -148,11 +149,12 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ]): .. seealso:: :meth:`add_handler`, :meth:`add_handlers`. error_handlers (Dict[:obj:`callable`, :obj:`bool`]): A dict, where the keys are error - handlers and the values indicate whether they are to be run asynchronously via - :meth:`run_async`. + handlers and the values indicate whether they are to be run blocking. .. seealso:: :meth:`add_error_handler` + context_types (:class:`telegram.ext.ContextTypes`): Specifies the types used by this + dispatcher for the ``context`` argument of handler and job callbacks. """ @@ -385,7 +387,7 @@ async def start(self, ready: Event = None) -> None: Note: This does *not* start fetching updates from Telegram. You need either start - :attr:`updater` manually or use one of :attr:`run_polling` or :attr:`run_webhook`. + :attr:`updater` manually or use one of :meth:`run_polling` or :meth:`run_webhook`. Args: ready (:obj:`asyncio.Event`, optional): If specified, the event will be set once the @@ -478,6 +480,9 @@ def run_polling( drop_pending_updates: bool = None, ready: asyncio.Event = None, ) -> None: + """Temp docstring to make this referencable + #TODO: Adda meaningful description + """ if not self.updater: raise RuntimeError( 'Application.run_polling is only available if the application has an Updater.' @@ -517,6 +522,9 @@ def run_webhook( max_connections: int = 40, ready: asyncio.Event = None, ) -> None: + """Temp docstring to make this referencable + #TODO: Adda meaningful description + """ if not self.updater: raise RuntimeError( 'Application.run_webhook is only available if the application has an Updater.' @@ -553,7 +561,7 @@ def __run(self, updater_coroutine: Coroutine, ready: asyncio.Event = None) -> No loop.close() def create_task(self, coroutine: Coroutine, update: object = None) -> asyncio.Task: - """Thin wrapper around :meth:`asyncio.create_task` that handles exceptions raised by + """Thin wrapper around :func:`asyncio.create_task` that handles exceptions raised by the ``coroutine`` with :meth:`dispatch_error`. Note: @@ -833,20 +841,23 @@ async def migrate_chat_data( self, message: 'Message' = None, old_chat_id: int = None, new_chat_id: int = None ) -> None: """Moves the contents of :attr:`chat_data` at key old_chat_id to the key new_chat_id. - Also updates the persistence by calling :attr:`update_persistence`. + Also marks the entries to be updated accordingly in the next run of + :meth:`update_persistence`. + Warning: * Any data stored in :attr:`chat_data` at key `new_chat_id` will be overridden * The key `old_chat_id` of :attr:`chat_data` will be deleted Args: - message (:class:`Message`, optional): A message with either - :attr:`telegram.Message.migrate_from_chat_id` or - :attr:`telegram.Message.migrate_to_chat_id`. - Mutually exclusive with passing ``old_chat_id`` and ``new_chat_id`` + message (:class:`telegram.Message`, optional): A message with either + :attr:`~telegram.Message.migrate_from_chat_id` or + :attr:`~telegram.Message.migrate_to_chat_id`. + Mutually exclusive with passing :paramref:`old_chat_id`` and + :paramref:`new_chat_id` .. seealso: `telegram.ext.filters.StatusUpdate.MIGRATE` old_chat_id (:obj:`int`, optional): The old chat ID. - Mutually exclusive with passing ``message`` + Mutually exclusive with passing :paramref:`message` new_chat_id (:obj:`int`, optional): The new chat ID. - Mutually exclusive with passing ``message`` + Mutually exclusive with passing :paramref:`message` """ if message and (old_chat_id or new_chat_id): raise ValueError("Message and chat_id pair are mutually exclusive") @@ -885,7 +896,7 @@ async def _persistence_updater(self) -> None: async def update_persistence(self) -> None: """Updates :attr:`user_data`, :attr:`chat_data`, :attr:`bot_data` in :attr:`persistence` - along with :attr:`~telegram.ExtBot.callback_data_cache` and the conversation states of + along with :attr:`~telegram.ext.ExtBot.callback_data_cache` and the conversation states of any persistent :class:`~telegram.ext.ConversationHandler` registered for this application. For :attr:`user_data`, :attr:`chat_data`, only entries accessed since the last run of this diff --git a/telegram/ext/_builders.py b/telegram/ext/_builders.py index 1c0637a04cf..5c43828a93f 100644 --- a/telegram/ext/_builders.py +++ b/telegram/ext/_builders.py @@ -378,8 +378,8 @@ def _request_param_check(self, get_updates: bool) -> None: ) def request(self: BuilderType, request: BaseRequest) -> BuilderType: - """Sets a :class:`telegram.utils.Request` object to be used for the ``request`` parameter - of :attr:`telegram.ext.Application.bot`. + """Sets a :class:`telegram.request.BaseRequest` object to be used for the ``request`` + parameter of :attr:`telegram.ext.Application.bot`. .. seealso:: :meth:`get_updates_request` @@ -424,7 +424,7 @@ def pool_timeout(self: BuilderType, pool_timeout: Optional[float]) -> BuilderTyp return self def get_updates_request(self: BuilderType, request: BaseRequest) -> BuilderType: - """Sets a :class:`telegram.utils.Request` object to be used for the ``get_updates_request`` + """Sets a :class:`telegram.request.BaseRequest` object to be used for the ``get_updates_request`` parameter of :attr:`telegram.ext.Application.bot`. .. seealso:: :meth:`request` @@ -599,15 +599,18 @@ def update_queue(self: BuilderType, update_queue: Queue) -> BuilderType: return self def concurrent_updates(self: BuilderType, concurrent_updates: Union[bool, int]) -> BuilderType: - """Sets the number of worker threads to be used for - :meth:`telegram.ext.Application.run_async`, i.e. the number of callbacks that can be run - asynchronously at the same time. + """Specifies if and how many updates may be processed concurrently instead of one by one. + + Warning: + Processing updates concurrently is not recommended when stateful handlers like + :class:`telegram.ext.ConversationHandler` are used. - .. seealso:: :paramref:`telegram.ext.Handler.run_async`, - :attr:`telegram.ext.Defaults.block` + .. seealso:: :paramref:`telegram.ext.Application.concurrent_updates` Args: - concurrent_updates (:obj:`int`): The number of worker threads. + concurrent_updates (:obj:`bool` | :obj:`int`): Passing :obj:`True` will allow for 4096 + updates to be processed concurrently. Pass an integer to specify a different number + of updates that may be processed concurrently. Returns: :class:`ApplicationBuilder`: The same builder with the updated argument. diff --git a/telegram/ext/_callbackcontext.py b/telegram/ext/_callbackcontext.py index 790dda9775a..e702974b5d5 100644 --- a/telegram/ext/_callbackcontext.py +++ b/telegram/ext/_callbackcontext.py @@ -63,10 +63,10 @@ class CallbackContext(Generic[BT, UD, CD, BD]): use a fairly unique name for the attributes. Warning: - Do not combine custom attributes and ``@run_async``/ - :func:`telegram.ext.Dispatcher.run_async`. Due to how ``run_async`` works, it will - almost certainly execute the callbacks for an update out of order, and the attributes - that you think you added will not be present. + Do not combine custom attributes with :paramref:`telegram.ext.Handler.block` set to + :obj:`False` or :paramref:`telegram.ext.Application.concurrent_updates` set to + :obj:`True`. Due to how those work, it will almost certainly execute the callbacks for an + update out of order, and the attributes that you think you added will not be present. Args: application (:class:`telegram.ext.Application`): The application associated with this @@ -84,12 +84,6 @@ class CallbackContext(Generic[BT, UD, CD, BD]): text after the command, using any whitespace string as a delimiter. error (:obj:`Exception`): Optional. The error that was raised. Only present when passed to a error handler registered with :attr:`telegram.ext.Application.add_error_handler`. - async_args (List[:obj:`object`]): Optional. Positional arguments of the function that - raised the error. Only present when the raising function was run asynchronously using - :meth:`telegram.ext.Application.run_async`. - async_kwargs (Dict[:obj:`str`, :obj:`object`]): Optional. Keyword arguments of the function - that raised the error. Only present when the raising function was run asynchronously - using :meth:`telegram.ext.Application.run_async`. job (:class:`telegram.ext.Job`): Optional. The job which originated this callback. Only present when passed to the callback of :class:`telegram.ext.Job` or in error handlers if the error is caused by a job. @@ -268,18 +262,15 @@ def from_error( .. seealso:: :meth:`telegram.ext.Application.add_error_handler` + .. versionchanged:: 14.0 + Removed arguments ``async_args`` and ``async_kwargs``. + Args: update (:obj:`object` | :class:`telegram.Update`): The update associated with the error. May be :obj:`None`, e.g. for errors in job callbacks. error (:obj:`Exception`): The error. application (:class:`telegram.ext.Application`): The application associated with this context. - async_args (List[:obj:`object`], optional): Positional arguments of the function that - raised the error. Pass only when the raising function was run asynchronously using - :meth:`telegram.ext.Application.run_async`. - async_kwargs (Dict[:obj:`str`, :obj:`object`], optional): Keyword arguments of the - function that raised the error. Pass only when the raising function was run - asynchronously using :meth:`telegram.ext.Application.run_async`. job (:class:`telegram.ext.Job`, optional): The job associated with the error. .. versionadded:: 14.0 diff --git a/telegram/ext/_defaults.py b/telegram/ext/_defaults.py index 1f35813b594..81460795d22 100644 --- a/telegram/ext/_defaults.py +++ b/telegram/ext/_defaults.py @@ -30,8 +30,7 @@ class Defaults: .. versionchanged:: 14.0 Removed the argument and attribute ``timeout``. Specify default timeout behavior for the - networking backend directly via :class:`telegram.ext.UpdaterBuilder` or - :class:`telegram.ext.ApplicationBuilder` instead. + networking backend directly via :class:`telegram.ext.ApplicationBuilder` instead. Parameters: @@ -205,7 +204,7 @@ def block(self) -> bool: @block.setter def block(self, value: object) -> NoReturn: - raise AttributeError("You can not assign a new value to run_async after initialization.") + raise AttributeError("You can not assign a new value to block after initialization.") @property def protect_content(self) -> Optional[bool]: diff --git a/telegram/ext/_dictpersistence.py b/telegram/ext/_dictpersistence.py index aaceef9356b..655c8c5913c 100644 --- a/telegram/ext/_dictpersistence.py +++ b/telegram/ext/_dictpersistence.py @@ -35,8 +35,8 @@ class DictPersistence(BasePersistence): Attention: The interface provided by this class is intended to be accessed exclusively by - :class:`~telegram.ext.Dispatcher`. Calling any of the methods below manually might - interfere with the integration of persistence into :class:`~telegram.ext.Dispatcher`. + :class:`~telegram.ext.Application`. Calling any of the methods below manually might + interfere with the integration of persistence into :class:`~telegram.ext.Application`. Note: This class does *not* implement a :meth:`flush` method, meaning that data managed by diff --git a/telegram/ext/_picklepersistence.py b/telegram/ext/_picklepersistence.py index 64603017964..90d5c37a14b 100644 --- a/telegram/ext/_picklepersistence.py +++ b/telegram/ext/_picklepersistence.py @@ -39,8 +39,8 @@ class PicklePersistence(BasePersistence[UD, CD, BD]): Attention: The interface provided by this class is intended to be accessed exclusively by - :class:`~telegram.ext.Dispatcher`. Calling any of the methods below manually might - interfere with the integration of persistence into :class:`~telegram.ext.Dispatcher`. + :class:`~telegram.ext.Application`. Calling any of the methods below manually might + interfere with the integration of persistence into :class:`~telegram.ext.Application`. Warning: :class:`PicklePersistence` will try to replace :class:`telegram.Bot` instances by diff --git a/telegram/request/_baserequest.py b/telegram/request/_baserequest.py index 260e3ace1f5..00f37faebe3 100644 --- a/telegram/request/_baserequest.py +++ b/telegram/request/_baserequest.py @@ -54,8 +54,7 @@ class BaseRequest( ): """Abstract interface class that allows python-telegram-bot to make requests to the Bot API. Can be implemented via different asyncio HTTP libraries. An implementation of this class - must implement all abstract methods and properties. In addition, :attr:`connection_pool_size` - can optionally be overridden. + must implement all abstract methods and properties. Instances of this class can be used as asyncio context managers, where @@ -130,9 +129,8 @@ async def post( """Makes a request to the Bot API handles the return code and parses the answer. Warning: - This method will be called by the methods of :class:`Bot` and should *not* be called - manually. - + This method will be called by the methods of :class:`telegram.Bot` and should *not* be + called manually. Args: url (:obj:`str`): The URL to request. @@ -181,8 +179,8 @@ async def retrieve( """Retrieve the contents of a file by its URL. Warning: - This method will be called by the methods of :class:`Bot` and should *not* be called - manually. + This method will be called by the methods of :class:`telegram.Bot` and should *not* be + called manually. Args: url (:obj:`str`): The web location we want to retrieve. diff --git a/telegram/request/_httpxrequest.py b/telegram/request/_httpxrequest.py index 6654eb89b2f..afac4902a4e 100644 --- a/telegram/request/_httpxrequest.py +++ b/telegram/request/_httpxrequest.py @@ -42,7 +42,7 @@ class HTTPXRequest(BaseRequest): Args: connection_pool_size (:obj:`int`, optional): Number of connections to keep in the - connection pool. Defaults to :obj:`1`. + connection pool. Defaults to ``1``. Note: Independent of the value, one additional connection will be reserved for @@ -73,7 +73,7 @@ class HTTPXRequest(BaseRequest): infinite timeout. Defaults to :obj:`None`. Warning: - With a finite pool timeout, you must expect :exc:`telegram.error.TimeOut` + With a finite pool timeout, you must expect :exc:`telegram.error.TimedOut` exceptions to be thrown when more requests are made simultaneously than there are connections in the connection pool! """ @@ -113,7 +113,7 @@ async def initialize(self) -> None: """See :meth:`BaseRequest.initialize`.""" async def shutdown(self) -> None: - """See :meth:`BaseRequest.stop`.""" + """See :meth:`BaseRequest.shutdown`.""" await self._client.aclose() async def do_request( diff --git a/telegram/request/_requestdata.py b/telegram/request/_requestdata.py index 156f4e364ee..93fba9e6608 100644 --- a/telegram/request/_requestdata.py +++ b/telegram/request/_requestdata.py @@ -71,11 +71,11 @@ def json_parameters(self) -> Dict[str, str]: return {param.name: param.json_value for param in self._parameters} def url_encoded_parameters(self, encode_kwargs: Dict[str, Any] = None) -> str: - """Encodes the parameters with :meth:`urllib.parse.urlencode`. + """Encodes the parameters with :func:`urllib.parse.urlencode`. Args: encode_kwargs (Dict[:obj:`str`, any], optional): Additional keyword arguments to pass - along to :meth:`urllib.parse.urlencode`. + along to :func:`urllib.parse.urlencode`. """ if encode_kwargs: return urlencode(self.json_parameters, **encode_kwargs) @@ -88,7 +88,7 @@ def parametrized_url(self, url: str, encode_kwargs: Dict[str, Any] = None) -> st Args: url (:obj:`str`): The URL the parameters will be attached to. encode_kwargs (Dict[:obj:`str`, any], optional): Additional keyword arguments to pass - along to :attr:`urllib.parse.urlencode`. + along to :func:`urllib.parse.urlencode`. """ url_parameters = self.url_encoded_parameters(encode_kwargs=encode_kwargs) return f'{url}?{url_parameters}' From d0565fc2fdb9a088d662283fb4e5e651ad3ff291 Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Thu, 10 Feb 2022 13:41:55 +0100 Subject: [PATCH 005/153] Try fixing wrong source links on RTD --- docs/source/conf.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index 46adc79744f..b59bc35189e 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -453,10 +453,14 @@ def _git_branch() -> str: """Get's the current git sha if available or fall back to `master`""" try: output = subprocess.check_output( # skipcq: BAN-B607 - ["git", "describe", "--tags"], stderr=subprocess.STDOUT + ["git", "describe", "--tags", "--always"], stderr=subprocess.STDOUT ) return output.decode().strip() - except Exception: + except Exception as exc: + sphinx_logger.exception( + f'Failed to get a description of the current commit. Falling back to `master`.', + exc_info=exc + ) return 'master' From 985d24fcb630e39553e904ca517a9b8b8eae0543 Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Thu, 10 Feb 2022 21:31:50 +0100 Subject: [PATCH 006/153] Slowly get started on request tests --- telegram/request/_httpxrequest.py | 5 +- tests/test_request.py | 87 ++++++++++++++++--------------- tests/test_requestdata.py | 6 +-- 3 files changed, 51 insertions(+), 47 deletions(-) diff --git a/telegram/request/_httpxrequest.py b/telegram/request/_httpxrequest.py index afac4902a4e..0b6b10c01e8 100644 --- a/telegram/request/_httpxrequest.py +++ b/telegram/request/_httpxrequest.py @@ -90,13 +90,12 @@ def __init__( pool_timeout: Optional[float] = 1.0, ): self.__pool_semaphore = asyncio.BoundedSemaphore(connection_pool_size) - self._pool_timeout = pool_timeout timeout = httpx.Timeout( connect=connect_timeout, read=read_timeout, write=write_timeout, - pool=1, + pool=pool_timeout, ) limits = httpx.Limits( max_connections=connection_pool_size, @@ -128,7 +127,7 @@ async def do_request( ) -> Tuple[int, bytes]: """See :meth:`BaseRequest.do_request`.""" if isinstance(pool_timeout, DefaultValue): - pool_timeout = self._pool_timeout + pool_timeout = self._client.timeout.pool if pool_timeout != 0 and self.__pool_semaphore.locked(): _logger.debug( diff --git a/tests/test_request.py b/tests/test_request.py index dd813e4e016..21e2f8329c9 100644 --- a/tests/test_request.py +++ b/tests/test_request.py @@ -19,6 +19,7 @@ """Here we run tests directly with HTTPXRequest because that's easier than providing dummy implementations for BaseRequest and we want to test HTTPXRequest anyway.""" import json +from dataclasses import dataclass from http import HTTPStatus from typing import Tuple, Any, Coroutine, Callable @@ -26,7 +27,6 @@ import pytest from telegram._utils.defaultvalue import DEFAULT_NONE -from telegram._utils.types import ODVInput from telegram.error import ( TelegramError, ChatMigrated, @@ -69,9 +69,6 @@ async def httpx_request(): yield rq -# TODO: Test timeouts - - class TestRequest: test_flag = None @@ -82,6 +79,8 @@ def reset(self): def test_slot_behaviour(self, mro_slots): inst = HTTPXRequest() for attr in inst.__slots__: + if attr.startswith('__'): + attr = f'_{inst.__class__.__name__}{attr}' assert getattr(inst, attr, 'err') != 'err', f"got extra slot '{attr}'" assert len(mro_slots(inst)) == len(set(mro_slots(inst))), "duplicate slot" @@ -90,13 +89,13 @@ async def test_context_manager(self, monkeypatch): async def initialize(): self.test_flag = ['initialize'] - async def stop(): + async def shutdown(): self.test_flag.append('stop') httpx_request = HTTPXRequest() monkeypatch.setattr(httpx_request, 'initialize', initialize) - monkeypatch.setattr(httpx_request, 'stop', stop) + monkeypatch.setattr(httpx_request, 'shutdown', shutdown) async with httpx_request: pass @@ -108,13 +107,13 @@ async def test_context_manager_exception_on_init(self, monkeypatch): async def initialize(): raise RuntimeError('initialize') - async def stop(): + async def shutdown(): self.test_flag = 'stop' httpx_request = HTTPXRequest() monkeypatch.setattr(httpx_request, 'initialize', initialize) - monkeypatch.setattr(httpx_request, 'stop', stop) + monkeypatch.setattr(httpx_request, 'shutdown', shutdown) with pytest.raises(RuntimeError, match='initialize'): async with httpx_request: @@ -237,7 +236,11 @@ async def test_special_errors( ['exception', 'catch_class', 'match'], [ (TelegramError('TelegramError'), TelegramError, 'TelegramError'), - (RuntimeError('CustomError'), Exception, 'HTTP implementation: CustomError'), + ( + RuntimeError('CustomError'), + Exception, + "HTTP implementation: RuntimeError\('CustomError'\)", + ), ], ) async def test_exceptions_in_do_request( @@ -266,63 +269,65 @@ async def test_retrieve(self, monkeypatch, httpx_request): assert await httpx_request.retrieve(None, None) == server_response - def test_connection_pool_size(self): - class Request(BaseRequest): - async def do_request(self, *args, **kwargs): - pass - - async def initialize(self, *args, **kwargs): - pass - - async def shutdown(self, *args, **kwargs): - pass - - with pytest.raises(NotImplementedError): - Request().connection_pool_size - @pytest.mark.asyncio async def test_timeout_propagation(self, monkeypatch, httpx_request): - """Here we just test that retrieve gives us the raw bytes instead of trying to parse them - as json - """ - - async def make_assertion( - method: str, - url: str, - request_data: RequestData = None, - read_timeout: ODVInput[float] = DEFAULT_NONE, - *args, - **kwargs, - ): - self.test_flag = read_timeout + async def make_assertion(*args, **kwargs): + self.test_flag = ( + kwargs.get('read_timeout'), + kwargs.get('connect_timeout'), + kwargs.get('write_timeout'), + kwargs.get('pool_timeout'), + ) return HTTPStatus.OK, b'{"ok": "True", "result": {}}' monkeypatch.setattr(httpx_request, 'do_request', make_assertion) - await httpx_request.post('url', None, read_timeout=42.314) - assert self.test_flag == 42.314 + await httpx_request.post('url', 'method') + assert self.test_flag == (DEFAULT_NONE, DEFAULT_NONE, DEFAULT_NONE, DEFAULT_NONE) + + await httpx_request.post( + 'url', None, read_timeout=1, connect_timeout=2, write_timeout=3, pool_timeout=4 + ) + assert self.test_flag == (1, 2, 3, 4) class TestHTTPXRequest: + # TODO: Properly timeouts + test_flag = None @pytest.fixture(autouse=True) def reset(self): self.test_flag = None - def test_init(self): + def test_init(self, monkeypatch): + @dataclass + class Client: + timeout: object + proxies: object + limits: object + + monkeypatch.setattr(httpx, 'AsyncClient', Client) + request = HTTPXRequest() - assert request.connection_pool_size == 1 assert request._client.timeout == httpx.Timeout(connect=5.0, read=5.0, write=5.0, pool=1.0) + assert request._client.proxies is None + assert request._client.limits == httpx.Limits( + max_connections=1, max_keepalive_connections=1 + ) request = HTTPXRequest( connection_pool_size=42, + proxy_url='proxy_url', connect_timeout=43, read_timeout=44, write_timeout=45, pool_timeout=46, ) - assert request.connection_pool_size == 42 + assert request._client.proxies == 'proxy_url' + assert request._client.limits == httpx.Limits( + max_connections=42, max_keepalive_connections=42 + ) assert request._client.timeout == httpx.Timeout(connect=43, read=44, write=45, pool=46) @pytest.mark.asyncio diff --git a/tests/test_requestdata.py b/tests/test_requestdata.py index a8b0356c195..3a38a8e82b8 100644 --- a/tests/test_requestdata.py +++ b/tests/test_requestdata.py @@ -145,7 +145,7 @@ def test_parameters( ): assert simple_rqs.parameters == simple_params # We don't test these for now since that's a struggle - # And the conversation part is already being tested in test_requestparameter.py + # And the conversion part is already being tested in test_requestparameter.py # assert file_rqs.parameters == file_params # assert mixed_rqs.parameters == mixed_params @@ -192,7 +192,7 @@ def test_url_encoding(self, monkeypatch): expected_params = 'chat_id=123&text=Hello+there%2F%21' expected_url = 'https://te.st/method?' + expected_params assert data.url_encoded_parameters() == expected_params - assert data.build_parametrized_url('https://te.st/method') == expected_url + assert data.parametrized_url('https://te.st/method') == expected_url expected_params = 'chat_id=123&text=Hello%20there/!' expected_url = 'https://te.st/method?' + expected_params @@ -201,7 +201,7 @@ def test_url_encoding(self, monkeypatch): == expected_params ) assert ( - data.build_parametrized_url( + data.parametrized_url( 'https://te.st/method', encode_kwargs={'quote_via': quote, 'safe': '/!'} ) == expected_url From 916fc20d7521ef0a0b47e4af7131b218d2d7e696 Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Sat, 12 Feb 2022 17:46:22 +0100 Subject: [PATCH 007/153] More request tests --- telegram/request/_httpxrequest.py | 17 +-- tests/test_request.py | 178 ++++++++++++++++++++++++------ 2 files changed, 154 insertions(+), 41 deletions(-) diff --git a/telegram/request/_httpxrequest.py b/telegram/request/_httpxrequest.py index 0b6b10c01e8..ca23c474c36 100644 --- a/telegram/request/_httpxrequest.py +++ b/telegram/request/_httpxrequest.py @@ -129,12 +129,13 @@ async def do_request( if isinstance(pool_timeout, DefaultValue): pool_timeout = self._client.timeout.pool - if pool_timeout != 0 and self.__pool_semaphore.locked(): - _logger.debug( - 'All connections in the pool are currently busy. Waiting pool_timeout=%s for ' - 'a connection to become available.', - pool_timeout, - ) + # TODO: This doesn't seem to work. + # if pool_timeout != 0 and self.__pool_semaphore.locked(): + # _logger.debug( + # 'All connections in the pool are currently busy. Waiting pool_timeout=%s for ' + # 'a connection to become available.', + # pool_timeout, + # ) try: await asyncio.wait_for(self.__pool_semaphore.acquire(), timeout=pool_timeout) @@ -145,6 +146,7 @@ async def do_request( out = await self._do_request( url=url, method=method, + pool_timeout=pool_timeout, request_data=request_data, connect_timeout=connect_timeout, read_timeout=read_timeout, @@ -158,6 +160,7 @@ async def _do_request( self, url: str, method: str, + pool_timeout: Optional[float], request_data: RequestData = None, connect_timeout: ODVInput[float] = BaseRequest.DEFAULT_NONE, read_timeout: ODVInput[float] = BaseRequest.DEFAULT_NONE, @@ -167,7 +170,7 @@ async def _do_request( connect=self._client.timeout.connect, read=self._client.timeout.read, write=self._client.timeout.write, - pool=1, + pool=pool_timeout, ) if not isinstance(read_timeout, DefaultValue): timeout.read = read_timeout diff --git a/tests/test_request.py b/tests/test_request.py index 21e2f8329c9..e75d0b05af4 100644 --- a/tests/test_request.py +++ b/tests/test_request.py @@ -18,7 +18,9 @@ # along with this program. If not, see [http://www.gnu.org/licenses/]. """Here we run tests directly with HTTPXRequest because that's easier than providing dummy implementations for BaseRequest and we want to test HTTPXRequest anyway.""" +import asyncio import json +import logging from dataclasses import dataclass from http import HTTPStatus from typing import Tuple, Any, Coroutine, Callable @@ -38,7 +40,6 @@ Conflict, TimedOut, ) -from telegram.request import BaseRequest, RequestData from telegram.request._httpxrequest import HTTPXRequest # We only need the first fixture, but it uses the others, so pytest needs us to import them as well @@ -239,7 +240,7 @@ async def test_special_errors( ( RuntimeError('CustomError'), Exception, - "HTTP implementation: RuntimeError\('CustomError'\)", + r"HTTP implementation: RuntimeError\('CustomError'\)", ), ], ) @@ -292,8 +293,6 @@ async def make_assertion(*args, **kwargs): class TestHTTPXRequest: - # TODO: Properly timeouts - test_flag = None @pytest.fixture(autouse=True) @@ -368,65 +367,86 @@ async def aclose(*args): assert self.test_flag == 'stop' @pytest.mark.asyncio - async def test_do_request_default_timeouts(self, monkeypatch, httpx_request): - default_timeouts = httpx.Timeout(connect=5.0, read=5.0, write=5.0, pool=1.0) + async def test_do_request_default_timeouts(self, monkeypatch): + default_timeouts = httpx.Timeout(connect=42, read=43, write=44, pool=45) - async def make_assertion(self, method, url, headers, timeout, files, data): - self.test_flag = timeout == default_timeouts + async def make_assertion(_, **kwargs): + self.test_flag = kwargs.get('timeout') == default_timeouts return httpx.Response(HTTPStatus.OK) - monkeypatch.setattr(httpx.AsyncClient, 'request', make_assertion) - await httpx_request.do_request('GET', 'URL') - assert httpx_request._client.timeout == default_timeouts + async with HTTPXRequest( + connect_timeout=default_timeouts.connect, + read_timeout=default_timeouts.read, + write_timeout=default_timeouts.write, + pool_timeout=default_timeouts.pool, + ) as httpx_request: + + monkeypatch.setattr(httpx.AsyncClient, 'request', make_assertion) + await httpx_request.do_request(method='GET', url='URL') + + assert self.test_flag @pytest.mark.asyncio async def test_do_request_manual_timeouts(self, monkeypatch, httpx_request): - default_timeouts = httpx.Timeout(connect=5.0, read=5.0, write=5.0, pool=1.0) + default_timeouts = httpx.Timeout(connect=42, read=43, write=44, pool=45) + manual_timeouts = httpx.Timeout(connect=52, read=53, write=54, pool=55) - async def make_assertion(self, method, url, headers, timeout, files, data): - self.test_flag = timeout == httpx.Timeout(connect=5.0, read=5.5, write=5.6, pool=1.0) + async def make_assertion(_, **kwargs): + print(kwargs.get('timeout'), manual_timeouts) + self.test_flag = kwargs.get('timeout') == manual_timeouts return httpx.Response(HTTPStatus.OK) - monkeypatch.setattr(httpx.AsyncClient, 'request', make_assertion) - await httpx_request.do_request('GET', 'URL', read_timeout=5.5, write_timeout=5.6) - assert httpx_request._client.timeout == default_timeouts + async with HTTPXRequest( + connect_timeout=default_timeouts.connect, + read_timeout=default_timeouts.read, + write_timeout=default_timeouts.write, + pool_timeout=default_timeouts.pool, + ) as httpx_request: + + monkeypatch.setattr(httpx.AsyncClient, 'request', make_assertion) + await httpx_request.do_request( + method='GET', + url='URL', + connect_timeout=manual_timeouts.connect, + read_timeout=manual_timeouts.read, + write_timeout=manual_timeouts.write, + pool_timeout=manual_timeouts.pool, + ) + + assert self.test_flag @pytest.mark.asyncio async def test_do_request_params_no_data(self, monkeypatch, httpx_request): - async def make_assertion(self, method, url, headers, timeout, files, data): - method_assertion = method == 'method' - url_assertion = url == 'url' - files_assertion = files is None - data_assertion = data is None + async def make_assertion(self, **kwargs): + method_assertion = kwargs.get('method') == 'method' + url_assertion = kwargs.get('url') == 'url' + files_assertion = kwargs.get('files') is None + data_assertion = kwargs.get('data') is None if method_assertion and url_assertion and files_assertion and data_assertion: return httpx.Response(HTTPStatus.OK) return httpx.Response(HTTPStatus.BAD_REQUEST) monkeypatch.setattr(httpx.AsyncClient, 'request', make_assertion) - code, _ = await httpx_request.do_request( - 'method', 'url', read_timeout=5.5, write_timeout=5.6 - ) + code, _ = await httpx_request.do_request(method='method', url='url') assert code == HTTPStatus.OK @pytest.mark.asyncio async def test_do_request_params_with_data( self, monkeypatch, httpx_request, mixed_rqs # noqa: 9811 ): - async def make_assertion(self, method, url, headers, timeout, files, data): - method_assertion = method == 'method' - url_assertion = url == 'url' - files_assertion = files == mixed_rqs.multipart_data - data_assertion = data == mixed_rqs.json_parameters + async def make_assertion(self, **kwargs): + method_assertion = kwargs.get('method') == 'method' + url_assertion = kwargs.get('url') == 'url' + files_assertion = kwargs.get('files') == mixed_rqs.multipart_data + data_assertion = kwargs.get('data') == mixed_rqs.json_parameters if method_assertion and url_assertion and files_assertion and data_assertion: return httpx.Response(HTTPStatus.OK) return httpx.Response(HTTPStatus.BAD_REQUEST) monkeypatch.setattr(httpx.AsyncClient, 'request', make_assertion) code, _ = await httpx_request.do_request( - 'method', - 'url', - read_timeout=5.5, - write_timeout=5.6, + method='method', + url='url', request_data=mixed_rqs, ) assert code == HTTPStatus.OK @@ -462,3 +482,93 @@ async def make_assertion(self, method, url, headers, timeout, files, data): 'method', 'url', ) + + @pytest.mark.asyncio + async def test_do_request_pool_timeout(self, monkeypatch): + async def request(self, **kwargs): + await asyncio.sleep(0.05) + return httpx.Response(HTTPStatus.OK) + + monkeypatch.setattr(httpx.AsyncClient, 'request', request) + + with pytest.raises(TimedOut, match='Pool timeout'): + async with HTTPXRequest(pool_timeout=0.02) as httpx_request: + await asyncio.gather( + httpx_request.do_request(method='GET', url='URL'), + httpx_request.do_request(method='GET', url='URL'), + ) + + @pytest.mark.asyncio + async def test_do_request_wait_for_pool(self, monkeypatch, httpx_request): + async def request(_, **kwargs): + await asyncio.sleep(0.05) + if self.test_flag is None: + self.test_flag = 1 + else: + self.test_flag += 1 + return httpx.Response(HTTPStatus.OK) + + monkeypatch.setattr(httpx.AsyncClient, 'request', request) + + task_1 = asyncio.create_task(httpx_request.do_request(method='GET', url='URL')) + task_2 = asyncio.create_task(httpx_request.do_request(method='GET', url='URL')) + await asyncio.sleep(0.07) + assert self.test_flag == 1 + await asyncio.sleep(0.07) + assert self.test_flag == 2 + await asyncio.gather(task_1, task_2) + + @pytest.mark.asyncio + async def test_do_request_wait_for_pool_with_exception(self, monkeypatch, httpx_request): + async def request(_, **kwargs): + await asyncio.sleep(0.05) + if self.test_flag is None: + self.test_flag = 1 + raise RuntimeError('Raising Exception') + else: + self.test_flag += 1 + return httpx.Response(HTTPStatus.OK) + + monkeypatch.setattr(httpx.AsyncClient, 'request', request) + + task_1 = asyncio.create_task(httpx_request.do_request(method='GET', url='URL')) + task_2 = asyncio.create_task(httpx_request.do_request(method='GET', url='URL')) + await asyncio.sleep(0.06) + assert self.test_flag == 1 + await asyncio.sleep(0.06) + assert self.test_flag == 2 + out = await asyncio.gather(task_1, task_2, return_exceptions=True) + assert sum(isinstance(entry, RuntimeError) for entry in out) == 1 + + @pytest.mark.asyncio + async def test_do_request_critical_timeout_logging(self, httpx_request, caplog, monkeypatch): + async def request(self, **kwargs): + raise httpx.PoolTimeout('pool timeout') + + monkeypatch.setattr(httpx.AsyncClient, 'request', request) + + with pytest.raises(TimedOut): + with caplog.at_level(logging.CRITICAL): + await httpx_request.do_request(method='GET', url='URL') + + assert len(caplog.records) == 1 + assert ( + 'All connections in the connection pool are occupied' in caplog.records[0].getMessage() + ) + + # @pytest.mark.asyncio + # async def test_do_request_busy_logging(self, httpx_request, caplog, monkeypatch): + # async def request(self, **kwargs): + # await asyncio.sleep(0.5) + # return httpx.Response(HTTPStatus.OK) + # + # monkeypatch.setattr(httpx.AsyncClient, 'request', request) + # + # with caplog.at_level(logging.DEBUG): + # await asyncio.gather( + # httpx_request.do_request(method='GET', url='URL'), + # httpx_request.do_request(method='GET', url='URL'), + # ) + # + # assert len(caplog.records) == 1 + # assert 'currently busy. Waiting pool_timeout' in caplog.records[0].getMessage() From c6b898aa86a853ee9e0975f2b4aadde307a66305 Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Sat, 12 Feb 2022 20:12:25 +0100 Subject: [PATCH 008/153] Add two comments --- telegram/ext/_application.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/telegram/ext/_application.py b/telegram/ext/_application.py index cb91ee93048..27444d5883e 100644 --- a/telegram/ext/_application.py +++ b/telegram/ext/_application.py @@ -548,12 +548,14 @@ def run_webhook( ) def __run(self, updater_coroutine: Coroutine, ready: asyncio.Event = None) -> None: - loop = asyncio.get_event_loop() + # TODO: get_event_loop is deprecated - switch to get_running_loop() + loop = asyncio.get_event_loop() # get_running_loop() loop.run_until_complete(self.initialize()) loop.run_until_complete(self.start(ready=ready)) loop.run_until_complete(updater_coroutine) try: loop.run_forever() + # TODO: maybe allow for custom exception classes to catch here? Or provide a custom one? except (KeyboardInterrupt, SystemExit): loop.run_until_complete(self.stop()) loop.run_until_complete(self.shutdown()) From 3008aad99c1294419d3a2f5eb2d5eb24e3427399 Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Sun, 13 Feb 2022 21:32:10 +0100 Subject: [PATCH 009/153] Rework persistence logic --- telegram/_bot.py | 2 +- telegram/ext/_application.py | 189 +++++++++++++-------- telegram/ext/_builders.py | 5 +- telegram/ext/_jobqueue.py | 5 + telegram/ext/_utils/trackingdefaultdict.py | 3 + 5 files changed, 132 insertions(+), 72 deletions(-) diff --git a/telegram/_bot.py b/telegram/_bot.py index ba3eb91ad31..7e44ebbec97 100644 --- a/telegram/_bot.py +++ b/telegram/_bot.py @@ -164,7 +164,7 @@ class Bot(TelegramObject, AbstractAsyncContextManager): :class:`telegram.request.BaseRequest` instances. Will be used for all bot methods *except* for :attr:`get_updates`. If not passed, an instance of :class:`telegram.request.HTTPXRequest` will be used. - request (:class:`telegram.request.BaseRequest`, optional): Pre initialized + get_updates_request (:class:`telegram.request.BaseRequest`, optional): Pre initialized :class:`telegram.request.BaseRequest` instances. Will be used exclusively for :attr:`get_updates`. If not passed, an instance of :class:`telegram.request.HTTPXRequest` will be used. diff --git a/telegram/ext/_application.py b/telegram/ext/_application.py index 27444d5883e..533b65a373f 100644 --- a/telegram/ext/_application.py +++ b/telegram/ext/_application.py @@ -41,10 +41,10 @@ Any, Set, Mapping, - cast, - MutableMapping, + DefaultDict, ) +from telegram import Update from telegram._utils.types import DVInput, ODVInput from telegram.error import TelegramError from telegram.ext import BasePersistence, ContextTypes, ExtBot, Updater @@ -167,12 +167,16 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ]): '__update_persistence_task', '__weakref__', '_chat_data', + '_chat_ids_to_be_deleted_in_persistence', + '_chat_ids_to_be_updated_in_persistence', '_concurrent_updates', '_concurrent_updates_sem', '_conversation_handler_conversations', '_initialized', '_running', '_user_data', + '_user_ids_to_be_deleted_in_persistence', + '_user_ids_to_be_updated_in_persistence', 'bot', 'bot_data', 'chat_data', @@ -224,30 +228,22 @@ def __init__( self.job_queue.set_application(self) self.bot_data = self.context_types.bot_data() + self._user_data: DefaultDict[int, UD] = defaultdict(self.context_types.user_data) + self._chat_data: DefaultDict[int, CD] = defaultdict(self.context_types.chat_data) + # Read only mapping + self.user_data: Mapping[int, UD] = MappingProxyType(self._user_data) + self.chat_data: Mapping[int, CD] = MappingProxyType(self._chat_data) + self.persistence: Optional[BasePersistence] = None if persistence and not isinstance(persistence, BasePersistence): raise TypeError("persistence must be based on telegram.ext.BasePersistence") self.persistence = persistence - # Track access to chat_ids only if necessary for the persistence - if self.persistence and self.persistence.store_data.user_data: - self._user_data: MutableMapping[int, UD] = TrackingDefaultDict( - default_factory=self.context_types.user_data, track_read=True, track_write=True - ) - else: - self._user_data = defaultdict(self.context_types.user_data) - # Track access to user_ids only if necessary for the persistence - if self.persistence and self.persistence.store_data.chat_data: - self._chat_data: MutableMapping[int, CD] = TrackingDefaultDict( - # track_write = True for self.migrate_chat_data - default_factory=self.context_types.chat_data, - track_read=True, - track_write=True, - ) - else: - self._chat_data = defaultdict(self.context_types.chat_data) - # Read only mapping - self.user_data: Mapping[int, UD] = MappingProxyType(self._user_data) - self.chat_data: Mapping[int, CD] = MappingProxyType(self._chat_data) + + # Some book keeping for persistence logic + self._chat_ids_to_be_updated_in_persistence: Set[int] = set() + self._user_ids_to_be_updated_in_persistence: Set[int] = set() + self._chat_ids_to_be_deleted_in_persistence: Set[int] = set() + self._user_ids_to_be_deleted_in_persistence: Set[int] = set() # This attribute will hold references to the conversation dicts of all conversation # handlers so that we can extract the changed states during `update_persistence` @@ -339,13 +335,9 @@ async def _initialize_persistence(self) -> None: return if self.persistence.store_data.user_data: - cast(TrackingDefaultDict, self._user_data).update_no_track( - await self.persistence.get_user_data() - ) + self._user_data.update(await self.persistence.get_user_data()) if self.persistence.store_data.chat_data: - cast(TrackingDefaultDict, self._chat_data).update_no_track( - await self.persistence.get_chat_data() - ) + self._chat_data.update(await self.persistence.get_chat_data()) if self.persistence.store_data.bot_data: self.bot_data = await self.persistence.get_bot_data() if not isinstance(self.bot_data, self.context_types.bot_data): @@ -564,18 +556,20 @@ def __run(self, updater_coroutine: Coroutine, ready: asyncio.Event = None) -> No def create_task(self, coroutine: Coroutine, update: object = None) -> asyncio.Task: """Thin wrapper around :func:`asyncio.create_task` that handles exceptions raised by - the ``coroutine`` with :meth:`dispatch_error`. + the :paramref:`coroutine` with :meth:`dispatch_error`. Note: - * If ``coroutine`` raises an exception, it will be set on the task created by this - method even though it's handled by :meth:`dispatch_error`. + * If :paramref:`coroutine` raises an exception, it will be set on the task created by + this method even though it's handled by :meth:`dispatch_error`. * If the application is currently running, tasks created by this methods will be awaited by :meth:`stop`. Args: coroutine: The coroutine to run as task. update: Optional. If passed, will be passed to :meth:`dispatch_error` as additional - information for the error handlers. + information for the error handlers. Moreover, the corresponding :attr:`chat_data` + and :attr:`user_data` entries will be updated in the next run of + :meth:`update_persistence` after the :paramref:`coroutine` is finished. Returns: :class:`asyncio.Task`: The created task. @@ -632,9 +626,11 @@ async def __create_task_callback( # If we arrive here, an exception happened in the task and was neither # ApplicationHandlerStop nor raised by an error handler. # So we can and must handle it - self.create_task(self.dispatch_error(update, exception, coroutine=coroutine)) + await self.dispatch_error(update, exception, coroutine=coroutine) raise exception + finally: + self._mark_update_for_persistence_update(update=update) async def _update_fetcher(self) -> None: # Continuously fetch updates from the queue. Exit only once the signal object is found. @@ -676,11 +672,6 @@ async def process_update(self, update: object) -> None: The update to process. """ - # An error happened while polling - if isinstance(update, TelegramError): - await self.dispatch_error(None, update) - return - context = None for handlers in self.handlers.values(): @@ -709,6 +700,8 @@ async def process_update(self, update: object) -> None: _logger.debug('Error handler stopped further handlers.') break + self._mark_update_for_persistence_update(update=update) + def add_handler(self, handler: Handler[Any, CCT], group: int = DEFAULT_GROUP) -> None: """Register a handler. @@ -817,29 +810,45 @@ def remove_handler(self, handler: Handler, group: int = DEFAULT_GROUP) -> None: if not self.handlers[group]: del self.handlers[group] - async def drop_chat_data(self, chat_id: int) -> None: - """Used for deleting a key from the :attr:`chat_data`. + def drop_chat_data(self, chat_id: int) -> None: + """Drops the corresponding entry from the :attr:`chat_data`. Will also be deleted from + the persistence on the next run of :meth:`update_persistence`, if applicable. + + Warning: + When using :paramref:`concurrent_updates` or the :attr:`job_queue`, + :meth:`process_update` or :meth:`telegram.ext.Job.run` may re-create this entry due to + the asynchronous nature of these features. Please make sure that your program can + avoid or handle such situations. .. versionadded:: 14.0 Args: - chat_id (:obj:`int`): The chat id to delete from the persistence. The entry - will be deleted even if it is not empty. + chat_id (:obj:`int`): The chat id to delete. The entry will be deleted even if it is + not empty. """ self._chat_data.pop(chat_id, None) # type: ignore[arg-type] + self._chat_ids_to_be_deleted_in_persistence.add(chat_id) - async def drop_user_data(self, user_id: int) -> None: - """Used for deleting a key from the :attr:`user_data`. + def drop_user_data(self, user_id: int) -> None: + """Drops the corresponding entry from the :attr:`user_data`. Will also be deleted from + the persistence on the next run of :meth:`update_persistence`, if applicable. + + Warning: + When using :paramref:`concurrent_updates` or the :attr:`job_queue`, + :meth:`process_update` or :meth:`telegram.ext.Job.run` may re-create this entry due to + the asynchronous nature of these features. Please make sure that your program can + avoid or handle such situations. .. versionadded:: 14.0 Args: - user_id (:obj:`int`): The user id to delete from the persistence. The entry - will be deleted even if it is not empty. + user_id (:obj:`int`): The user id to delete. The entry will be deleted even if it is + not empty. """ self._user_data.pop(user_id, None) # type: ignore[arg-type] + self._user_ids_to_be_deleted_in_persistence.add(user_id) - async def migrate_chat_data( + def migrate_chat_data( self, message: 'Message' = None, old_chat_id: int = None, new_chat_id: int = None ) -> None: """Moves the contents of :attr:`chat_data` at key old_chat_id to the key new_chat_id. @@ -849,6 +858,15 @@ async def migrate_chat_data( Warning: * Any data stored in :attr:`chat_data` at key `new_chat_id` will be overridden * The key `old_chat_id` of :attr:`chat_data` will be deleted + * This does not update the :attr:`~telegram.ext.Job.chat_id` attribute of any scheduled + :class:`telegram.ext.Job`. + + Warning: + When using :paramref:`concurrent_updates` or the :attr:`job_queue`, + :meth:`process_update` or :meth:`telegram.ext.Job.run` may re-create the old entry due + to the asynchronous nature of these features. Please make sure that your program can + avoid or handle such situations. + Args: message (:class:`telegram.Message`, optional): A message with either :attr:`~telegram.Message.migrate_from_chat_id` or @@ -880,6 +898,28 @@ async def migrate_chat_data( raise ValueError("old_chat_id and new_chat_id must be integers") self._chat_data[new_chat_id] = self._chat_data[old_chat_id] + self.drop_chat_data(old_chat_id) + + self._chat_ids_to_be_updated_in_persistence.add(new_chat_id) + self._chat_ids_to_be_deleted_in_persistence.add(old_chat_id) + + def _mark_update_for_persistence_update( + self, *, update: object = None, job: 'Job' = None + ) -> None: + # TODO: This should be at the end of `Application.process_update`, when the task created + # by `Application.create_task` is done and when a `Job` is done. Add tests to make sure + # that this is happening + if isinstance(update, Update): + if update.effective_chat: + self._chat_ids_to_be_updated_in_persistence.add(update.effective_chat.id) + if update.effective_user: + self._user_ids_to_be_updated_in_persistence.add(update.effective_user.id) + + if job: + if job.chat_id: + self._chat_ids_to_be_updated_in_persistence.add(job.chat_id) + if job.user_id: + self._user_ids_to_be_updated_in_persistence.add(job.user_id) async def _persistence_updater(self) -> None: # Update the persistence in regular intervals. Exit only when the stop event has been set @@ -893,15 +933,16 @@ async def _persistence_updater(self) -> None: self.__update_persistence_event.wait(), timeout=self.persistence.update_interval, ) - except asyncio.TimeoutError: return + except asyncio.TimeoutError: + pass async def update_persistence(self) -> None: """Updates :attr:`user_data`, :attr:`chat_data`, :attr:`bot_data` in :attr:`persistence` along with :attr:`~telegram.ext.ExtBot.callback_data_cache` and the conversation states of any persistent :class:`~telegram.ext.ConversationHandler` registered for this application. - For :attr:`user_data`, :attr:`chat_data`, only entries accessed since the last run of this + For :attr:`user_data`, :attr:`chat_data`, only entries used since the last run of this method are updated. Tip: @@ -936,28 +977,38 @@ async def __update_persistence(self) -> None: coroutines.add(self.persistence.update_bot_data(deepcopy(self.bot_data))) if self.persistence.store_data.chat_data: - # Mypy can't handle the conditional assignment in `__init__` - chat_data = cast(TrackingDefaultDict, self._chat_data) - for chat_id, data in chat_data.pop_accessed_read_items(): - coroutines.add(self.persistence.update_chat_data(chat_id, deepcopy(data))) - for chat_id, data in chat_data.pop_accessed_write_items(): - if data is not chat_data.DELETED: - _logger.critical('`Application._chat_data[%s]` was written manually', chat_id) - coroutines.add(self.persistence.update_chat_data(chat_id, deepcopy(data))) - else: - coroutines.add(self.persistence.drop_chat_data(chat_id)) + update_ids = self._chat_ids_to_be_updated_in_persistence + self._chat_ids_to_be_updated_in_persistence = set() + delete_ids = self._chat_ids_to_be_deleted_in_persistence + self._chat_ids_to_be_deleted_in_persistence = set() + + # We don't want to update any data that has been deleted! + update_ids -= delete_ids + print('deleting chat_ids', delete_ids) + print('updating chat_ids', update_ids) + + for chat_id in update_ids: + coroutines.add( + self.persistence.update_chat_data(chat_id, deepcopy(self.chat_data[chat_id])) + ) + for chat_id in delete_ids: + coroutines.add(self.persistence.drop_chat_data(chat_id)) if self.persistence.store_data.user_data: - # Mypy can't handle the conditional assignment in `__init__` - user_data = cast(TrackingDefaultDict, self._user_data) - for user_id, data in user_data.pop_accessed_read_items(): - coroutines.add(self.persistence.update_user_data(user_id, deepcopy(data))) - for user_id, data in user_data.pop_accessed_write_items(): - if data is not user_data.DELETED: - _logger.critical('`Application._user_data[%s]` was written manually', user_id) - coroutines.add(self.persistence.update_user_data(user_id, deepcopy(data))) - else: - coroutines.add(self.persistence.drop_user_data(user_id)) + update_ids = self._user_ids_to_be_updated_in_persistence + self._user_ids_to_be_updated_in_persistence = set() + delete_ids = self._user_ids_to_be_deleted_in_persistence + self._user_ids_to_be_deleted_in_persistence = set() + + # We don't want to update any data that has been deleted! + update_ids -= delete_ids + + for user_id in update_ids: + coroutines.add( + self.persistence.update_user_data(user_id, deepcopy(self.user_data[user_id])) + ) + for user_id in delete_ids: + coroutines.add(self.persistence.drop_user_data(user_id)) # Unfortunately due to circular imports this has to be here # pylint: disable=import-outside-toplevel diff --git a/telegram/ext/_builders.py b/telegram/ext/_builders.py index 5c43828a93f..d330bfb4099 100644 --- a/telegram/ext/_builders.py +++ b/telegram/ext/_builders.py @@ -424,8 +424,9 @@ def pool_timeout(self: BuilderType, pool_timeout: Optional[float]) -> BuilderTyp return self def get_updates_request(self: BuilderType, request: BaseRequest) -> BuilderType: - """Sets a :class:`telegram.request.BaseRequest` object to be used for the ``get_updates_request`` - parameter of :attr:`telegram.ext.Application.bot`. + """Sets a :class:`telegram.request.BaseRequest` object to be used for the + :paramref:`~telegram.Bot.get_updates_request` parameter of + :attr:`telegram.ext.Application.bot`. .. seealso:: :meth:`request` diff --git a/telegram/ext/_jobqueue.py b/telegram/ext/_jobqueue.py index ce775927cad..9a5eb09e1b4 100644 --- a/telegram/ext/_jobqueue.py +++ b/telegram/ext/_jobqueue.py @@ -624,6 +624,11 @@ async def _run(self, application: 'Application') -> None: await self.callback(context) except Exception as exc: await application.create_task(application.dispatch_error(None, exc, job=self)) + finally: + # This is internal logic of application - let's keep it private for now + application._mark_update_for_persistence_update( # pylint: disable=protected-access + job=self + ) def schedule_removal(self) -> None: """ diff --git a/telegram/ext/_utils/trackingdefaultdict.py b/telegram/ext/_utils/trackingdefaultdict.py index 1b48f4b4206..fc0c429f297 100644 --- a/telegram/ext/_utils/trackingdefaultdict.py +++ b/telegram/ext/_utils/trackingdefaultdict.py @@ -58,6 +58,9 @@ # For methods like `pop`, `get`, `setdefault`, we should also check that we have the same # behavior as defaultdict +# TODO: We currently don't use `track_read=True` anywhere. We might want to drop that for +# ease of maintenance + class TrackingDefaultDict(MutableMapping[_KT, _VT]): """DefaultDict that keeps track of which keys where accessed. From f44ec187630dba28134b6a7f1df7ab1f9a71b6b5 Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Mon, 14 Feb 2022 17:42:15 +0100 Subject: [PATCH 010/153] Change pool timeout logic in HTTPXRequest --- telegram/error.py | 4 ++ telegram/request/_httpxrequest.py | 72 ++++++------------------ tests/test_request.py | 92 +++++++------------------------ 3 files changed, 40 insertions(+), 128 deletions(-) diff --git a/telegram/error.py b/telegram/error.py index e10b1359458..69aa9024920 100644 --- a/telegram/error.py +++ b/telegram/error.py @@ -74,6 +74,10 @@ def __init__(self, message: str): def __str__(self) -> str: return self.message + # TODO: test this + def __repr__(self) -> str: + return f'{self.__class__.__name__}({self.message})' + def __reduce__(self) -> Tuple[type, Tuple[str]]: return self.__class__, (self.message,) diff --git a/telegram/request/_httpxrequest.py b/telegram/request/_httpxrequest.py index ca23c474c36..923d6c08d36 100644 --- a/telegram/request/_httpxrequest.py +++ b/telegram/request/_httpxrequest.py @@ -16,7 +16,6 @@ # You should have received a copy of the GNU Lesser Public License # along with this program. If not, see [http://www.gnu.org/licenses/]. """This module contains methods to make POST and GET requests using the httpx library.""" -import asyncio import logging from typing import Tuple, Optional @@ -78,7 +77,7 @@ class HTTPXRequest(BaseRequest): connections in the connection pool! """ - __slots__ = ('_client', '__pool_semaphore') + __slots__ = ('_client',) def __init__( self, @@ -89,8 +88,6 @@ def __init__( write_timeout: Optional[float] = 5.0, pool_timeout: Optional[float] = 1.0, ): - self.__pool_semaphore = asyncio.BoundedSemaphore(connection_pool_size) - timeout = httpx.Timeout( connect=connect_timeout, read=read_timeout, @@ -126,58 +123,21 @@ async def do_request( pool_timeout: ODVInput[float] = BaseRequest.DEFAULT_NONE, ) -> Tuple[int, bytes]: """See :meth:`BaseRequest.do_request`.""" + if isinstance(read_timeout, DefaultValue): + read_timeout = self._client.timeout.read + if isinstance(write_timeout, DefaultValue): + write_timeout = self._client.timeout.write + if isinstance(connect_timeout, DefaultValue): + connect_timeout = self._client.timeout.connect if isinstance(pool_timeout, DefaultValue): pool_timeout = self._client.timeout.pool - # TODO: This doesn't seem to work. - # if pool_timeout != 0 and self.__pool_semaphore.locked(): - # _logger.debug( - # 'All connections in the pool are currently busy. Waiting pool_timeout=%s for ' - # 'a connection to become available.', - # pool_timeout, - # ) - - try: - await asyncio.wait_for(self.__pool_semaphore.acquire(), timeout=pool_timeout) - except asyncio.TimeoutError as exc: - raise TimedOut('Pool timeout') from exc - - try: - out = await self._do_request( - url=url, - method=method, - pool_timeout=pool_timeout, - request_data=request_data, - connect_timeout=connect_timeout, - read_timeout=read_timeout, - write_timeout=write_timeout, - ) - return out - finally: - self.__pool_semaphore.release() - - async def _do_request( - self, - url: str, - method: str, - pool_timeout: Optional[float], - request_data: RequestData = None, - connect_timeout: ODVInput[float] = BaseRequest.DEFAULT_NONE, - read_timeout: ODVInput[float] = BaseRequest.DEFAULT_NONE, - write_timeout: ODVInput[float] = BaseRequest.DEFAULT_NONE, - ) -> Tuple[int, bytes]: timeout = httpx.Timeout( - connect=self._client.timeout.connect, - read=self._client.timeout.read, - write=self._client.timeout.write, + connect=connect_timeout, + read=read_timeout, + write=write_timeout, pool=pool_timeout, ) - if not isinstance(read_timeout, DefaultValue): - timeout.read = read_timeout - if not isinstance(write_timeout, DefaultValue): - timeout.write = write_timeout - if not isinstance(connect_timeout, DefaultValue): - timeout.connect = connect_timeout # TODO p0: On Linux, use setsockopt to properly set socket level keepalive. # (socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 120) @@ -199,11 +159,13 @@ async def _do_request( ) except httpx.TimeoutException as err: if isinstance(err, httpx.PoolTimeout): - _logger.critical( - 'All connections in the connection pool are occupied. Request was *not* sent ' - 'to Telegram. Adjust connection pool size!', - ) - raise TimedOut('Pool timeout') from err + raise TimedOut( + message=( + 'Pool timeout: All connections in the connection pool are occupied. ' + 'Request was *not* sent to Telegram. Consider adjusting the connection ' + 'pool size or the pool timeout.' + ) + ) from err raise TimedOut from err except httpx.HTTPError as err: # HTTPError must come last as its the base httpx exception class diff --git a/tests/test_request.py b/tests/test_request.py index e75d0b05af4..3490ca4b7dc 100644 --- a/tests/test_request.py +++ b/tests/test_request.py @@ -20,13 +20,13 @@ implementations for BaseRequest and we want to test HTTPXRequest anyway.""" import asyncio import json -import logging from dataclasses import dataclass from http import HTTPStatus from typing import Tuple, Any, Coroutine, Callable import httpx import pytest +from flaky import flaky from telegram._utils.defaultvalue import DEFAULT_NONE from telegram.error import ( @@ -485,8 +485,11 @@ async def make_assertion(self, method, url, headers, timeout, files, data): @pytest.mark.asyncio async def test_do_request_pool_timeout(self, monkeypatch): - async def request(self, **kwargs): - await asyncio.sleep(0.05) + async def request(_, **kwargs): + if self.test_flag is None: + self.test_flag = True + else: + raise httpx.PoolTimeout('pool timeout') return httpx.Response(HTTPStatus.OK) monkeypatch.setattr(httpx.AsyncClient, 'request', request) @@ -499,76 +502,19 @@ async def request(self, **kwargs): ) @pytest.mark.asyncio + @flaky(3, 1) async def test_do_request_wait_for_pool(self, monkeypatch, httpx_request): - async def request(_, **kwargs): - await asyncio.sleep(0.05) - if self.test_flag is None: - self.test_flag = 1 - else: - self.test_flag += 1 - return httpx.Response(HTTPStatus.OK) - - monkeypatch.setattr(httpx.AsyncClient, 'request', request) - - task_1 = asyncio.create_task(httpx_request.do_request(method='GET', url='URL')) - task_2 = asyncio.create_task(httpx_request.do_request(method='GET', url='URL')) - await asyncio.sleep(0.07) - assert self.test_flag == 1 - await asyncio.sleep(0.07) - assert self.test_flag == 2 - await asyncio.gather(task_1, task_2) - - @pytest.mark.asyncio - async def test_do_request_wait_for_pool_with_exception(self, monkeypatch, httpx_request): - async def request(_, **kwargs): - await asyncio.sleep(0.05) - if self.test_flag is None: - self.test_flag = 1 - raise RuntimeError('Raising Exception') - else: - self.test_flag += 1 - return httpx.Response(HTTPStatus.OK) + """The pool logic is buried rather deeply in httpxcore, so we make actual requests here + instead of mocking""" - monkeypatch.setattr(httpx.AsyncClient, 'request', request) - - task_1 = asyncio.create_task(httpx_request.do_request(method='GET', url='URL')) - task_2 = asyncio.create_task(httpx_request.do_request(method='GET', url='URL')) - await asyncio.sleep(0.06) - assert self.test_flag == 1 - await asyncio.sleep(0.06) - assert self.test_flag == 2 - out = await asyncio.gather(task_1, task_2, return_exceptions=True) - assert sum(isinstance(entry, RuntimeError) for entry in out) == 1 - - @pytest.mark.asyncio - async def test_do_request_critical_timeout_logging(self, httpx_request, caplog, monkeypatch): - async def request(self, **kwargs): - raise httpx.PoolTimeout('pool timeout') - - monkeypatch.setattr(httpx.AsyncClient, 'request', request) - - with pytest.raises(TimedOut): - with caplog.at_level(logging.CRITICAL): - await httpx_request.do_request(method='GET', url='URL') - - assert len(caplog.records) == 1 - assert ( - 'All connections in the connection pool are occupied' in caplog.records[0].getMessage() + task_1 = httpx_request.do_request( + method='GET', url='https://python-telegram-bot.org/static/testfiles/telegram.mp4' ) - - # @pytest.mark.asyncio - # async def test_do_request_busy_logging(self, httpx_request, caplog, monkeypatch): - # async def request(self, **kwargs): - # await asyncio.sleep(0.5) - # return httpx.Response(HTTPStatus.OK) - # - # monkeypatch.setattr(httpx.AsyncClient, 'request', request) - # - # with caplog.at_level(logging.DEBUG): - # await asyncio.gather( - # httpx_request.do_request(method='GET', url='URL'), - # httpx_request.do_request(method='GET', url='URL'), - # ) - # - # assert len(caplog.records) == 1 - # assert 'currently busy. Waiting pool_timeout' in caplog.records[0].getMessage() + task_2 = httpx_request.do_request( + method='GET', url='https://python-telegram-bot.org/static/testfiles/telegram.mp4' + ) + done, pending = await asyncio.wait({task_1, task_2}, return_when=asyncio.FIRST_COMPLETED) + assert len(done) == len(pending) == 1 + done, pending = await asyncio.wait({task_1, task_2}, return_when=asyncio.ALL_COMPLETED) + assert len(done) == 2 + assert len(pending) == 0 From 6dbafc2d3f161dc48d69845997a9d690b4fa166a Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Tue, 15 Feb 2022 20:32:47 +0100 Subject: [PATCH 011/153] Handle unknown RequestResponse paramas --- telegram/ext/_conversationhandler.py | 4 +++- telegram/request/_baserequest.py | 12 +++++++----- tests/test_request.py | 16 ++++++++++++++++ 3 files changed, 26 insertions(+), 6 deletions(-) diff --git a/telegram/ext/_conversationhandler.py b/telegram/ext/_conversationhandler.py index cf944220f6c..306ee03ac6a 100644 --- a/telegram/ext/_conversationhandler.py +++ b/telegram/ext/_conversationhandler.py @@ -75,7 +75,9 @@ class _ConversationTimeoutContext(Generic[CCT]): @dataclass class PendingState: """Thin wrapper around asyncio.Task to handle block=False handlers. Note that this is a - public class of this module, since `Application.update_persistence` needs to access it.""" + public class of this module, since `Application.update_persistence` needs to access it. + It's still hidden from users, since this module itself is private. + """ __slots__ = ('task', 'old_state') diff --git a/telegram/request/_baserequest.py b/telegram/request/_baserequest.py index 00f37faebe3..18d643fdd47 100644 --- a/telegram/request/_baserequest.py +++ b/telegram/request/_baserequest.py @@ -257,6 +257,12 @@ async def _request_wrapper( response_data = self._parse_json_response(payload) + description = response_data.get('description') + if description: + message = description + else: + message = 'Unknown HTTPError' + # In some special cases, we ca raise more informative exceptions: # see https://core.telegram.org/bots/api#responseparameters and # https://core.telegram.org/bots/api#making-requests @@ -269,11 +275,7 @@ async def _request_wrapper( if retry_after: raise RetryAfter(retry_after) - description = response_data.get('description') - if description: - message = description - else: - message = 'Unknown HTTPError' + message += f'\nThe server response contained unknown parameters: {parameters}' if code == HTTPStatus.FORBIDDEN: raise Forbidden(message) diff --git a/tests/test_request.py b/tests/test_request.py index 3490ca4b7dc..ac6e868e086 100644 --- a/tests/test_request.py +++ b/tests/test_request.py @@ -173,6 +173,22 @@ async def test_retry_after(self, monkeypatch, httpx_request: HTTPXRequest): assert exc_info.value.retry_after == 42.0 + @pytest.mark.asyncio + async def test_unknown_request_params(self, monkeypatch, httpx_request: HTTPXRequest): + server_response = b'{"ok": "False", "parameters": {"unknown": "42"}}' + + monkeypatch.setattr( + httpx_request, + 'do_request', + mocker_factory(response=server_response, return_code=HTTPStatus.BAD_REQUEST), + ) + + with pytest.raises( + BadRequest, + match="{'unknown': '42'}", + ): + await httpx_request.post(None, None, None) + @pytest.mark.asyncio @pytest.mark.parametrize('description', [True, False]) async def test_error_description(self, monkeypatch, httpx_request: HTTPXRequest, description): From 65410907b99d1baba508a691198804b98bbff8ec Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Tue, 15 Feb 2022 22:06:42 +0100 Subject: [PATCH 012/153] Work on TrackingDict + tests --- telegram/ext/_application.py | 8 +- telegram/ext/_conversationhandler.py | 13 +- telegram/ext/_utils/trackingdefaultdict.py | 222 --------------------- telegram/ext/_utils/trackingdict.py | 125 ++++++++++++ tests/test_trackingdict.py | 161 +++++++++++++++ 5 files changed, 294 insertions(+), 235 deletions(-) delete mode 100644 telegram/ext/_utils/trackingdefaultdict.py create mode 100644 telegram/ext/_utils/trackingdict.py create mode 100644 tests/test_trackingdict.py diff --git a/telegram/ext/_application.py b/telegram/ext/_application.py index 533b65a373f..be5bbfef4f2 100644 --- a/telegram/ext/_application.py +++ b/telegram/ext/_application.py @@ -52,7 +52,7 @@ from telegram.ext._callbackdatacache import CallbackDataCache from telegram._utils.defaultvalue import DefaultValue, DEFAULT_TRUE, DEFAULT_NONE from telegram._utils.warnings import warn -from telegram.ext._utils.trackingdefaultdict import TrackingDefaultDict +from telegram.ext._utils.trackingdict import TrackingDict from telegram.ext._utils.types import CCT, UD, CD, BD, BT, JQ, HandlerCallback from telegram.ext._utils.stack import was_called_by @@ -248,7 +248,7 @@ def __init__( # This attribute will hold references to the conversation dicts of all conversation # handlers so that we can extract the changed states during `update_persistence` self._conversation_handler_conversations: Dict[ - str, TrackingDefaultDict[Tuple[int, ...], object] + str, TrackingDict[Tuple[int, ...], object] ] = {} # A number of low-level helpers for the internal logic @@ -1034,10 +1034,10 @@ async def __update_persistence(self) -> None: else: result = new_state.resolve() - effective_new_state = None if result is TrackingDefaultDict.DELETED else result + effective_new_state = None if result is TrackingDict.DELETED else result print(name, key, effective_new_state) # TODO: Test that we actually pass `None` here in case the conversation had ended, - # i.e. effective_new_state is TrackingDefaultDict.DELETED + # i.e. effective_new_state is TrackingDict.DELETED coroutines.add( self.persistence.update_conversation( name=name, key=key, new_state=effective_new_state diff --git a/telegram/ext/_conversationhandler.py b/telegram/ext/_conversationhandler.py index 306ee03ac6a..1d08a17d696 100644 --- a/telegram/ext/_conversationhandler.py +++ b/telegram/ext/_conversationhandler.py @@ -51,7 +51,7 @@ TypeHandler, ) from telegram._utils.warnings import warn -from telegram.ext._utils.trackingdefaultdict import TrackingDefaultDict +from telegram.ext._utils.trackingdict import TrackingDict from telegram.ext._utils.types import ConversationDict from telegram.ext._utils.types import CCT @@ -525,7 +525,7 @@ def map_to_parent(self, value: object) -> NoReturn: async def _initialize_persistence( self, application: 'Application' - ) -> TrackingDefaultDict[Tuple[int, ...], object]: + ) -> TrackingDict[Tuple[int, ...], object]: """Initializes the persistence for this handler. While this method is marked as protected, we expect it to be called by the Application/parent conversations. It's just protected to hide it from users. @@ -540,14 +540,9 @@ async def _initialize_persistence( 'persistence!' ) - def default_factory() -> NoReturn: - raise KeyError - self._conversations = cast( - TrackingDefaultDict[Tuple[int, ...], object], - TrackingDefaultDict( - default_factory=default_factory, track_read=False, track_write=True - ), + TrackingDict[Tuple[int, ...], object], + TrackingDict(), ) self._conversations.update(await application.persistence.get_conversations(self.name)) diff --git a/telegram/ext/_utils/trackingdefaultdict.py b/telegram/ext/_utils/trackingdefaultdict.py deleted file mode 100644 index fc0c429f297..00000000000 --- a/telegram/ext/_utils/trackingdefaultdict.py +++ /dev/null @@ -1,222 +0,0 @@ -#!/usr/bin/env python -# -# A library that provides a Python interface to the Telegram Bot API -# Copyright (C) 2015-2022 -# Leandro Toledo de Souza -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Lesser Public License for more details. -# -# You should have received a copy of the GNU Lesser Public License -# along with this program. If not, see [http://www.gnu.org/licenses/]. -"""This module contains subclasses of :class:`collections.defaultdict` that keeps track of the -keys that where accessed. - -.. versionadded:: 14.0 - -Warning: - Contents of this module are intended to be used internally by the library and *not* by the - user. Changes to this module are not considered breaking changes and may not be documented in - the changelog. -""" -from typing import ( - TypeVar, - DefaultDict, - Callable, - Set, - ClassVar, - Iterator, - Optional, - Union, - Tuple, - overload, - MutableMapping, - List, - Mapping, -) -from collections import defaultdict - -from telegram._utils.defaultvalue import DEFAULT_NONE, DefaultValue - -_VT = TypeVar('_VT') -_KT = TypeVar('_KT') -_T = TypeVar('_T') - - -# TODO: Implement tests for this class that cover all methods implemented by (Mutable)Mapping and -# check if they give the correct behavior in terms of keeping track on the access. This includes -# __eq__ & access through Key/ItemViews -# We should also test that all this behavior stays the same when accessing the mapping through -# a MappingProxyType -# For methods like `pop`, `get`, `setdefault`, we should also check that we have the same -# behavior as defaultdict - -# TODO: We currently don't use `track_read=True` anywhere. We might want to drop that for -# ease of maintenance - - -class TrackingDefaultDict(MutableMapping[_KT, _VT]): - """DefaultDict that keeps track of which keys where accessed. - - Note: - * ``key in tdd`` is not considered reading - * ``setdefault()`` is considered both reading and writing depending on - whether or not the key is present - * ``pop`` is only considered writing, since the value is deleted instead of being changed - - Args: - default_factory (Callable): Default factory for missing entries - track_read (:obj:`bool`): Whether read access should be tracked. Deleting entries is - not considered reading. - track_write (:obj:`bool`): Whether write access should be tracked. Deleting entries is - considered writing. - """ - - DELETED: ClassVar = object() - """Special marker indicating that an entry was deleted.""" - - __slots__ = ('_data', '_write_access_keys', '_read_access_keys', 'track_read', 'track_write') - - def __init__(self, default_factory: Callable[[], _VT], track_read: bool, track_write: bool): - # The default_factory argument for defaultdict is positional only! - self._data: DefaultDict[_KT, _VT] = defaultdict(default_factory) - self.track_read = track_read - self.track_write = track_write - self._write_access_keys: Set[_KT] = set() - self._read_access_keys: Set[_KT] = set() - - def __track_read(self, key: Union[_KT, Set[_KT]]) -> None: - if self.track_read: - if isinstance(key, set): - self._read_access_keys |= key - else: - self._read_access_keys.add(key) - - def __track_write(self, key: Union[_KT, Set[_KT]]) -> None: - if self.track_write: - if isinstance(key, set): - self._write_access_keys |= key - else: - self._write_access_keys.add(key) - - def __repr__(self) -> str: - return repr(self._data) - - def __str__(self) -> str: - return str(self._data) - - def __eq__(self, other: object) -> bool: - return other == self._data - - def pop_accessed_read_keys(self) -> Set[_KT]: - """Returns all keys that were read-accessed since the last time this method was called.""" - if not self.track_read: - raise RuntimeError('Not tracking read access!') - - out = self._read_access_keys - self._read_access_keys = set() - return out - - def pop_accessed_write_keys(self) -> Set[_KT]: - """Returns all keys that were write-accessed since the last time this method was called.""" - if not self.track_write: - raise RuntimeError('Not tracking write access!') - - out = self._write_access_keys - self._write_access_keys = set() - return out - - def pop_accessed_read_items(self) -> List[Tuple[_KT, _VT]]: - """ - Returns all keys & corresponding values as set of tuples that were read-accessed since - the last time this method was called. - """ - keys = self.pop_accessed_read_keys() - return [(key, self._data[key]) for key in keys] - - def pop_accessed_write_items(self) -> List[Tuple[_KT, _VT]]: - """ - Returns all keys & corresponding values as set of tuples that were write-accessed since - the last time this method was called. If a key was deleted, the value will be - :attr:`DELETED`. - """ - keys = self.pop_accessed_write_keys() - return [(key, self._data[key] if key in self._data else self.DELETED) for key in keys] - - # Implement abstract interface - - def __getitem__(self, key: _KT) -> _VT: - item = self._data[key] - self.__track_read(key) - return item - - def __setitem__(self, key: _KT, value: _VT) -> None: - self._data[key] = value - self.__track_write(key) - - def __delitem__(self, key: _KT) -> None: - del self._data[key] - self.__track_write(key) - - def __iter__(self) -> Iterator[_KT]: - for key in self._data: - self.__track_read(key) - yield key - - def __len__(self) -> int: - return len(self._data) - - def update_no_track(self, mapping: Mapping[_KT, _VT]) -> None: - return self._data.update(mapping) - - # Override some methods so that they fit better with the read/write access book keeping - - def __contains__(self, key: object) -> bool: - return key in self._data - - # Mypy seems a bit inconsistent about what it wants as types for `default` and return value - # so we just ignore a bit - def pop( # type: ignore[override] - self, key: _KT, default: _VT = DEFAULT_NONE # type: ignore[assignment] - ) -> _VT: - self.__track_write(key) - if isinstance(default, DefaultValue): - return self._data.pop(key) - return self._data.pop(key, default=default) - - def clear(self) -> None: - self.__track_write(set(self._data.keys())) - self._data.clear() - - # Mypy seems a bit inconsistent about what it wants as types for `default` and return value - # so we just ignore a bit - def setdefault(self: 'TrackingDefaultDict[_KT, _T]', key: _KT, default: _T = None) -> _T: - if key in self._data: - self.__track_read(key) - return self._data[key] - - self.__track_write(key) - self._data[key] = default # type: ignore[assignment] - return default # type: ignore[return-value] - - # Overriding to comply with the behavior of `defaultdict` - - @overload - def get(self, key: _KT) -> Optional[_VT]: # pylint: disable=arguments-differ - ... - - @overload - def get(self, key: _KT, default: _T) -> _T: # pylint: disable=signature-differs - ... - - def get(self, key: _KT, default: _T = None) -> Optional[Union[_VT, _T]]: - if key in self: - return self[key] - return default diff --git a/telegram/ext/_utils/trackingdict.py b/telegram/ext/_utils/trackingdict.py new file mode 100644 index 00000000000..f65e20d080c --- /dev/null +++ b/telegram/ext/_utils/trackingdict.py @@ -0,0 +1,125 @@ +#!/usr/bin/env python +# +# A library that provides a Python interface to the Telegram Bot API +# Copyright (C) 2015-2022 +# Leandro Toledo de Souza +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser Public License for more details. +# +# You should have received a copy of the GNU Lesser Public License +# along with this program. If not, see [http://www.gnu.org/licenses/]. +"""This module contains a mutable mapping that keeps track of the keys that where accessed. + +.. versionadded:: 14.0 + +Warning: + Contents of this module are intended to be used internally by the library and *not* by the + user. Changes to this module are not considered breaking changes and may not be documented in + the changelog. +""" +from typing import ( + TypeVar, + Set, + ClassVar, + Union, + Tuple, + List, + Mapping, + Generic, +) +from collections import UserDict + +from telegram._utils.defaultvalue import DEFAULT_NONE, DefaultValue + +_VT = TypeVar('_VT') +_KT = TypeVar('_KT') +_T = TypeVar('_T') + + +class TrackingDict(UserDict, Generic[_KT, _VT]): + """Mutable mapping that keeps track of which keys where accessed with write access. + Read-access is not tracked. + + Note: + * ``setdefault()`` and ``pop`` are considered writing only depending on whether or not the + key is present + * deleting values is considered writing + """ + + DELETED: ClassVar = object() + """Special marker indicating that an entry was deleted.""" + + __slots__ = ('_data', '_write_access_keys') + + def __init__(self) -> None: + super().__init__() + self._write_access_keys: Set[_KT] = set() + + def __track_write(self, key: Union[_KT, Set[_KT]]) -> None: + if isinstance(key, set): + self._write_access_keys |= key + else: + self._write_access_keys.add(key) + + def pop_accessed_keys(self) -> Set[_KT]: + """Returns all keys that were write-accessed since the last time this method was called.""" + out = self._write_access_keys + self._write_access_keys = set() + return out + + def pop_accessed_write_items(self) -> List[Tuple[_KT, _VT]]: + """ + Returns all keys & corresponding values as set of tuples that were write-accessed since + the last time this method was called. If a key was deleted, the value will be + :attr:`DELETED`. + """ + keys = self.pop_accessed_keys() + return [(key, self[key] if key in self else self.DELETED) for key in keys] + + # Override methods to track access + + def __setitem__(self, key: _KT, value: _VT) -> None: + self.__track_write(key) + super().__setitem__(key, value) + + def __delitem__(self, key: _KT) -> None: + self.__track_write(key) + super().__delitem__(key) + + def update_no_track(self, mapping: Mapping[_KT, _VT]) -> None: + """Like ``update``, but doesn't count towards write access.""" + for key, value in mapping.items(): + self.data[key] = value + + # Mypy seems a bit inconsistent about what it wants as types for `default` and return value + # so we just ignore a bit + def pop( # type: ignore[override] + self, key: _KT, default: _VT = DEFAULT_NONE # type: ignore[assignment] + ) -> _VT: + if key in self: + self.__track_write(key) + if isinstance(default, DefaultValue): + return super().pop(key) + return super().pop(key, default=default) + + def clear(self) -> None: + self.__track_write(set(super().keys())) + super().clear() + + # Mypy seems a bit inconsistent about what it wants as types for `default` and return value + # so we just ignore a bit + def setdefault(self: 'TrackingDict[_KT, _T]', key: _KT, default: _T = None) -> _T: + if key in self: + return self[key] + + self.__track_write(key) + self[key] = default # type: ignore[assignment] + return default # type: ignore[return-value] diff --git a/tests/test_trackingdict.py b/tests/test_trackingdict.py new file mode 100644 index 00000000000..7f8849a693f --- /dev/null +++ b/tests/test_trackingdict.py @@ -0,0 +1,161 @@ +#!/usr/bin/env python +# +# A library that provides a Python interface to the Telegram Bot API +# Copyright (C) 2015-2022 +# Leandro Toledo de Souza +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser Public License for more details. +# +# You should have received a copy of the GNU Lesser Public License +# along with this program. If not, see [http://www.gnu.org/licenses/]. + +import pytest + +from telegram.ext._utils.trackingdict import TrackingDict + + +@pytest.fixture(scope='function') +def td() -> TrackingDict: + td = TrackingDict() + td.update_no_track({1: 1}) + return td + + +@pytest.fixture(scope='function') +def data() -> dict: + return {1: 1} + + +class TestTrackingDict: + def test_representations(self, td, data): + assert repr(td) == repr(data) + assert str(td) == str(data) + + def test_len(self, td, data): + assert len(td) == len(data) + + def test_boolean(self, td, data): + assert bool(td) == bool(data) + assert bool(TrackingDict()) == bool({}) + + def test_equality(self, td, data): + assert td == data + assert data == td + assert td != TrackingDict() + assert TrackingDict() != td + td_2 = TrackingDict() + td_2['foo'] = 7 + assert td != td_2 + assert td_2 != td + assert td != 1 + assert 1 != td + assert td != 5 + assert 5 != td + + def test_getitem(self, td): + assert td[1] == 1 + assert not td.pop_accessed_write_items() + assert not td.pop_accessed_keys() + + def test_setitem(self, td): + td[5] = 5 + assert td[5] == 5 + assert td.pop_accessed_write_items() == [(5, 5)] + td[5] = 7 + assert td[5] == 7 + assert td.pop_accessed_keys() == {5} + + def test_delitem(self, td): + assert not td.pop_accessed_keys() + td[5] = 7 + del td[1] + assert 1 not in td + assert td.pop_accessed_keys() == {1, 5} + td[1] = 7 + td[5] = 7 + assert td.pop_accessed_keys() == {1, 5} + del td[5] + assert 5 not in td + assert td.pop_accessed_write_items() == [(5, TrackingDict.DELETED)] + + def test_update_no_track(self, td): + assert not td.pop_accessed_keys() + td.update_no_track({2: 2, 3: 3}) + assert td == {1: 1, 2: 2, 3: 3} + assert not td.pop_accessed_keys() + + def test_pop(self, td): + td.pop(1) + assert 1 not in td + assert td.pop_accessed_keys() == {1} + td[1] = 7 + td[5] = 8 + assert 1 in td + assert 5 in td + assert td.pop_accessed_keys() == {1, 5} + td.pop(5) + assert 5 not in td + assert td.pop_accessed_write_items() == [(5, TrackingDict.DELETED)] + + with pytest.raises(KeyError): + td.pop(5) + + assert td.pop(5, 8) == 8 + assert 5 not in td + assert not td.pop_accessed_keys() + + assert td.pop(5, 8) == 8 + assert 5 not in td + assert not td.pop_accessed_write_items() + + def test_popitem(self, td): + td.update_no_track({2: 2}) + assert td.popitem() == (1, 1) + assert 1 not in td + assert td.pop_accessed_keys() == {1} + + assert td.popitem() == (2, 2) + assert 2 not in td + assert not td + assert td.pop_accessed_write_items() == [(2, TrackingDict.DELETED)] + + with pytest.raises(KeyError): + td.popitem() + + def test_clear(self, td): + td.clear() + assert td == {} + assert td.pop_accessed_keys() == {1} + td[5] = 7 + assert 5 in td + assert td.pop_accessed_keys() == {5} + td.clear() + assert td == {} + assert td.pop_accessed_write_items() == [(5, TrackingDict.DELETED)] + + def test_set_default(self, td): + assert td.setdefault(1, 2) == 1 + assert td[1] == 1 + assert not td.pop_accessed_keys() + assert not td.pop_accessed_write_items() + + assert td.setdefault(2, 3) == 3 + assert td[2] == 3 + assert td.pop_accessed_keys() == {2} + assert td.setdefault(3, 4) == 4 + assert td[3] == 4 + assert td.pop_accessed_write_items() == [(3, 4)] + + def test_iter(self, td, data): + data.update({2: 2, 3: 3, 4: 4}) + td.update_no_track({2: 2, 3: 3, 4: 4}) + assert not td.pop_accessed_keys() + assert list(iter(td)) == list(iter(data)) From fb76b4fbde80ef7cb16cf98206409ba2d94579dd Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Tue, 15 Feb 2022 22:09:12 +0100 Subject: [PATCH 013/153] Re-add test_callbackdatacache.py --- tests/test_callbackdatacache.py | 381 ++++++++++++++++++++++++++++++++ 1 file changed, 381 insertions(+) create mode 100644 tests/test_callbackdatacache.py diff --git a/tests/test_callbackdatacache.py b/tests/test_callbackdatacache.py new file mode 100644 index 00000000000..1d97022d29c --- /dev/null +++ b/tests/test_callbackdatacache.py @@ -0,0 +1,381 @@ +#!/usr/bin/env python +# +# A library that provides a Python interface to the Telegram Bot API +# Copyright (C) 2015-2022 +# Leandro Toledo de Souza +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser Public License for more details. +# +# You should have received a copy of the GNU Lesser Public License +# along with this program. If not, see [http://www.gnu.org/licenses/]. +import time +from copy import deepcopy +from datetime import datetime +from uuid import uuid4 + +import pytest +import pytz + +from telegram import InlineKeyboardButton, InlineKeyboardMarkup, CallbackQuery, Message, User +from telegram.ext._callbackdatacache import ( + CallbackDataCache, + _KeyboardData, + InvalidCallbackData, +) + + +@pytest.fixture(scope='function') +def callback_data_cache(bot): + return CallbackDataCache(bot) + + +class TestInvalidCallbackData: + def test_slot_behaviour(self, mro_slots): + invalid_callback_data = InvalidCallbackData() + for attr in invalid_callback_data.__slots__: + assert getattr(invalid_callback_data, attr, 'err') != 'err', f"got extra slot '{attr}'" + assert len(mro_slots(invalid_callback_data)) == len( + set(mro_slots(invalid_callback_data)) + ), "duplicate slot" + + +class TestKeyboardData: + def test_slot_behaviour(self, mro_slots): + keyboard_data = _KeyboardData('uuid') + for attr in keyboard_data.__slots__: + assert getattr(keyboard_data, attr, 'err') != 'err', f"got extra slot '{attr}'" + assert len(mro_slots(keyboard_data)) == len( + set(mro_slots(keyboard_data)) + ), "duplicate slot" + + +class TestCallbackDataCache: + def test_slot_behaviour(self, callback_data_cache, mro_slots): + for attr in callback_data_cache.__slots__: + attr = ( + f"_CallbackDataCache{attr}" + if attr.startswith('__') and not attr.endswith('__') + else attr + ) + assert getattr(callback_data_cache, attr, 'err') != 'err', f"got extra slot '{attr}'" + assert len(mro_slots(callback_data_cache)) == len( + set(mro_slots(callback_data_cache)) + ), "duplicate slot" + + @pytest.mark.parametrize('maxsize', [1, 5, 2048]) + def test_init_maxsize(self, maxsize, bot): + assert CallbackDataCache(bot).maxsize == 1024 + cdc = CallbackDataCache(bot, maxsize=maxsize) + assert cdc.maxsize == maxsize + assert cdc.bot is bot + + def test_init_and_access__persistent_data(self, bot): + keyboard_data = _KeyboardData('123', 456, {'button': 678}) + persistent_data = ([keyboard_data.to_tuple()], {'id': '123'}) + cdc = CallbackDataCache(bot, persistent_data=persistent_data) + + assert cdc.maxsize == 1024 + assert dict(cdc._callback_queries) == {'id': '123'} + assert list(cdc._keyboard_data.keys()) == ['123'] + assert cdc._keyboard_data['123'].keyboard_uuid == '123' + assert cdc._keyboard_data['123'].access_time == 456 + assert cdc._keyboard_data['123'].button_data == {'button': 678} + + assert cdc.persistence_data == persistent_data + + def test_process_keyboard(self, callback_data_cache): + changing_button_1 = InlineKeyboardButton('changing', callback_data='some data 1') + changing_button_2 = InlineKeyboardButton('changing', callback_data='some data 2') + non_changing_button = InlineKeyboardButton('non-changing', url='https://ptb.org') + reply_markup = InlineKeyboardMarkup.from_row( + [non_changing_button, changing_button_1, changing_button_2] + ) + + out = callback_data_cache.process_keyboard(reply_markup) + assert out.inline_keyboard[0][0] is non_changing_button + assert out.inline_keyboard[0][1] != changing_button_1 + assert out.inline_keyboard[0][2] != changing_button_2 + + keyboard_1, button_1 = callback_data_cache.extract_uuids( + out.inline_keyboard[0][1].callback_data + ) + keyboard_2, button_2 = callback_data_cache.extract_uuids( + out.inline_keyboard[0][2].callback_data + ) + assert keyboard_1 == keyboard_2 + assert ( + callback_data_cache._keyboard_data[keyboard_1].button_data[button_1] == 'some data 1' + ) + assert ( + callback_data_cache._keyboard_data[keyboard_2].button_data[button_2] == 'some data 2' + ) + + def test_process_keyboard_no_changing_button(self, callback_data_cache): + reply_markup = InlineKeyboardMarkup.from_button( + InlineKeyboardButton('non-changing', url='https://ptb.org') + ) + assert callback_data_cache.process_keyboard(reply_markup) is reply_markup + + def test_process_keyboard_full(self, bot): + cdc = CallbackDataCache(bot, maxsize=1) + changing_button_1 = InlineKeyboardButton('changing', callback_data='some data 1') + changing_button_2 = InlineKeyboardButton('changing', callback_data='some data 2') + non_changing_button = InlineKeyboardButton('non-changing', url='https://ptb.org') + reply_markup = InlineKeyboardMarkup.from_row( + [non_changing_button, changing_button_1, changing_button_2] + ) + + out1 = cdc.process_keyboard(reply_markup) + assert len(cdc.persistence_data[0]) == 1 + out2 = cdc.process_keyboard(reply_markup) + assert len(cdc.persistence_data[0]) == 1 + + keyboard_1, button_1 = cdc.extract_uuids(out1.inline_keyboard[0][1].callback_data) + keyboard_2, button_2 = cdc.extract_uuids(out2.inline_keyboard[0][2].callback_data) + assert cdc.persistence_data[0][0][0] != keyboard_1 + assert cdc.persistence_data[0][0][0] == keyboard_2 + + @pytest.mark.parametrize('data', [True, False]) + @pytest.mark.parametrize('message', [True, False]) + @pytest.mark.parametrize('invalid', [True, False]) + def test_process_callback_query(self, callback_data_cache, data, message, invalid): + """This also tests large parts of process_message""" + changing_button_1 = InlineKeyboardButton('changing', callback_data='some data 1') + changing_button_2 = InlineKeyboardButton('changing', callback_data='some data 2') + non_changing_button = InlineKeyboardButton('non-changing', url='https://ptb.org') + reply_markup = InlineKeyboardMarkup.from_row( + [non_changing_button, changing_button_1, changing_button_2] + ) + + out = callback_data_cache.process_keyboard(reply_markup) + if invalid: + callback_data_cache.clear_callback_data() + + effective_message = Message(message_id=1, date=None, chat=None, reply_markup=out) + effective_message.reply_to_message = deepcopy(effective_message) + effective_message.pinned_message = deepcopy(effective_message) + cq_id = uuid4().hex + callback_query = CallbackQuery( + cq_id, + from_user=None, + chat_instance=None, + # not all CallbackQueries have callback_data + data=out.inline_keyboard[0][1].callback_data if data else None, + # CallbackQueries from inline messages don't have the message attached, so we test that + message=effective_message if message else None, + ) + callback_data_cache.process_callback_query(callback_query) + + if not invalid: + if data: + assert callback_query.data == 'some data 1' + # make sure that we stored the mapping CallbackQuery.id -> keyboard_uuid correctly + assert len(callback_data_cache._keyboard_data) == 1 + assert ( + callback_data_cache._callback_queries[cq_id] + == list(callback_data_cache._keyboard_data.keys())[0] + ) + else: + assert callback_query.data is None + if message: + for msg in ( + callback_query.message, + callback_query.message.reply_to_message, + callback_query.message.pinned_message, + ): + assert msg.reply_markup == reply_markup + else: + if data: + assert isinstance(callback_query.data, InvalidCallbackData) + else: + assert callback_query.data is None + if message: + for msg in ( + callback_query.message, + callback_query.message.reply_to_message, + callback_query.message.pinned_message, + ): + assert isinstance( + msg.reply_markup.inline_keyboard[0][1].callback_data, + InvalidCallbackData, + ) + assert isinstance( + msg.reply_markup.inline_keyboard[0][2].callback_data, + InvalidCallbackData, + ) + + @pytest.mark.parametrize('pass_from_user', [True, False]) + @pytest.mark.parametrize('pass_via_bot', [True, False]) + def test_process_message_wrong_sender(self, pass_from_user, pass_via_bot, callback_data_cache): + reply_markup = InlineKeyboardMarkup.from_button( + InlineKeyboardButton('test', callback_data='callback_data') + ) + user = User(1, 'first', False) + message = Message( + 1, + None, + None, + from_user=user if pass_from_user else None, + via_bot=user if pass_via_bot else None, + reply_markup=reply_markup, + ) + callback_data_cache.process_message(message) + if pass_from_user or pass_via_bot: + # Here we can determine that the message is not from our bot, so no replacing + assert message.reply_markup.inline_keyboard[0][0].callback_data == 'callback_data' + else: + # Here we have no chance to know, so InvalidCallbackData + assert isinstance( + message.reply_markup.inline_keyboard[0][0].callback_data, InvalidCallbackData + ) + + @pytest.mark.parametrize('pass_from_user', [True, False]) + def test_process_message_inline_mode(self, pass_from_user, callback_data_cache): + """Check that via_bot tells us correctly that our bot sent the message, even if + from_user is not our bot.""" + reply_markup = InlineKeyboardMarkup.from_button( + InlineKeyboardButton('test', callback_data='callback_data') + ) + user = User(1, 'first', False) + message = Message( + 1, + None, + None, + from_user=user if pass_from_user else None, + via_bot=callback_data_cache.bot.bot, + reply_markup=callback_data_cache.process_keyboard(reply_markup), + ) + callback_data_cache.process_message(message) + # Here we can determine that the message is not from our bot, so no replacing + assert message.reply_markup.inline_keyboard[0][0].callback_data == 'callback_data' + + def test_process_message_no_reply_markup(self, callback_data_cache): + message = Message(1, None, None) + callback_data_cache.process_message(message) + assert message.reply_markup is None + + def test_drop_data(self, callback_data_cache): + changing_button_1 = InlineKeyboardButton('changing', callback_data='some data 1') + changing_button_2 = InlineKeyboardButton('changing', callback_data='some data 2') + reply_markup = InlineKeyboardMarkup.from_row([changing_button_1, changing_button_2]) + + out = callback_data_cache.process_keyboard(reply_markup) + callback_query = CallbackQuery( + '1', + from_user=None, + chat_instance=None, + data=out.inline_keyboard[0][1].callback_data, + ) + callback_data_cache.process_callback_query(callback_query) + + assert len(callback_data_cache.persistence_data[1]) == 1 + assert len(callback_data_cache.persistence_data[0]) == 1 + + callback_data_cache.drop_data(callback_query) + assert len(callback_data_cache.persistence_data[1]) == 0 + assert len(callback_data_cache.persistence_data[0]) == 0 + + def test_drop_data_missing_data(self, callback_data_cache): + changing_button_1 = InlineKeyboardButton('changing', callback_data='some data 1') + changing_button_2 = InlineKeyboardButton('changing', callback_data='some data 2') + reply_markup = InlineKeyboardMarkup.from_row([changing_button_1, changing_button_2]) + + out = callback_data_cache.process_keyboard(reply_markup) + callback_query = CallbackQuery( + '1', + from_user=None, + chat_instance=None, + data=out.inline_keyboard[0][1].callback_data, + ) + + with pytest.raises(KeyError, match='CallbackQuery was not found in cache.'): + callback_data_cache.drop_data(callback_query) + + callback_data_cache.process_callback_query(callback_query) + callback_data_cache.clear_callback_data() + callback_data_cache.drop_data(callback_query) + assert callback_data_cache.persistence_data == ([], {}) + + @pytest.mark.parametrize('method', ('callback_data', 'callback_queries')) + def test_clear_all(self, callback_data_cache, method): + changing_button_1 = InlineKeyboardButton('changing', callback_data='some data 1') + changing_button_2 = InlineKeyboardButton('changing', callback_data='some data 2') + reply_markup = InlineKeyboardMarkup.from_row([changing_button_1, changing_button_2]) + + for i in range(100): + out = callback_data_cache.process_keyboard(reply_markup) + callback_query = CallbackQuery( + str(i), + from_user=None, + chat_instance=None, + data=out.inline_keyboard[0][1].callback_data, + ) + callback_data_cache.process_callback_query(callback_query) + + if method == 'callback_data': + callback_data_cache.clear_callback_data() + # callback_data was cleared, callback_queries weren't + assert len(callback_data_cache.persistence_data[0]) == 0 + assert len(callback_data_cache.persistence_data[1]) == 100 + else: + callback_data_cache.clear_callback_queries() + # callback_queries were cleared, callback_data wasn't + assert len(callback_data_cache.persistence_data[0]) == 100 + assert len(callback_data_cache.persistence_data[1]) == 0 + + @pytest.mark.parametrize('time_method', ['time', 'datetime', 'defaults']) + def test_clear_cutoff(self, callback_data_cache, time_method, tz_bot): + # Fill the cache with some fake data + for i in range(50): + reply_markup = InlineKeyboardMarkup.from_button( + InlineKeyboardButton('changing', callback_data=str(i)) + ) + out = callback_data_cache.process_keyboard(reply_markup) + callback_query = CallbackQuery( + str(i), + from_user=None, + chat_instance=None, + data=out.inline_keyboard[0][0].callback_data, + ) + callback_data_cache.process_callback_query(callback_query) + + # sleep a bit before saving the time cutoff, to make test more reliable + time.sleep(0.1) + if time_method == 'time': + cutoff = time.time() + elif time_method == 'datetime': + cutoff = datetime.now(pytz.utc) + else: + cutoff = datetime.now(tz_bot.defaults.tzinfo).replace(tzinfo=None) + callback_data_cache.bot = tz_bot + time.sleep(0.1) + + # more fake data after the time cutoff + for i in range(50, 100): + reply_markup = InlineKeyboardMarkup.from_button( + InlineKeyboardButton('changing', callback_data=str(i)) + ) + out = callback_data_cache.process_keyboard(reply_markup) + callback_query = CallbackQuery( + str(i), + from_user=None, + chat_instance=None, + data=out.inline_keyboard[0][0].callback_data, + ) + callback_data_cache.process_callback_query(callback_query) + + callback_data_cache.clear_callback_data(time_cutoff=cutoff) + assert len(callback_data_cache.persistence_data[0]) == 50 + assert len(callback_data_cache.persistence_data[1]) == 100 + callback_data = [ + list(data[2].values())[0] for data in callback_data_cache.persistence_data[0] + ] + assert callback_data == list(str(i) for i in range(50, 100)) From 70820eacb33a83508fb68644d73c4118c5d86054 Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Tue, 15 Feb 2022 22:11:12 +0100 Subject: [PATCH 014/153] re-add test_filters.py --- tests/test_filters.py | 2274 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 2274 insertions(+) create mode 100644 tests/test_filters.py diff --git a/tests/test_filters.py b/tests/test_filters.py new file mode 100644 index 00000000000..853460730c9 --- /dev/null +++ b/tests/test_filters.py @@ -0,0 +1,2274 @@ +#!/usr/bin/env python +# +# A library that provides a Python interface to the Telegram Bot API +# Copyright (C) 2015-2022 +# Leandro Toledo de Souza +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser Public License for more details. +# +# You should have received a copy of the GNU Lesser Public License +# along with this program. If not, see [http://www.gnu.org/licenses/]. +import datetime + +import pytest + +from telegram import Message, User, Chat, MessageEntity, Document, Update, Dice, CallbackQuery +from telegram.ext import filters +import inspect +import re + + +@pytest.fixture(scope='function') +def update(): + return Update( + 0, + Message( + 0, + datetime.datetime.utcnow(), + Chat(0, 'private'), + from_user=User(0, 'Testuser', False), + via_bot=User(0, "Testbot", True), + sender_chat=Chat(0, 'Channel'), + forward_from=User(0, "HAL9000", False), + forward_from_chat=Chat(0, "Channel"), + ), + ) + + +@pytest.fixture(scope='function', params=MessageEntity.ALL_TYPES) +def message_entity(request): + return MessageEntity(request.param, 0, 0, url='', user=User(1, 'first_name', False)) + + +@pytest.fixture( + scope='class', + params=[{'class': filters.MessageFilter}, {'class': filters.UpdateFilter}], + ids=['MessageFilter', 'UpdateFilter'], +) +def base_class(request): + return request.param['class'] + + +class TestFilters: + def test_all_filters_slot_behaviour(self, mro_slots): + """ + Use depth first search to get all nested filters, and instantiate them (which need it) with + the correct number of arguments, then test each filter separately. Also tests setting + custom attributes on custom filters. + """ + + def filter_class(obj): + return True if inspect.isclass(obj) and "filters" in repr(obj) else False + + # The total no. of filters is about 72 as of 31/10/21. + # Gather all the filters to test using DFS- + visited = [] + classes = inspect.getmembers(filters, predicate=filter_class) # List[Tuple[str, type]] + stack = classes.copy() + while stack: + cls = stack[-1][-1] # get last element and its class + for inner_cls in inspect.getmembers( + cls, # Get inner filters + lambda a: inspect.isclass(a) and not issubclass(a, cls.__class__), + ): + if inner_cls[1] not in visited: + stack.append(inner_cls) + visited.append(inner_cls[1]) + classes.append(inner_cls) + break + else: + stack.pop() + + # Now start the actual testing + for name, cls in classes: + # Can't instantiate abstract classes without overriding methods, so skip them for now + exclude = {'_MergedFilter', '_XORFilter'} + if inspect.isabstract(cls) or name in {'__class__', '__base__'} | exclude: + continue + + assert '__slots__' in cls.__dict__, f"Filter {name!r} doesn't have __slots__" + # get no. of args minus the 'self', 'args' and 'kwargs' argument + init_sig = inspect.signature(cls.__init__).parameters + extra = 0 + for param in init_sig: + if param in {'self', 'args', 'kwargs'}: + extra += 1 + args = len(init_sig) - extra + + if not args: + inst = cls() + elif args == 1: + inst = cls('1') + else: + inst = cls(*['blah']) + + assert len(mro_slots(inst)) == len(set(mro_slots(inst))), f"same slot in {name}" + + for attr in cls.__slots__: + assert getattr(inst, attr, 'err') != 'err', f"got extra slot '{attr}' for {name}" + + def test_filters_all(self, update): + assert filters.ALL.check_update(update) + + def test_filters_text(self, update): + update.message.text = 'test' + assert filters.TEXT.check_update(update) + update.message.text = '/test' + assert filters.Text().check_update(update) + + def test_filters_text_strings(self, update): + update.message.text = '/test' + assert filters.Text(('/test', 'test1')).check_update(update) + assert not filters.Text(['test1', 'test2']).check_update(update) + + def test_filters_caption(self, update): + update.message.caption = 'test' + assert filters.CAPTION.check_update(update) + update.message.caption = None + assert not filters.CAPTION.check_update(update) + + def test_filters_caption_strings(self, update): + update.message.caption = 'test' + assert filters.Caption(('test', 'test1')).check_update(update) + assert not filters.Caption(['test1', 'test2']).check_update(update) + + def test_filters_command_default(self, update): + update.message.text = 'test' + assert not filters.COMMAND.check_update(update) + update.message.text = '/test' + assert not filters.COMMAND.check_update(update) + # Only accept commands at the beginning + update.message.entities = [MessageEntity(MessageEntity.BOT_COMMAND, 3, 5)] + assert not filters.COMMAND.check_update(update) + update.message.entities = [MessageEntity(MessageEntity.BOT_COMMAND, 0, 5)] + assert filters.COMMAND.check_update(update) + + def test_filters_command_anywhere(self, update): + update.message.entities = [MessageEntity(MessageEntity.BOT_COMMAND, 5, 4)] + assert filters.Command(False).check_update(update) + + def test_filters_regex(self, update): + sre_type = type(re.match("", "")) + update.message.text = '/start deep-linked param' + result = filters.Regex(r'deep-linked param').check_update(update) + assert result + assert isinstance(result, dict) + matches = result['matches'] + assert isinstance(matches, list) + assert type(matches[0]) is sre_type + update.message.text = '/help' + assert filters.Regex(r'help').check_update(update) + + update.message.text = 'test' + assert not filters.Regex(r'fail').check_update(update) + assert filters.Regex(r'test').check_update(update) + assert filters.Regex(re.compile(r'test')).check_update(update) + assert filters.Regex(re.compile(r'TEST', re.IGNORECASE)).check_update(update) + + update.message.text = 'i love python' + assert filters.Regex(r'.\b[lo]{2}ve python').check_update(update) + + update.message.text = None + assert not filters.Regex(r'fail').check_update(update) + + def test_filters_regex_multiple(self, update): + sre_type = type(re.match("", "")) + update.message.text = '/start deep-linked param' + result = (filters.Regex('deep') & filters.Regex(r'linked param')).check_update(update) + assert result + assert isinstance(result, dict) + matches = result['matches'] + assert isinstance(matches, list) + assert all(type(res) is sre_type for res in matches) + result = (filters.Regex('deep') | filters.Regex(r'linked param')).check_update(update) + assert result + assert isinstance(result, dict) + matches = result['matches'] + assert isinstance(matches, list) + assert all(type(res) is sre_type for res in matches) + result = (filters.Regex('not int') | filters.Regex(r'linked param')).check_update(update) + assert result + assert isinstance(result, dict) + matches = result['matches'] + assert isinstance(matches, list) + assert all(type(res) is sre_type for res in matches) + result = (filters.Regex('not int') & filters.Regex(r'linked param')).check_update(update) + assert not result + + def test_filters_merged_with_regex(self, update): + sre_type = type(re.match("", "")) + update.message.text = '/start deep-linked param' + update.message.entities = [MessageEntity(MessageEntity.BOT_COMMAND, 0, 6)] + result = (filters.COMMAND & filters.Regex(r'linked param')).check_update(update) + assert result + assert isinstance(result, dict) + matches = result['matches'] + assert isinstance(matches, list) + assert all(type(res) is sre_type for res in matches) + result = (filters.Regex(r'linked param') & filters.COMMAND).check_update(update) + assert result + assert isinstance(result, dict) + matches = result['matches'] + assert isinstance(matches, list) + assert all(type(res) is sre_type for res in matches) + result = (filters.Regex(r'linked param') | filters.COMMAND).check_update(update) + assert result + assert isinstance(result, dict) + matches = result['matches'] + assert isinstance(matches, list) + assert all(type(res) is sre_type for res in matches) + # Should not give a match since it's a or filter and it short circuits + result = (filters.COMMAND | filters.Regex(r'linked param')).check_update(update) + assert result is True + + def test_regex_complex_merges(self, update): + sre_type = type(re.match("", "")) + update.message.text = 'test it out' + test_filter = filters.Regex('test') & ( + (filters.StatusUpdate.ALL | filters.FORWARDED) | filters.Regex('out') + ) + result = test_filter.check_update(update) + assert result + assert isinstance(result, dict) + matches = result['matches'] + assert isinstance(matches, list) + assert len(matches) == 2 + assert all(type(res) is sre_type for res in matches) + update.message.forward_date = datetime.datetime.utcnow() + result = test_filter.check_update(update) + assert result + assert isinstance(result, dict) + matches = result['matches'] + assert isinstance(matches, list) + assert all(type(res) is sre_type for res in matches) + update.message.text = 'test it' + result = test_filter.check_update(update) + assert result + assert isinstance(result, dict) + matches = result['matches'] + assert isinstance(matches, list) + assert all(type(res) is sre_type for res in matches) + update.message.forward_date = None + result = test_filter.check_update(update) + assert not result + update.message.text = 'test it out' + result = test_filter.check_update(update) + assert result + assert isinstance(result, dict) + matches = result['matches'] + assert isinstance(matches, list) + assert all(type(res) is sre_type for res in matches) + update.message.pinned_message = True + result = test_filter.check_update(update) + assert result + assert isinstance(result, dict) + matches = result['matches'] + assert isinstance(matches, list) + assert all(type(res) is sre_type for res in matches) + update.message.text = 'it out' + result = test_filter.check_update(update) + assert not result + + update.message.text = 'test it out' + update.message.forward_date = None + update.message.pinned_message = None + test_filter = (filters.Regex('test') | filters.COMMAND) & ( + filters.Regex('it') | filters.StatusUpdate.ALL + ) + result = test_filter.check_update(update) + assert result + assert isinstance(result, dict) + matches = result['matches'] + assert isinstance(matches, list) + assert len(matches) == 2 + assert all(type(res) is sre_type for res in matches) + update.message.text = 'test' + result = test_filter.check_update(update) + assert not result + update.message.pinned_message = True + result = test_filter.check_update(update) + assert result + assert isinstance(result, dict) + matches = result['matches'] + assert isinstance(matches, list) + assert len(matches) == 1 + assert all(type(res) is sre_type for res in matches) + update.message.text = 'nothing' + result = test_filter.check_update(update) + assert not result + update.message.text = '/start' + update.message.entities = [MessageEntity(MessageEntity.BOT_COMMAND, 0, 6)] + result = test_filter.check_update(update) + assert result + assert isinstance(result, bool) + update.message.text = '/start it' + result = test_filter.check_update(update) + assert result + assert isinstance(result, dict) + matches = result['matches'] + assert isinstance(matches, list) + assert len(matches) == 1 + assert all(type(res) is sre_type for res in matches) + + def test_regex_inverted(self, update): + update.message.text = '/start deep-linked param' + update.message.entities = [MessageEntity(MessageEntity.BOT_COMMAND, 0, 5)] + inv = ~filters.Regex(r'deep-linked param') + result = inv.check_update(update) + assert not result + update.message.text = 'not it' + result = inv.check_update(update) + assert result + assert isinstance(result, bool) + + inv = ~filters.Regex('linked') & filters.COMMAND + update.message.text = "it's linked" + result = inv.check_update(update) + assert not result + update.message.text = '/start' + update.message.entities = [MessageEntity(MessageEntity.BOT_COMMAND, 0, 6)] + result = inv.check_update(update) + assert result + update.message.text = '/linked' + result = inv.check_update(update) + assert not result + + inv = ~filters.Regex('linked') | filters.COMMAND + update.message.text = "it's linked" + update.message.entities = [] + result = inv.check_update(update) + assert not result + update.message.text = '/start linked' + update.message.entities = [MessageEntity(MessageEntity.BOT_COMMAND, 0, 6)] + result = inv.check_update(update) + assert result + update.message.text = '/start' + result = inv.check_update(update) + assert result + update.message.text = 'nothig' + update.message.entities = [] + result = inv.check_update(update) + assert result + + def test_filters_caption_regex(self, update): + sre_type = type(re.match("", "")) + update.message.caption = '/start deep-linked param' + result = filters.CaptionRegex(r'deep-linked param').check_update(update) + assert result + assert isinstance(result, dict) + matches = result['matches'] + assert isinstance(matches, list) + assert type(matches[0]) is sre_type + update.message.caption = '/help' + assert filters.CaptionRegex(r'help').check_update(update) + + update.message.caption = 'test' + assert not filters.CaptionRegex(r'fail').check_update(update) + assert filters.CaptionRegex(r'test').check_update(update) + assert filters.CaptionRegex(re.compile(r'test')).check_update(update) + assert filters.CaptionRegex(re.compile(r'TEST', re.IGNORECASE)).check_update(update) + + update.message.caption = 'i love python' + assert filters.CaptionRegex(r'.\b[lo]{2}ve python').check_update(update) + + update.message.caption = None + assert not filters.CaptionRegex(r'fail').check_update(update) + + def test_filters_caption_regex_multiple(self, update): + sre_type = type(re.match("", "")) + update.message.caption = '/start deep-linked param' + _and = filters.CaptionRegex('deep') & filters.CaptionRegex(r'linked param') + result = _and.check_update(update) + assert result + assert isinstance(result, dict) + matches = result['matches'] + assert isinstance(matches, list) + assert all(type(res) is sre_type for res in matches) + _or = filters.CaptionRegex('deep') | filters.CaptionRegex(r'linked param') + result = _or.check_update(update) + assert result + assert isinstance(result, dict) + matches = result['matches'] + assert isinstance(matches, list) + assert all(type(res) is sre_type for res in matches) + _or = filters.CaptionRegex('not int') | filters.CaptionRegex(r'linked param') + result = _or.check_update(update) + assert result + assert isinstance(result, dict) + matches = result['matches'] + assert isinstance(matches, list) + assert all(type(res) is sre_type for res in matches) + _and = filters.CaptionRegex('not int') & filters.CaptionRegex(r'linked param') + result = _and.check_update(update) + assert not result + + def test_filters_merged_with_caption_regex(self, update): + sre_type = type(re.match("", "")) + update.message.caption = '/start deep-linked param' + update.message.entities = [MessageEntity(MessageEntity.BOT_COMMAND, 0, 6)] + result = (filters.COMMAND & filters.CaptionRegex(r'linked param')).check_update(update) + assert result + assert isinstance(result, dict) + matches = result['matches'] + assert isinstance(matches, list) + assert all(type(res) is sre_type for res in matches) + result = (filters.CaptionRegex(r'linked param') & filters.COMMAND).check_update(update) + assert result + assert isinstance(result, dict) + matches = result['matches'] + assert isinstance(matches, list) + assert all(type(res) is sre_type for res in matches) + result = (filters.CaptionRegex(r'linked param') | filters.COMMAND).check_update(update) + assert result + assert isinstance(result, dict) + matches = result['matches'] + assert isinstance(matches, list) + assert all(type(res) is sre_type for res in matches) + # Should not give a match since it's a or filter and it short circuits + result = (filters.COMMAND | filters.CaptionRegex(r'linked param')).check_update(update) + assert result is True + + def test_caption_regex_complex_merges(self, update): + sre_type = type(re.match("", "")) + update.message.caption = 'test it out' + test_filter = filters.CaptionRegex('test') & ( + (filters.StatusUpdate.ALL | filters.FORWARDED) | filters.CaptionRegex('out') + ) + result = test_filter.check_update(update) + assert result + assert isinstance(result, dict) + matches = result['matches'] + assert isinstance(matches, list) + assert len(matches) == 2 + assert all(type(res) is sre_type for res in matches) + update.message.forward_date = datetime.datetime.utcnow() + result = test_filter.check_update(update) + assert result + assert isinstance(result, dict) + matches = result['matches'] + assert isinstance(matches, list) + assert all(type(res) is sre_type for res in matches) + update.message.caption = 'test it' + result = test_filter.check_update(update) + assert result + assert isinstance(result, dict) + matches = result['matches'] + assert isinstance(matches, list) + assert all(type(res) is sre_type for res in matches) + update.message.forward_date = None + result = test_filter.check_update(update) + assert not result + update.message.caption = 'test it out' + result = test_filter.check_update(update) + assert result + assert isinstance(result, dict) + matches = result['matches'] + assert isinstance(matches, list) + assert all(type(res) is sre_type for res in matches) + update.message.pinned_message = True + result = test_filter.check_update(update) + assert result + assert isinstance(result, dict) + matches = result['matches'] + assert isinstance(matches, list) + assert all(type(res) is sre_type for res in matches) + update.message.caption = 'it out' + result = test_filter.check_update(update) + assert not result + + update.message.caption = 'test it out' + update.message.forward_date = None + update.message.pinned_message = None + test_filter = (filters.CaptionRegex('test') | filters.COMMAND) & ( + filters.CaptionRegex('it') | filters.StatusUpdate.ALL + ) + result = test_filter.check_update(update) + assert result + assert isinstance(result, dict) + matches = result['matches'] + assert isinstance(matches, list) + assert len(matches) == 2 + assert all(type(res) is sre_type for res in matches) + update.message.caption = 'test' + result = test_filter.check_update(update) + assert not result + update.message.pinned_message = True + result = test_filter.check_update(update) + assert result + assert isinstance(result, dict) + matches = result['matches'] + assert isinstance(matches, list) + assert len(matches) == 1 + assert all(type(res) is sre_type for res in matches) + update.message.caption = 'nothing' + result = test_filter.check_update(update) + assert not result + update.message.caption = '/start' + update.message.entities = [MessageEntity(MessageEntity.BOT_COMMAND, 0, 6)] + result = test_filter.check_update(update) + assert result + assert isinstance(result, bool) + update.message.caption = '/start it' + result = test_filter.check_update(update) + assert result + assert isinstance(result, dict) + matches = result['matches'] + assert isinstance(matches, list) + assert len(matches) == 1 + assert all(type(res) is sre_type for res in matches) + + def test_caption_regex_inverted(self, update): + update.message.caption = '/start deep-linked param' + update.message.entities = [MessageEntity(MessageEntity.BOT_COMMAND, 0, 5)] + test_filter = ~filters.CaptionRegex(r'deep-linked param') + result = test_filter.check_update(update) + assert not result + update.message.caption = 'not it' + result = test_filter.check_update(update) + assert result + assert isinstance(result, bool) + + test_filter = ~filters.CaptionRegex('linked') & filters.COMMAND + update.message.caption = "it's linked" + result = test_filter.check_update(update) + assert not result + update.message.caption = '/start' + update.message.entities = [MessageEntity(MessageEntity.BOT_COMMAND, 0, 6)] + result = test_filter.check_update(update) + assert result + update.message.caption = '/linked' + result = test_filter.check_update(update) + assert not result + + test_filter = ~filters.CaptionRegex('linked') | filters.COMMAND + update.message.caption = "it's linked" + update.message.entities = [] + result = test_filter.check_update(update) + assert not result + update.message.caption = '/start linked' + update.message.entities = [MessageEntity(MessageEntity.BOT_COMMAND, 0, 6)] + result = test_filter.check_update(update) + assert result + update.message.caption = '/start' + result = test_filter.check_update(update) + assert result + update.message.caption = 'nothig' + update.message.entities = [] + result = test_filter.check_update(update) + assert result + + def test_filters_reply(self, update): + another_message = Message( + 1, + datetime.datetime.utcnow(), + Chat(0, 'private'), + from_user=User(1, 'TestOther', False), + ) + update.message.text = 'test' + assert not filters.REPLY.check_update(update) + update.message.reply_to_message = another_message + assert filters.REPLY.check_update(update) + + def test_filters_audio(self, update): + assert not filters.AUDIO.check_update(update) + update.message.audio = 'test' + assert filters.AUDIO.check_update(update) + + def test_filters_document(self, update): + assert not filters.DOCUMENT.check_update(update) + update.message.document = 'test' + assert filters.DOCUMENT.check_update(update) + + def test_filters_document_type(self, update): + update.message.document = Document( + "file_id", 'unique_id', mime_type="application/vnd.android.package-archive" + ) + assert filters.Document.APK.check_update(update) + assert filters.Document.APPLICATION.check_update(update) + assert not filters.Document.DOC.check_update(update) + assert not filters.Document.AUDIO.check_update(update) + + update.message.document.mime_type = "application/msword" + assert filters.Document.DOC.check_update(update) + assert filters.Document.APPLICATION.check_update(update) + assert not filters.Document.DOCX.check_update(update) + assert not filters.Document.AUDIO.check_update(update) + + update.message.document.mime_type = ( + "application/vnd.openxmlformats-officedocument.wordprocessingml.document" + ) + assert filters.Document.DOCX.check_update(update) + assert filters.Document.APPLICATION.check_update(update) + assert not filters.Document.EXE.check_update(update) + assert not filters.Document.AUDIO.check_update(update) + + update.message.document.mime_type = "application/octet-stream" + assert filters.Document.EXE.check_update(update) + assert filters.Document.APPLICATION.check_update(update) + assert not filters.Document.DOCX.check_update(update) + assert not filters.Document.AUDIO.check_update(update) + + update.message.document.mime_type = "image/gif" + assert filters.Document.GIF.check_update(update) + assert filters.Document.IMAGE.check_update(update) + assert not filters.Document.JPG.check_update(update) + assert not filters.Document.TEXT.check_update(update) + + update.message.document.mime_type = "image/jpeg" + assert filters.Document.JPG.check_update(update) + assert filters.Document.IMAGE.check_update(update) + assert not filters.Document.MP3.check_update(update) + assert not filters.Document.VIDEO.check_update(update) + + update.message.document.mime_type = "audio/mpeg" + assert filters.Document.MP3.check_update(update) + assert filters.Document.AUDIO.check_update(update) + assert not filters.Document.PDF.check_update(update) + assert not filters.Document.IMAGE.check_update(update) + + update.message.document.mime_type = "application/pdf" + assert filters.Document.PDF.check_update(update) + assert filters.Document.APPLICATION.check_update(update) + assert not filters.Document.PY.check_update(update) + assert not filters.Document.AUDIO.check_update(update) + + update.message.document.mime_type = "text/x-python" + assert filters.Document.PY.check_update(update) + assert filters.Document.TEXT.check_update(update) + assert not filters.Document.SVG.check_update(update) + assert not filters.Document.APPLICATION.check_update(update) + + update.message.document.mime_type = "image/svg+xml" + assert filters.Document.SVG.check_update(update) + assert filters.Document.IMAGE.check_update(update) + assert not filters.Document.TXT.check_update(update) + assert not filters.Document.VIDEO.check_update(update) + + update.message.document.mime_type = "text/plain" + assert filters.Document.TXT.check_update(update) + assert filters.Document.TEXT.check_update(update) + assert not filters.Document.TARGZ.check_update(update) + assert not filters.Document.APPLICATION.check_update(update) + + update.message.document.mime_type = "application/x-compressed-tar" + assert filters.Document.TARGZ.check_update(update) + assert filters.Document.APPLICATION.check_update(update) + assert not filters.Document.WAV.check_update(update) + assert not filters.Document.AUDIO.check_update(update) + + update.message.document.mime_type = "audio/x-wav" + assert filters.Document.WAV.check_update(update) + assert filters.Document.AUDIO.check_update(update) + assert not filters.Document.XML.check_update(update) + assert not filters.Document.IMAGE.check_update(update) + + update.message.document.mime_type = "text/xml" + assert filters.Document.XML.check_update(update) + assert filters.Document.TEXT.check_update(update) + assert not filters.Document.ZIP.check_update(update) + assert not filters.Document.AUDIO.check_update(update) + + update.message.document.mime_type = "application/zip" + assert filters.Document.ZIP.check_update(update) + assert filters.Document.APPLICATION.check_update(update) + assert not filters.Document.APK.check_update(update) + assert not filters.Document.AUDIO.check_update(update) + + update.message.document.mime_type = "image/x-rgb" + assert not filters.Document.Category("application/").check_update(update) + assert not filters.Document.MimeType("application/x-sh").check_update(update) + update.message.document.mime_type = "application/x-sh" + assert filters.Document.Category("application/").check_update(update) + assert filters.Document.MimeType("application/x-sh").check_update(update) + + update.message.document.mime_type = None + assert not filters.Document.Category("application/").check_update(update) + assert not filters.Document.MimeType("application/x-sh").check_update(update) + + def test_filters_file_extension_basic(self, update): + update.message.document = Document( + "file_id", + "unique_id", + file_name="file.jpg", + mime_type="image/jpeg", + ) + assert filters.Document.FileExtension("jpg").check_update(update) + assert not filters.Document.FileExtension("jpeg").check_update(update) + assert not filters.Document.FileExtension("file.jpg").check_update(update) + + update.message.document.file_name = "file.tar.gz" + assert filters.Document.FileExtension("tar.gz").check_update(update) + assert filters.Document.FileExtension("gz").check_update(update) + assert not filters.Document.FileExtension("tgz").check_update(update) + assert not filters.Document.FileExtension("jpg").check_update(update) + + update.message.document.file_name = None + assert not filters.Document.FileExtension("jpg").check_update(update) + + update.message.document = None + assert not filters.Document.FileExtension("jpg").check_update(update) + + def test_filters_file_extension_minds_dots(self, update): + update.message.document = Document( + "file_id", + "unique_id", + file_name="file.jpg", + mime_type="image/jpeg", + ) + assert not filters.Document.FileExtension(".jpg").check_update(update) + assert not filters.Document.FileExtension("e.jpg").check_update(update) + assert not filters.Document.FileExtension("file.jpg").check_update(update) + assert not filters.Document.FileExtension("").check_update(update) + + update.message.document.file_name = "file..jpg" + assert filters.Document.FileExtension("jpg").check_update(update) + assert filters.Document.FileExtension(".jpg").check_update(update) + assert not filters.Document.FileExtension("..jpg").check_update(update) + + update.message.document.file_name = "file.docx" + assert filters.Document.FileExtension("docx").check_update(update) + assert not filters.Document.FileExtension("doc").check_update(update) + assert not filters.Document.FileExtension("ocx").check_update(update) + + update.message.document.file_name = "file" + assert not filters.Document.FileExtension("").check_update(update) + assert not filters.Document.FileExtension("file").check_update(update) + + update.message.document.file_name = "file." + assert filters.Document.FileExtension("").check_update(update) + + def test_filters_file_extension_none_arg(self, update): + update.message.document = Document( + "file_id", + "unique_id", + file_name="file.jpg", + mime_type="image/jpeg", + ) + assert not filters.Document.FileExtension(None).check_update(update) + + update.message.document.file_name = "file" + assert filters.Document.FileExtension(None).check_update(update) + assert not filters.Document.FileExtension("None").check_update(update) + + update.message.document.file_name = "file." + assert not filters.Document.FileExtension(None).check_update(update) + + update.message.document = None + assert not filters.Document.FileExtension(None).check_update(update) + + def test_filters_file_extension_case_sensitivity(self, update): + update.message.document = Document( + "file_id", + "unique_id", + file_name="file.jpg", + mime_type="image/jpeg", + ) + assert filters.Document.FileExtension("JPG").check_update(update) + assert filters.Document.FileExtension("jpG").check_update(update) + + update.message.document.file_name = "file.JPG" + assert filters.Document.FileExtension("jpg").check_update(update) + assert not filters.Document.FileExtension("jpg", case_sensitive=True).check_update(update) + + update.message.document.file_name = "file.Dockerfile" + assert filters.Document.FileExtension("Dockerfile", case_sensitive=True).check_update( + update + ) + assert not filters.Document.FileExtension("DOCKERFILE", case_sensitive=True).check_update( + update + ) + + def test_filters_file_extension_name(self): + assert filters.Document.FileExtension("jpg").name == ( + "filters.Document.FileExtension('jpg')" + ) + assert filters.Document.FileExtension("JPG").name == ( + "filters.Document.FileExtension('jpg')" + ) + assert filters.Document.FileExtension("jpg", case_sensitive=True).name == ( + "filters.Document.FileExtension('jpg', case_sensitive=True)" + ) + assert filters.Document.FileExtension("JPG", case_sensitive=True).name == ( + "filters.Document.FileExtension('JPG', case_sensitive=True)" + ) + assert filters.Document.FileExtension(".jpg").name == ( + "filters.Document.FileExtension('.jpg')" + ) + assert filters.Document.FileExtension("").name == "filters.Document.FileExtension('')" + assert filters.Document.FileExtension(None).name == "filters.Document.FileExtension(None)" + + def test_filters_animation(self, update): + assert not filters.ANIMATION.check_update(update) + update.message.animation = 'test' + assert filters.ANIMATION.check_update(update) + + def test_filters_photo(self, update): + assert not filters.PHOTO.check_update(update) + update.message.photo = 'test' + assert filters.PHOTO.check_update(update) + + def test_filters_sticker(self, update): + assert not filters.STICKER.check_update(update) + update.message.sticker = 'test' + assert filters.STICKER.check_update(update) + + def test_filters_video(self, update): + assert not filters.VIDEO.check_update(update) + update.message.video = 'test' + assert filters.VIDEO.check_update(update) + + def test_filters_voice(self, update): + assert not filters.VOICE.check_update(update) + update.message.voice = 'test' + assert filters.VOICE.check_update(update) + + def test_filters_video_note(self, update): + assert not filters.VIDEO_NOTE.check_update(update) + update.message.video_note = 'test' + assert filters.VIDEO_NOTE.check_update(update) + + def test_filters_contact(self, update): + assert not filters.CONTACT.check_update(update) + update.message.contact = 'test' + assert filters.CONTACT.check_update(update) + + def test_filters_location(self, update): + assert not filters.LOCATION.check_update(update) + update.message.location = 'test' + assert filters.LOCATION.check_update(update) + + def test_filters_venue(self, update): + assert not filters.VENUE.check_update(update) + update.message.venue = 'test' + assert filters.VENUE.check_update(update) + + def test_filters_status_update(self, update): + assert not filters.StatusUpdate.ALL.check_update(update) + + update.message.new_chat_members = ['test'] + assert filters.StatusUpdate.ALL.check_update(update) + assert filters.StatusUpdate.NEW_CHAT_MEMBERS.check_update(update) + update.message.new_chat_members = None + + update.message.left_chat_member = 'test' + assert filters.StatusUpdate.ALL.check_update(update) + assert filters.StatusUpdate.LEFT_CHAT_MEMBER.check_update(update) + update.message.left_chat_member = None + + update.message.new_chat_title = 'test' + assert filters.StatusUpdate.ALL.check_update(update) + assert filters.StatusUpdate.NEW_CHAT_TITLE.check_update(update) + update.message.new_chat_title = '' + + update.message.new_chat_photo = 'test' + assert filters.StatusUpdate.ALL.check_update(update) + assert filters.StatusUpdate.NEW_CHAT_PHOTO.check_update(update) + update.message.new_chat_photo = None + + update.message.delete_chat_photo = True + assert filters.StatusUpdate.ALL.check_update(update) + assert filters.StatusUpdate.DELETE_CHAT_PHOTO.check_update(update) + update.message.delete_chat_photo = False + + update.message.group_chat_created = True + assert filters.StatusUpdate.ALL.check_update(update) + assert filters.StatusUpdate.CHAT_CREATED.check_update(update) + update.message.group_chat_created = False + + update.message.supergroup_chat_created = True + assert filters.StatusUpdate.ALL.check_update(update) + assert filters.StatusUpdate.CHAT_CREATED.check_update(update) + update.message.supergroup_chat_created = False + + update.message.channel_chat_created = True + assert filters.StatusUpdate.ALL.check_update(update) + assert filters.StatusUpdate.CHAT_CREATED.check_update(update) + update.message.channel_chat_created = False + + update.message.message_auto_delete_timer_changed = True + assert filters.StatusUpdate.ALL.check_update(update) + assert filters.StatusUpdate.MESSAGE_AUTO_DELETE_TIMER_CHANGED.check_update(update) + update.message.message_auto_delete_timer_changed = False + + update.message.migrate_to_chat_id = 100 + assert filters.StatusUpdate.ALL.check_update(update) + assert filters.StatusUpdate.MIGRATE.check_update(update) + update.message.migrate_to_chat_id = 0 + + update.message.migrate_from_chat_id = 100 + assert filters.StatusUpdate.ALL.check_update(update) + assert filters.StatusUpdate.MIGRATE.check_update(update) + update.message.migrate_from_chat_id = 0 + + update.message.pinned_message = 'test' + assert filters.StatusUpdate.ALL.check_update(update) + assert filters.StatusUpdate.PINNED_MESSAGE.check_update(update) + update.message.pinned_message = None + + update.message.connected_website = 'https://example.com/' + assert filters.StatusUpdate.ALL.check_update(update) + assert filters.StatusUpdate.CONNECTED_WEBSITE.check_update(update) + update.message.connected_website = None + + update.message.proximity_alert_triggered = 'alert' + assert filters.StatusUpdate.ALL.check_update(update) + assert filters.StatusUpdate.PROXIMITY_ALERT_TRIGGERED.check_update(update) + update.message.proximity_alert_triggered = None + + update.message.voice_chat_scheduled = 'scheduled' + assert filters.StatusUpdate.ALL.check_update(update) + assert filters.StatusUpdate.VOICE_CHAT_SCHEDULED.check_update(update) + update.message.voice_chat_scheduled = None + + update.message.voice_chat_started = 'hello' + assert filters.StatusUpdate.ALL.check_update(update) + assert filters.StatusUpdate.VOICE_CHAT_STARTED.check_update(update) + update.message.voice_chat_started = None + + update.message.voice_chat_ended = 'bye' + assert filters.StatusUpdate.ALL.check_update(update) + assert filters.StatusUpdate.VOICE_CHAT_ENDED.check_update(update) + update.message.voice_chat_ended = None + + update.message.voice_chat_participants_invited = 'invited' + assert filters.StatusUpdate.ALL.check_update(update) + assert filters.StatusUpdate.VOICE_CHAT_PARTICIPANTS_INVITED.check_update(update) + update.message.voice_chat_participants_invited = None + + def test_filters_forwarded(self, update): + assert not filters.FORWARDED.check_update(update) + update.message.forward_date = datetime.datetime.utcnow() + assert filters.FORWARDED.check_update(update) + + def test_filters_game(self, update): + assert not filters.GAME.check_update(update) + update.message.game = 'test' + assert filters.GAME.check_update(update) + + def test_entities_filter(self, update, message_entity): + update.message.entities = [message_entity] + assert filters.Entity(message_entity.type).check_update(update) + + update.message.entities = [] + assert not filters.Entity(MessageEntity.MENTION).check_update(update) + + second = message_entity.to_dict() + second['type'] = 'bold' + second = MessageEntity.de_json(second, None) + update.message.entities = [message_entity, second] + assert filters.Entity(message_entity.type).check_update(update) + assert not filters.CaptionEntity(message_entity.type).check_update(update) + + def test_caption_entities_filter(self, update, message_entity): + update.message.caption_entities = [message_entity] + assert filters.CaptionEntity(message_entity.type).check_update(update) + + update.message.caption_entities = [] + assert not filters.CaptionEntity(MessageEntity.MENTION).check_update(update) + + second = message_entity.to_dict() + second['type'] = 'bold' + second = MessageEntity.de_json(second, None) + update.message.caption_entities = [message_entity, second] + assert filters.CaptionEntity(message_entity.type).check_update(update) + assert not filters.Entity(message_entity.type).check_update(update) + + @pytest.mark.parametrize( + 'chat_type, results', + [ + (Chat.PRIVATE, (True, False, False, False, False)), + (Chat.GROUP, (False, True, False, True, False)), + (Chat.SUPERGROUP, (False, False, True, True, False)), + (Chat.CHANNEL, (False, False, False, False, True)), + ], + ) + def test_filters_chat_types(self, update, chat_type, results): + update.message.chat.type = chat_type + assert filters.ChatType.PRIVATE.check_update(update) is results[0] + assert filters.ChatType.GROUP.check_update(update) is results[1] + assert filters.ChatType.SUPERGROUP.check_update(update) is results[2] + assert filters.ChatType.GROUPS.check_update(update) is results[3] + assert filters.ChatType.CHANNEL.check_update(update) is results[4] + + def test_filters_user_init(self): + with pytest.raises(RuntimeError, match='in conjunction with'): + filters.User(user_id=1, username='user') + + def test_filters_user_allow_empty(self, update): + assert not filters.User().check_update(update) + assert filters.User(allow_empty=True).check_update(update) + + def test_filters_user_id(self, update): + assert not filters.User(user_id=1).check_update(update) + update.message.from_user.id = 1 + assert filters.User(user_id=1).check_update(update) + assert filters.USER.check_update(update) + update.message.from_user.id = 2 + assert filters.User(user_id=[1, 2]).check_update(update) + assert not filters.User(user_id=[3, 4]).check_update(update) + update.message.from_user = None + assert not filters.USER.check_update(update) + assert not filters.User(user_id=[3, 4]).check_update(update) + + def test_filters_username(self, update): + assert not filters.User(username='user').check_update(update) + assert not filters.User(username='Testuser').check_update(update) + update.message.from_user.username = 'user@' + assert filters.User(username='@user@').check_update(update) + assert filters.User(username='user@').check_update(update) + assert filters.User(username=['user1', 'user@', 'user2']).check_update(update) + assert not filters.User(username=['@username', '@user_2']).check_update(update) + update.message.from_user = None + assert not filters.User(username=['@username', '@user_2']).check_update(update) + + def test_filters_user_change_id(self, update): + f = filters.User(user_id=1) + assert f.user_ids == {1} + update.message.from_user.id = 1 + assert f.check_update(update) + update.message.from_user.id = 2 + assert not f.check_update(update) + f.user_ids = 2 + assert f.user_ids == {2} + assert f.check_update(update) + + with pytest.raises(RuntimeError, match='username in conjunction'): + f.usernames = 'user' + + def test_filters_user_change_username(self, update): + f = filters.User(username='user') + update.message.from_user.username = 'user' + assert f.check_update(update) + update.message.from_user.username = 'User' + assert not f.check_update(update) + f.usernames = 'User' + assert f.check_update(update) + + with pytest.raises(RuntimeError, match='user_id in conjunction'): + f.user_ids = 1 + + def test_filters_user_add_user_by_name(self, update): + users = ['user_a', 'user_b', 'user_c'] + f = filters.User() + + for user in users: + update.message.from_user.username = user + assert not f.check_update(update) + + f.add_usernames('user_a') + f.add_usernames(['user_b', 'user_c']) + + for user in users: + update.message.from_user.username = user + assert f.check_update(update) + + with pytest.raises(RuntimeError, match='user_id in conjunction'): + f.add_user_ids(1) + + def test_filters_user_add_user_by_id(self, update): + users = [1, 2, 3] + f = filters.User() + + for user in users: + update.message.from_user.id = user + assert not f.check_update(update) + + f.add_user_ids(1) + f.add_user_ids([2, 3]) + + for user in users: + update.message.from_user.username = user + assert f.check_update(update) + + with pytest.raises(RuntimeError, match='username in conjunction'): + f.add_usernames('user') + + def test_filters_user_remove_user_by_name(self, update): + users = ['user_a', 'user_b', 'user_c'] + f = filters.User(username=users) + + with pytest.raises(RuntimeError, match='user_id in conjunction'): + f.remove_user_ids(1) + + for user in users: + update.message.from_user.username = user + assert f.check_update(update) + + f.remove_usernames('user_a') + f.remove_usernames(['user_b', 'user_c']) + + for user in users: + update.message.from_user.username = user + assert not f.check_update(update) + + def test_filters_user_remove_user_by_id(self, update): + users = [1, 2, 3] + f = filters.User(user_id=users) + + with pytest.raises(RuntimeError, match='username in conjunction'): + f.remove_usernames('user') + + for user in users: + update.message.from_user.id = user + assert f.check_update(update) + + f.remove_user_ids(1) + f.remove_user_ids([2, 3]) + + for user in users: + update.message.from_user.username = user + assert not f.check_update(update) + + def test_filters_user_repr(self): + f = filters.User([1, 2]) + assert str(f) == 'filters.User(1, 2)' + f.remove_user_ids(1) + f.remove_user_ids(2) + assert str(f) == 'filters.User()' + f.add_usernames('@foobar') + assert str(f) == 'filters.User(foobar)' + f.add_usernames('@barfoo') + assert str(f).startswith('filters.User(') + # we don't know th exact order + assert 'barfoo' in str(f) and 'foobar' in str(f) + + with pytest.raises(RuntimeError, match='Cannot set name'): + f.name = 'foo' + + def test_filters_chat_init(self): + with pytest.raises(RuntimeError, match='in conjunction with'): + filters.Chat(chat_id=1, username='chat') + + def test_filters_chat_allow_empty(self, update): + assert not filters.Chat().check_update(update) + assert filters.Chat(allow_empty=True).check_update(update) + + def test_filters_chat_id(self, update): + assert not filters.Chat(chat_id=1).check_update(update) + assert filters.CHAT.check_update(update) + update.message.chat.id = 1 + assert filters.Chat(chat_id=1).check_update(update) + assert filters.CHAT.check_update(update) + update.message.chat.id = 2 + assert filters.Chat(chat_id=[1, 2]).check_update(update) + assert not filters.Chat(chat_id=[3, 4]).check_update(update) + update.message.chat = None + assert not filters.CHAT.check_update(update) + assert not filters.Chat(chat_id=[3, 4]).check_update(update) + + def test_filters_chat_username(self, update): + assert not filters.Chat(username='chat').check_update(update) + assert not filters.Chat(username='Testchat').check_update(update) + update.message.chat.username = 'chat@' + assert filters.Chat(username='@chat@').check_update(update) + assert filters.Chat(username='chat@').check_update(update) + assert filters.Chat(username=['chat1', 'chat@', 'chat2']).check_update(update) + assert not filters.Chat(username=['@username', '@chat_2']).check_update(update) + update.message.chat = None + assert not filters.Chat(username=['@username', '@chat_2']).check_update(update) + + def test_filters_chat_change_id(self, update): + f = filters.Chat(chat_id=1) + assert f.chat_ids == {1} + update.message.chat.id = 1 + assert f.check_update(update) + update.message.chat.id = 2 + assert not f.check_update(update) + f.chat_ids = 2 + assert f.chat_ids == {2} + assert f.check_update(update) + + with pytest.raises(RuntimeError, match='username in conjunction'): + f.usernames = 'chat' + + def test_filters_chat_change_username(self, update): + f = filters.Chat(username='chat') + update.message.chat.username = 'chat' + assert f.check_update(update) + update.message.chat.username = 'User' + assert not f.check_update(update) + f.usernames = 'User' + assert f.check_update(update) + + with pytest.raises(RuntimeError, match='chat_id in conjunction'): + f.chat_ids = 1 + + def test_filters_chat_add_chat_by_name(self, update): + chats = ['chat_a', 'chat_b', 'chat_c'] + f = filters.Chat() + + for chat in chats: + update.message.chat.username = chat + assert not f.check_update(update) + + f.add_usernames('chat_a') + f.add_usernames(['chat_b', 'chat_c']) + + for chat in chats: + update.message.chat.username = chat + assert f.check_update(update) + + with pytest.raises(RuntimeError, match='chat_id in conjunction'): + f.add_chat_ids(1) + + def test_filters_chat_add_chat_by_id(self, update): + chats = [1, 2, 3] + f = filters.Chat() + + for chat in chats: + update.message.chat.id = chat + assert not f.check_update(update) + + f.add_chat_ids(1) + f.add_chat_ids([2, 3]) + + for chat in chats: + update.message.chat.username = chat + assert f.check_update(update) + + with pytest.raises(RuntimeError, match='username in conjunction'): + f.add_usernames('chat') + + def test_filters_chat_remove_chat_by_name(self, update): + chats = ['chat_a', 'chat_b', 'chat_c'] + f = filters.Chat(username=chats) + + with pytest.raises(RuntimeError, match='chat_id in conjunction'): + f.remove_chat_ids(1) + + for chat in chats: + update.message.chat.username = chat + assert f.check_update(update) + + f.remove_usernames('chat_a') + f.remove_usernames(['chat_b', 'chat_c']) + + for chat in chats: + update.message.chat.username = chat + assert not f.check_update(update) + + def test_filters_chat_remove_chat_by_id(self, update): + chats = [1, 2, 3] + f = filters.Chat(chat_id=chats) + + with pytest.raises(RuntimeError, match='username in conjunction'): + f.remove_usernames('chat') + + for chat in chats: + update.message.chat.id = chat + assert f.check_update(update) + + f.remove_chat_ids(1) + f.remove_chat_ids([2, 3]) + + for chat in chats: + update.message.chat.username = chat + assert not f.check_update(update) + + def test_filters_chat_repr(self): + f = filters.Chat([1, 2]) + assert str(f) == 'filters.Chat(1, 2)' + f.remove_chat_ids(1) + f.remove_chat_ids(2) + assert str(f) == 'filters.Chat()' + f.add_usernames('@foobar') + assert str(f) == 'filters.Chat(foobar)' + f.add_usernames('@barfoo') + assert str(f).startswith('filters.Chat(') + # we don't know th exact order + assert 'barfoo' in str(f) and 'foobar' in str(f) + + with pytest.raises(RuntimeError, match='Cannot set name'): + f.name = 'foo' + + def test_filters_forwarded_from_init(self): + with pytest.raises(RuntimeError, match='in conjunction with'): + filters.ForwardedFrom(chat_id=1, username='chat') + + def test_filters_forwarded_from_allow_empty(self, update): + assert not filters.ForwardedFrom().check_update(update) + assert filters.ForwardedFrom(allow_empty=True).check_update(update) + + def test_filters_forwarded_from_id(self, update): + # Test with User id- + assert not filters.ForwardedFrom(chat_id=1).check_update(update) + update.message.forward_from.id = 1 + assert filters.ForwardedFrom(chat_id=1).check_update(update) + update.message.forward_from.id = 2 + assert filters.ForwardedFrom(chat_id=[1, 2]).check_update(update) + assert not filters.ForwardedFrom(chat_id=[3, 4]).check_update(update) + update.message.forward_from = None + assert not filters.ForwardedFrom(chat_id=[3, 4]).check_update(update) + + # Test with Chat id- + update.message.forward_from_chat.id = 4 + assert filters.ForwardedFrom(chat_id=[4]).check_update(update) + assert filters.ForwardedFrom(chat_id=[3, 4]).check_update(update) + + update.message.forward_from_chat.id = 2 + assert not filters.ForwardedFrom(chat_id=[3, 4]).check_update(update) + assert filters.ForwardedFrom(chat_id=2).check_update(update) + update.message.forward_from_chat = None + + def test_filters_forwarded_from_username(self, update): + # For User username + assert not filters.ForwardedFrom(username='chat').check_update(update) + assert not filters.ForwardedFrom(username='Testchat').check_update(update) + update.message.forward_from.username = 'chat@' + assert filters.ForwardedFrom(username='@chat@').check_update(update) + assert filters.ForwardedFrom(username='chat@').check_update(update) + assert filters.ForwardedFrom(username=['chat1', 'chat@', 'chat2']).check_update(update) + assert not filters.ForwardedFrom(username=['@username', '@chat_2']).check_update(update) + update.message.forward_from = None + assert not filters.ForwardedFrom(username=['@username', '@chat_2']).check_update(update) + + # For Chat username + assert not filters.ForwardedFrom(username='chat').check_update(update) + assert not filters.ForwardedFrom(username='Testchat').check_update(update) + update.message.forward_from_chat.username = 'chat@' + assert filters.ForwardedFrom(username='@chat@').check_update(update) + assert filters.ForwardedFrom(username='chat@').check_update(update) + assert filters.ForwardedFrom(username=['chat1', 'chat@', 'chat2']).check_update(update) + assert not filters.ForwardedFrom(username=['@username', '@chat_2']).check_update(update) + update.message.forward_from_chat = None + assert not filters.ForwardedFrom(username=['@username', '@chat_2']).check_update(update) + + def test_filters_forwarded_from_change_id(self, update): + f = filters.ForwardedFrom(chat_id=1) + # For User ids- + assert f.chat_ids == {1} + update.message.forward_from.id = 1 + assert f.check_update(update) + update.message.forward_from.id = 2 + assert not f.check_update(update) + f.chat_ids = 2 + assert f.chat_ids == {2} + assert f.check_update(update) + + # For Chat ids- + f = filters.ForwardedFrom(chat_id=1) # reset this + update.message.forward_from = None # and change this to None, only one of them can be True + assert f.chat_ids == {1} + update.message.forward_from_chat.id = 1 + assert f.check_update(update) + update.message.forward_from_chat.id = 2 + assert not f.check_update(update) + f.chat_ids = 2 + assert f.chat_ids == {2} + assert f.check_update(update) + + with pytest.raises(RuntimeError, match='username in conjunction'): + f.usernames = 'chat' + + def test_filters_forwarded_from_change_username(self, update): + # For User usernames + f = filters.ForwardedFrom(username='chat') + update.message.forward_from.username = 'chat' + assert f.check_update(update) + update.message.forward_from.username = 'User' + assert not f.check_update(update) + f.usernames = 'User' + assert f.check_update(update) + + # For Chat usernames + update.message.forward_from = None + f = filters.ForwardedFrom(username='chat') + update.message.forward_from_chat.username = 'chat' + assert f.check_update(update) + update.message.forward_from_chat.username = 'User' + assert not f.check_update(update) + f.usernames = 'User' + assert f.check_update(update) + + with pytest.raises(RuntimeError, match='chat_id in conjunction'): + f.chat_ids = 1 + + def test_filters_forwarded_from_add_chat_by_name(self, update): + chats = ['chat_a', 'chat_b', 'chat_c'] + f = filters.ForwardedFrom() + + # For User usernames + for chat in chats: + update.message.forward_from.username = chat + assert not f.check_update(update) + + f.add_usernames('chat_a') + f.add_usernames(['chat_b', 'chat_c']) + + for chat in chats: + update.message.forward_from.username = chat + assert f.check_update(update) + + # For Chat usernames + update.message.forward_from = None + f = filters.ForwardedFrom() + for chat in chats: + update.message.forward_from_chat.username = chat + assert not f.check_update(update) + + f.add_usernames('chat_a') + f.add_usernames(['chat_b', 'chat_c']) + + for chat in chats: + update.message.forward_from_chat.username = chat + assert f.check_update(update) + + with pytest.raises(RuntimeError, match='chat_id in conjunction'): + f.add_chat_ids(1) + + def test_filters_forwarded_from_add_chat_by_id(self, update): + chats = [1, 2, 3] + f = filters.ForwardedFrom() + + # For User ids + for chat in chats: + update.message.forward_from.id = chat + assert not f.check_update(update) + + f.add_chat_ids(1) + f.add_chat_ids([2, 3]) + + for chat in chats: + update.message.forward_from.username = chat + assert f.check_update(update) + + # For Chat ids- + update.message.forward_from = None + f = filters.ForwardedFrom() + for chat in chats: + update.message.forward_from_chat.id = chat + assert not f.check_update(update) + + f.add_chat_ids(1) + f.add_chat_ids([2, 3]) + + for chat in chats: + update.message.forward_from_chat.username = chat + assert f.check_update(update) + + with pytest.raises(RuntimeError, match='username in conjunction'): + f.add_usernames('chat') + + def test_filters_forwarded_from_remove_chat_by_name(self, update): + chats = ['chat_a', 'chat_b', 'chat_c'] + f = filters.ForwardedFrom(username=chats) + + with pytest.raises(RuntimeError, match='chat_id in conjunction'): + f.remove_chat_ids(1) + + # For User usernames + for chat in chats: + update.message.forward_from.username = chat + assert f.check_update(update) + + f.remove_usernames('chat_a') + f.remove_usernames(['chat_b', 'chat_c']) + + for chat in chats: + update.message.forward_from.username = chat + assert not f.check_update(update) + + # For Chat usernames + update.message.forward_from = None + f = filters.ForwardedFrom(username=chats) + for chat in chats: + update.message.forward_from_chat.username = chat + assert f.check_update(update) + + f.remove_usernames('chat_a') + f.remove_usernames(['chat_b', 'chat_c']) + + for chat in chats: + update.message.forward_from_chat.username = chat + assert not f.check_update(update) + + def test_filters_forwarded_from_remove_chat_by_id(self, update): + chats = [1, 2, 3] + f = filters.ForwardedFrom(chat_id=chats) + + with pytest.raises(RuntimeError, match='username in conjunction'): + f.remove_usernames('chat') + + # For User ids + for chat in chats: + update.message.forward_from.id = chat + assert f.check_update(update) + + f.remove_chat_ids(1) + f.remove_chat_ids([2, 3]) + + for chat in chats: + update.message.forward_from.username = chat + assert not f.check_update(update) + + # For Chat ids + update.message.forward_from = None + f = filters.ForwardedFrom(chat_id=chats) + for chat in chats: + update.message.forward_from_chat.id = chat + assert f.check_update(update) + + f.remove_chat_ids(1) + f.remove_chat_ids([2, 3]) + + for chat in chats: + update.message.forward_from_chat.username = chat + assert not f.check_update(update) + + def test_filters_forwarded_from_repr(self): + f = filters.ForwardedFrom([1, 2]) + assert str(f) == 'filters.ForwardedFrom(1, 2)' + f.remove_chat_ids(1) + f.remove_chat_ids(2) + assert str(f) == 'filters.ForwardedFrom()' + f.add_usernames('@foobar') + assert str(f) == 'filters.ForwardedFrom(foobar)' + f.add_usernames('@barfoo') + assert str(f).startswith('filters.ForwardedFrom(') + # we don't know the exact order + assert 'barfoo' in str(f) and 'foobar' in str(f) + + with pytest.raises(RuntimeError, match='Cannot set name'): + f.name = 'foo' + + def test_filters_sender_chat_init(self): + with pytest.raises(RuntimeError, match='in conjunction with'): + filters.SenderChat(chat_id=1, username='chat') + + def test_filters_sender_chat_allow_empty(self, update): + assert not filters.SenderChat().check_update(update) + assert filters.SenderChat(allow_empty=True).check_update(update) + + def test_filters_sender_chat_id(self, update): + assert not filters.SenderChat(chat_id=1).check_update(update) + update.message.sender_chat.id = 1 + assert filters.SenderChat(chat_id=1).check_update(update) + update.message.sender_chat.id = 2 + assert filters.SenderChat(chat_id=[1, 2]).check_update(update) + assert not filters.SenderChat(chat_id=[3, 4]).check_update(update) + assert filters.SenderChat.ALL.check_update(update) + update.message.sender_chat = None + assert not filters.SenderChat(chat_id=[3, 4]).check_update(update) + assert not filters.SenderChat.ALL.check_update(update) + + def test_filters_sender_chat_username(self, update): + assert not filters.SenderChat(username='chat').check_update(update) + assert not filters.SenderChat(username='Testchat').check_update(update) + update.message.sender_chat.username = 'chat@' + assert filters.SenderChat(username='@chat@').check_update(update) + assert filters.SenderChat(username='chat@').check_update(update) + assert filters.SenderChat(username=['chat1', 'chat@', 'chat2']).check_update(update) + assert not filters.SenderChat(username=['@username', '@chat_2']).check_update(update) + assert filters.SenderChat.ALL.check_update(update) + update.message.sender_chat = None + assert not filters.SenderChat(username=['@username', '@chat_2']).check_update(update) + assert not filters.SenderChat.ALL.check_update(update) + + def test_filters_sender_chat_change_id(self, update): + f = filters.SenderChat(chat_id=1) + assert f.chat_ids == {1} + update.message.sender_chat.id = 1 + assert f.check_update(update) + update.message.sender_chat.id = 2 + assert not f.check_update(update) + f.chat_ids = 2 + assert f.chat_ids == {2} + assert f.check_update(update) + + with pytest.raises(RuntimeError, match='username in conjunction'): + f.usernames = 'chat' + + def test_filters_sender_chat_change_username(self, update): + f = filters.SenderChat(username='chat') + update.message.sender_chat.username = 'chat' + assert f.check_update(update) + update.message.sender_chat.username = 'User' + assert not f.check_update(update) + f.usernames = 'User' + assert f.check_update(update) + + with pytest.raises(RuntimeError, match='chat_id in conjunction'): + f.chat_ids = 1 + + def test_filters_sender_chat_add_sender_chat_by_name(self, update): + chats = ['chat_a', 'chat_b', 'chat_c'] + f = filters.SenderChat() + + for chat in chats: + update.message.sender_chat.username = chat + assert not f.check_update(update) + + f.add_usernames('chat_a') + f.add_usernames(['chat_b', 'chat_c']) + + for chat in chats: + update.message.sender_chat.username = chat + assert f.check_update(update) + + with pytest.raises(RuntimeError, match='chat_id in conjunction'): + f.add_chat_ids(1) + + def test_filters_sender_chat_add_sender_chat_by_id(self, update): + chats = [1, 2, 3] + f = filters.SenderChat() + + for chat in chats: + update.message.sender_chat.id = chat + assert not f.check_update(update) + + f.add_chat_ids(1) + f.add_chat_ids([2, 3]) + + for chat in chats: + update.message.sender_chat.username = chat + assert f.check_update(update) + + with pytest.raises(RuntimeError, match='username in conjunction'): + f.add_usernames('chat') + + def test_filters_sender_chat_remove_sender_chat_by_name(self, update): + chats = ['chat_a', 'chat_b', 'chat_c'] + f = filters.SenderChat(username=chats) + + with pytest.raises(RuntimeError, match='chat_id in conjunction'): + f.remove_chat_ids(1) + + for chat in chats: + update.message.sender_chat.username = chat + assert f.check_update(update) + + f.remove_usernames('chat_a') + f.remove_usernames(['chat_b', 'chat_c']) + + for chat in chats: + update.message.sender_chat.username = chat + assert not f.check_update(update) + + def test_filters_sender_chat_remove_sender_chat_by_id(self, update): + chats = [1, 2, 3] + f = filters.SenderChat(chat_id=chats) + + with pytest.raises(RuntimeError, match='username in conjunction'): + f.remove_usernames('chat') + + for chat in chats: + update.message.sender_chat.id = chat + assert f.check_update(update) + + f.remove_chat_ids(1) + f.remove_chat_ids([2, 3]) + + for chat in chats: + update.message.sender_chat.username = chat + assert not f.check_update(update) + + def test_filters_sender_chat_repr(self): + f = filters.SenderChat([1, 2]) + assert str(f) == 'filters.SenderChat(1, 2)' + f.remove_chat_ids(1) + f.remove_chat_ids(2) + assert str(f) == 'filters.SenderChat()' + f.add_usernames('@foobar') + assert str(f) == 'filters.SenderChat(foobar)' + f.add_usernames('@barfoo') + assert str(f).startswith('filters.SenderChat(') + # we don't know th exact order + assert 'barfoo' in str(f) and 'foobar' in str(f) + + with pytest.raises(RuntimeError, match='Cannot set name'): + f.name = 'foo' + + def test_filters_sender_chat_super_group(self, update): + update.message.sender_chat.type = Chat.PRIVATE + assert not filters.SenderChat.SUPER_GROUP.check_update(update) + assert filters.SenderChat.ALL.check_update(update) + update.message.sender_chat.type = Chat.CHANNEL + assert not filters.SenderChat.SUPER_GROUP.check_update(update) + update.message.sender_chat.type = Chat.SUPERGROUP + assert filters.SenderChat.SUPER_GROUP.check_update(update) + assert filters.SenderChat.ALL.check_update(update) + update.message.sender_chat = None + assert not filters.SenderChat.SUPER_GROUP.check_update(update) + assert not filters.SenderChat.ALL.check_update(update) + + def test_filters_sender_chat_channel(self, update): + update.message.sender_chat.type = Chat.PRIVATE + assert not filters.SenderChat.CHANNEL.check_update(update) + update.message.sender_chat.type = Chat.SUPERGROUP + assert not filters.SenderChat.CHANNEL.check_update(update) + update.message.sender_chat.type = Chat.CHANNEL + assert filters.SenderChat.CHANNEL.check_update(update) + update.message.sender_chat = None + assert not filters.SenderChat.CHANNEL.check_update(update) + + def test_filters_is_automatic_forward(self, update): + assert not filters.IS_AUTOMATIC_FORWARD.check_update(update) + update.message.is_automatic_forward = True + assert filters.IS_AUTOMATIC_FORWARD.check_update(update) + + def test_filters_has_protected_content(self, update): + assert not filters.HAS_PROTECTED_CONTENT.check_update(update) + update.message.has_protected_content = True + assert filters.HAS_PROTECTED_CONTENT.check_update(update) + + def test_filters_invoice(self, update): + assert not filters.INVOICE.check_update(update) + update.message.invoice = 'test' + assert filters.INVOICE.check_update(update) + + def test_filters_successful_payment(self, update): + assert not filters.SUCCESSFUL_PAYMENT.check_update(update) + update.message.successful_payment = 'test' + assert filters.SUCCESSFUL_PAYMENT.check_update(update) + + def test_filters_passport_data(self, update): + assert not filters.PASSPORT_DATA.check_update(update) + update.message.passport_data = 'test' + assert filters.PASSPORT_DATA.check_update(update) + + def test_filters_poll(self, update): + assert not filters.POLL.check_update(update) + update.message.poll = 'test' + assert filters.POLL.check_update(update) + + @pytest.mark.parametrize('emoji', Dice.ALL_EMOJI) + def test_filters_dice(self, update, emoji): + update.message.dice = Dice(4, emoji) + assert filters.Dice.ALL.check_update(update) and filters.Dice().check_update(update) + + to_camel = emoji.name.title().replace('_', '') + assert repr(filters.Dice.ALL) == "filters.Dice.ALL" + assert repr(getattr(filters.Dice, to_camel)(4)) == f"filters.Dice.{to_camel}([4])" + + update.message.dice = None + assert not filters.Dice.ALL.check_update(update) + + @pytest.mark.parametrize('emoji', Dice.ALL_EMOJI) + def test_filters_dice_list(self, update, emoji): + update.message.dice = None + assert not filters.Dice(5).check_update(update) + + update.message.dice = Dice(5, emoji) + assert filters.Dice(5).check_update(update) + assert repr(filters.Dice(5)) == "filters.Dice([5])" + assert filters.Dice({5, 6}).check_update(update) + assert not filters.Dice(1).check_update(update) + assert not filters.Dice([2, 3]).check_update(update) + + def test_filters_dice_type(self, update): + update.message.dice = Dice(5, '🎲') + assert filters.Dice.DICE.check_update(update) + assert repr(filters.Dice.DICE) == "filters.Dice.DICE" + assert filters.Dice.Dice([4, 5]).check_update(update) + assert not filters.Dice.Darts(5).check_update(update) + assert not filters.Dice.BASKETBALL.check_update(update) + assert not filters.Dice.Dice([6]).check_update(update) + + update.message.dice = Dice(5, '🎯') + assert filters.Dice.DARTS.check_update(update) + assert filters.Dice.Darts([4, 5]).check_update(update) + assert not filters.Dice.Dice(5).check_update(update) + assert not filters.Dice.BASKETBALL.check_update(update) + assert not filters.Dice.Darts([6]).check_update(update) + + update.message.dice = Dice(5, '🏀') + assert filters.Dice.BASKETBALL.check_update(update) + assert filters.Dice.Basketball([4, 5]).check_update(update) + assert not filters.Dice.Dice(5).check_update(update) + assert not filters.Dice.DARTS.check_update(update) + assert not filters.Dice.Basketball([4]).check_update(update) + + update.message.dice = Dice(5, '⚽') + assert filters.Dice.FOOTBALL.check_update(update) + assert filters.Dice.Football([4, 5]).check_update(update) + assert not filters.Dice.Dice(5).check_update(update) + assert not filters.Dice.DARTS.check_update(update) + assert not filters.Dice.Football([4]).check_update(update) + + update.message.dice = Dice(5, '🎰') + assert filters.Dice.SLOT_MACHINE.check_update(update) + assert filters.Dice.SlotMachine([4, 5]).check_update(update) + assert not filters.Dice.Dice(5).check_update(update) + assert not filters.Dice.DARTS.check_update(update) + assert not filters.Dice.SlotMachine([4]).check_update(update) + + update.message.dice = Dice(5, '🎳') + assert filters.Dice.BOWLING.check_update(update) + assert filters.Dice.Bowling([4, 5]).check_update(update) + assert not filters.Dice.Dice(5).check_update(update) + assert not filters.Dice.DARTS.check_update(update) + assert not filters.Dice.Bowling([4]).check_update(update) + + def test_language_filter_single(self, update): + update.message.from_user.language_code = 'en_US' + assert filters.Language('en_US').check_update(update) + assert filters.Language('en').check_update(update) + assert not filters.Language('en_GB').check_update(update) + assert not filters.Language('da').check_update(update) + update.message.from_user.language_code = 'da' + assert not filters.Language('en_US').check_update(update) + assert not filters.Language('en').check_update(update) + assert not filters.Language('en_GB').check_update(update) + assert filters.Language('da').check_update(update) + + update.message.from_user = None + assert not filters.Language('da').check_update(update) + + def test_language_filter_multiple(self, update): + f = filters.Language(['en_US', 'da']) + update.message.from_user.language_code = 'en_US' + assert f.check_update(update) + update.message.from_user.language_code = 'en_GB' + assert not f.check_update(update) + update.message.from_user.language_code = 'da' + assert f.check_update(update) + + def test_and_filters(self, update): + update.message.text = 'test' + update.message.forward_date = datetime.datetime.utcnow() + assert (filters.TEXT & filters.FORWARDED).check_update(update) + update.message.text = '/test' + assert (filters.TEXT & filters.FORWARDED).check_update(update) + update.message.text = 'test' + update.message.forward_date = None + assert not (filters.TEXT & filters.FORWARDED).check_update(update) + + update.message.text = 'test' + update.message.forward_date = datetime.datetime.utcnow() + assert (filters.TEXT & filters.FORWARDED & filters.ChatType.PRIVATE).check_update(update) + + def test_or_filters(self, update): + update.message.text = 'test' + assert (filters.TEXT | filters.StatusUpdate.ALL).check_update(update) + update.message.group_chat_created = True + assert (filters.TEXT | filters.StatusUpdate.ALL).check_update(update) + update.message.text = None + assert (filters.TEXT | filters.StatusUpdate.ALL).check_update(update) + update.message.group_chat_created = False + assert not (filters.TEXT | filters.StatusUpdate.ALL).check_update(update) + + def test_and_or_filters(self, update): + update.message.text = 'test' + update.message.forward_date = datetime.datetime.utcnow() + assert (filters.TEXT & (filters.StatusUpdate.ALL | filters.FORWARDED)).check_update(update) + update.message.forward_date = None + assert not (filters.TEXT & (filters.FORWARDED | filters.StatusUpdate.ALL)).check_update( + update + ) + update.message.pinned_message = True + assert filters.TEXT & (filters.FORWARDED | filters.StatusUpdate.ALL).check_update(update) + + assert ( + str(filters.TEXT & (filters.FORWARDED | filters.Entity(MessageEntity.MENTION))) + == '>' + ) + + def test_xor_filters(self, update): + update.message.text = 'test' + update.effective_user.id = 123 + assert not (filters.TEXT ^ filters.User(123)).check_update(update) + update.message.text = None + update.effective_user.id = 1234 + assert not (filters.TEXT ^ filters.User(123)).check_update(update) + update.message.text = 'test' + assert (filters.TEXT ^ filters.User(123)).check_update(update) + update.message.text = None + update.effective_user.id = 123 + assert (filters.TEXT ^ filters.User(123)).check_update(update) + + def test_xor_filters_repr(self, update): + assert str(filters.TEXT ^ filters.User(123)) == '' + with pytest.raises(RuntimeError, match='Cannot set name'): + (filters.TEXT ^ filters.User(123)).name = 'foo' + + def test_and_xor_filters(self, update): + update.message.text = 'test' + update.message.forward_date = datetime.datetime.utcnow() + assert (filters.FORWARDED & (filters.TEXT ^ filters.User(123))).check_update(update) + update.message.text = None + update.effective_user.id = 123 + assert (filters.FORWARDED & (filters.TEXT ^ filters.User(123))).check_update(update) + update.message.text = 'test' + assert not (filters.FORWARDED & (filters.TEXT ^ filters.User(123))).check_update(update) + update.message.forward_date = None + update.message.text = None + update.effective_user.id = 123 + assert not (filters.FORWARDED & (filters.TEXT ^ filters.User(123))).check_update(update) + update.message.text = 'test' + update.effective_user.id = 456 + assert not (filters.FORWARDED & (filters.TEXT ^ filters.User(123))).check_update(update) + + assert ( + str(filters.FORWARDED & (filters.TEXT ^ filters.User(123))) + == '>' + ) + + def test_xor_regex_filters(self, update): + sre_type = type(re.match("", "")) + update.message.text = 'test' + update.message.forward_date = datetime.datetime.utcnow() + assert not (filters.FORWARDED ^ filters.Regex('^test$')).check_update(update) + update.message.forward_date = None + result = (filters.FORWARDED ^ filters.Regex('^test$')).check_update(update) + assert result + assert isinstance(result, dict) + matches = result['matches'] + assert isinstance(matches, list) + assert type(matches[0]) is sre_type + update.message.forward_date = datetime.datetime.utcnow() + update.message.text = None + assert (filters.FORWARDED ^ filters.Regex('^test$')).check_update(update) is True + + def test_inverted_filters(self, update): + update.message.text = '/test' + update.message.entities = [MessageEntity(MessageEntity.BOT_COMMAND, 0, 5)] + assert filters.COMMAND.check_update(update) + assert not (~filters.COMMAND).check_update(update) + update.message.text = 'test' + update.message.entities = [] + assert not filters.COMMAND.check_update(update) + assert (~filters.COMMAND).check_update(update) + + def test_inverted_filters_repr(self, update): + assert str(~filters.TEXT) == '' + with pytest.raises(RuntimeError, match='Cannot set name'): + (~filters.TEXT).name = 'foo' + + def test_inverted_and_filters(self, update): + update.message.text = '/test' + update.message.entities = [MessageEntity(MessageEntity.BOT_COMMAND, 0, 5)] + update.message.forward_date = 1 + assert (filters.FORWARDED & filters.COMMAND).check_update(update) + assert not (~filters.FORWARDED & filters.COMMAND).check_update(update) + assert not (filters.FORWARDED & ~filters.COMMAND).check_update(update) + assert not (~(filters.FORWARDED & filters.COMMAND)).check_update(update) + update.message.forward_date = None + assert not (filters.FORWARDED & filters.COMMAND).check_update(update) + assert (~filters.FORWARDED & filters.COMMAND).check_update(update) + assert not (filters.FORWARDED & ~filters.COMMAND).check_update(update) + assert (~(filters.FORWARDED & filters.COMMAND)).check_update(update) + update.message.text = 'test' + update.message.entities = [] + assert not (filters.FORWARDED & filters.COMMAND).check_update(update) + assert not (~filters.FORWARDED & filters.COMMAND).check_update(update) + assert not (filters.FORWARDED & ~filters.COMMAND).check_update(update) + assert (~(filters.FORWARDED & filters.COMMAND)).check_update(update) + + def test_indirect_message(self, update): + class _CustomFilter(filters.MessageFilter): + test_flag = False + + def filter(self, message: Message): + self.test_flag = True + return self.test_flag + + c = _CustomFilter() + u = Update(0, callback_query=CallbackQuery('0', update.effective_user, '', update.message)) + assert not c.check_update(u) + assert not c.test_flag + assert c.check_update(update) + assert c.test_flag + + def test_custom_unnamed_filter(self, update, base_class): + class Unnamed(base_class): + def filter(self, _): + return True + + unnamed = Unnamed() + assert str(unnamed) == Unnamed.__name__ + + def test_update_type_message(self, update): + assert filters.UpdateType.MESSAGE.check_update(update) + assert not filters.UpdateType.EDITED_MESSAGE.check_update(update) + assert filters.UpdateType.MESSAGES.check_update(update) + assert not filters.UpdateType.CHANNEL_POST.check_update(update) + assert not filters.UpdateType.EDITED_CHANNEL_POST.check_update(update) + assert not filters.UpdateType.CHANNEL_POSTS.check_update(update) + assert not filters.UpdateType.EDITED.check_update(update) + + def test_update_type_edited_message(self, update): + update.edited_message, update.message = update.message, update.edited_message + assert not filters.UpdateType.MESSAGE.check_update(update) + assert filters.UpdateType.EDITED_MESSAGE.check_update(update) + assert filters.UpdateType.MESSAGES.check_update(update) + assert not filters.UpdateType.CHANNEL_POST.check_update(update) + assert not filters.UpdateType.EDITED_CHANNEL_POST.check_update(update) + assert not filters.UpdateType.CHANNEL_POSTS.check_update(update) + assert filters.UpdateType.EDITED.check_update(update) + + def test_update_type_channel_post(self, update): + update.channel_post, update.message = update.message, update.edited_message + assert not filters.UpdateType.MESSAGE.check_update(update) + assert not filters.UpdateType.EDITED_MESSAGE.check_update(update) + assert not filters.UpdateType.MESSAGES.check_update(update) + assert filters.UpdateType.CHANNEL_POST.check_update(update) + assert not filters.UpdateType.EDITED_CHANNEL_POST.check_update(update) + assert filters.UpdateType.CHANNEL_POSTS.check_update(update) + assert not filters.UpdateType.EDITED.check_update(update) + + def test_update_type_edited_channel_post(self, update): + update.edited_channel_post, update.message = update.message, update.edited_message + assert not filters.UpdateType.MESSAGE.check_update(update) + assert not filters.UpdateType.EDITED_MESSAGE.check_update(update) + assert not filters.UpdateType.MESSAGES.check_update(update) + assert not filters.UpdateType.CHANNEL_POST.check_update(update) + assert filters.UpdateType.EDITED_CHANNEL_POST.check_update(update) + assert filters.UpdateType.CHANNEL_POSTS.check_update(update) + assert filters.UpdateType.EDITED.check_update(update) + + def test_merged_short_circuit_and(self, update, base_class): + update.message.text = '/test' + update.message.entities = [MessageEntity(MessageEntity.BOT_COMMAND, 0, 5)] + + class TestException(Exception): + pass + + class RaisingFilter(base_class): + def filter(self, _): + raise TestException + + raising_filter = RaisingFilter() + + with pytest.raises(TestException): + (filters.COMMAND & raising_filter).check_update(update) + + update.message.text = 'test' + update.message.entities = [] + (filters.COMMAND & raising_filter).check_update(update) + + def test_merged_filters_repr(self, update): + with pytest.raises(RuntimeError, match='Cannot set name'): + (filters.TEXT & filters.PHOTO).name = 'foo' + + def test_merged_short_circuit_or(self, update, base_class): + update.message.text = 'test' + + class TestException(Exception): + pass + + class RaisingFilter(base_class): + def filter(self, _): + raise TestException + + raising_filter = RaisingFilter() + + with pytest.raises(TestException): + (filters.COMMAND | raising_filter).check_update(update) + + update.message.text = '/test' + update.message.entities = [MessageEntity(MessageEntity.BOT_COMMAND, 0, 5)] + (filters.COMMAND | raising_filter).check_update(update) + + def test_merged_data_merging_and(self, update, base_class): + update.message.text = '/test' + update.message.entities = [MessageEntity(MessageEntity.BOT_COMMAND, 0, 5)] + + class DataFilter(base_class): + data_filter = True + + def __init__(self, data): + self.data = data + + def filter(self, _): + return {'test': [self.data]} + + result = (filters.COMMAND & DataFilter('blah')).check_update(update) + assert result['test'] == ['blah'] + + result = (DataFilter('blah1') & DataFilter('blah2')).check_update(update) + assert result['test'] == ['blah1', 'blah2'] + + update.message.text = 'test' + update.message.entities = [] + result = (filters.COMMAND & DataFilter('blah')).check_update(update) + assert not result + + def test_merged_data_merging_or(self, update, base_class): + update.message.text = '/test' + + class DataFilter(base_class): + data_filter = True + + def __init__(self, data): + self.data = data + + def filter(self, _): + return {'test': [self.data]} + + result = (filters.COMMAND | DataFilter('blah')).check_update(update) + assert result + + result = (DataFilter('blah1') | DataFilter('blah2')).check_update(update) + assert result['test'] == ['blah1'] + + update.message.text = 'test' + result = (filters.COMMAND | DataFilter('blah')).check_update(update) + assert result['test'] == ['blah'] + + def test_filters_via_bot_init(self): + with pytest.raises(RuntimeError, match='in conjunction with'): + filters.ViaBot(bot_id=1, username='bot') + + def test_filters_via_bot_allow_empty(self, update): + assert not filters.ViaBot().check_update(update) + assert filters.ViaBot(allow_empty=True).check_update(update) + + def test_filters_via_bot_id(self, update): + assert not filters.ViaBot(bot_id=1).check_update(update) + update.message.via_bot.id = 1 + assert filters.ViaBot(bot_id=1).check_update(update) + update.message.via_bot.id = 2 + assert filters.ViaBot(bot_id=[1, 2]).check_update(update) + assert not filters.ViaBot(bot_id=[3, 4]).check_update(update) + update.message.via_bot = None + assert not filters.ViaBot(bot_id=[3, 4]).check_update(update) + + def test_filters_via_bot_username(self, update): + assert not filters.ViaBot(username='bot').check_update(update) + assert not filters.ViaBot(username='Testbot').check_update(update) + update.message.via_bot.username = 'bot@' + assert filters.ViaBot(username='@bot@').check_update(update) + assert filters.ViaBot(username='bot@').check_update(update) + assert filters.ViaBot(username=['bot1', 'bot@', 'bot2']).check_update(update) + assert not filters.ViaBot(username=['@username', '@bot_2']).check_update(update) + update.message.via_bot = None + assert not filters.User(username=['@username', '@bot_2']).check_update(update) + + def test_filters_via_bot_change_id(self, update): + f = filters.ViaBot(bot_id=3) + assert f.bot_ids == {3} + update.message.via_bot.id = 3 + assert f.check_update(update) + update.message.via_bot.id = 2 + assert not f.check_update(update) + f.bot_ids = 2 + assert f.bot_ids == {2} + assert f.check_update(update) + + with pytest.raises(RuntimeError, match='username in conjunction'): + f.usernames = 'user' + + def test_filters_via_bot_change_username(self, update): + f = filters.ViaBot(username='bot') + update.message.via_bot.username = 'bot' + assert f.check_update(update) + update.message.via_bot.username = 'Bot' + assert not f.check_update(update) + f.usernames = 'Bot' + assert f.check_update(update) + + with pytest.raises(RuntimeError, match='bot_id in conjunction'): + f.bot_ids = 1 + + def test_filters_via_bot_add_user_by_name(self, update): + users = ['bot_a', 'bot_b', 'bot_c'] + f = filters.ViaBot() + + for user in users: + update.message.via_bot.username = user + assert not f.check_update(update) + + f.add_usernames('bot_a') + f.add_usernames(['bot_b', 'bot_c']) + + for user in users: + update.message.via_bot.username = user + assert f.check_update(update) + + with pytest.raises(RuntimeError, match='bot_id in conjunction'): + f.add_bot_ids(1) + + def test_filters_via_bot_add_user_by_id(self, update): + users = [1, 2, 3] + f = filters.ViaBot() + + for user in users: + update.message.via_bot.id = user + assert not f.check_update(update) + + f.add_bot_ids(1) + f.add_bot_ids([2, 3]) + + for user in users: + update.message.via_bot.username = user + assert f.check_update(update) + + with pytest.raises(RuntimeError, match='username in conjunction'): + f.add_usernames('bot') + + def test_filters_via_bot_remove_user_by_name(self, update): + users = ['bot_a', 'bot_b', 'bot_c'] + f = filters.ViaBot(username=users) + + with pytest.raises(RuntimeError, match='bot_id in conjunction'): + f.remove_bot_ids(1) + + for user in users: + update.message.via_bot.username = user + assert f.check_update(update) + + f.remove_usernames('bot_a') + f.remove_usernames(['bot_b', 'bot_c']) + + for user in users: + update.message.via_bot.username = user + assert not f.check_update(update) + + def test_filters_via_bot_remove_user_by_id(self, update): + users = [1, 2, 3] + f = filters.ViaBot(bot_id=users) + + with pytest.raises(RuntimeError, match='username in conjunction'): + f.remove_usernames('bot') + + for user in users: + update.message.via_bot.id = user + assert f.check_update(update) + + f.remove_bot_ids(1) + f.remove_bot_ids([2, 3]) + + for user in users: + update.message.via_bot.username = user + assert not f.check_update(update) + + def test_filters_via_bot_repr(self): + f = filters.ViaBot([1, 2]) + assert str(f) == 'filters.ViaBot(1, 2)' + f.remove_bot_ids(1) + f.remove_bot_ids(2) + assert str(f) == 'filters.ViaBot()' + f.add_usernames('@foobar') + assert str(f) == 'filters.ViaBot(foobar)' + f.add_usernames('@barfoo') + assert str(f).startswith('filters.ViaBot(') + # we don't know th exact order + assert 'barfoo' in str(f) and 'foobar' in str(f) + + with pytest.raises(RuntimeError, match='Cannot set name'): + f.name = 'foo' + + def test_filters_attachment(self, update): + assert not filters.ATTACHMENT.check_update(update) + # we need to define a new Update (or rather, message class) here because + # effective_attachment is only evaluated once per instance, and the filter relies on that + up = Update( + 0, + Message( + 0, + datetime.datetime.utcnow(), + Chat(0, 'private'), + document=Document("str", "other_str"), + ), + ) + assert filters.ATTACHMENT.check_update(up) From 1f564e3b702f0decf4e65f08bc5a142040ca1eda Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Wed, 16 Feb 2022 08:51:42 +0100 Subject: [PATCH 015/153] Try fixing stuck CI --- tests/test_request.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/tests/test_request.py b/tests/test_request.py index ac6e868e086..ffc4d73736e 100644 --- a/tests/test_request.py +++ b/tests/test_request.py @@ -517,17 +517,20 @@ async def request(_, **kwargs): httpx_request.do_request(method='GET', url='URL'), ) - @pytest.mark.asyncio @flaky(3, 1) + @pytest.mark.asyncio async def test_do_request_wait_for_pool(self, monkeypatch, httpx_request): """The pool logic is buried rather deeply in httpxcore, so we make actual requests here instead of mocking""" - - task_1 = httpx_request.do_request( - method='GET', url='https://python-telegram-bot.org/static/testfiles/telegram.mp4' + task_1 = asyncio.create_task( + httpx_request.do_request( + method='GET', url='https://python-telegram-bot.org/static/testfiles/telegram.mp4' + ) ) - task_2 = httpx_request.do_request( - method='GET', url='https://python-telegram-bot.org/static/testfiles/telegram.mp4' + task_2 = asyncio.create_task( + httpx_request.do_request( + method='GET', url='https://python-telegram-bot.org/static/testfiles/telegram.mp4' + ) ) done, pending = await asyncio.wait({task_1, task_2}, return_when=asyncio.FIRST_COMPLETED) assert len(done) == len(pending) == 1 From d3d99cf19d123369fa9cf08dabe68a997f6de117 Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Wed, 16 Feb 2022 22:15:53 +0100 Subject: [PATCH 016/153] Get started on builder tests --- telegram/_files/inputfile.py | 15 +-- telegram/_utils/files.py | 5 +- telegram/ext/_builders.py | 1 - tests/test_builders.py | 188 +++++++++++++++++++++++++++++++++++ 4 files changed, 198 insertions(+), 11 deletions(-) create mode 100644 tests/test_builders.py diff --git a/telegram/_files/inputfile.py b/telegram/_files/inputfile.py index 247089ce7cd..cc68ef622a1 100644 --- a/telegram/_files/inputfile.py +++ b/telegram/_files/inputfile.py @@ -82,10 +82,6 @@ def __init__(self, obj: Union[IO[bytes], bytes, str], filename: str = None): self.filename = filename or self.mimetype.replace('/', '.') - @property - def field_tuple(self) -> FieldTuple: # skipcq: PY-D0003 - return self.filename, self.input_file_content, self.mimetype - @staticmethod def is_image(stream: bytes) -> Optional[str]: """Check if the content file is an image by analyzing its headers. @@ -109,9 +105,14 @@ def is_image(stream: bytes) -> Optional[str]: ) return None - @staticmethod - def is_file(obj: object) -> bool: # skipcq: PY-D0003 - return hasattr(obj, 'read') + @property + def field_tuple(self) -> FieldTuple: + """Field tuple representing the contents of the file for upload to the Telegram servers. + + Returns: + Tuple[:obj:`str`, :obj:`bytes`, :obj:`str`]: + """ + return self.filename, self.input_file_content, self.mimetype @property def attach_uri(self) -> str: diff --git a/telegram/_utils/files.py b/telegram/_utils/files.py index dc2d97d70d3..b074d99f2fd 100644 --- a/telegram/_utils/files.py +++ b/telegram/_utils/files.py @@ -95,9 +95,8 @@ def parse_file_input( return out if isinstance(file_input, bytes): return InputFile(file_input, filename=filename) - if InputFile.is_file(file_input): - file_input = cast(IO, file_input) - return InputFile(file_input, filename=filename) + if hasattr(file_input, 'read'): + return InputFile(cast(IO, file_input), filename=filename) if tg_type and isinstance(file_input, tg_type): return file_input.file_id # type: ignore[attr-defined] return file_input diff --git a/telegram/ext/_builders.py b/telegram/ext/_builders.py index d330bfb4099..f213aeccec9 100644 --- a/telegram/ext/_builders.py +++ b/telegram/ext/_builders.py @@ -77,7 +77,6 @@ _BOT_CHECKS = [ ('updater', 'Updater instance'), ('request', 'Request instance'), - ('request_kwargs', 'request_kwargs'), ('base_file_url', 'base_file_url'), ('base_url', 'base_url'), ('token', 'token'), diff --git a/tests/test_builders.py b/tests/test_builders.py new file mode 100644 index 00000000000..e97e4a8aa31 --- /dev/null +++ b/tests/test_builders.py @@ -0,0 +1,188 @@ +#!/usr/bin/env python +# +# A library that provides a Python interface to the Telegram Bot API +# Copyright (C) 2015-2022 +# Leandro Toledo de Souza +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser Public License for more details. +# +# You should have received a copy of the GNU Lesser Public License +# along with this program. If not, see [http://www.gnu.org/licenses/]. +from pathlib import Path + +import pytest + +from telegram.request import HTTPXRequest +from .conftest import PRIVATE_KEY, data_file + +from telegram.ext import ( + ApplicationBuilder, + Defaults, + JobQueue, + PicklePersistence, + ContextTypes, + Application, + Updater, +) +from telegram.ext._builders import _BOT_CHECKS + + +@pytest.fixture(scope='function') +def builder(): + return ApplicationBuilder() + + +class TestApplicationBuilder: + @pytest.mark.parametrize( + 'method, description', _BOT_CHECKS, ids=[entry[0] for entry in _BOT_CHECKS] + ) + def test_mutually_exclusive_for_bot(self, builder, method, description): + # First test that e.g. `bot` can't be set if `request` was already set + # We pass the private key since `private_key` is the only method that doesn't just save + # the passed value + getattr(builder, method)(data_file('private.key')) + with pytest.raises(RuntimeError, match=f'`bot` may only be set, if no {description}'): + builder.bot(None) + + # Now test that `request` can't be set if `bot` was already set + builder = builder.__class__() + builder.bot(None) + with pytest.raises(RuntimeError, match=f'`{method}` may only be set, if no bot instance'): + getattr(builder, method)(None) + + # def test_mutually_exclusive_for_request(self, builder): + # builder.request(None) + # with pytest.raises( + # RuntimeError, match='`request_kwargs` may only be set, if no Request instance' + # ): + # builder.request_kwargs(None) + # + # builder = builder.__class__() + # builder.request_kwargs(None) + # with pytest.raises(RuntimeError, match='`request` may only be set, if no request_kwargs'): + # builder.request(None) + # + # def test_build_without_token(self, builder): + # with pytest.raises(RuntimeError, match='No bot token was set.'): + # builder.build() + # + # def test_build_custom_bot(self, builder, bot): + # builder.bot(bot) + # obj = builder.build() + # assert obj.bot is bot + # + # if isinstance(obj, Updater): + # assert obj.dispatcher.bot is bot + # assert obj.dispatcher.job_queue.dispatcher is obj.dispatcher + # assert obj.exception_event is obj.dispatcher.exception_event + # + # def test_all_bot_args_custom(self, builder, bot): + # defaults = Defaults() + # request = HTTPXRequest(connection_pool_size=8) + # builder.token(bot.token).base_url('base_url').base_file_url('base_file_url').private_key( + # PRIVATE_KEY + # ).defaults(defaults).arbitrary_callback_data(42).request(request) + # built_bot = builder.build().bot + # + # assert built_bot.token == bot.token + # assert built_bot.base_url == 'base_url' + bot.token + # assert built_bot.base_file_url == 'base_file_url' + bot.token + # assert built_bot.defaults is defaults + # assert built_bot.request is request + # assert built_bot.callback_data_cache.maxsize == 42 + # + # builder = builder.__class__() + # builder.token(bot.token).request_kwargs({'connect_timeout': 42}) + # built_bot = builder.build().bot + # + # assert built_bot.token == bot.token + # assert built_bot.request._connect_timeout == 42 + # + # def test_all_dispatcher_args_custom(self, app, builder): + # job_queue = JobQueue() + # persistence = PicklePersistence('filename') + # context_types = ContextTypes() + # builder.bot(app.bot).update_queue(app.update_queue).exception_event( + # app.exception_event + # ).job_queue(job_queue).persistence(persistence).context_types(context_types).workers(3) + # dispatcher = builder.build() + # + # assert dispatcher.bot is app.bot + # assert dispatcher.update_queue is app.update_queue + # assert dispatcher.exception_event is app.exception_event + # assert dispatcher.job_queue is job_queue + # assert dispatcher.job_queue.dispatcher is dispatcher + # assert dispatcher.persistence is persistence + # assert dispatcher.context_types is context_types + # assert dispatcher.workers == 3 + # + # def test_all_updater_args_custom(self, app, builder): + # updater = ( + # builder.bot(app.bot) + # .exception_event(app.exception_event) + # .update_queue(app.update_queue) + # .user_signal_handler(42) + # .build() + # ) + # + # assert updater.dispatcher is None + # assert updater.bot is app.bot + # assert updater.exception_event is app.exception_event + # assert updater.update_queue is app.update_queue + # assert updater.user_signal_handler == 42 + # + # def test_connection_pool_size_with_workers(self, bot, builder): + # app = builder.token(bot.token).workers(42).build() + # assert app.workers == 42 + # assert app.bot.request.con_pool_size == 46 + # + # def test_connection_pool_size_warning(self, bot, builder, recwarn): + # builder.token(bot.token).workers(42).request_kwargs({'con_pool_size': 1}) + # app = builder.build() + # assert app.workers == 42 + # assert app.bot.request.con_pool_size == 1 + # + # assert len(recwarn) == 1 + # message = str(recwarn[-1].message) + # assert 'smaller (1)' in message + # assert 'recommended value of 46.' in message + # assert recwarn[-1].filename == __file__, "wrong stacklevel" + # + # def test_custom_classes(self, bot, builder): + # class CustomApplication(Application): + # def __init__(self, arg, **kwargs): + # super().__init__(**kwargs) + # self.arg = arg + # + # builder.application_class(CustomApplication, kwargs={'arg': 2}).token(bot.token) + # + # obj = builder.build() + # assert isinstance(obj, CustomApplication) + # assert obj.arg == 2 + # + # @pytest.mark.parametrize('input_type', ('bytes', 'str', 'Path')) + # def test_all_private_key_input_types(self, builder, bot, input_type): + # private_key = Path('tests/data/private.key') + # password = Path('tests/data/private_key.password') + # + # if input_type == 'bytes': + # private_key = private_key.read_bytes() + # password = password.read_bytes() + # if input_type == 'str': + # private_key = str(private_key) + # password = str(password) + # + # builder.token(bot.token).private_key( + # private_key=private_key, + # password=password, + # ) + # bot = builder.build().bot + # assert bot.private_key From ac07df1e939359c11852e6ddd788a9a6883a9117 Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Thu, 17 Feb 2022 19:24:24 +0100 Subject: [PATCH 017/153] Builder tests --- telegram/ext/_builders.py | 63 ++++++++++++++++++++------------- tests/test_builders.py | 74 +++++++++++++++++++++++++++------------ 2 files changed, 90 insertions(+), 47 deletions(-) diff --git a/telegram/ext/_builders.py b/telegram/ext/_builders.py index f213aeccec9..804a9b14903 100644 --- a/telegram/ext/_builders.py +++ b/telegram/ext/_builders.py @@ -76,7 +76,8 @@ _BOT_CHECKS = [ ('updater', 'Updater instance'), - ('request', 'Request instance'), + ('request', 'request instance'), + ('get_updates_request', 'get_updates_request instance'), ('base_file_url', 'base_file_url'), ('base_url', 'base_url'), ('token', 'token'), @@ -351,23 +352,28 @@ def base_file_url(self: BuilderType, base_file_url: str) -> BuilderType: return self def _request_check(self, get_updates: bool) -> None: - name = 'get_updates_request' if get_updates else 'request' + prefix = 'get_updates_' if get_updates else '' + name = prefix + 'request' for attr in ('connect_timeout', 'read_timeout', 'write_timeout', 'pool_timeout'): - if not isinstance(getattr(self, f"_{attr}"), DefaultValue): + if not isinstance(getattr(self, f"_{prefix}{attr}"), DefaultValue): raise RuntimeError(_TWO_ARGS_REQ.format(name, attr)) - if self._connection_pool_size is not None: + if getattr(self, f'_{prefix}connection_pool_size') is not None: raise RuntimeError(_TWO_ARGS_REQ.format(name, 'connection_pool_size')) - if self._proxy_url is not None: + if getattr(self, f'_{prefix}proxy_url') is not None: raise RuntimeError(_TWO_ARGS_REQ.format(name, 'proxy_url')) if self._bot is not DEFAULT_NONE: raise RuntimeError(_TWO_ARGS_REQ.format(name, 'bot instance')) + if self._updater not in (DEFAULT_NONE, None): + raise RuntimeError(_TWO_ARGS_REQ.format(name, 'updater instance')) - def _request_param_check(self, get_updates: bool) -> None: + def _request_param_check(self, name: str, get_updates: bool) -> None: if get_updates and self._get_updates_request is not DEFAULT_NONE: - raise RuntimeError(_TWO_ARGS_REQ.format('get_updates_request', 'bot instance')) + raise RuntimeError( + _TWO_ARGS_REQ.format(f'get_updates_{name}', 'get_updates_request instance') + ) if self._request is not DEFAULT_NONE: - raise RuntimeError(_TWO_ARGS_REQ.format('request', 'bot instance')) + raise RuntimeError(_TWO_ARGS_REQ.format(name, 'request instance')) if self._bot is not DEFAULT_NONE: raise RuntimeError( @@ -393,36 +399,36 @@ def request(self: BuilderType, request: BaseRequest) -> BuilderType: return self def connection_pool_size(self: BuilderType, connection_pool_size: int) -> BuilderType: - self._request_param_check(get_updates=False) + self._request_param_check(name='connection_pool_size', get_updates=False) self._connection_pool_size = connection_pool_size return self def proxy_url(self: BuilderType, proxy_url: str) -> BuilderType: - self._request_param_check(get_updates=False) + self._request_param_check(name='proxy_url', get_updates=False) self._proxy_url = proxy_url return self def connect_timeout(self: BuilderType, connect_timeout: Optional[float]) -> BuilderType: - self._request_param_check(get_updates=False) + self._request_param_check(name='connect_timeout', get_updates=False) self._connect_timeout = connect_timeout return self def read_timeout(self: BuilderType, read_timeout: Optional[float]) -> BuilderType: - self._request_param_check(get_updates=False) + self._request_param_check(name='read_timeout', get_updates=False) self._read_timeout = read_timeout return self def write_timeout(self: BuilderType, write_timeout: Optional[float]) -> BuilderType: - self._request_param_check(get_updates=False) + self._request_param_check(name='write_timeout', get_updates=False) self._write_timeout = write_timeout return self def pool_timeout(self: BuilderType, pool_timeout: Optional[float]) -> BuilderType: - self._request_param_check(get_updates=False) + self._request_param_check(name='pool_timeout', get_updates=False) self._pool_timeout = pool_timeout return self - def get_updates_request(self: BuilderType, request: BaseRequest) -> BuilderType: + def get_updates_request(self: BuilderType, get_updates_request: BaseRequest) -> BuilderType: """Sets a :class:`telegram.request.BaseRequest` object to be used for the :paramref:`~telegram.Bot.get_updates_request` parameter of :attr:`telegram.ext.Application.bot`. @@ -430,52 +436,52 @@ def get_updates_request(self: BuilderType, request: BaseRequest) -> BuilderType: .. seealso:: :meth:`request` Args: - request (:class:`telegram.request.BaseRequest`): The request object. + get_updates_request (:class:`telegram.request.BaseRequest`): The request object. Returns: :class:`ApplicationBuilder`: The same builder with the updated argument. """ self._request_check(get_updates=True) - self._request = request + self._get_updates_request = get_updates_request return self def get_updates_connection_pool_size( self: BuilderType, get_updates_connection_pool_size: int ) -> BuilderType: - self._request_param_check(get_updates=True) + self._request_param_check(name='connection_pool_size', get_updates=True) self._get_updates_connection_pool_size = get_updates_connection_pool_size return self def get_updates_proxy_url(self: BuilderType, get_updates_proxy_url: str) -> BuilderType: - self._request_param_check(get_updates=True) + self._request_param_check(name='proxy_url', get_updates=True) self._get_updates_proxy_url = get_updates_proxy_url return self def get_updates_connect_timeout( self: BuilderType, get_updates_connect_timeout: Optional[float] ) -> BuilderType: - self._request_param_check(get_updates=True) + self._request_param_check(name='connect_timeout', get_updates=True) self._get_updates_connect_timeout = get_updates_connect_timeout return self def get_updates_read_timeout( self: BuilderType, get_updates_read_timeout: Optional[float] ) -> BuilderType: - self._request_param_check(get_updates=True) + self._request_param_check(name='read_timeout', get_updates=True) self._get_updates_read_timeout = get_updates_read_timeout return self def get_updates_write_timeout( self: BuilderType, get_updates_write_timeout: Optional[float] ) -> BuilderType: - self._request_param_check(get_updates=True) + self._request_param_check(name='write_timeout', get_updates=True) self._get_updates_write_timeout = get_updates_write_timeout return self def get_updates_pool_timeout( self: BuilderType, get_updates_pool_timeout: Optional[float] ) -> BuilderType: - self._request_param_check(get_updates=True) + self._request_param_check(name='pool_timeout', get_updates=True) self._get_updates_pool_timeout = get_updates_pool_timeout return self @@ -694,7 +700,9 @@ def context_types( def updater(self: BuilderType, updater: Union[Updater, None]) -> BuilderType: """Sets a :class:`telegram.ext.Updater` instance to be used for - :attr:`telegram.ext.Application.updater`. + :attr:`telegram.ext.Application.updater`. The :attr:`telegram.ext.Updater.bot` and + :attr:`telegram.ext.Updater.update_queue` be used for :attr:`telegram.ext.Application.bot` + and :attr:`telegram.ext.Application.update_queue`, respectively. Args: updater (:class:`telegram.ext.Updater` | :obj:`None`): The updater instance or @@ -703,7 +711,12 @@ def updater(self: BuilderType, updater: Union[Updater, None]) -> BuilderType: Returns: :class:`ApplicationBuilder`: The same builder with the updated argument. """ - for attr, error in (self._bot, 'bot instance'), (self._update_queue, 'update queue'): + for attr, error in ( + (self._bot, 'bot instance'), + (self._request, 'request instance'), + (self._get_updates_request, 'get_updates_request instance'), + (self._update_queue, 'update queue'), + ): if not isinstance(attr, DefaultValue): raise RuntimeError(_TWO_ARGS_REQ.format('updater', error)) diff --git a/tests/test_builders.py b/tests/test_builders.py index e97e4a8aa31..ecbabaafb14 100644 --- a/tests/test_builders.py +++ b/tests/test_builders.py @@ -16,21 +16,12 @@ # # You should have received a copy of the GNU Lesser Public License # along with this program. If not, see [http://www.gnu.org/licenses/]. -from pathlib import Path - import pytest -from telegram.request import HTTPXRequest -from .conftest import PRIVATE_KEY, data_file +from .conftest import data_file from telegram.ext import ( ApplicationBuilder, - Defaults, - JobQueue, - PicklePersistence, - ContextTypes, - Application, - Updater, ) from telegram.ext._builders import _BOT_CHECKS @@ -58,18 +49,57 @@ def test_mutually_exclusive_for_bot(self, builder, method, description): with pytest.raises(RuntimeError, match=f'`{method}` may only be set, if no bot instance'): getattr(builder, method)(None) - # def test_mutually_exclusive_for_request(self, builder): - # builder.request(None) - # with pytest.raises( - # RuntimeError, match='`request_kwargs` may only be set, if no Request instance' - # ): - # builder.request_kwargs(None) - # - # builder = builder.__class__() - # builder.request_kwargs(None) - # with pytest.raises(RuntimeError, match='`request` may only be set, if no request_kwargs'): - # builder.request(None) - # + def test_mutually_exclusive_for_request(self, builder): + builder.request(None) + methods = ( + 'connection_pool_size', + 'connect_timeout', + 'pool_timeout', + 'read_timeout', + 'write_timeout', + 'proxy_url', + 'bot', + 'updater', + ) + + for method in methods: + with pytest.raises( + RuntimeError, match=f'`{method}` may only be set, if no request instance' + ): + getattr(builder, method)(None) + + for method in methods: + builder = ApplicationBuilder() + getattr(builder, method)(1) + with pytest.raises(RuntimeError, match='`request` may only be set, if no'): + builder.request(None) + + def test_mutually_exclusive_for_get_updates_request(self, builder): + builder.get_updates_request(None) + methods = ( + 'get_updates_connection_pool_size', + 'get_updates_connect_timeout', + 'get_updates_pool_timeout', + 'get_updates_read_timeout', + 'get_updates_write_timeout', + 'get_updates_proxy_url', + 'bot', + 'updater', + ) + + for method in methods: + with pytest.raises( + RuntimeError, + match=f'`{method}` may only be set, if no get_updates_request instance', + ): + getattr(builder, method)(None) + + for method in methods: + builder = ApplicationBuilder() + getattr(builder, method)(1) + with pytest.raises(RuntimeError, match='`get_updates_request` may only be set, if no'): + builder.get_updates_request(None) + # def test_build_without_token(self, builder): # with pytest.raises(RuntimeError, match='No bot token was set.'): # builder.build() From ff04f6554a910b44adc669096f0a287afe905480 Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Thu, 17 Feb 2022 22:33:14 +0100 Subject: [PATCH 018/153] Builder tests --- tests/test_bot.py | 2 +- tests/test_builders.py | 77 ++++++++++++++++++++++-------------------- 2 files changed, 42 insertions(+), 37 deletions(-) diff --git a/tests/test_bot.py b/tests/test_bot.py index d020c884402..f6dda21d70e 100644 --- a/tests/test_bot.py +++ b/tests/test_bot.py @@ -1664,7 +1664,7 @@ async def test_set_webhook_get_webhook_info_and_delete_webhook(self, bot): url, max_connections=max_connections, allowed_updates=allowed_updates, - ip_address='192.0.2.142', + ip_address='198.51.100.127', ) await asyncio.sleep(2) live_info = await bot.get_webhook_info() diff --git a/tests/test_builders.py b/tests/test_builders.py index ecbabaafb14..fbfe60d41b9 100644 --- a/tests/test_builders.py +++ b/tests/test_builders.py @@ -18,10 +18,12 @@ # along with this program. If not, see [http://www.gnu.org/licenses/]. import pytest -from .conftest import data_file +from telegram.request import HTTPXRequest +from .conftest import data_file, PRIVATE_KEY from telegram.ext import ( ApplicationBuilder, + Defaults, ) from telegram.ext._builders import _BOT_CHECKS @@ -100,41 +102,44 @@ def test_mutually_exclusive_for_get_updates_request(self, builder): with pytest.raises(RuntimeError, match='`get_updates_request` may only be set, if no'): builder.get_updates_request(None) - # def test_build_without_token(self, builder): - # with pytest.raises(RuntimeError, match='No bot token was set.'): - # builder.build() - # - # def test_build_custom_bot(self, builder, bot): - # builder.bot(bot) - # obj = builder.build() - # assert obj.bot is bot - # - # if isinstance(obj, Updater): - # assert obj.dispatcher.bot is bot - # assert obj.dispatcher.job_queue.dispatcher is obj.dispatcher - # assert obj.exception_event is obj.dispatcher.exception_event - # - # def test_all_bot_args_custom(self, builder, bot): - # defaults = Defaults() - # request = HTTPXRequest(connection_pool_size=8) - # builder.token(bot.token).base_url('base_url').base_file_url('base_file_url').private_key( - # PRIVATE_KEY - # ).defaults(defaults).arbitrary_callback_data(42).request(request) - # built_bot = builder.build().bot - # - # assert built_bot.token == bot.token - # assert built_bot.base_url == 'base_url' + bot.token - # assert built_bot.base_file_url == 'base_file_url' + bot.token - # assert built_bot.defaults is defaults - # assert built_bot.request is request - # assert built_bot.callback_data_cache.maxsize == 42 - # - # builder = builder.__class__() - # builder.token(bot.token).request_kwargs({'connect_timeout': 42}) - # built_bot = builder.build().bot - # - # assert built_bot.token == bot.token - # assert built_bot.request._connect_timeout == 42 + def test_build_without_token(self, builder): + with pytest.raises(RuntimeError, match='No bot token was set.'): + builder.build() + + def test_build_custom_bot(self, builder, bot): + builder.bot(bot) + app = builder.build() + assert app.bot is bot + assert app.updater.bot is bot + + def test_all_bot_args_custom(self, builder, bot): + defaults = Defaults() + request = HTTPXRequest() + get_updates_request = HTTPXRequest() + builder.token(bot.token).base_url('base_url').base_file_url('base_file_url').private_key( + PRIVATE_KEY + ).defaults(defaults).arbitrary_callback_data(42).request(request).get_updates_request( + get_updates_request + ) + built_bot = builder.build().bot + + assert built_bot.token == bot.token + assert built_bot.base_url == 'base_url' + bot.token + assert built_bot.base_file_url == 'base_file_url' + bot.token + assert built_bot.defaults is defaults + assert built_bot.request is request + assert built_bot._request[0] is get_updates_request + assert built_bot.callback_data_cache.maxsize == 42 + + builder = ApplicationBuilder() + builder.connection_pool_size(1).connect_timeout(2).pool_timeout(3) + # TODO: This test is not finished + # builder.token(bot.token).request_kwargs({'connect_timeout': 42}) + # built_bot = builder.build().bot + # + # assert built_bot.token == bot.token + # assert built_bot.request._connect_timeout == 42 + # # def test_all_dispatcher_args_custom(self, app, builder): # job_queue = JobQueue() From 21c7cb7a3ffbad45482a6fba9727cd76020a5c47 Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Fri, 18 Feb 2022 23:36:41 +0100 Subject: [PATCH 019/153] Builder tests finished --- telegram/ext/_application.py | 5 +- telegram/ext/_builders.py | 72 +++++-- tests/test_builders.py | 395 +++++++++++++++++++++++------------ 3 files changed, 325 insertions(+), 147 deletions(-) diff --git a/telegram/ext/_application.py b/telegram/ext/_application.py index be5bbfef4f2..61d0050c2a2 100644 --- a/telegram/ext/_application.py +++ b/telegram/ext/_application.py @@ -222,7 +222,7 @@ def __init__( if concurrent_updates is True: concurrent_updates = 4096 self._concurrent_updates_sem = asyncio.BoundedSemaphore(concurrent_updates or 1) - self._concurrent_updates = bool(concurrent_updates) + self._concurrent_updates: int = concurrent_updates or 0 if self.job_queue: self.job_queue.set_application(self) @@ -270,7 +270,8 @@ def running(self) -> bool: return self._running @property - def concurrent_updates(self) -> bool: + def concurrent_updates(self) -> int: + """0 == not concurrent""" return self._concurrent_updates async def initialize(self) -> None: diff --git a/telegram/ext/_builders.py b/telegram/ext/_builders.py index 804a9b14903..c4471ba6f41 100644 --- a/telegram/ext/_builders.py +++ b/telegram/ext/_builders.py @@ -75,13 +75,24 @@ _BOT_CHECKS = [ - ('updater', 'Updater instance'), ('request', 'request instance'), + ('connection_pool_size', 'connection_pool_size'), + ('proxy_url', 'proxy_url'), + ('pool_timeout', 'pool_timeout'), + ('connect_timeout', 'connect_timeout'), + ('read_timeout', 'read_timeout'), + ('write_timeout', 'write_timeout'), + ('get_updates_connection_pool_size', 'get_updates_connection_pool_size'), + ('get_updates_proxy_url', 'get_updates_proxy_url'), + ('get_updates_pool_timeout', 'get_updates_pool_timeout'), + ('get_updates_connect_timeout', 'get_updates_connect_timeout'), + ('get_updates_read_timeout', 'get_updates_read_timeout'), + ('get_updates_write_timeout', 'get_updates_write_timeout'), ('get_updates_request', 'get_updates_request instance'), ('base_file_url', 'base_file_url'), ('base_url', 'base_url'), ('token', 'token'), - ('defaults', 'Defaults instance'), + ('defaults', 'defaults'), ('arbitrary_callback_data', 'arbitrary_callback_data'), ('private_key', 'private_key'), ] @@ -152,15 +163,15 @@ def __init__(self: 'InitApplicationBuilder'): self._token: DVInput[str] = DefaultValue('') self._base_url: DVInput[str] = DefaultValue('https://api.telegram.org/bot') self._base_file_url: DVInput[str] = DefaultValue('https://api.telegram.org/file/bot') - self._connection_pool_size: Optional[int] = None - self._proxy_url: Optional[str] = None + self._connection_pool_size: DVInput[int] = DEFAULT_NONE + self._proxy_url: DVInput[str] = DEFAULT_NONE self._connect_timeout: ODVInput[float] = DEFAULT_NONE self._read_timeout: ODVInput[float] = DEFAULT_NONE self._write_timeout: ODVInput[float] = DEFAULT_NONE self._pool_timeout: ODVInput[float] = DEFAULT_NONE self._request: DVInput['BaseRequest'] = DEFAULT_NONE - self._get_updates_connection_pool_size: Optional[int] = None - self._get_updates_proxy_url: Optional[str] = None + self._get_updates_connection_pool_size: DVInput[int] = DEFAULT_NONE + self._get_updates_proxy_url: DVInput[str] = DEFAULT_NONE self._get_updates_connect_timeout: ODVInput[float] = DEFAULT_NONE self._get_updates_read_timeout: ODVInput[float] = DEFAULT_NONE self._get_updates_write_timeout: ODVInput[float] = DEFAULT_NONE @@ -185,11 +196,15 @@ def _build_request(self, get_updates: bool) -> BaseRequest: if not isinstance(getattr(self, f'{prefix}request'), DefaultValue): return getattr(self, f'{prefix}request') - proxy_url = getattr(self, f'{prefix}proxy_url') + proxy_url = DefaultValue.get_value(getattr(self, f'{prefix}proxy_url')) if get_updates: - connection_pool_size = getattr(self, f'{prefix}connection_pool_size') or 1 + connection_pool_size = ( + DefaultValue.get_value(getattr(self, f'{prefix}connection_pool_size')) or 1 + ) else: - connection_pool_size = getattr(self, f'{prefix}connection_pool_size') or 128 + connection_pool_size = ( + DefaultValue.get_value(getattr(self, f'{prefix}connection_pool_size')) or 128 + ) timeouts = dict( connect_timeout=getattr(self, f'{prefix}connect_timeout'), @@ -249,6 +264,8 @@ def build( bot = self._updater.bot update_queue = self._updater.update_queue + print(self._concurrent_updates) + application: Application[ BT, CCT, UD, CD, BD, JQ ] = DefaultValue.get_value( # type: ignore[call-arg] # pylint: disable=not-callable @@ -310,6 +327,8 @@ def token(self: BuilderType, token: str) -> BuilderType: """ if self._bot is not DEFAULT_NONE: raise RuntimeError(_TWO_ARGS_REQ.format('token', 'bot instance')) + if self._updater not in (DEFAULT_NONE, None): + raise RuntimeError(_TWO_ARGS_REQ.format('token', 'updater')) self._token = token return self @@ -329,6 +348,8 @@ def base_url(self: BuilderType, base_url: str) -> BuilderType: """ if self._bot is not DEFAULT_NONE: raise RuntimeError(_TWO_ARGS_REQ.format('base_url', 'bot instance')) + if self._updater not in (DEFAULT_NONE, None): + raise RuntimeError(_TWO_ARGS_REQ.format('base_url', 'updater')) self._base_url = base_url return self @@ -348,6 +369,8 @@ def base_file_url(self: BuilderType, base_file_url: str) -> BuilderType: """ if self._bot is not DEFAULT_NONE: raise RuntimeError(_TWO_ARGS_REQ.format('base_file_url', 'bot instance')) + if self._updater not in (DEFAULT_NONE, None): + raise RuntimeError(_TWO_ARGS_REQ.format('base_file_url', 'updater')) self._base_file_url = base_file_url return self @@ -358,9 +381,9 @@ def _request_check(self, get_updates: bool) -> None: for attr in ('connect_timeout', 'read_timeout', 'write_timeout', 'pool_timeout'): if not isinstance(getattr(self, f"_{prefix}{attr}"), DefaultValue): raise RuntimeError(_TWO_ARGS_REQ.format(name, attr)) - if getattr(self, f'_{prefix}connection_pool_size') is not None: + if not isinstance(getattr(self, f'_{prefix}connection_pool_size'), DefaultValue): raise RuntimeError(_TWO_ARGS_REQ.format(name, 'connection_pool_size')) - if getattr(self, f'_{prefix}proxy_url') is not None: + if not isinstance(getattr(self, f'_{prefix}proxy_url'), DefaultValue): raise RuntimeError(_TWO_ARGS_REQ.format(name, 'proxy_url')) if self._bot is not DEFAULT_NONE: raise RuntimeError(_TWO_ARGS_REQ.format(name, 'bot instance')) @@ -378,10 +401,15 @@ def _request_param_check(self, name: str, get_updates: bool) -> None: if self._bot is not DEFAULT_NONE: raise RuntimeError( _TWO_ARGS_REQ.format( - 'get_updates_request' if get_updates else 'request', 'bot instance' + f'get_updates_{name}' if get_updates else name, 'bot instance' ) ) + if self._updater not in (DEFAULT_NONE, None): + raise RuntimeError( + _TWO_ARGS_REQ.format(f'get_updates_{name}' if get_updates else name, 'updater') + ) + def request(self: BuilderType, request: BaseRequest) -> BuilderType: """Sets a :class:`telegram.request.BaseRequest` object to be used for the ``request`` parameter of :attr:`telegram.ext.Application.bot`. @@ -510,6 +538,8 @@ def private_key( """ if self._bot is not DEFAULT_NONE: raise RuntimeError(_TWO_ARGS_REQ.format('private_key', 'bot instance')) + if self._updater not in (DEFAULT_NONE, None): + raise RuntimeError(_TWO_ARGS_REQ.format('private_key', 'updater')) self._private_key = ( private_key if isinstance(private_key, bytes) else Path(private_key).read_bytes() @@ -536,6 +566,8 @@ def defaults(self: BuilderType, defaults: 'Defaults') -> BuilderType: """ if self._bot is not DEFAULT_NONE: raise RuntimeError(_TWO_ARGS_REQ.format('defaults', 'bot instance')) + if self._updater not in (DEFAULT_NONE, None): + raise RuntimeError(_TWO_ARGS_REQ.format('defaults', 'updater')) self._defaults = defaults return self @@ -562,6 +594,8 @@ def arbitrary_callback_data( """ if self._bot is not DEFAULT_NONE: raise RuntimeError(_TWO_ARGS_REQ.format('arbitrary_callback_data', 'bot instance')) + if self._updater not in (DEFAULT_NONE, None): + raise RuntimeError(_TWO_ARGS_REQ.format('arbitrary_callback_data', 'updater')) self._arbitrary_callback_data = arbitrary_callback_data return self @@ -579,6 +613,8 @@ def bot( Returns: :class:`ApplicationBuilder`: The same builder with the updated argument. """ + if self._updater not in (DEFAULT_NONE, None): + raise RuntimeError(_TWO_ARGS_REQ.format('bot', 'updater')) for attr, error in _BOT_CHECKS: if not isinstance(getattr(self, f'_{attr}'), DefaultValue): raise RuntimeError(_TWO_ARGS_REQ.format('bot', error)) @@ -599,7 +635,7 @@ def update_queue(self: BuilderType, update_queue: Queue) -> BuilderType: Returns: :class:`ApplicationBuilder`: The same builder with the updated argument. """ - if isinstance(self._updater, DefaultValue): + if self._updater not in (DEFAULT_NONE, None): raise RuntimeError(_TWO_ARGS_REQ.format('update_queue', 'updater instance')) self._update_queue = update_queue return self @@ -711,14 +747,22 @@ def updater(self: BuilderType, updater: Union[Updater, None]) -> BuilderType: Returns: :class:`ApplicationBuilder`: The same builder with the updated argument. """ + if updater is None: + self._updater = updater + return self + for attr, error in ( (self._bot, 'bot instance'), (self._request, 'request instance'), (self._get_updates_request, 'get_updates_request instance'), - (self._update_queue, 'update queue'), + (self._update_queue, 'update_queue'), ): if not isinstance(attr, DefaultValue): raise RuntimeError(_TWO_ARGS_REQ.format('updater', error)) + for attr_name, error in _BOT_CHECKS: + if not isinstance(getattr(self, f'_{attr_name}'), DefaultValue): + raise RuntimeError(_TWO_ARGS_REQ.format('updater', error)) + self._updater = updater return self diff --git a/tests/test_builders.py b/tests/test_builders.py index fbfe60d41b9..aa59e352357 100644 --- a/tests/test_builders.py +++ b/tests/test_builders.py @@ -16,6 +16,10 @@ # # You should have received a copy of the GNU Lesser Public License # along with this program. If not, see [http://www.gnu.org/licenses/]. +import asyncio +from dataclasses import dataclass + +import httpx import pytest from telegram.request import HTTPXRequest @@ -24,6 +28,12 @@ from telegram.ext import ( ApplicationBuilder, Defaults, + Application, + JobQueue, + PicklePersistence, + ContextTypes, + Updater, + ExtBot, ) from telegram.ext._builders import _BOT_CHECKS @@ -34,6 +44,64 @@ def builder(): class TestApplicationBuilder: + def test_build_without_token(self, builder): + with pytest.raises(RuntimeError, match='No bot token was set.'): + builder.build() + + def test_build_custom_bot(self, builder, bot): + builder.bot(bot) + app = builder.build() + assert app.bot is bot + assert app.updater.bot is bot + + def test_default_values(self, bot, monkeypatch, builder): + @dataclass + class Client: + timeout: object + proxies: object + limits: object + + monkeypatch.setattr(httpx, 'AsyncClient', Client) + + app = builder.token(bot.token).build() + + assert isinstance(app, Application) + assert app.concurrent_updates == 0 + + assert isinstance(app.bot, ExtBot) + assert isinstance(app.bot.request, HTTPXRequest) + assert 'api.telegram.org' in app.bot.base_url + assert bot.token in app.bot.base_url + assert 'api.telegram.org' in app.bot.base_file_url + assert bot.token in app.bot.base_file_url + assert app.bot.private_key is None + assert app.bot.arbitrary_callback_data is False + assert app.bot.defaults is None + + get_updates_client = app.bot._request[0]._client + assert get_updates_client.limits == httpx.Limits( + max_connections=1, max_keepalive_connections=1 + ) + assert get_updates_client.proxies is None + assert get_updates_client.timeout == httpx.Timeout( + connect=5.0, read=5.0, write=5.0, pool=1.0 + ) + + client = app.bot.request._client + assert client.limits == httpx.Limits(max_connections=128, max_keepalive_connections=128) + assert client.proxies is None + assert client.timeout == httpx.Timeout(connect=5.0, read=5.0, write=5.0, pool=1.0) + + assert isinstance(app.update_queue, asyncio.Queue) + assert isinstance(app.updater, Updater) + assert app.updater.bot is app.bot + assert app.updater.update_queue is app.update_queue + + assert isinstance(app.job_queue, JobQueue) + assert app.job_queue.application is app + + assert app.persistence is None + @pytest.mark.parametrize( 'method, description', _BOT_CHECKS, ids=[entry[0] for entry in _BOT_CHECKS] ) @@ -49,11 +117,11 @@ def test_mutually_exclusive_for_bot(self, builder, method, description): builder = builder.__class__() builder.bot(None) with pytest.raises(RuntimeError, match=f'`{method}` may only be set, if no bot instance'): - getattr(builder, method)(None) + getattr(builder, method)(data_file('private.key')) - def test_mutually_exclusive_for_request(self, builder): - builder.request(None) - methods = ( + @pytest.mark.parametrize( + 'method', + ( 'connection_pool_size', 'connect_timeout', 'pool_timeout', @@ -62,23 +130,24 @@ def test_mutually_exclusive_for_request(self, builder): 'proxy_url', 'bot', 'updater', - ) + ), + ) + def test_mutually_exclusive_for_request(self, builder, method): + builder.request(1) + + with pytest.raises( + RuntimeError, match=f'`{method}` may only be set, if no request instance' + ): + getattr(builder, method)(data_file('private.key')) + + builder = ApplicationBuilder() + getattr(builder, method)(1) + with pytest.raises(RuntimeError, match='`request` may only be set, if no'): + builder.request(1) - for method in methods: - with pytest.raises( - RuntimeError, match=f'`{method}` may only be set, if no request instance' - ): - getattr(builder, method)(None) - - for method in methods: - builder = ApplicationBuilder() - getattr(builder, method)(1) - with pytest.raises(RuntimeError, match='`request` may only be set, if no'): - builder.request(None) - - def test_mutually_exclusive_for_get_updates_request(self, builder): - builder.get_updates_request(None) - methods = ( + @pytest.mark.parametrize( + 'method', + ( 'get_updates_connection_pool_size', 'get_updates_connect_timeout', 'get_updates_pool_timeout', @@ -87,32 +156,87 @@ def test_mutually_exclusive_for_get_updates_request(self, builder): 'get_updates_proxy_url', 'bot', 'updater', - ) + ), + ) + def test_mutually_exclusive_for_get_updates_request(self, builder, method): + builder.get_updates_request(1) - for method in methods: - with pytest.raises( - RuntimeError, - match=f'`{method}` may only be set, if no get_updates_request instance', - ): - getattr(builder, method)(None) + with pytest.raises( + RuntimeError, + match=f'`{method}` may only be set, if no get_updates_request instance', + ): + getattr(builder, method)(data_file('private.key')) - for method in methods: - builder = ApplicationBuilder() - getattr(builder, method)(1) - with pytest.raises(RuntimeError, match='`get_updates_request` may only be set, if no'): - builder.get_updates_request(None) + builder = ApplicationBuilder() + getattr(builder, method)(1) + with pytest.raises(RuntimeError, match='`get_updates_request` may only be set, if no'): + builder.get_updates_request(1) - def test_build_without_token(self, builder): - with pytest.raises(RuntimeError, match='No bot token was set.'): - builder.build() + @pytest.mark.parametrize( + 'method', + [ + 'get_updates_connection_pool_size', + 'get_updates_connect_timeout', + 'get_updates_pool_timeout', + 'get_updates_read_timeout', + 'get_updates_write_timeout', + 'get_updates_proxy_url', + 'connection_pool_size', + 'connect_timeout', + 'pool_timeout', + 'read_timeout', + 'write_timeout', + 'proxy_url', + 'bot', + 'update_queue', + ] + + [entry[0] for entry in _BOT_CHECKS], + ) + def test_mutually_exclusive_for_updater(self, builder, method): + builder.updater(1) - def test_build_custom_bot(self, builder, bot): - builder.bot(bot) - app = builder.build() - assert app.bot is bot - assert app.updater.bot is bot + with pytest.raises( + RuntimeError, + match=f'`{method}` may only be set, if no updater', + ): + getattr(builder, method)(data_file('private.key')) - def test_all_bot_args_custom(self, builder, bot): + builder = ApplicationBuilder() + getattr(builder, method)(data_file('private.key')) + with pytest.raises(RuntimeError, match=f'`updater` may only be set, if no {method}'): + builder.updater(1) + + @pytest.mark.parametrize( + 'method', + [ + 'get_updates_connection_pool_size', + 'get_updates_connect_timeout', + 'get_updates_pool_timeout', + 'get_updates_read_timeout', + 'get_updates_write_timeout', + 'get_updates_proxy_url', + 'connection_pool_size', + 'connect_timeout', + 'pool_timeout', + 'read_timeout', + 'write_timeout', + 'proxy_url', + 'bot', + ] + + [entry[0] for entry in _BOT_CHECKS], + ) + def test_mutually_non_exclusive_for_updater(self, builder, method): + # If no updater is to be used, all these parameters should be settable + # Since the parameters themself are tested in the other tests, we here just make sure + # that no exception is raised + builder.updater(None) + getattr(builder, method)(data_file('private.key')) + + builder = ApplicationBuilder() + getattr(builder, method)(data_file('private.key')) + builder.updater(None) + + def test_all_bot_args_custom(self, builder, bot, monkeypatch): defaults = Defaults() request = HTTPXRequest() get_updates_request = HTTPXRequest() @@ -123,6 +247,10 @@ def test_all_bot_args_custom(self, builder, bot): ) built_bot = builder.build().bot + # In the following we access some private attributes of bot and request. this is not + # really nice as we want to test the public interface, but here it's hard to ensure by + # other means that the parameters are passed correctly + assert built_bot.token == bot.token assert built_bot.base_url == 'base_url' + bot.token assert built_bot.base_file_url == 'base_file_url' + bot.token @@ -130,94 +258,99 @@ def test_all_bot_args_custom(self, builder, bot): assert built_bot.request is request assert built_bot._request[0] is get_updates_request assert built_bot.callback_data_cache.maxsize == 42 + assert built_bot.private_key - builder = ApplicationBuilder() - builder.connection_pool_size(1).connect_timeout(2).pool_timeout(3) - # TODO: This test is not finished - # builder.token(bot.token).request_kwargs({'connect_timeout': 42}) - # built_bot = builder.build().bot - # - # assert built_bot.token == bot.token - # assert built_bot.request._connect_timeout == 42 - - # - # def test_all_dispatcher_args_custom(self, app, builder): - # job_queue = JobQueue() - # persistence = PicklePersistence('filename') - # context_types = ContextTypes() - # builder.bot(app.bot).update_queue(app.update_queue).exception_event( - # app.exception_event - # ).job_queue(job_queue).persistence(persistence).context_types(context_types).workers(3) - # dispatcher = builder.build() - # - # assert dispatcher.bot is app.bot - # assert dispatcher.update_queue is app.update_queue - # assert dispatcher.exception_event is app.exception_event - # assert dispatcher.job_queue is job_queue - # assert dispatcher.job_queue.dispatcher is dispatcher - # assert dispatcher.persistence is persistence - # assert dispatcher.context_types is context_types - # assert dispatcher.workers == 3 - # - # def test_all_updater_args_custom(self, app, builder): - # updater = ( - # builder.bot(app.bot) - # .exception_event(app.exception_event) - # .update_queue(app.update_queue) - # .user_signal_handler(42) - # .build() - # ) - # - # assert updater.dispatcher is None - # assert updater.bot is app.bot - # assert updater.exception_event is app.exception_event - # assert updater.update_queue is app.update_queue - # assert updater.user_signal_handler == 42 - # - # def test_connection_pool_size_with_workers(self, bot, builder): - # app = builder.token(bot.token).workers(42).build() - # assert app.workers == 42 - # assert app.bot.request.con_pool_size == 46 - # - # def test_connection_pool_size_warning(self, bot, builder, recwarn): - # builder.token(bot.token).workers(42).request_kwargs({'con_pool_size': 1}) - # app = builder.build() - # assert app.workers == 42 - # assert app.bot.request.con_pool_size == 1 - # - # assert len(recwarn) == 1 - # message = str(recwarn[-1].message) - # assert 'smaller (1)' in message - # assert 'recommended value of 46.' in message - # assert recwarn[-1].filename == __file__, "wrong stacklevel" - # - # def test_custom_classes(self, bot, builder): - # class CustomApplication(Application): - # def __init__(self, arg, **kwargs): - # super().__init__(**kwargs) - # self.arg = arg - # - # builder.application_class(CustomApplication, kwargs={'arg': 2}).token(bot.token) - # - # obj = builder.build() - # assert isinstance(obj, CustomApplication) - # assert obj.arg == 2 - # - # @pytest.mark.parametrize('input_type', ('bytes', 'str', 'Path')) - # def test_all_private_key_input_types(self, builder, bot, input_type): - # private_key = Path('tests/data/private.key') - # password = Path('tests/data/private_key.password') - # - # if input_type == 'bytes': - # private_key = private_key.read_bytes() - # password = password.read_bytes() - # if input_type == 'str': - # private_key = str(private_key) - # password = str(password) - # - # builder.token(bot.token).private_key( - # private_key=private_key, - # password=password, - # ) - # bot = builder.build().bot - # assert bot.private_key + @dataclass + class Client: + timeout: object + proxies: object + limits: object + + monkeypatch.setattr(httpx, 'AsyncClient', Client) + + builder = ApplicationBuilder().token(bot.token) + builder.connection_pool_size(1).connect_timeout(2).pool_timeout(3).read_timeout( + 4 + ).write_timeout(5).proxy_url('proxy_url') + app = builder.build() + client = app.bot.request._client + + assert client.timeout == httpx.Timeout(pool=3, connect=2, read=4, write=5) + assert client.limits == httpx.Limits(max_connections=1, max_keepalive_connections=1) + assert client.proxies == 'proxy_url' + + builder = ApplicationBuilder().token(bot.token) + builder.get_updates_connection_pool_size(1).get_updates_connect_timeout( + 2 + ).get_updates_pool_timeout(3).get_updates_read_timeout(4).get_updates_write_timeout( + 5 + ).get_updates_proxy_url( + 'proxy_url' + ) + app = builder.build() + client = app.bot._request[0]._client + + assert client.timeout == httpx.Timeout(pool=3, connect=2, read=4, write=5) + assert client.limits == httpx.Limits(max_connections=1, max_keepalive_connections=1) + assert client.proxies == 'proxy_url' + + def test_custom_application_class(self, bot, builder): + class CustomApplication(Application): + def __init__(self, arg, **kwargs): + super().__init__(**kwargs) + self.arg = arg + + builder.application_class(CustomApplication, kwargs={'arg': 2}).token(bot.token) + + app = builder.build() + assert isinstance(app, CustomApplication) + assert app.arg == 2 + + def test_all_application_args_custom(self, builder, bot, monkeypatch): + job_queue = JobQueue() + persistence = PicklePersistence('file_path') + update_queue = asyncio.Queue() + context_types = ContextTypes() + concurrent_updates = 123 + app = ( + builder.token(bot.token) + .job_queue(job_queue) + .persistence(persistence) + .update_queue(update_queue) + .context_types(context_types) + .concurrent_updates(concurrent_updates) + ).build() + assert app.job_queue is job_queue + assert app.job_queue.application is app + assert app.persistence is persistence + assert app.persistence.bot is app.bot + assert app.update_queue is update_queue + assert app.updater.update_queue is update_queue + assert app.updater.bot is app.bot + assert app.context_types is context_types + assert app.concurrent_updates == concurrent_updates + + updater = Updater(bot=bot, update_queue=update_queue) + app = ApplicationBuilder().updater(updater).build() + assert app.updater is updater + assert app.bot is updater.bot + assert app.update_queue is updater.update_queue + + @pytest.mark.parametrize('input_type', ('bytes', 'str', 'Path')) + def test_all_private_key_input_types(self, builder, bot, input_type): + private_key = data_file('private.key') + password = data_file('private_key.password') + + if input_type == 'bytes': + private_key = private_key.read_bytes() + password = password.read_bytes() + if input_type == 'str': + private_key = str(private_key) + password = str(password) + + builder.token(bot.token).private_key( + private_key=private_key, + password=password, + ) + bot = builder.build().bot + assert bot.private_key From b7d007fcf34830645dfcc8002318c58196957c16 Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Fri, 18 Feb 2022 23:41:40 +0100 Subject: [PATCH 020/153] rename a file --- telegram/ext/__init__.py | 2 +- telegram/ext/_application.py | 4 ++-- telegram/ext/{_builders.py => _applicationbuilder.py} | 0 tests/test_builders.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) rename telegram/ext/{_builders.py => _applicationbuilder.py} (100%) diff --git a/telegram/ext/__init__.py b/telegram/ext/__init__.py index e2ad407d8bc..d193549b986 100644 --- a/telegram/ext/__init__.py +++ b/telegram/ext/__init__.py @@ -83,4 +83,4 @@ from ._chatjoinrequesthandler import ChatJoinRequestHandler from ._defaults import Defaults from ._callbackdatacache import CallbackDataCache, InvalidCallbackData -from ._builders import ApplicationBuilder +from ._applicationbuilder import ApplicationBuilder diff --git a/telegram/ext/_application.py b/telegram/ext/_application.py index 61d0050c2a2..0844d762c7a 100644 --- a/telegram/ext/_application.py +++ b/telegram/ext/_application.py @@ -59,7 +59,7 @@ if TYPE_CHECKING: from telegram import Message from telegram.ext._jobqueue import Job - from telegram.ext._builders import InitApplicationBuilder + from telegram.ext._applicationbuilder import InitApplicationBuilder DEFAULT_GROUP: int = 0 @@ -202,7 +202,7 @@ def __init__( context_types: ContextTypes[CCT, UD, CD, BD], ): if not was_called_by( - inspect.currentframe(), Path(__file__).parent.resolve() / '_builders.py' + inspect.currentframe(), Path(__file__).parent.resolve() / '_applicationbuilder.py' ): warn( '`Application` instances should be built via the `ApplicationBuilder`.', diff --git a/telegram/ext/_builders.py b/telegram/ext/_applicationbuilder.py similarity index 100% rename from telegram/ext/_builders.py rename to telegram/ext/_applicationbuilder.py diff --git a/tests/test_builders.py b/tests/test_builders.py index aa59e352357..cc6a6078d50 100644 --- a/tests/test_builders.py +++ b/tests/test_builders.py @@ -35,7 +35,7 @@ Updater, ExtBot, ) -from telegram.ext._builders import _BOT_CHECKS +from telegram.ext._applicationbuilder import _BOT_CHECKS @pytest.fixture(scope='function') From 683edf6488e051e3dbc857b67e885a68ddd1ec1e Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Sun, 20 Feb 2022 15:01:12 +0100 Subject: [PATCH 021/153] Get started on updater tests --- tests/test_updater.py | 601 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 601 insertions(+) create mode 100644 tests/test_updater.py diff --git a/tests/test_updater.py b/tests/test_updater.py new file mode 100644 index 00000000000..d66d66f03eb --- /dev/null +++ b/tests/test_updater.py @@ -0,0 +1,601 @@ +#!/usr/bin/env python +# +# A library that provides a Python interface to the Telegram Bot API +# Copyright (C) 2015-2022 +# Leandro Toledo de Souza +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser Public License for more details. +# +# You should have received a copy of the GNU Lesser Public License +# along with this program. If not, see [http://www.gnu.org/licenses/]. +import asyncio +from threading import Event + +from urllib.request import Request, urlopen + +import pytest + +from telegram import ( + Bot, +) +from telegram.ext import ( + Updater, +) + + +class TestUpdater: + message_count = 0 + received = None + attempts = 0 + err_handler_called = Event() + cb_handler_called = Event() + offset = 0 + test_flag = False + + @pytest.fixture(autouse=True) + def reset(self): + self.message_count = 0 + self.received = None + self.attempts = 0 + self.err_handler_called.clear() + self.cb_handler_called.clear() + self.test_flag = False + + def error_handler(self, update, context): + self.received = context.error.message + self.err_handler_called.set() + + def callback(self, update, context): + self.received = update.message.text + self.cb_handler_called.set() + + def _send_webhook_msg( + self, + ip, + port, + payload_str, + url_path='', + content_len=-1, + content_type='application/json', + get_method=None, + ): + headers = { + 'content-type': content_type, + } + + if not payload_str: + content_len = None + payload = None + else: + payload = bytes(payload_str, encoding='utf-8') + + if content_len == -1: + content_len = len(payload) + + if content_len is not None: + headers['content-length'] = str(content_len) + + url = f'http://{ip}:{port}/{url_path}' + + req = Request(url, data=payload, headers=headers) + + if get_method is not None: + req.get_method = get_method + + return urlopen(req) + + def test_slot_behaviour(self, updater, mro_slots): + for at in updater.__slots__: + at = f"_Updater{at}" if at.startswith('__') and not at.endswith('__') else at + assert getattr(updater, at, 'err') != 'err', f"got extra slot '{at}'" + assert len(mro_slots(updater)) == len(set(mro_slots(updater))), "duplicate slot" + + def test_init(self, bot): + queue = asyncio.Queue() + updater = Updater(bot=bot, update_queue=queue) + assert updater.bot is bot + assert updater.update_queue is queue + + @pytest.mark.asyncio + async def test_initialize(self, bot, monkeypatch): + async def initialize_bot(*args, **kwargs): + self.test_flag = True + + async with Bot(bot.token) as test_bot: + monkeypatch.setattr(test_bot, 'initialize', initialize_bot) + + updater = Updater(bot=test_bot, update_queue=asyncio.Queue()) + await updater.initialize() + + assert self.test_flag + + @pytest.mark.asyncio + async def test_shutdown(self, bot, monkeypatch): + async def shutdown_bot(*args, **kwargs): + self.test_flag = True + + async with Bot(bot.token) as test_bot: + monkeypatch.setattr(test_bot, 'shutdown', shutdown_bot) + + updater = Updater(bot=test_bot, update_queue=asyncio.Queue()) + await updater.initialize() + await updater.shutdown() + + assert self.test_flag + + @pytest.mark.asyncio + async def test_context_manager(self, monkeypatch, updater): + async def initialize(*args, **kwargs): + self.test_flag = ['initialize'] + + async def shutdown(*args, **kwargs): + self.test_flag.append('stop') + + monkeypatch.setattr(Updater, 'initialize', initialize) + monkeypatch.setattr(Updater, 'shutdown', shutdown) + + async with updater: + pass + + assert self.test_flag == ['initialize', 'stop'] + + @pytest.mark.asyncio + async def test_context_manager_exception_on_init(self, monkeypatch, updater): + async def initialize(*args, **kwargs): + raise RuntimeError('initialize') + + async def shutdown(*args): + self.test_flag = 'stop' + + monkeypatch.setattr(Updater, 'initialize', initialize) + monkeypatch.setattr(Updater, 'shutdown', shutdown) + + with pytest.raises(RuntimeError, match='initialize'): + async with updater: + pass + + assert self.test_flag == 'stop' + + # @pytest.mark.asyncio + # async def test_polling(self, updater, monkeypatch): + # updates = asyncio.Queue() + # await updates.put(Update(update_id=1)) + # await updates.put(Update(update_id=2)) + # await updates.put(Update(update_id=3)) + # await updates.put(Update(update_id=4)) + # + # async def get_updates(*args, **kwargs): + # if not updates.empty(): + # return [updates.get_nowait()] + # return [] + # + # monkeypatch.setattr(updater.bot, 'get_updates', get_updates) + # + # async with updater: + # await updater.start_polling() + # assert updater.running + # await asyncio.sleep(1) + # await updater.stop() + # + # while not updater.update_queue.empty(): + # update = updater.update_queue.get_nowait() + # self.message_count += update.update_id + # + # assert self.message_count == 10 + + # @pytest.mark.parametrize( + # ('error',), + # argvalues=[(TelegramError('Test Error 2'),), (Unauthorized('Test Unauthorized'),)], + # ids=('TelegramError', 'Unauthorized'), + # ) + # def test_get_updates_normal_err(self, monkeypatch, updater, error): + # def test(*args, **kwargs): + # raise error + # + # monkeypatch.setattr(updater.bot, 'get_updates', test) + # monkeypatch.setattr(updater.bot, 'set_webhook', lambda *args, **kwargs: True) + # updater.dispatcher.add_error_handler(self.error_handler) + # updater.start_polling(0.01) + # + # # Make sure that the error handler was called + # self.err_handler_called.wait() + # assert self.received == error.message + # + # # Make sure that Updater polling thread keeps running + # self.err_handler_called.clear() + # self.err_handler_called.wait() + # + # @pytest.mark.filterwarnings('ignore:.*:pytest.PytestUnhandledThreadExceptionWarning') + # def test_get_updates_bailout_err(self, monkeypatch, updater, caplog): + # error = InvalidToken() + # + # def test(*args, **kwargs): + # raise error + # + # with caplog.at_level(logging.DEBUG): + # monkeypatch.setattr(updater.bot, 'get_updates', test) + # monkeypatch.setattr(updater.bot, 'set_webhook', lambda *args, **kwargs: True) + # updater.dispatcher.add_error_handler(self.error_handler) + # updater.start_polling(0.01) + # assert self.err_handler_called.wait(1) is not True + # + # sleep(1) + # # NOTE: This test might hit a race condition and fail (though the 1 seconds delay above + # # should work around it). + # # NOTE: Checking Updater.running is problematic because it is not set to False when there's + # # an unhandled exception. + # # TODO: We should have a way to poll Updater status and decide if it's running or not. + # import pprint + # + # pprint.pprint([rec.getMessage() for rec in caplog.get_records('call')]) + # assert any( + # f'unhandled exception in Bot:{updater.bot.id}:updater' in rec.getMessage() + # for rec in caplog.get_records('call') + # ) + # + # @pytest.mark.parametrize( + # ('error',), argvalues=[(RetryAfter(0.01),), (TimedOut(),)], ids=('RetryAfter', 'TimedOut') + # ) + # def test_get_updates_retries(self, monkeypatch, updater, error): + # event = Event() + # + # def test(*args, **kwargs): + # event.set() + # raise error + # + # monkeypatch.setattr(updater.bot, 'get_updates', test) + # monkeypatch.setattr(updater.bot, 'set_webhook', lambda *args, **kwargs: True) + # updater.dispatcher.add_error_handler(self.error_handler) + # updater.start_polling(0.01) + # + # # Make sure that get_updates was called, but not the error handler + # event.wait() + # assert self.err_handler_called.wait(0.5) is not True + # assert self.received != error.message + # + # # Make sure that Updater polling thread keeps running + # event.clear() + # event.wait() + # assert self.err_handler_called.wait(0.5) is not True + # + # @pytest.mark.parametrize('ext_bot', [True, False]) + # def test_webhook(self, monkeypatch, updater, ext_bot): + # # Testing with both ExtBot and Bot to make sure any logic in WebhookHandler + # # that depends on this distinction works + # if ext_bot and not isinstance(updater.bot, ExtBot): + # updater.bot = ExtBot(updater.bot.token) + # if not ext_bot and not type(updater.bot) is Bot: + # updater.bot = DictBot(updater.bot.token) + # + # q = Queue() + # monkeypatch.setattr(updater.bot, 'set_webhook', lambda *args, **kwargs: True) + # monkeypatch.setattr(updater.bot, 'delete_webhook', lambda *args, **kwargs: True) + # monkeypatch.setattr('telegram.ext.Dispatcher.process_update', lambda _, u: q.put(u)) + # + # ip = '127.0.0.1' + # port = randrange(1024, 49152) # Select random port + # updater.start_webhook(ip, port, url_path='TOKEN') + # sleep(0.2) + # try: + # # Now, we send an update to the server via urlopen + # update = Update( + # 1, + # message=Message( + # 1, None, Chat(1, ''), from_user=User(1, '', False), text='Webhook' + # ), + # ) + # self._send_webhook_msg(ip, port, update.to_json(), 'TOKEN') + # sleep(0.2) + # assert q.get(False) == update + # + # # Returns 404 if path is incorrect + # with pytest.raises(HTTPError) as excinfo: + # self._send_webhook_msg(ip, port, None, 'webookhandler.py') + # assert excinfo.value.code == 404 + # + # with pytest.raises(HTTPError) as excinfo: + # self._send_webhook_msg( + # ip, port, None, 'webookhandler.py', get_method=lambda: 'HEAD' + # ) + # assert excinfo.value.code == 404 + # + # # Test multiple shutdown() calls + # updater.httpd.shutdown() + # finally: + # updater.httpd.shutdown() + # sleep(0.2) + # assert not updater.httpd.is_running + # updater.stop() + # + # @pytest.mark.parametrize('invalid_data', [True, False]) + # def test_webhook_arbitrary_callback_data(self, monkeypatch, updater, invalid_data): + # """Here we only test one simple setup. telegram.ext.ExtBot.insert_callback_data is tested + # extensively in test_bot.py in conjunction with get_updates.""" + # updater.bot.arbitrary_callback_data = True + # try: + # q = Queue() + # monkeypatch.setattr(updater.bot, 'set_webhook', lambda *args, **kwargs: True) + # monkeypatch.setattr(updater.bot, 'delete_webhook', lambda *args, **kwargs: True) + # monkeypatch.setattr('telegram.ext.Dispatcher.process_update', lambda _, u: q.put(u)) + # + # ip = '127.0.0.1' + # port = randrange(1024, 49152) # Select random port + # updater.start_webhook(ip, port, url_path='TOKEN') + # sleep(0.2) + # try: + # # Now, we send an update to the server via urlopen + # reply_markup = InlineKeyboardMarkup.from_button( + # InlineKeyboardButton(text='text', callback_data='callback_data') + # ) + # if not invalid_data: + # reply_markup = updater.bot.callback_data_cache.process_keyboard(reply_markup) + # + # message = Message( + # 1, + # None, + # None, + # reply_markup=reply_markup, + # ) + # update = Update(1, message=message) + # self._send_webhook_msg(ip, port, update.to_json(), 'TOKEN') + # sleep(0.2) + # received_update = q.get(False) + # assert received_update == update + # + # button = received_update.message.reply_markup.inline_keyboard[0][0] + # if invalid_data: + # assert isinstance(button.callback_data, InvalidCallbackData) + # else: + # assert button.callback_data == 'callback_data' + # + # # Test multiple shutdown() calls + # updater.httpd.shutdown() + # finally: + # updater.httpd.shutdown() + # sleep(0.2) + # assert not updater.httpd.is_running + # updater.stop() + # finally: + # updater.bot.arbitrary_callback_data = False + # updater.bot.callback_data_cache.clear_callback_data() + # updater.bot.callback_data_cache.clear_callback_queries() + # + # @pytest.mark.parametrize('use_dispatcher', (True, False)) + # def test_start_webhook_no_warning_or_error_logs( + # self, caplog, updater, monkeypatch, use_dispatcher + # ): + # if not use_dispatcher: + # updater.dispatcher = None + # + # self.test_flag = 0 + # + # def set_flag(): + # self.test_flag += 1 + # + # monkeypatch.setattr(updater.bot, 'set_webhook', lambda *args, **kwargs: True) + # monkeypatch.setattr(updater.bot, 'delete_webhook', lambda *args, **kwargs: True) + # monkeypatch.setattr(updater.bot._request, 'stop', lambda *args, **kwargs: set_flag()) + # # prevent api calls from @info decorator when updater.bot.id is used in thread names + # monkeypatch.setattr(updater.bot, '_bot', User(id=123, first_name='bot', is_bot=True)) + # + # ip = '127.0.0.1' + # port = randrange(1024, 49152) # Select random port + # with caplog.at_level(logging.WARNING): + # updater.start_webhook(ip, port) + # updater.stop() + # assert not caplog.records + # # Make sure that bot.request.stop() has been called exactly once + # assert self.test_flag == 1 + # + # def test_webhook_ssl(self, monkeypatch, updater): + # monkeypatch.setattr(updater.bot, 'set_webhook', lambda *args, **kwargs: True) + # monkeypatch.setattr(updater.bot, 'delete_webhook', lambda *args, **kwargs: True) + # ip = '127.0.0.1' + # port = randrange(1024, 49152) # Select random port + # tg_err = False + # try: + # updater._start_webhook( + # ip, + # port, + # url_path='TOKEN', + # cert=Path(__file__).as_posix(), + # key=Path(__file__).as_posix(), + # bootstrap_retries=0, + # drop_pending_updates=False, + # webhook_url=None, + # allowed_updates=None, + # ) + # except TelegramError: + # tg_err = True + # assert tg_err + # + # def test_webhook_no_ssl(self, monkeypatch, updater): + # q = Queue() + # monkeypatch.setattr(updater.bot, 'set_webhook', lambda *args, **kwargs: True) + # monkeypatch.setattr(updater.bot, 'delete_webhook', lambda *args, **kwargs: True) + # monkeypatch.setattr('telegram.ext.Dispatcher.process_update', lambda _, u: q.put(u)) + # + # ip = '127.0.0.1' + # port = randrange(1024, 49152) # Select random port + # updater.start_webhook(ip, port, webhook_url=None) + # sleep(0.2) + # + # # Now, we send an update to the server via urlopen + # update = Update( + # 1, + # message=Message(1, None, Chat(1, ''), from_user=User(1, '', False), text='Webhook 2'), + # ) + # self._send_webhook_msg(ip, port, update.to_json()) + # sleep(0.2) + # assert q.get(False) == update + # updater.stop() + # + # def test_webhook_ssl_just_for_telegram(self, monkeypatch, updater): + # q = Queue() + # + # def set_webhook(**kwargs): + # self.test_flag.append(bool(kwargs.get('certificate'))) + # return True + # + # orig_wh_server_init = WebhookServer.__init__ + # + # def webhook_server_init(*args): + # self.test_flag = [args[-1] is None] + # orig_wh_server_init(*args) + # + # monkeypatch.setattr(updater.bot, 'set_webhook', set_webhook) + # monkeypatch.setattr(updater.bot, 'delete_webhook', lambda *args, **kwargs: True) + # monkeypatch.setattr('telegram.ext.Dispatcher.process_update', lambda _, u: q.put(u)) + # monkeypatch.setattr( + # 'telegram.ext._utils.webhookhandler.WebhookServer.__init__', webhook_server_init + # ) + # + # ip = '127.0.0.1' + # port = randrange(1024, 49152) # Select random port + # updater.start_webhook(ip, port, webhook_url=None, cert=Path(__file__).as_posix()) + # sleep(0.2) + # + # # Now, we send an update to the server via urlopen + # update = Update( + # 1, + # message=Message(1, None, Chat(1, ''), from_user=User(1, '', False), text='Webhook 2'), + # ) + # self._send_webhook_msg(ip, port, update.to_json()) + # sleep(0.2) + # assert q.get(False) == update + # updater.stop() + # assert self.test_flag == [True, True] + # + # @pytest.mark.parametrize('pass_max_connections', [True, False]) + # def test_webhook_max_connections(self, monkeypatch, updater, pass_max_connections): + # q = Queue() + # max_connections = 42 + # + # def set_webhook(**kwargs): + # print(kwargs) + # self.test_flag = kwargs.get('max_connections') == ( + # max_connections if pass_max_connections else 40 + # ) + # return True + # + # monkeypatch.setattr(updater.bot, 'set_webhook', set_webhook) + # monkeypatch.setattr(updater.bot, 'delete_webhook', lambda *args, **kwargs: True) + # monkeypatch.setattr('telegram.ext.Dispatcher.process_update', lambda _, u: q.put(u)) + # + # ip = '127.0.0.1' + # port = randrange(1024, 49152) # Select random port + # if pass_max_connections: + # updater.start_webhook(ip, port, webhook_url=None, max_connections=max_connections) + # else: + # updater.start_webhook(ip, port, webhook_url=None) + # + # sleep(0.2) + # + # # Now, we send an update to the server via urlopen + # update = Update( + # 1, + # message=Message(1, None, Chat(1, ''), from_user=User(1, '', False), text='Webhook 2'), + # ) + # self._send_webhook_msg(ip, port, update.to_json()) + # sleep(0.2) + # assert q.get(False) == update + # updater.stop() + # assert self.test_flag is True + # + # @pytest.mark.parametrize(('error',), argvalues=[(TelegramError(''),)], ids=('TelegramError',)) + # def test_bootstrap_retries_success(self, monkeypatch, updater, error): + # retries = 2 + # + # def attempt(*args, **kwargs): + # if self.attempts < retries: + # self.attempts += 1 + # raise error + # + # monkeypatch.setattr(updater.bot, 'set_webhook', attempt) + # + # updater.running = True + # updater._bootstrap(retries, False, 'path', None, bootstrap_interval=0) + # assert self.attempts == retries + # + # @pytest.mark.parametrize( + # ('error', 'attempts'), + # argvalues=[(TelegramError(''), 2), (Unauthorized(''), 1), (InvalidToken(), 1)], + # ids=('TelegramError', 'Unauthorized', 'InvalidToken'), + # ) + # def test_bootstrap_retries_error(self, monkeypatch, updater, error, attempts): + # retries = 1 + # + # def attempt(*args, **kwargs): + # self.attempts += 1 + # raise error + # + # monkeypatch.setattr(updater.bot, 'set_webhook', attempt) + # + # updater.running = True + # with pytest.raises(type(error)): + # updater._bootstrap(retries, False, 'path', None, bootstrap_interval=0) + # assert self.attempts == attempts + # + # @pytest.mark.parametrize('drop_pending_updates', (True, False)) + # def test_bootstrap_clean_updates(self, monkeypatch, updater, drop_pending_updates): + # # As dropping pending updates is done by passing `drop_pending_updates` to + # # set_webhook, we just check that we pass the correct value + # self.test_flag = False + # + # def delete_webhook(**kwargs): + # self.test_flag = kwargs.get('drop_pending_updates') == drop_pending_updates + # + # monkeypatch.setattr(updater.bot, 'delete_webhook', delete_webhook) + # + # updater.running = True + # updater._bootstrap( + # 1, + # drop_pending_updates=drop_pending_updates, + # webhook_url=None, + # allowed_updates=None, + # bootstrap_interval=0, + # ) + # assert self.test_flag is True + # + # @flaky(3, 1) + # def test_webhook_invalid_posts(self, updater): + # ip = '127.0.0.1' + # port = randrange(1024, 49152) # select random port for travis + # thr = Thread( + # target=updater._start_webhook, args=(ip, port, '', None, None, 0, False, None, None) + # ) + # thr.start() + # + # sleep(0.2) + # + # try: + # with pytest.raises(HTTPError) as excinfo: + # self._send_webhook_msg( + # ip, port, 'data', content_type='application/xml' + # ) + # assert excinfo.value.code == 403 + # + # with pytest.raises(HTTPError) as excinfo: + # self._send_webhook_msg(ip, port, 'dummy-payload', content_len=-2) + # assert excinfo.value.code == 500 + # + # # TODO: prevent urllib or the underlying from adding content-length + # # with pytest.raises(HTTPError) as excinfo: + # # self._send_webhook_msg(ip, port, 'dummy-payload', content_len=None) + # # assert excinfo.value.code == 411 + # + # with pytest.raises(HTTPError): + # self._send_webhook_msg(ip, port, 'dummy-payload', content_len='not-a-number') + # assert excinfo.value.code == 500 + # + # finally: + # updater.httpd.shutdown() + # thr.join() From 01cb345714ddb2a112d3b0fd19f61ee7bbcca562 Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Tue, 22 Feb 2022 16:20:03 +0100 Subject: [PATCH 022/153] One more test & some code simplifications --- telegram/ext/_updater.py | 112 +++++++++++++++++++-------------------- tests/conftest.py | 5 +- tests/test_updater.py | 65 ++++++++++++++--------- 3 files changed, 97 insertions(+), 85 deletions(-) diff --git a/telegram/ext/_updater.py b/telegram/ext/_updater.py index 945238e7c11..1b56afa6aad 100644 --- a/telegram/ext/_updater.py +++ b/telegram/ext/_updater.py @@ -75,7 +75,7 @@ class Updater: '_running', '_httpd', '__lock', - '__asyncio_tasks', + '__polling_task', ) def __init__( @@ -90,7 +90,7 @@ def __init__( self._running = False self._httpd: Optional[WebhookServer] = None self.__lock = asyncio.Lock() - self.__asyncio_tasks: List[asyncio.Task] = [] + self.__polling_task: Optional[asyncio.Task] = None self._logger = logging.getLogger(__name__) @property @@ -122,16 +122,6 @@ async def __aexit__( # https://docs.python.org/3/reference/datamodel.html?#object.__aexit__ await self.shutdown() - def _init_task( - self, target: Callable[..., Coroutine], name: str, *args: object, **kwargs: object - ) -> None: - task = asyncio.create_task( - coro=self._task_wrapper(target, name, *args, **kwargs), - # TODO: Add this once we drop py3.7 - # name=f"Updater:{self.bot.id}:{name}", - ) - self.__asyncio_tasks.append(task) - async def _task_wrapper( self, target: Callable, name: str, *args: object, **kwargs: object ) -> None: @@ -190,28 +180,29 @@ async def start_polling( Returns: :class:`asyncio.Queue`: The update queue that can be filled from the main thread. + Raises: + :exc:`RuntimeError`: If the updater is already running. + """ async with self.__lock: if self.running: - return self.update_queue + raise RuntimeError('This Updater is already running!') self._running = True # Create & start tasks polling_ready = asyncio.Event() - self._init_task( - self._start_polling, - "Polling Background task", - poll_interval, - timeout, - read_timeout, - write_timeout, - connect_timeout, - pool_timeout, - bootstrap_retries, - drop_pending_updates, - allowed_updates, + await self._start_polling( + poll_interval=poll_interval, + timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, + bootstrap_retries=bootstrap_retries, + drop_pending_updates=drop_pending_updates, + allowed_updates=allowed_updates, ready=polling_ready, error_callback=error_callback, ) @@ -227,18 +218,15 @@ async def _start_polling( poll_interval: float, timeout: int, read_timeout: Optional[float], - write_timeout: Optional[float], - connect_timeout: Optional[float], - pool_timeout: Optional[float], + write_timeout: ODVInput[float], + connect_timeout: ODVInput[float], + pool_timeout: ODVInput[float], bootstrap_retries: int, - drop_pending_updates: bool, + drop_pending_updates: Optional[bool], allowed_updates: Optional[List[str]], - ready: asyncio.Event = None, - error_callback: Callable[[TelegramError], None] = None, + ready: asyncio.Event, + error_callback: Optional[Callable[[TelegramError], None]], ) -> None: - # Target of task 'updater.start_polling()'. Runs in background, pulls - # updates from Telegram and inserts them in the update queue of the - # Application. self._logger.debug('Updater started (polling)') @@ -275,16 +263,21 @@ async def polling_action_cb() -> bool: def default_error_callback(exc: TelegramError) -> None: self._logger.exception('Exception happened while polling for updates.', exc_info=exc) + # Start task that runs in background, pulls + # updates from Telegram and inserts them in the update queue of the + # Application. + self.__polling_task = asyncio.create_task( + self._network_loop_retry( + action_cb=polling_action_cb, + on_err_cb=error_callback or default_error_callback, + description='getting Updates', + interval=poll_interval, + ) + ) + if ready is not None: ready.set() - await self._network_loop_retry( - action_cb=polling_action_cb, - onerr_cb=error_callback or default_error_callback, - description='getting Updates', - interval=poll_interval, - ) - async def start_webhook( self, listen: str = '127.0.0.1', @@ -341,10 +334,13 @@ async def start_webhook( .. versionadded:: 13.6 Returns: :class:`queue.Queue`: The update queue that can be filled from the main thread. + + Raises: + :exc:`RuntimeError`: If the updater is already running. """ async with self.__lock: if self.running: - return self.update_queue + raise RuntimeError('This Updater is already running!') self._running = True @@ -448,7 +444,7 @@ def _gen_webhook_url(listen: str, port: int, url_path: str) -> str: async def _network_loop_retry( self, action_cb: Callable[..., Coroutine], - onerr_cb: Callable[[TelegramError], None], + on_err_cb: Callable[[TelegramError], None], description: str, interval: float, ) -> None: @@ -459,8 +455,8 @@ async def _network_loop_retry( Args: action_cb (:obj:`callable`): Network oriented callback function to call. - onerr_cb (:obj:`callable`): Callback to call when TelegramError is caught. Receives the - exception object as a parameter. + on_err_cb (:obj:`callable`): Callback to call when TelegramError is caught. Receives + the exception object as a parameter. description (:obj:`str`): Description text to use for logs and exception raised. interval (:obj:`float` | :obj:`int`): Interval to sleep between each call to `action_cb`. @@ -485,7 +481,7 @@ async def _network_loop_retry( raise pex except TelegramError as telegram_exc: self._logger.error('Error while %s: %s', description, telegram_exc) - onerr_cb(telegram_exc) + on_err_cb(telegram_exc) cur_interval = self._increase_poll_interval(cur_interval) else: cur_interval = interval @@ -542,8 +538,10 @@ async def bootstrap_set_webhook() -> bool: ) return False - def bootstrap_onerr_cb(exc: Exception) -> None: - if not isinstance(exc, Forbidden) and (max_retries < 0 or retries[0] < max_retries): + def bootstrap_on_err_cb(exc: Exception) -> None: + if not isinstance(exc, (Forbidden, InvalidToken)) and ( + max_retries < 0 or retries[0] < max_retries + ): retries[0] += 1 self._logger.warning( 'Failed bootstrap phase; try=%s max_retries=%s', retries[0], max_retries @@ -559,7 +557,7 @@ def bootstrap_onerr_cb(exc: Exception) -> None: if drop_pending_updates or not webhook_url: await self._network_loop_retry( bootstrap_del_webhook, - bootstrap_onerr_cb, + bootstrap_on_err_cb, 'bootstrap del webhook', bootstrap_interval, ) @@ -570,7 +568,7 @@ def bootstrap_onerr_cb(exc: Exception) -> None: if webhook_url: await self._network_loop_retry( bootstrap_set_webhook, - bootstrap_onerr_cb, + bootstrap_on_err_cb, 'bootstrap set webhook', bootstrap_interval, ) @@ -584,7 +582,7 @@ async def stop(self) -> None: self._running = False await self._stop_httpd() - await self._join_tasks() + await self._stop_polling() self._logger.debug('Updater.stop() is complete') @@ -594,9 +592,9 @@ async def _stop_httpd(self) -> None: await self._httpd.shutdown() self._httpd = None - async def _join_tasks(self) -> None: - self._logger.debug('Stopping Background tasks') - for task in self.__asyncio_tasks: - task.cancel() - await asyncio.gather(*self.__asyncio_tasks) - self.__asyncio_tasks = [] + async def _stop_polling(self) -> None: + if self.__polling_task: + self._logger.debug('Waiting background polling task to join.') + self.__polling_task.cancel() + await self.__polling_task + self.__polling_task = None diff --git a/tests/conftest.py b/tests/conftest.py index 6a061b41342..693fa0fbac8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -259,11 +259,12 @@ def app(_app): @pytest.fixture(scope='function') -def updater(bot): +@pytest.mark.asyncio +async def updater(bot): up = Updater(bot=bot, update_queue=asyncio.Queue()) yield up if up.running: - up.stop() + await up.stop() PROJECT_ROOT_PATH = Path(__file__).parent.parent.resolve() diff --git a/tests/test_updater.py b/tests/test_updater.py index d66d66f03eb..831fac3a8d5 100644 --- a/tests/test_updater.py +++ b/tests/test_updater.py @@ -25,6 +25,7 @@ from telegram import ( Bot, + Update, ) from telegram.ext import ( Updater, @@ -164,32 +165,44 @@ async def shutdown(*args): assert self.test_flag == 'stop' - # @pytest.mark.asyncio - # async def test_polling(self, updater, monkeypatch): - # updates = asyncio.Queue() - # await updates.put(Update(update_id=1)) - # await updates.put(Update(update_id=2)) - # await updates.put(Update(update_id=3)) - # await updates.put(Update(update_id=4)) - # - # async def get_updates(*args, **kwargs): - # if not updates.empty(): - # return [updates.get_nowait()] - # return [] - # - # monkeypatch.setattr(updater.bot, 'get_updates', get_updates) - # - # async with updater: - # await updater.start_polling() - # assert updater.running - # await asyncio.sleep(1) - # await updater.stop() - # - # while not updater.update_queue.empty(): - # update = updater.update_queue.get_nowait() - # self.message_count += update.update_id - # - # assert self.message_count == 10 + @pytest.mark.asyncio + async def test_polling_basic(self, monkeypatch, updater): + updates = asyncio.Queue() + await updates.put(Update(update_id=1)) + await updates.put(Update(update_id=2)) + + async def get_updates(*args, **kwargs): + next_update = await updates.get() + updates.task_done() + return [next_update] + + monkeypatch.setattr(updater.bot, 'get_updates', get_updates) + + async with updater: + await updater.start_polling() + assert updater.running + await updates.join() + await updater.stop() + assert not updater.running + + await updates.put(Update(update_id=3)) + await updates.put(Update(update_id=4)) + + # We call the same logic twice to make sure that restarting the updater works as well + await updater.start_polling() + assert updater.running + await updates.join() + await updater.stop() + assert not updater.running + + self.received = [] + while not updater.update_queue.empty(): + update = updater.update_queue.get_nowait() + self.message_count += 1 + self.received.append(update.update_id) + + assert self.message_count == 4 + assert self.received == [1, 2, 3, 4] # @pytest.mark.parametrize( # ('error',), From aa8ea12838d56dab4cceec68c10ef23d90902774 Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Tue, 22 Feb 2022 21:18:14 +0100 Subject: [PATCH 023/153] Tests for Updater.start_polling --- telegram/ext/_updater.py | 12 +- tests/conftest.py | 2 +- tests/test_updater.py | 250 +++++++++++++++++++++++++-------------- 3 files changed, 166 insertions(+), 98 deletions(-) diff --git a/telegram/ext/_updater.py b/telegram/ext/_updater.py index 1b56afa6aad..f10207d7ba5 100644 --- a/telegram/ext/_updater.py +++ b/telegram/ext/_updater.py @@ -35,7 +35,7 @@ from telegram._utils.defaultvalue import DEFAULT_NONE from telegram._utils.types import ODVInput -from telegram.error import InvalidToken, RetryAfter, TimedOut, Forbidden, TelegramError +from telegram.error import InvalidToken, RetryAfter, TimedOut, TelegramError from telegram.ext._utils.webhookhandler import WebhookAppClass, WebhookServer if TYPE_CHECKING: @@ -132,8 +132,6 @@ async def _task_wrapper( self._logger.exception('Unhandled exception in %s.', name) self._logger.debug('%s - ended', name) - # TODO: Probably drop `pool_connect` timeout again, because we probably want to just make - # sure that `getUpdates` always gets a connection without waiting async def start_polling( self, poll_interval: float = 0.0, @@ -241,7 +239,7 @@ async def _start_polling( async def polling_action_cb() -> bool: updates = await self.bot.get_updates( - self.last_update_id, + offset=self.last_update_id, timeout=timeout, read_timeout=read_timeout, connect_timeout=connect_timeout, @@ -511,7 +509,7 @@ async def _bootstrap( allowed_updates: Optional[List[str]], drop_pending_updates: bool = None, cert: Union[str, Path] = None, - bootstrap_interval: float = 5, + bootstrap_interval: float = 1, ip_address: str = None, max_connections: int = 40, ) -> None: @@ -539,9 +537,7 @@ async def bootstrap_set_webhook() -> bool: return False def bootstrap_on_err_cb(exc: Exception) -> None: - if not isinstance(exc, (Forbidden, InvalidToken)) and ( - max_retries < 0 or retries[0] < max_retries - ): + if not isinstance(exc, InvalidToken) and (max_retries < 0 or retries[0] < max_retries): retries[0] += 1 self._logger.warning( 'Failed bootstrap phase; try=%s max_retries=%s', retries[0], max_retries diff --git a/tests/conftest.py b/tests/conftest.py index 693fa0fbac8..d8145b37852 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -156,7 +156,7 @@ class DictApplication(Application): pass -@pytest.fixture(scope='session') +@pytest.fixture(scope='function') @pytest.mark.asyncio async def bot(bot_info): async with DictExtBot( diff --git a/tests/test_updater.py b/tests/test_updater.py index 831fac3a8d5..9075610f820 100644 --- a/tests/test_updater.py +++ b/tests/test_updater.py @@ -17,7 +17,6 @@ # You should have received a copy of the GNU Lesser Public License # along with this program. If not, see [http://www.gnu.org/licenses/]. import asyncio -from threading import Event from urllib.request import Request, urlopen @@ -27,17 +26,20 @@ Bot, Update, ) +from telegram._utils.defaultvalue import DEFAULT_NONE +from telegram.error import InvalidToken, TelegramError, TimedOut, RetryAfter from telegram.ext import ( Updater, ) +from telegram.request import HTTPXRequest class TestUpdater: message_count = 0 received = None attempts = 0 - err_handler_called = Event() - cb_handler_called = Event() + err_handler_called = None + cb_handler_called = None offset = 0 test_flag = False @@ -46,12 +48,12 @@ def reset(self): self.message_count = 0 self.received = None self.attempts = 0 - self.err_handler_called.clear() - self.cb_handler_called.clear() + self.err_handler_called = None + self.cb_handler_called = None self.test_flag = False - def error_handler(self, update, context): - self.received = context.error.message + def error_callback(self, error): + self.received = error self.err_handler_called.set() def callback(self, update, context): @@ -166,7 +168,8 @@ async def shutdown(*args): assert self.test_flag == 'stop' @pytest.mark.asyncio - async def test_polling_basic(self, monkeypatch, updater): + @pytest.mark.parametrize('drop_pending_updates', (True, False)) + async def test_polling_basic(self, monkeypatch, updater, drop_pending_updates): updates = asyncio.Queue() await updates.put(Update(update_id=1)) await updates.put(Update(update_id=2)) @@ -176,26 +179,42 @@ async def get_updates(*args, **kwargs): updates.task_done() return [next_update] + orig_del_webhook = updater.bot.delete_webhook + + async def delete_webhook(*args, **kwargs): + # Dropping pending updates is done by passing the parameter to delete_webhook + if kwargs.get('drop_pending_updates'): + self.message_count += 1 + return await orig_del_webhook(*args, **kwargs) + monkeypatch.setattr(updater.bot, 'get_updates', get_updates) + monkeypatch.setattr(updater.bot, 'delete_webhook', delete_webhook) async with updater: - await updater.start_polling() + await updater.start_polling(drop_pending_updates=drop_pending_updates) assert updater.running await updates.join() await updater.stop() assert not updater.running + assert not (await updater.bot.get_webhook_info()).url + if drop_pending_updates: + assert self.message_count == 1 + else: + assert self.message_count == 0 await updates.put(Update(update_id=3)) await updates.put(Update(update_id=4)) # We call the same logic twice to make sure that restarting the updater works as well - await updater.start_polling() + await updater.start_polling(drop_pending_updates=drop_pending_updates) assert updater.running await updates.join() await updater.stop() assert not updater.running + assert not (await updater.bot.get_webhook_info()).url self.received = [] + self.message_count = 0 while not updater.update_queue.empty(): update = updater.update_queue.get_nowait() self.message_count += 1 @@ -204,81 +223,134 @@ async def get_updates(*args, **kwargs): assert self.message_count == 4 assert self.received == [1, 2, 3, 4] - # @pytest.mark.parametrize( - # ('error',), - # argvalues=[(TelegramError('Test Error 2'),), (Unauthorized('Test Unauthorized'),)], - # ids=('TelegramError', 'Unauthorized'), - # ) - # def test_get_updates_normal_err(self, monkeypatch, updater, error): - # def test(*args, **kwargs): - # raise error - # - # monkeypatch.setattr(updater.bot, 'get_updates', test) - # monkeypatch.setattr(updater.bot, 'set_webhook', lambda *args, **kwargs: True) - # updater.dispatcher.add_error_handler(self.error_handler) - # updater.start_polling(0.01) - # - # # Make sure that the error handler was called - # self.err_handler_called.wait() - # assert self.received == error.message - # - # # Make sure that Updater polling thread keeps running - # self.err_handler_called.clear() - # self.err_handler_called.wait() - # - # @pytest.mark.filterwarnings('ignore:.*:pytest.PytestUnhandledThreadExceptionWarning') - # def test_get_updates_bailout_err(self, monkeypatch, updater, caplog): - # error = InvalidToken() - # - # def test(*args, **kwargs): - # raise error - # - # with caplog.at_level(logging.DEBUG): - # monkeypatch.setattr(updater.bot, 'get_updates', test) - # monkeypatch.setattr(updater.bot, 'set_webhook', lambda *args, **kwargs: True) - # updater.dispatcher.add_error_handler(self.error_handler) - # updater.start_polling(0.01) - # assert self.err_handler_called.wait(1) is not True - # - # sleep(1) - # # NOTE: This test might hit a race condition and fail (though the 1 seconds delay above - # # should work around it). - # # NOTE: Checking Updater.running is problematic because it is not set to False when there's - # # an unhandled exception. - # # TODO: We should have a way to poll Updater status and decide if it's running or not. - # import pprint - # - # pprint.pprint([rec.getMessage() for rec in caplog.get_records('call')]) - # assert any( - # f'unhandled exception in Bot:{updater.bot.id}:updater' in rec.getMessage() - # for rec in caplog.get_records('call') - # ) - # - # @pytest.mark.parametrize( - # ('error',), argvalues=[(RetryAfter(0.01),), (TimedOut(),)], ids=('RetryAfter', 'TimedOut') - # ) - # def test_get_updates_retries(self, monkeypatch, updater, error): - # event = Event() - # - # def test(*args, **kwargs): - # event.set() - # raise error - # - # monkeypatch.setattr(updater.bot, 'get_updates', test) - # monkeypatch.setattr(updater.bot, 'set_webhook', lambda *args, **kwargs: True) - # updater.dispatcher.add_error_handler(self.error_handler) - # updater.start_polling(0.01) - # - # # Make sure that get_updates was called, but not the error handler - # event.wait() - # assert self.err_handler_called.wait(0.5) is not True - # assert self.received != error.message - # - # # Make sure that Updater polling thread keeps running - # event.clear() - # event.wait() - # assert self.err_handler_called.wait(0.5) is not True - # + @pytest.mark.asyncio + async def test_start_polling_already_running(self, updater): + async with updater: + await updater.start_polling() + task = asyncio.create_task(updater.start_polling()) + with pytest.raises(RuntimeError, match='already running'): + await task + await updater.stop() + + @pytest.mark.asyncio + async def test_start_polling_get_updates_parameters(self, updater, monkeypatch): + update_queue = asyncio.Queue() + await update_queue.put(Update(update_id=1)) + + expected = dict( + timeout=10, + read_timeout=2, + write_timeout=DEFAULT_NONE, + connect_timeout=DEFAULT_NONE, + pool_timeout=DEFAULT_NONE, + allowed_updates=None, + ) + + async def get_updates(*args, **kwargs): + for key, value in expected.items(): + assert kwargs.get(key) == value + await update_queue.get() + update_queue.task_done() + return [] + + monkeypatch.setattr(updater.bot, 'get_updates', get_updates) + + async with updater: + await updater.start_polling() + await update_queue.join() + await updater.stop() + + expected = dict( + timeout=42, + read_timeout=43, + write_timeout=44, + connect_timeout=45, + pool_timeout=46, + allowed_updates=['message'], + ) + + await update_queue.put(Update(update_id=1)) + await updater.start_polling( + timeout=42, + read_timeout=43, + write_timeout=44, + connect_timeout=45, + pool_timeout=46, + allowed_updates=['message'], + ) + await update_queue.join() + await updater.stop() + + @pytest.mark.asyncio + @pytest.mark.parametrize('exception_class', (InvalidToken, TelegramError)) + @pytest.mark.parametrize('retries', (3, 0)) + async def test_start_polling_bootstrap_retries( + self, updater, monkeypatch, exception_class, retries + ): + async def do_request(*args, **kwargs): + self.message_count += 1 + raise exception_class(str(self.message_count)) + + monkeypatch.setattr(HTTPXRequest, 'do_request', do_request) + + async with updater: + if exception_class == InvalidToken: + with pytest.raises(InvalidToken, match='1'): + await updater.start_polling(bootstrap_retries=retries) + else: + with pytest.raises(TelegramError, match=str(retries + 1)): + await updater.start_polling( + bootstrap_retries=retries, + ) + + @pytest.mark.parametrize( + 'error,callback', + argvalues=[ + (TelegramError('TestMessage'), True), + (RetryAfter(1), False), + (TimedOut('TestMessage'), False), + ], + ids=('TelegramError', 'RetryAfter', 'TimedOut'), + ) + @pytest.mark.asyncio + async def test_start_polling_exceptions_and_error_callback( + self, monkeypatch, updater, error, callback + ): + get_updates_event = asyncio.Event() + + async def get_updates(*args, **kwargs): + # So that the main task has a chance to be called + await asyncio.sleep(0) + + get_updates_event.set() + raise error + + monkeypatch.setattr(updater.bot, 'get_updates', get_updates) + monkeypatch.setattr(updater.bot, 'set_webhook', lambda *args, **kwargs: True) + + async with updater: + self.err_handler_called = asyncio.Event() + + await updater.start_polling(error_callback=self.error_callback) + await asyncio.sleep(1) + + if callback: + # Make sure that the error handler was called + assert self.err_handler_called.is_set() + assert self.received == error + # Make sure that get_updates was called + assert get_updates_event.is_set() + + # Make sure that Updater polling keeps running + self.err_handler_called.clear() + get_updates_event.clear() + await get_updates_event.wait() + if callback: + # Make sure that the error handler was called + assert self.err_handler_called.is_set() + assert self.received == error + await updater.stop() + # @pytest.mark.parametrize('ext_bot', [True, False]) # def test_webhook(self, monkeypatch, updater, ext_bot): # # Testing with both ExtBot and Bot to make sure any logic in WebhookHandler @@ -444,7 +516,7 @@ async def get_updates(*args, **kwargs): # # Now, we send an update to the server via urlopen # update = Update( # 1, - # message=Message(1, None, Chat(1, ''), from_user=User(1, '', False), text='Webhook 2'), + # message=Message(1, None, Chat(1, ''), from_user=User(1, '', False), text='Webhook 2'), # ) # self._send_webhook_msg(ip, port, update.to_json()) # sleep(0.2) @@ -479,7 +551,7 @@ async def get_updates(*args, **kwargs): # # Now, we send an update to the server via urlopen # update = Update( # 1, - # message=Message(1, None, Chat(1, ''), from_user=User(1, '', False), text='Webhook 2'), + # message=Message(1, None, Chat(1, ''), from_user=User(1, '', False), text='Webhook 2'), # ) # self._send_webhook_msg(ip, port, update.to_json()) # sleep(0.2) @@ -515,7 +587,7 @@ async def get_updates(*args, **kwargs): # # Now, we send an update to the server via urlopen # update = Update( # 1, - # message=Message(1, None, Chat(1, ''), from_user=User(1, '', False), text='Webhook 2'), + # message=Message(1, None, Chat(1, ''), from_user=User(1, '', False), text='Webhook 2'), # ) # self._send_webhook_msg(ip, port, update.to_json()) # sleep(0.2) @@ -523,7 +595,7 @@ async def get_updates(*args, **kwargs): # updater.stop() # assert self.test_flag is True # - # @pytest.mark.parametrize(('error',), argvalues=[(TelegramError(''),)], ids=('TelegramError',)) + # @pytest.mark.parametrize(('error',),argvalues=[(TelegramError(''),)], ids=('TelegramError',)) # def test_bootstrap_retries_success(self, monkeypatch, updater, error): # retries = 2 # From 884eb4eb7d46961657081d3bfa0a4cdbb8f38a59 Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Wed, 23 Feb 2022 15:53:32 +0100 Subject: [PATCH 024/153] First webhook test --- tests/conftest.py | 2 +- tests/test_updater.py | 115 +++++++++++++++++++++++++++++++++++------- 2 files changed, 98 insertions(+), 19 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index d8145b37852..8c0fab1ffda 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -312,7 +312,7 @@ def make_bot(bot_info, **kwargs): DATE = datetime.datetime.now() -async def make_message(text, **kwargs): +def make_message(text, **kwargs): """ Testing utility factory to create a fake ``telegram.Message`` with reasonable defaults for mimicking a real message. diff --git a/tests/test_updater.py b/tests/test_updater.py index 9075610f820..d4fa0781048 100644 --- a/tests/test_updater.py +++ b/tests/test_updater.py @@ -17,10 +17,12 @@ # You should have received a copy of the GNU Lesser Public License # along with this program. If not, see [http://www.gnu.org/licenses/]. import asyncio - -from urllib.request import Request, urlopen +from http import HTTPStatus +from random import randrange +from typing import Optional import pytest +from httpx import AsyncClient, Response from telegram import ( Bot, @@ -30,8 +32,10 @@ from telegram.error import InvalidToken, TelegramError, TimedOut, RetryAfter from telegram.ext import ( Updater, + ExtBot, ) from telegram.request import HTTPXRequest +from tests.conftest import make_message_update, make_message, DictBot class TestUpdater: @@ -60,16 +64,16 @@ def callback(self, update, context): self.received = update.message.text self.cb_handler_called.set() - def _send_webhook_msg( - self, - ip, - port, - payload_str, - url_path='', - content_len=-1, - content_type='application/json', - get_method=None, - ): + @staticmethod + async def _send_webhook_message( + ip: str, + port: int, + payload_str: Optional[str], + url_path: str = '', + content_len: int = -1, + content_type: str = 'application/json', + get_method: str = None, + ) -> Response: headers = { 'content-type': content_type, } @@ -88,12 +92,10 @@ def _send_webhook_msg( url = f'http://{ip}:{port}/{url_path}' - req = Request(url, data=payload, headers=headers) - - if get_method is not None: - req.get_method = get_method - - return urlopen(req) + async with AsyncClient() as client: + return await client.request( + url=url, method=get_method or 'POST', data=payload, headers=headers + ) def test_slot_behaviour(self, updater, mro_slots): for at in updater.__slots__: @@ -351,6 +353,82 @@ async def get_updates(*args, **kwargs): assert self.received == error await updater.stop() + @pytest.mark.asyncio + @pytest.mark.parametrize('ext_bot', [True, False]) + @pytest.mark.parametrize('drop_pending_updates', (True, False)) + async def test_webhook_basic(self, monkeypatch, updater, drop_pending_updates, ext_bot): + # Testing with both ExtBot and Bot to make sure any logic in WebhookHandler + # that depends on this distinction works + if ext_bot and not isinstance(updater.bot, ExtBot): + updater.bot = ExtBot(updater.bot.token) + if not ext_bot and not type(updater.bot) is Bot: + updater.bot = DictBot(updater.bot.token) + + async def delete_webhook(*args, **kwargs): + # Dropping pending updates is done by passing the parameter to delete_webhook + if kwargs.get('drop_pending_updates'): + self.message_count += 1 + return True + + async def set_webhook(*args, **kwargs): + return True + + monkeypatch.setattr(updater.bot, 'set_webhook', set_webhook) + monkeypatch.setattr(updater.bot, 'delete_webhook', delete_webhook) + + ip = '127.0.0.1' + port = randrange(1024, 49152) # Select random port + + async with updater: + await updater.start_webhook( + drop_pending_updates=drop_pending_updates, + ip_address=ip, + port=port, + url_path='TOKEN', + ) + assert updater.running + + # Now, we send an update to the server + update = make_message_update('Webhook', message_factory=make_message) + await self._send_webhook_message(ip, port, update.to_json(), 'TOKEN') + assert (await updater.update_queue.get()).to_dict() == update.to_dict() + + # Returns Forbidden if wrong content types + response = await self._send_webhook_message( + ip, port, None, 'TOKEN', content_type='invalid' + ) + assert response.status_code == HTTPStatus.FORBIDDEN + + # Returns Not Found if path is incorrect + response = await self._send_webhook_message(ip, port, '123456', 'webhook_handler.py') + assert response.status_code == HTTPStatus.NOT_FOUND + + # Returns METHOD_NOT_ALLOWED if method is not allowed + response = await self._send_webhook_message(ip, port, None, 'TOKEN', get_method='HEAD') + assert response.status_code == HTTPStatus.METHOD_NOT_ALLOWED + + await updater.stop() + assert not updater.running + + if drop_pending_updates: + assert self.message_count == 1 + else: + assert self.message_count == 0 + + # We call the same logic twice to make sure that restarting the updater works as well + await updater.start_webhook( + drop_pending_updates=drop_pending_updates, + ip_address=ip, + port=port, + url_path='TOKEN', + ) + assert updater.running + update = make_message_update('Webhook', message_factory=make_message) + await self._send_webhook_message(ip, port, update.to_json(), 'TOKEN') + assert (await updater.update_queue.get()).to_dict() == update.to_dict() + await updater.stop() + assert not updater.running + # @pytest.mark.parametrize('ext_bot', [True, False]) # def test_webhook(self, monkeypatch, updater, ext_bot): # # Testing with both ExtBot and Bot to make sure any logic in WebhookHandler @@ -399,6 +477,7 @@ async def get_updates(*args, **kwargs): # sleep(0.2) # assert not updater.httpd.is_running # updater.stop() + # # @pytest.mark.parametrize('invalid_data', [True, False]) # def test_webhook_arbitrary_callback_data(self, monkeypatch, updater, invalid_data): From f55ab059c5ec0cd39786828976faa6a0a7a08b70 Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Thu, 24 Feb 2022 14:08:20 +0100 Subject: [PATCH 025/153] test arbitrary callback data --- telegram/ext/_application.py | 8 +- telegram/ext/_jobqueue.py | 4 +- tests/test_updater.py | 159 +++++++++++++---------------------- 3 files changed, 61 insertions(+), 110 deletions(-) diff --git a/telegram/ext/_application.py b/telegram/ext/_application.py index 0844d762c7a..e71bb8ae719 100644 --- a/telegram/ext/_application.py +++ b/telegram/ext/_application.py @@ -631,7 +631,7 @@ async def __create_task_callback( raise exception finally: - self._mark_update_for_persistence_update(update=update) + self._mark_for_persistence_update(update=update) async def _update_fetcher(self) -> None: # Continuously fetch updates from the queue. Exit only once the signal object is found. @@ -701,7 +701,7 @@ async def process_update(self, update: object) -> None: _logger.debug('Error handler stopped further handlers.') break - self._mark_update_for_persistence_update(update=update) + self._mark_for_persistence_update(update=update) def add_handler(self, handler: Handler[Any, CCT], group: int = DEFAULT_GROUP) -> None: """Register a handler. @@ -904,9 +904,7 @@ def migrate_chat_data( self._chat_ids_to_be_updated_in_persistence.add(new_chat_id) self._chat_ids_to_be_deleted_in_persistence.add(old_chat_id) - def _mark_update_for_persistence_update( - self, *, update: object = None, job: 'Job' = None - ) -> None: + def _mark_for_persistence_update(self, *, update: object = None, job: 'Job' = None) -> None: # TODO: This should be at the end of `Application.process_update`, when the task created # by `Application.create_task` is done and when a `Job` is done. Add tests to make sure # that this is happening diff --git a/telegram/ext/_jobqueue.py b/telegram/ext/_jobqueue.py index 9a5eb09e1b4..5e40c3d04ed 100644 --- a/telegram/ext/_jobqueue.py +++ b/telegram/ext/_jobqueue.py @@ -626,9 +626,7 @@ async def _run(self, application: 'Application') -> None: await application.create_task(application.dispatch_error(None, exc, job=self)) finally: # This is internal logic of application - let's keep it private for now - application._mark_update_for_persistence_update( # pylint: disable=protected-access - job=self - ) + application._mark_for_persistence_update(job=self) # pylint: disable=protected-access def schedule_removal(self) -> None: """ diff --git a/tests/test_updater.py b/tests/test_updater.py index d4fa0781048..2aca7f1fa33 100644 --- a/tests/test_updater.py +++ b/tests/test_updater.py @@ -27,12 +27,15 @@ from telegram import ( Bot, Update, + InlineKeyboardMarkup, + InlineKeyboardButton, ) from telegram._utils.defaultvalue import DEFAULT_NONE from telegram.error import InvalidToken, TelegramError, TimedOut, RetryAfter from telegram.ext import ( Updater, ExtBot, + InvalidCallbackData, ) from telegram.request import HTTPXRequest from tests.conftest import make_message_update, make_message, DictBot @@ -429,109 +432,61 @@ async def set_webhook(*args, **kwargs): await updater.stop() assert not updater.running - # @pytest.mark.parametrize('ext_bot', [True, False]) - # def test_webhook(self, monkeypatch, updater, ext_bot): - # # Testing with both ExtBot and Bot to make sure any logic in WebhookHandler - # # that depends on this distinction works - # if ext_bot and not isinstance(updater.bot, ExtBot): - # updater.bot = ExtBot(updater.bot.token) - # if not ext_bot and not type(updater.bot) is Bot: - # updater.bot = DictBot(updater.bot.token) - # - # q = Queue() - # monkeypatch.setattr(updater.bot, 'set_webhook', lambda *args, **kwargs: True) - # monkeypatch.setattr(updater.bot, 'delete_webhook', lambda *args, **kwargs: True) - # monkeypatch.setattr('telegram.ext.Dispatcher.process_update', lambda _, u: q.put(u)) - # - # ip = '127.0.0.1' - # port = randrange(1024, 49152) # Select random port - # updater.start_webhook(ip, port, url_path='TOKEN') - # sleep(0.2) - # try: - # # Now, we send an update to the server via urlopen - # update = Update( - # 1, - # message=Message( - # 1, None, Chat(1, ''), from_user=User(1, '', False), text='Webhook' - # ), - # ) - # self._send_webhook_msg(ip, port, update.to_json(), 'TOKEN') - # sleep(0.2) - # assert q.get(False) == update - # - # # Returns 404 if path is incorrect - # with pytest.raises(HTTPError) as excinfo: - # self._send_webhook_msg(ip, port, None, 'webookhandler.py') - # assert excinfo.value.code == 404 - # - # with pytest.raises(HTTPError) as excinfo: - # self._send_webhook_msg( - # ip, port, None, 'webookhandler.py', get_method=lambda: 'HEAD' - # ) - # assert excinfo.value.code == 404 - # - # # Test multiple shutdown() calls - # updater.httpd.shutdown() - # finally: - # updater.httpd.shutdown() - # sleep(0.2) - # assert not updater.httpd.is_running - # updater.stop() + @pytest.mark.parametrize('invalid_data', [True, False], ids=('invalid data', 'valid data')) + @pytest.mark.asyncio + async def test_webhook_arbitrary_callback_data( + self, monkeypatch, updater, invalid_data, chat_id + ): + """Here we only test one simple setup. telegram.ext.ExtBot.insert_callback_data is tested + extensively in test_bot.py in conjunction with get_updates.""" + updater.bot.arbitrary_callback_data = True + + async def return_true(*args, **kwargs): + return True + + try: + monkeypatch.setattr(updater.bot, 'set_webhook', return_true) + monkeypatch.setattr(updater.bot, 'delete_webhook', return_true) + + ip = '127.0.0.1' + port = randrange(1024, 49152) # Select random port + + async with updater: + await updater.start_webhook(ip, port, url_path='TOKEN') + # Now, we send an update to the server + reply_markup = InlineKeyboardMarkup.from_button( + InlineKeyboardButton(text='text', callback_data='callback_data') + ) + if not invalid_data: + reply_markup = updater.bot.callback_data_cache.process_keyboard(reply_markup) + + update = make_message_update( + message='test_webhook_arbitrary_callback_data', + message_factory=make_message, + reply_markup=reply_markup, + user=updater.bot.bot, + ) + + await self._send_webhook_message(ip, port, update.to_json(), 'TOKEN') + received_update = await updater.update_queue.get() + + assert received_update.update_id == update.update_id + message_dict = update.message.to_dict() + received_dict = received_update.message.to_dict() + message_dict.pop('reply_markup') + received_dict.pop('reply_markup') + assert message_dict == received_dict + + button = received_update.message.reply_markup.inline_keyboard[0][0] + if invalid_data: + assert isinstance(button.callback_data, InvalidCallbackData) + else: + assert button.callback_data == 'callback_data' + finally: + updater.bot.arbitrary_callback_data = False + updater.bot.callback_data_cache.clear_callback_data() + updater.bot.callback_data_cache.clear_callback_queries() - # - # @pytest.mark.parametrize('invalid_data', [True, False]) - # def test_webhook_arbitrary_callback_data(self, monkeypatch, updater, invalid_data): - # """Here we only test one simple setup. telegram.ext.ExtBot.insert_callback_data is tested - # extensively in test_bot.py in conjunction with get_updates.""" - # updater.bot.arbitrary_callback_data = True - # try: - # q = Queue() - # monkeypatch.setattr(updater.bot, 'set_webhook', lambda *args, **kwargs: True) - # monkeypatch.setattr(updater.bot, 'delete_webhook', lambda *args, **kwargs: True) - # monkeypatch.setattr('telegram.ext.Dispatcher.process_update', lambda _, u: q.put(u)) - # - # ip = '127.0.0.1' - # port = randrange(1024, 49152) # Select random port - # updater.start_webhook(ip, port, url_path='TOKEN') - # sleep(0.2) - # try: - # # Now, we send an update to the server via urlopen - # reply_markup = InlineKeyboardMarkup.from_button( - # InlineKeyboardButton(text='text', callback_data='callback_data') - # ) - # if not invalid_data: - # reply_markup = updater.bot.callback_data_cache.process_keyboard(reply_markup) - # - # message = Message( - # 1, - # None, - # None, - # reply_markup=reply_markup, - # ) - # update = Update(1, message=message) - # self._send_webhook_msg(ip, port, update.to_json(), 'TOKEN') - # sleep(0.2) - # received_update = q.get(False) - # assert received_update == update - # - # button = received_update.message.reply_markup.inline_keyboard[0][0] - # if invalid_data: - # assert isinstance(button.callback_data, InvalidCallbackData) - # else: - # assert button.callback_data == 'callback_data' - # - # # Test multiple shutdown() calls - # updater.httpd.shutdown() - # finally: - # updater.httpd.shutdown() - # sleep(0.2) - # assert not updater.httpd.is_running - # updater.stop() - # finally: - # updater.bot.arbitrary_callback_data = False - # updater.bot.callback_data_cache.clear_callback_data() - # updater.bot.callback_data_cache.clear_callback_queries() - # # @pytest.mark.parametrize('use_dispatcher', (True, False)) # def test_start_webhook_no_warning_or_error_logs( # self, caplog, updater, monkeypatch, use_dispatcher From d4d28f52b740a18857593b294771dd89ee2d9ed7 Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Thu, 24 Feb 2022 17:44:42 +0100 Subject: [PATCH 026/153] test webhooks with ssl --- tests/test_updater.py | 224 ++++++++++++++++-------------------------- 1 file changed, 82 insertions(+), 142 deletions(-) diff --git a/tests/test_updater.py b/tests/test_updater.py index 2aca7f1fa33..8fff44c2998 100644 --- a/tests/test_updater.py +++ b/tests/test_updater.py @@ -18,6 +18,7 @@ # along with this program. If not, see [http://www.gnu.org/licenses/]. import asyncio from http import HTTPStatus +from pathlib import Path from random import randrange from typing import Optional @@ -37,6 +38,7 @@ ExtBot, InvalidCallbackData, ) +from telegram.ext._utils.webhookhandler import WebhookServer from telegram.request import HTTPXRequest from tests.conftest import make_message_update, make_message, DictBot @@ -432,6 +434,23 @@ async def set_webhook(*args, **kwargs): await updater.stop() assert not updater.running + @pytest.mark.asyncio + async def test_start_webhook_already_running(self, updater, monkeypatch): + async def return_true(*args, **kwargs): + return True + + monkeypatch.setattr(updater.bot, 'set_webhook', return_true) + monkeypatch.setattr(updater.bot, 'delete_webhook', return_true) + + ip = '127.0.0.1' + port = randrange(1024, 49152) # Select random port + async with updater: + await updater.start_webhook(ip, port, url_path='TOKEN') + task = asyncio.create_task(updater.start_webhook(ip, port, url_path='TOKEN')) + with pytest.raises(RuntimeError, match='already running'): + await task + await updater.stop() + @pytest.mark.parametrize('invalid_data', [True, False], ids=('invalid data', 'valid data')) @pytest.mark.asyncio async def test_webhook_arbitrary_callback_data( @@ -487,148 +506,65 @@ async def return_true(*args, **kwargs): updater.bot.callback_data_cache.clear_callback_data() updater.bot.callback_data_cache.clear_callback_queries() - # @pytest.mark.parametrize('use_dispatcher', (True, False)) - # def test_start_webhook_no_warning_or_error_logs( - # self, caplog, updater, monkeypatch, use_dispatcher - # ): - # if not use_dispatcher: - # updater.dispatcher = None - # - # self.test_flag = 0 - # - # def set_flag(): - # self.test_flag += 1 - # - # monkeypatch.setattr(updater.bot, 'set_webhook', lambda *args, **kwargs: True) - # monkeypatch.setattr(updater.bot, 'delete_webhook', lambda *args, **kwargs: True) - # monkeypatch.setattr(updater.bot._request, 'stop', lambda *args, **kwargs: set_flag()) - # # prevent api calls from @info decorator when updater.bot.id is used in thread names - # monkeypatch.setattr(updater.bot, '_bot', User(id=123, first_name='bot', is_bot=True)) - # - # ip = '127.0.0.1' - # port = randrange(1024, 49152) # Select random port - # with caplog.at_level(logging.WARNING): - # updater.start_webhook(ip, port) - # updater.stop() - # assert not caplog.records - # # Make sure that bot.request.stop() has been called exactly once - # assert self.test_flag == 1 - # - # def test_webhook_ssl(self, monkeypatch, updater): - # monkeypatch.setattr(updater.bot, 'set_webhook', lambda *args, **kwargs: True) - # monkeypatch.setattr(updater.bot, 'delete_webhook', lambda *args, **kwargs: True) - # ip = '127.0.0.1' - # port = randrange(1024, 49152) # Select random port - # tg_err = False - # try: - # updater._start_webhook( - # ip, - # port, - # url_path='TOKEN', - # cert=Path(__file__).as_posix(), - # key=Path(__file__).as_posix(), - # bootstrap_retries=0, - # drop_pending_updates=False, - # webhook_url=None, - # allowed_updates=None, - # ) - # except TelegramError: - # tg_err = True - # assert tg_err - # - # def test_webhook_no_ssl(self, monkeypatch, updater): - # q = Queue() - # monkeypatch.setattr(updater.bot, 'set_webhook', lambda *args, **kwargs: True) - # monkeypatch.setattr(updater.bot, 'delete_webhook', lambda *args, **kwargs: True) - # monkeypatch.setattr('telegram.ext.Dispatcher.process_update', lambda _, u: q.put(u)) - # - # ip = '127.0.0.1' - # port = randrange(1024, 49152) # Select random port - # updater.start_webhook(ip, port, webhook_url=None) - # sleep(0.2) - # - # # Now, we send an update to the server via urlopen - # update = Update( - # 1, - # message=Message(1, None, Chat(1, ''), from_user=User(1, '', False), text='Webhook 2'), - # ) - # self._send_webhook_msg(ip, port, update.to_json()) - # sleep(0.2) - # assert q.get(False) == update - # updater.stop() - # - # def test_webhook_ssl_just_for_telegram(self, monkeypatch, updater): - # q = Queue() - # - # def set_webhook(**kwargs): - # self.test_flag.append(bool(kwargs.get('certificate'))) - # return True - # - # orig_wh_server_init = WebhookServer.__init__ - # - # def webhook_server_init(*args): - # self.test_flag = [args[-1] is None] - # orig_wh_server_init(*args) - # - # monkeypatch.setattr(updater.bot, 'set_webhook', set_webhook) - # monkeypatch.setattr(updater.bot, 'delete_webhook', lambda *args, **kwargs: True) - # monkeypatch.setattr('telegram.ext.Dispatcher.process_update', lambda _, u: q.put(u)) - # monkeypatch.setattr( - # 'telegram.ext._utils.webhookhandler.WebhookServer.__init__', webhook_server_init - # ) - # - # ip = '127.0.0.1' - # port = randrange(1024, 49152) # Select random port - # updater.start_webhook(ip, port, webhook_url=None, cert=Path(__file__).as_posix()) - # sleep(0.2) - # - # # Now, we send an update to the server via urlopen - # update = Update( - # 1, - # message=Message(1, None, Chat(1, ''), from_user=User(1, '', False), text='Webhook 2'), - # ) - # self._send_webhook_msg(ip, port, update.to_json()) - # sleep(0.2) - # assert q.get(False) == update - # updater.stop() - # assert self.test_flag == [True, True] - # - # @pytest.mark.parametrize('pass_max_connections', [True, False]) - # def test_webhook_max_connections(self, monkeypatch, updater, pass_max_connections): - # q = Queue() - # max_connections = 42 - # - # def set_webhook(**kwargs): - # print(kwargs) - # self.test_flag = kwargs.get('max_connections') == ( - # max_connections if pass_max_connections else 40 - # ) - # return True - # - # monkeypatch.setattr(updater.bot, 'set_webhook', set_webhook) - # monkeypatch.setattr(updater.bot, 'delete_webhook', lambda *args, **kwargs: True) - # monkeypatch.setattr('telegram.ext.Dispatcher.process_update', lambda _, u: q.put(u)) - # - # ip = '127.0.0.1' - # port = randrange(1024, 49152) # Select random port - # if pass_max_connections: - # updater.start_webhook(ip, port, webhook_url=None, max_connections=max_connections) - # else: - # updater.start_webhook(ip, port, webhook_url=None) - # - # sleep(0.2) - # - # # Now, we send an update to the server via urlopen - # update = Update( - # 1, - # message=Message(1, None, Chat(1, ''), from_user=User(1, '', False), text='Webhook 2'), - # ) - # self._send_webhook_msg(ip, port, update.to_json()) - # sleep(0.2) - # assert q.get(False) == update - # updater.stop() - # assert self.test_flag is True - # + @pytest.mark.asyncio + async def test_webhook_invalid_ssl(self, monkeypatch, updater): + async def return_true(*args, **kwargs): + return True + + monkeypatch.setattr(updater.bot, 'set_webhook', return_true) + monkeypatch.setattr(updater.bot, 'delete_webhook', return_true) + + ip = '127.0.0.1' + port = randrange(1024, 49152) # Select random port + async with updater: + with pytest.raises(TelegramError, match='Invalid SSL'): + await updater.start_webhook( + ip, + port, + url_path='TOKEN', + cert=Path(__file__).as_posix(), + key=Path(__file__).as_posix(), + bootstrap_retries=0, + drop_pending_updates=False, + webhook_url=None, + allowed_updates=None, + ) + + @pytest.mark.asyncio + async def test_webhook_ssl_just_for_telegram(self, monkeypatch, updater): + """Here we just test that the SSL info is pased to Telegram, but __not__ to the the + webhook server""" + + async def set_webhook(**kwargs): + self.test_flag.append(bool(kwargs.get('certificate'))) + return True + + async def return_true(*args, **kwargs): + return True + + orig_wh_server_init = WebhookServer.__init__ + + def webhook_server_init(*args, **kwargs): + self.test_flag = [kwargs.get('ssl_ctx') is None] + orig_wh_server_init(*args, **kwargs) + + monkeypatch.setattr(updater.bot, 'set_webhook', set_webhook) + monkeypatch.setattr(updater.bot, 'delete_webhook', return_true) + monkeypatch.setattr( + 'telegram.ext._utils.webhookhandler.WebhookServer.__init__', webhook_server_init + ) + + ip = '127.0.0.1' + port = randrange(1024, 49152) # Select random port + async with updater: + await updater.start_webhook(ip, port, webhook_url=None, cert=Path(__file__).as_posix()) + + # Now, we send an update to the server + update = make_message_update(message='test_message', message_factory=make_message) + await self._send_webhook_message(ip, port, update.to_json()) + assert (await updater.update_queue.get()).to_dict() == update.to_dict() + assert self.test_flag == [True, True] + # @pytest.mark.parametrize(('error',),argvalues=[(TelegramError(''),)], ids=('TelegramError',)) # def test_bootstrap_retries_success(self, monkeypatch, updater, error): # retries = 2 @@ -718,3 +654,7 @@ async def return_true(*args, **kwargs): # finally: # updater.httpd.shutdown() # thr.join() + + # TODO: + # test_start_webhook_set/delete_webhook_parameters + # test_start_webhook_bootstrap_retries From 29a286017eb9a5371274577c98905b0034938f4e Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Thu, 24 Feb 2022 22:46:31 +0100 Subject: [PATCH 027/153] Another test --- telegram/ext/_updater.py | 2 + tests/test_updater.py | 81 ++++++++++++++++++++++++++++++++++++++-- 2 files changed, 79 insertions(+), 4 deletions(-) diff --git a/telegram/ext/_updater.py b/telegram/ext/_updater.py index f10207d7ba5..9a142636db7 100644 --- a/telegram/ext/_updater.py +++ b/telegram/ext/_updater.py @@ -437,6 +437,8 @@ async def _start_webhook( @staticmethod def _gen_webhook_url(listen: str, port: int, url_path: str) -> str: + # TODO: double check if this should be https in any case - the docs of start_webhook + # say differently! return f'https://{listen}:{port}{url_path}' async def _network_loop_retry( diff --git a/tests/test_updater.py b/tests/test_updater.py index 8fff44c2998..13368959e99 100644 --- a/tests/test_updater.py +++ b/tests/test_updater.py @@ -251,14 +251,24 @@ async def test_start_polling_get_updates_parameters(self, updater, monkeypatch): connect_timeout=DEFAULT_NONE, pool_timeout=DEFAULT_NONE, allowed_updates=None, + api_kwargs=None, ) async def get_updates(*args, **kwargs): for key, value in expected.items(): - assert kwargs.get(key) == value - await update_queue.get() + assert kwargs.pop(key, None) == value + + offset = kwargs.pop('offset', None) + # Check that we don't get any unexpected kwargs + assert kwargs == {} + + if offset is not None and self.message_count != 0: + assert offset == self.message_count + 1, "get_updates got wrong `offset` parameter" + + update = await update_queue.get() + self.message_count = update.update_id update_queue.task_done() - return [] + return [update] monkeypatch.setattr(updater.bot, 'get_updates', get_updates) @@ -274,9 +284,10 @@ async def get_updates(*args, **kwargs): connect_timeout=45, pool_timeout=46, allowed_updates=['message'], + api_kwargs=None, ) - await update_queue.put(Update(update_id=1)) + await update_queue.put(Update(update_id=2)) await updater.start_polling( timeout=42, read_timeout=43, @@ -451,6 +462,68 @@ async def return_true(*args, **kwargs): await task await updater.stop() + @pytest.mark.asyncio + async def test_start_webhook_parameters_passing(self, updater, monkeypatch): + expected_delete_webhook = dict( + drop_pending_updates=None, + ) + + expected_set_webhook = dict( + certificate=None, + max_connections=40, + allowed_updates=None, + ip_address=None, + **expected_delete_webhook, + ) + + async def set_webhook(*args, **kwargs): + for key, value in expected_set_webhook.items(): + assert kwargs.pop(key, None) == value, f"set, {key}, {value}" + + # TODO: double check if this should be https + assert kwargs in ({'url': 'https://127.0.0.1:80/'}, {'url': 'https://listen:80/'}) + return True + + async def delete_webhook(*args, **kwargs): + for key, value in expected_delete_webhook.items(): + assert kwargs.pop(key, None) == value, f"delete, {key}, {value}" + + assert kwargs == {} + return True + + async def serve_forever(*args, **kwargs): + kwargs.get('ready').set() + + monkeypatch.setattr(updater.bot, 'set_webhook', set_webhook) + monkeypatch.setattr(updater.bot, 'delete_webhook', delete_webhook) + monkeypatch.setattr(WebhookServer, 'serve_forever', serve_forever) + + async with updater: + await updater.start_webhook() + await updater.stop() + expected_delete_webhook = dict( + drop_pending_updates=True, + api_kwargs=None, + ) + + expected_set_webhook = dict( + certificate='certificate', + max_connections=47, + allowed_updates=['message'], + ip_address='123.456.789', + **expected_delete_webhook, + ) + + await updater.start_webhook( + listen='listen', + allowed_updates=['message'], + drop_pending_updates=True, + ip_address='123.456.789', + max_connections=47, + cert='certificate', + ) + await updater.stop() + @pytest.mark.parametrize('invalid_data', [True, False], ids=('invalid data', 'valid data')) @pytest.mark.asyncio async def test_webhook_arbitrary_callback_data( From c9180fae9e0636d630cf71ed6125d660fd8c6e53 Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Fri, 25 Feb 2022 16:26:26 +0100 Subject: [PATCH 028/153] adjust ssl handling in webhooks --- telegram/ext/_updater.py | 11 ++++++++--- tests/data/sslcert.key | 28 ++++++++++++++++++++++++++++ tests/data/sslcert.pem | 23 +++++++++++++++++++++++ tests/test_updater.py | 23 ++++++++++++++++++++--- 4 files changed, 79 insertions(+), 6 deletions(-) create mode 100644 tests/data/sslcert.key create mode 100644 tests/data/sslcert.pem diff --git a/telegram/ext/_updater.py b/telegram/ext/_updater.py index 9a142636db7..86440d4b67e 100644 --- a/telegram/ext/_updater.py +++ b/telegram/ext/_updater.py @@ -410,7 +410,12 @@ async def _start_webhook( self._httpd = WebhookServer(listen, port, app, ssl_ctx) if not webhook_url: - webhook_url = self._gen_webhook_url(listen, port, url_path) + webhook_url = self._gen_webhook_url( + protocol='https' if ssl_ctx else 'http', + listen=listen, + port=port, + url_path=url_path, + ) # We pass along the cert to the webhook if present. if cert is not None: @@ -436,10 +441,10 @@ async def _start_webhook( await self._httpd.serve_forever(ready=ready) @staticmethod - def _gen_webhook_url(listen: str, port: int, url_path: str) -> str: + def _gen_webhook_url(protocol: str, listen: str, port: int, url_path: str) -> str: # TODO: double check if this should be https in any case - the docs of start_webhook # say differently! - return f'https://{listen}:{port}{url_path}' + return f'{protocol}://{listen}:{port}{url_path}' async def _network_loop_retry( self, diff --git a/tests/data/sslcert.key b/tests/data/sslcert.key new file mode 100644 index 00000000000..aff15436273 --- /dev/null +++ b/tests/data/sslcert.key @@ -0,0 +1,28 @@ +-----BEGIN PRIVATE KEY----- +MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQC56MwT6O0MyarZ +hpHdFDNugvFRQHYfnkuK7CLRTnR7XpDawW3ByByhKf/SWdjMphzuR1NklOPyKMSv +Wmr9+2grr1hIM8Ca++yqGb+GnkyHsdrIBlxDGgwdZ+nwzkRVcdmZhebwYuMGcNp6 +VnlZXnfPPFiqEd7ZsOM9GkTrL7dDnfpZG3CvpFMBQCLdQNyNmcg0dip9t93wU174 +tCNMWnTGTT5b/pghzdkJXblQUhIjNT9E1g1/iHgcckcFWUGqAXsRfI+EDz9afQB9 +8VHKZGPulVOxCtLAqKwg3JSTSLR+Z6omr+KmgRKfwzHpdctvBZrsoPOvr7zmG/cr +iTIB5KU7AgMBAAECggEANMEXsAqnwbo0The+qnKCCbkEi170ZhKAM0LAuo49xYhX +KIw8/gEwBpepbWJrf98fVIpO4rrRWDUzYuMQe1PtAoB2V76/x/r29GnsDGI9K0BP +6fTMF4p7p5iGLPwLLgfpjIQPvWUCMSCzDoYdVzvUWa0xJ8l8aF+mi/85UVev9HKS +l0RXXYg4nVT6EU+eXGZEINjC+DAd1zuVM1VMJJohgLSguvOgVypuloDRP4m6Uhuq +v/Lk1e+dwRQlXMURKgIIFU2Nnu9QY+KqVkSLomRg+hh2O2o5kKtiI3l4b22wgtOc ++gORzL5jIpkYPHOsMa7euCDqHHKhJgnKhvy6u1ZICQKBgQDr953WQXdDrVvWaB2v +IxXfyIo8ZIw5tacX6tONPW97mLySCL5mIEPEaHMvnSq759wYorEu85Jj3cvETUFy +u6022xipJ/NsVBkPywr8JFsnT26ENuei5KoF4fl/cg/INetSS/0HbmQZWvjOhh+Q +0LlngelkaCLdi0ymDue1uLgL9wKBgQDJsT/zsExVXBoMTj1u25dvIXp53J04unQp +qmndUxFgy1vuT08SbjHjK2EeZD/M5OLkXurdIZZ3kXPHMM/bKui92uGRTw9LCAAN +tDkNw+E+EwZfwbZsu4k3mWbSN16dO85K+Yo8hjsRLgvqQadMwbAz7RxEFiKX5tlG +gGaZkIH33QKBgFT8lgh5A6+IXK9YSHivtk0nOUKPJEIUvt3KYe9Y1TI6zI/8Pjci +H8Y5qGLZxG5xD8B/uDkk2PDHDYDiIlRka/p55uPl07KMh4o8ovQ1U+9QmIleDQeK +PAJqZSYVusFtShgV7kgi5kKLlVkszWmnA1/YVmsnZodMiIq2i5XTtdX5AoGBAK2r +4tWDSTd3RzaxaFS84Xjf6wZj4T2nz77Q7reVf7FJaq+ZuwyztmFWSRpSWF2l+XmM +AdDHyzjKFle+wDyIhkB06SamXRTOnr0uIrKnqJw65ZIuy1Z1ZYJqpQ7+fooFpW0J +0u6q5tG0RK5COjztyzvrQBugs8j5Dr6WccJpnIBBAoGBAMOm2g9OlSu8tbFXK9GJ +sFadmjXgM1quDkCfLJgJInw20YCy6NFnujbgczbrxpOg9sk6Gqbznw0iguU2mAZQ +UtDt3mbKrtUtR4kPFFwG51OgFx3D4TJM8EkKLKzthxGKjgJuRtP6glRgHTMIlwmT +Lmi6uZuyrC8kxwQiV2cmlA5u +-----END PRIVATE KEY----- diff --git a/tests/data/sslcert.pem b/tests/data/sslcert.pem new file mode 100644 index 00000000000..87d5aaba9fa --- /dev/null +++ b/tests/data/sslcert.pem @@ -0,0 +1,23 @@ +-----BEGIN CERTIFICATE----- +MIID4zCCAsugAwIBAgIUbxUiUtDxld8EMB7W+gh02eBeqJgwDQYJKoZIhvcNAQEL +BQAwgYAxCzAJBgNVBAYTAlRHMQwwCgYDVQQIDANQVEIxDDAKBgNVBAcMA1BUQjEM +MAoGA1UECgwDUFRCMQwwCgYDVQQLDANQVEIxDDAKBgNVBAMMA1BUQjErMCkGCSqG +SIb3DQEJARYcZGV2c0BweXRob24tdGVsZWdyYW0tYm90Lm9yZzAeFw0yMjAyMjUx +MDEzMjFaFw0zMjAyMjMxMDEzMjFaMIGAMQswCQYDVQQGEwJURzEMMAoGA1UECAwD +UFRCMQwwCgYDVQQHDANQVEIxDDAKBgNVBAoMA1BUQjEMMAoGA1UECwwDUFRCMQww +CgYDVQQDDANQVEIxKzApBgkqhkiG9w0BCQEWHGRldnNAcHl0aG9uLXRlbGVncmFt +LWJvdC5vcmcwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQC56MwT6O0M +yarZhpHdFDNugvFRQHYfnkuK7CLRTnR7XpDawW3ByByhKf/SWdjMphzuR1NklOPy +KMSvWmr9+2grr1hIM8Ca++yqGb+GnkyHsdrIBlxDGgwdZ+nwzkRVcdmZhebwYuMG +cNp6VnlZXnfPPFiqEd7ZsOM9GkTrL7dDnfpZG3CvpFMBQCLdQNyNmcg0dip9t93w +U174tCNMWnTGTT5b/pghzdkJXblQUhIjNT9E1g1/iHgcckcFWUGqAXsRfI+EDz9a +fQB98VHKZGPulVOxCtLAqKwg3JSTSLR+Z6omr+KmgRKfwzHpdctvBZrsoPOvr7zm +G/criTIB5KU7AgMBAAGjUzBRMB0GA1UdDgQWBBRhCKLkt3RjoaSiV14n1u8590Pf +HDAfBgNVHSMEGDAWgBRhCKLkt3RjoaSiV14n1u8590PfHDAPBgNVHRMBAf8EBTAD +AQH/MA0GCSqGSIb3DQEBCwUAA4IBAQB1yXCnOWxZqhda5sKIQLwHPORz9kfPplYZ +RxLZaymGCrieRr0NWPy1CezBsXNES1ICpEZ02P6Bel8GEzGS5cAbYPvIP8qzz/Ic +zgN5QG86klixLO6Q7VWYRGMFEI9d/2/UVGbw6KltIQt0bznoKvkrTnNTydQc/L7e +Ae+oqVl3OUuhtdU0DOjncEVKWKY0Hl18juSkTO59oHaL3R0SeNZ38chv9wtSRE3+ +ACDH51i6L9cwG0hdpuIx1UKkSDvU4ci9YnZsTdwkjbvi8VX68Sn9WsnZq0k4V4vt ++uhH8RVdxHp/TSv5LSOTMCg2v33dZjW/xOnvpRQvZNNXBxOi/ZH8 +-----END CERTIFICATE----- diff --git a/tests/test_updater.py b/tests/test_updater.py index 13368959e99..20aaaac51ca 100644 --- a/tests/test_updater.py +++ b/tests/test_updater.py @@ -40,7 +40,7 @@ ) from telegram.ext._utils.webhookhandler import WebhookServer from telegram.request import HTTPXRequest -from tests.conftest import make_message_update, make_message, DictBot +from tests.conftest import make_message_update, make_message, DictBot, data_file class TestUpdater: @@ -480,8 +480,11 @@ async def set_webhook(*args, **kwargs): for key, value in expected_set_webhook.items(): assert kwargs.pop(key, None) == value, f"set, {key}, {value}" - # TODO: double check if this should be https - assert kwargs in ({'url': 'https://127.0.0.1:80/'}, {'url': 'https://listen:80/'}) + assert kwargs in ( + {'url': 'http://127.0.0.1:80/'}, + {'url': 'http://listen:80/'}, + {'url': 'https://listen-ssl:42/ssl-path'}, + ) return True async def delete_webhook(*args, **kwargs): @@ -524,6 +527,20 @@ async def serve_forever(*args, **kwargs): ) await updater.stop() + expected_set_webhook['certificate'] = data_file('sslcert.pem') + await updater.start_webhook( + listen='listen-ssl', + port=42, + url_path='ssl-path', + allowed_updates=['message'], + drop_pending_updates=True, + ip_address='123.456.789', + max_connections=47, + cert=data_file('sslcert.pem'), + key=data_file('sslcert.key'), + ) + await updater.stop() + @pytest.mark.parametrize('invalid_data', [True, False], ids=('invalid data', 'valid data')) @pytest.mark.asyncio async def test_webhook_arbitrary_callback_data( From 8b4aa7ab1e14f8e270440a2e0e815641f497246a Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Fri, 25 Feb 2022 22:23:09 +0100 Subject: [PATCH 029/153] Finish Updater tests --- telegram/ext/_updater.py | 15 +-- tests/test_updater.py | 245 +++++++++++++++++++++------------------ 2 files changed, 136 insertions(+), 124 deletions(-) diff --git a/telegram/ext/_updater.py b/telegram/ext/_updater.py index 86440d4b67e..19ab0dad796 100644 --- a/telegram/ext/_updater.py +++ b/telegram/ext/_updater.py @@ -122,16 +122,6 @@ async def __aexit__( # https://docs.python.org/3/reference/datamodel.html?#object.__aexit__ await self.shutdown() - async def _task_wrapper( - self, target: Callable, name: str, *args: object, **kwargs: object - ) -> None: - self._logger.debug('%s - started', name) - try: - await target(*args, **kwargs) - except Exception: - self._logger.exception('Unhandled exception in %s.', name) - self._logger.debug('%s - ended', name) - async def start_polling( self, poll_interval: float = 0.0, @@ -250,7 +240,10 @@ async def polling_action_cb() -> bool: if updates: if not self.running: - self._logger.debug('Updates ignored and will be pulled again on restart') + self._logger.critical( + 'Updater stopped unexpectedly. Pulled updates will be ignored and again ' + 'on restart.' + ) else: for update in updates: await self.update_queue.put(update) diff --git a/tests/test_updater.py b/tests/test_updater.py index 20aaaac51ca..cd8dba3255d 100644 --- a/tests/test_updater.py +++ b/tests/test_updater.py @@ -17,6 +17,7 @@ # You should have received a copy of the GNU Lesser Public License # along with this program. If not, see [http://www.gnu.org/licenses/]. import asyncio +import logging from http import HTTPStatus from pathlib import Path from random import randrange @@ -322,7 +323,7 @@ async def do_request(*args, **kwargs): ) @pytest.mark.parametrize( - 'error,callback', + 'error,callback_should_be_called', argvalues=[ (TelegramError('TestMessage'), True), (RetryAfter(1), False), @@ -330,9 +331,10 @@ async def do_request(*args, **kwargs): ], ids=('TelegramError', 'RetryAfter', 'TimedOut'), ) + @pytest.mark.parametrize('custom_error_callback', [True, False]) @pytest.mark.asyncio async def test_start_polling_exceptions_and_error_callback( - self, monkeypatch, updater, error, callback + self, monkeypatch, updater, error, callback_should_be_called, custom_error_callback, caplog ): get_updates_event = asyncio.Event() @@ -349,26 +351,81 @@ async def get_updates(*args, **kwargs): async with updater: self.err_handler_called = asyncio.Event() - await updater.start_polling(error_callback=self.error_callback) - await asyncio.sleep(1) + with caplog.at_level(logging.ERROR): + if custom_error_callback: + await updater.start_polling(error_callback=self.error_callback) + else: + await updater.start_polling() + + # Also makes sure that the error handler was called + await get_updates_event.wait() - if callback: - # Make sure that the error handler was called - assert self.err_handler_called.is_set() - assert self.received == error - # Make sure that get_updates was called - assert get_updates_event.is_set() + if callback_should_be_called: + # Make sure that the error handler was called + if custom_error_callback: + assert self.received == error + else: + assert len(caplog.records) > 0 + records = (record.getMessage() for record in caplog.records) + assert 'Error while getting Updates: TestMessage' in records + + # Make sure that get_updates was called + assert get_updates_event.is_set() # Make sure that Updater polling keeps running self.err_handler_called.clear() get_updates_event.clear() + caplog.clear() + + # Also makes sure that the error handler was called await get_updates_event.wait() - if callback: - # Make sure that the error handler was called - assert self.err_handler_called.is_set() - assert self.received == error + + if callback_should_be_called: + if callback_should_be_called: + if custom_error_callback: + assert self.received == error + else: + assert len(caplog.records) > 0 + records = (record.getMessage() for record in caplog.records) + assert 'Error while getting Updates: TestMessage' in records await updater.stop() + @pytest.mark.asyncio + async def test_start_polling_unexpected_shutdown(self, updater, monkeypatch, caplog): + update_queue = asyncio.Queue() + await update_queue.put(Update(update_id=1)) + await update_queue.put(Update(update_id=2)) + first_update_event = asyncio.Event() + second_update_event = asyncio.Event() + + async def get_updates(*args, **kwargs): + self.message_count = kwargs.get('offset') + update = await update_queue.get() + if update.update_id == 1: + first_update_event.set() + else: + await second_update_event.wait() + return [update] + + monkeypatch.setattr(updater.bot, 'get_updates', get_updates) + + async with updater: + with caplog.at_level(logging.ERROR): + await updater.start_polling() + + await first_update_event.wait() + # Unfortunately we need to use the private attribute here to produce the problem + updater._running = False + second_update_event.set() + + await asyncio.sleep(0.1) + assert caplog.records + records = (record.getMessage() for record in caplog.records) + assert any('Updater stopped unexpectedly.' in record for record in records) + + # Make sure that the update_id offset wasn't increased + assert self.message_count == 2 + @pytest.mark.asyncio @pytest.mark.parametrize('ext_bot', [True, False]) @pytest.mark.parametrize('drop_pending_updates', (True, False)) @@ -409,12 +466,6 @@ async def set_webhook(*args, **kwargs): await self._send_webhook_message(ip, port, update.to_json(), 'TOKEN') assert (await updater.update_queue.get()).to_dict() == update.to_dict() - # Returns Forbidden if wrong content types - response = await self._send_webhook_message( - ip, port, None, 'TOKEN', content_type='invalid' - ) - assert response.status_code == HTTPStatus.FORBIDDEN - # Returns Not Found if path is incorrect response = await self._send_webhook_message(ip, port, '123456', 'webhook_handler.py') assert response.status_code == HTTPStatus.NOT_FOUND @@ -655,96 +706,64 @@ def webhook_server_init(*args, **kwargs): assert (await updater.update_queue.get()).to_dict() == update.to_dict() assert self.test_flag == [True, True] - # @pytest.mark.parametrize(('error',),argvalues=[(TelegramError(''),)], ids=('TelegramError',)) - # def test_bootstrap_retries_success(self, monkeypatch, updater, error): - # retries = 2 - # - # def attempt(*args, **kwargs): - # if self.attempts < retries: - # self.attempts += 1 - # raise error - # - # monkeypatch.setattr(updater.bot, 'set_webhook', attempt) - # - # updater.running = True - # updater._bootstrap(retries, False, 'path', None, bootstrap_interval=0) - # assert self.attempts == retries - # - # @pytest.mark.parametrize( - # ('error', 'attempts'), - # argvalues=[(TelegramError(''), 2), (Unauthorized(''), 1), (InvalidToken(), 1)], - # ids=('TelegramError', 'Unauthorized', 'InvalidToken'), - # ) - # def test_bootstrap_retries_error(self, monkeypatch, updater, error, attempts): - # retries = 1 - # - # def attempt(*args, **kwargs): - # self.attempts += 1 - # raise error - # - # monkeypatch.setattr(updater.bot, 'set_webhook', attempt) - # - # updater.running = True - # with pytest.raises(type(error)): - # updater._bootstrap(retries, False, 'path', None, bootstrap_interval=0) - # assert self.attempts == attempts - # - # @pytest.mark.parametrize('drop_pending_updates', (True, False)) - # def test_bootstrap_clean_updates(self, monkeypatch, updater, drop_pending_updates): - # # As dropping pending updates is done by passing `drop_pending_updates` to - # # set_webhook, we just check that we pass the correct value - # self.test_flag = False - # - # def delete_webhook(**kwargs): - # self.test_flag = kwargs.get('drop_pending_updates') == drop_pending_updates - # - # monkeypatch.setattr(updater.bot, 'delete_webhook', delete_webhook) - # - # updater.running = True - # updater._bootstrap( - # 1, - # drop_pending_updates=drop_pending_updates, - # webhook_url=None, - # allowed_updates=None, - # bootstrap_interval=0, - # ) - # assert self.test_flag is True - # - # @flaky(3, 1) - # def test_webhook_invalid_posts(self, updater): - # ip = '127.0.0.1' - # port = randrange(1024, 49152) # select random port for travis - # thr = Thread( - # target=updater._start_webhook, args=(ip, port, '', None, None, 0, False, None, None) - # ) - # thr.start() - # - # sleep(0.2) - # - # try: - # with pytest.raises(HTTPError) as excinfo: - # self._send_webhook_msg( - # ip, port, 'data', content_type='application/xml' - # ) - # assert excinfo.value.code == 403 - # - # with pytest.raises(HTTPError) as excinfo: - # self._send_webhook_msg(ip, port, 'dummy-payload', content_len=-2) - # assert excinfo.value.code == 500 - # - # # TODO: prevent urllib or the underlying from adding content-length - # # with pytest.raises(HTTPError) as excinfo: - # # self._send_webhook_msg(ip, port, 'dummy-payload', content_len=None) - # # assert excinfo.value.code == 411 - # - # with pytest.raises(HTTPError): - # self._send_webhook_msg(ip, port, 'dummy-payload', content_len='not-a-number') - # assert excinfo.value.code == 500 - # - # finally: - # updater.httpd.shutdown() - # thr.join() - - # TODO: - # test_start_webhook_set/delete_webhook_parameters - # test_start_webhook_bootstrap_retries + @pytest.mark.asyncio + @pytest.mark.parametrize('exception_class', (InvalidToken, TelegramError)) + @pytest.mark.parametrize('retries', (3, 0)) + async def test_start_webhook_bootstrap_retries( + self, updater, monkeypatch, exception_class, retries + ): + async def do_request(*args, **kwargs): + self.message_count += 1 + raise exception_class(str(self.message_count)) + + monkeypatch.setattr(HTTPXRequest, 'do_request', do_request) + + async with updater: + if exception_class == InvalidToken: + with pytest.raises(InvalidToken, match='1'): + await updater.start_webhook(bootstrap_retries=retries) + else: + with pytest.raises(TelegramError, match=str(retries + 1)): + await updater.start_webhook( + bootstrap_retries=retries, + ) + + @pytest.mark.asyncio + async def test_webhook_invalid_posts(self, updater, monkeypatch): + async def return_true(*args, **kwargs): + return True + + monkeypatch.setattr(updater.bot, 'set_webhook', return_true) + monkeypatch.setattr(updater.bot, 'delete_webhook', return_true) + + ip = '127.0.0.1' + port = randrange(1024, 49152) + + async with updater: + await updater.start_webhook(listen=ip, port=port) + + response = await self._send_webhook_message(ip, port, None, content_type='invalid') + assert response.status_code == HTTPStatus.FORBIDDEN + + response = await self._send_webhook_message( + ip, + port, + payload_str='data', + content_type='application/xml', + ) + assert response.status_code == HTTPStatus.FORBIDDEN + + response = await self._send_webhook_message( + ip, port, 'dummy-payload', content_len=None + ) + assert response.status_code == HTTPStatus.INTERNAL_SERVER_ERROR + + # httpx already complains about bad content length in _send_webhook_message + # before the requests below reach the webhook, but not testing this is probably + # okay + # response = await self._send_webhook_message( + # ip, port, 'dummy-payload', content_len=-2) + # assert response.status_code == HTTPStatus.FORBIDDEN + # response = await self._send_webhook_message( + # ip, port, 'dummy-payload', content_len='not-a-number') + # assert response.status_code == HTTPStatus.INTERNAL_SERVER_ERROR From f2624574f5e1171d5f07a62b653be57cdb7e21a8 Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Sat, 26 Feb 2022 11:48:10 +0100 Subject: [PATCH 030/153] Try fixing existing tests --- telegram/ext/_application.py | 2 -- tests/test_bot.py | 27 +++++++++++++++++++-------- 2 files changed, 19 insertions(+), 10 deletions(-) diff --git a/telegram/ext/_application.py b/telegram/ext/_application.py index e71bb8ae719..06558f6d848 100644 --- a/telegram/ext/_application.py +++ b/telegram/ext/_application.py @@ -120,8 +120,6 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ]): updater (:class:`telegram.ext.Updater`, optional): The updater used by this application. job_queue (:class:`telegram.ext.JobQueue`): Optional. The :class:`telegram.ext.JobQueue` instance to pass onto handler callbacks. - concurrent_updates (:obj:`int`, optional): Number updates that may be processed in - parallel. chat_data (:obj:`types.MappingProxyType`): A dictionary handlers can use to store data for the chat. diff --git a/tests/test_bot.py b/tests/test_bot.py index f6dda21d70e..d27bee6a301 100644 --- a/tests/test_bot.py +++ b/tests/test_bot.py @@ -20,6 +20,7 @@ import asyncio import inspect import logging +import socket import time import datetime as dtm from collections import defaultdict @@ -1656,16 +1657,26 @@ async def post(*args, **kwargs): @flaky(3, 1) @pytest.mark.asyncio - async def test_set_webhook_get_webhook_info_and_delete_webhook(self, bot): + @pytest.mark.parametrize('use_ip', [True, False]) + async def test_set_webhook_get_webhook_info_and_delete_webhook(self, bot, use_ip): url = 'https://python-telegram-bot.org/test/webhook' max_connections = 7 allowed_updates = ['message'] - await bot.set_webhook( - url, - max_connections=max_connections, - allowed_updates=allowed_updates, - ip_address='198.51.100.127', - ) + if use_ip: + # Get the ip address of the website - dynamically just in case it ever changes + ip = socket.gethostbyname('python-telegram-bot.org') + await bot.set_webhook( + url, + max_connections=max_connections, + allowed_updates=allowed_updates, + ip_address=ip, + ) + else: + await bot.set_webhook( + url, + max_connections=max_connections, + allowed_updates=allowed_updates, + ) await asyncio.sleep(2) live_info = await bot.get_webhook_info() await asyncio.sleep(6) @@ -1676,7 +1687,7 @@ async def test_set_webhook_get_webhook_info_and_delete_webhook(self, bot): assert live_info.url == url assert live_info.max_connections == max_connections assert live_info.allowed_updates == allowed_updates - assert live_info.ip_address == '198.51.100.142' + assert live_info.ip_address == ip @pytest.mark.parametrize('drop_pending_updates', [True, False]) @pytest.mark.asyncio From a4080a491054340a32c963a9e9e10bed75f5e911 Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Sat, 26 Feb 2022 12:35:17 +0100 Subject: [PATCH 031/153] Try harder --- tests/conftest.py | 16 ++++++---------- tests/test_bot.py | 42 +++++++++++++++++++----------------------- tests/test_sticker.py | 18 +++++++++--------- tests/test_updater.py | 24 +++++++++++++++--------- 4 files changed, 49 insertions(+), 51 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 8c0fab1ffda..2aa22e908d9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -156,15 +156,10 @@ class DictApplication(Application): pass -@pytest.fixture(scope='function') +@pytest.fixture(scope='session') @pytest.mark.asyncio async def bot(bot_info): - async with DictExtBot( - bot_info['token'], - private_key=PRIVATE_KEY, - request=TestHttpxRequest(8), - get_updates_request=TestHttpxRequest(1), - ) as _bot: + async with make_bot(bot_info) as _bot: yield _bot @@ -260,8 +255,9 @@ def app(_app): @pytest.fixture(scope='function') @pytest.mark.asyncio -async def updater(bot): - up = Updater(bot=bot, update_queue=asyncio.Queue()) +async def updater(bot_info): + # We build a new bot each time so that we use `updater` in a context manager without problems + up = Updater(bot=make_bot(bot_info), update_queue=asyncio.Queue()) yield up if up.running: await up.stop() @@ -298,7 +294,7 @@ def make_bot(bot_info, **kwargs): """ Tests are executed on tg.ext.ExtBot, as that class only extends the functionality of tg.bot """ - _bot = ExtBot( + _bot = DictExtBot( bot_info['token'], private_key=PRIVATE_KEY, request=TestHttpxRequest(8), diff --git a/tests/test_bot.py b/tests/test_bot.py index d27bee6a301..cba01a13e96 100644 --- a/tests/test_bot.py +++ b/tests/test_bot.py @@ -162,8 +162,9 @@ class TestBot: def reset(self): self.test_flag = None - @pytest.mark.parametrize('inst', ['bot', "default_bot"], indirect=True) - def test_slot_behaviour(self, inst, mro_slots): + @pytest.mark.parametrize('bot_class', [Bot, ExtBot]) + def test_slot_behaviour(self, bot_class, bot, mro_slots): + inst = bot_class(bot.token) for attr in inst.__slots__: assert getattr(inst, attr, 'err') != 'err', f"got extra slot '{attr}'" assert len(mro_slots(inst)) == len(set(mro_slots(inst))), "duplicate slot" @@ -1660,35 +1661,30 @@ async def post(*args, **kwargs): @pytest.mark.parametrize('use_ip', [True, False]) async def test_set_webhook_get_webhook_info_and_delete_webhook(self, bot, use_ip): url = 'https://python-telegram-bot.org/test/webhook' + # Get the ip address of the website - dynamically just in case it ever changes + ip = socket.gethostbyname('python-telegram-bot.org') max_connections = 7 allowed_updates = ['message'] - if use_ip: - # Get the ip address of the website - dynamically just in case it ever changes - ip = socket.gethostbyname('python-telegram-bot.org') - await bot.set_webhook( - url, - max_connections=max_connections, - allowed_updates=allowed_updates, - ip_address=ip, - ) - else: - await bot.set_webhook( - url, - max_connections=max_connections, - allowed_updates=allowed_updates, - ) - await asyncio.sleep(2) + await bot.set_webhook( + url, + max_connections=max_connections, + allowed_updates=allowed_updates, + ip_address=ip if use_ip else None, + ) + + await asyncio.sleep(1) live_info = await bot.get_webhook_info() - await asyncio.sleep(6) - await bot.delete_webhook() - await asyncio.sleep(2) - info = await bot.get_webhook_info() - assert info.url == '' assert live_info.url == url assert live_info.max_connections == max_connections assert live_info.allowed_updates == allowed_updates assert live_info.ip_address == ip + await bot.delete_webhook() + await asyncio.sleep(1) + info = await bot.get_webhook_info() + assert info.url == '' + assert info.ip_address is None + @pytest.mark.parametrize('drop_pending_updates', [True, False]) @pytest.mark.asyncio async def test_set_webhook_delete_webhook_drop_pending_updates( diff --git a/tests/test_sticker.py b/tests/test_sticker.py index 022093ba162..181b740a766 100644 --- a/tests/test_sticker.py +++ b/tests/test_sticker.py @@ -16,9 +16,9 @@ # # You should have received a copy of the GNU Lesser Public License # along with this program. If not, see [http://www.gnu.org/licenses/]. +import asyncio import os from pathlib import Path -from time import sleep import pytest from flaky import flaky @@ -62,13 +62,13 @@ async def animated_sticker(bot, chat_id): @pytest.fixture(scope='function') def video_sticker_file(): - with Path('tests/data/telegram_video_sticker.webm').open('rb') as f: + with data_file('telegram_video_sticker.webm').open('rb') as f: yield f @pytest.fixture(scope='class') def video_sticker(bot, chat_id): - with Path('tests/data/telegram_video_sticker.webm').open('rb') as f: + with data_file('telegram_video_sticker.webm').open('rb') as f: return bot.send_sticker(chat_id, sticker=f, timeout=50).sticker @@ -517,7 +517,7 @@ async def test_bot_methods_1_tgs(self, bot, chat_id): @flaky(3, 1) @pytest.mark.asyncio async def test_bot_methods_1_webm(self, bot, chat_id): - with Path('tests/data/telegram_video_sticker.webm').open('rb') as f: + with data_file('telegram_video_sticker.webm').open('rb') as f: assert await bot.add_sticker_to_set( chat_id, f'video_test_by_{bot.username}', webm_sticker=f, emojis='🤔' ) @@ -554,7 +554,7 @@ async def test_bot_methods_2_webm(self, bot, video_sticker_set): @flaky(10, 1) @pytest.mark.asyncio async def test_bot_methods_3_png(self, bot, chat_id, sticker_set_thumb_file): - sleep(1) + await asyncio.sleep(1) assert await bot.set_sticker_set_thumb( f'test_by_{bot.username}', chat_id, sticker_set_thumb_file ) @@ -564,7 +564,7 @@ async def test_bot_methods_3_png(self, bot, chat_id, sticker_set_thumb_file): async def test_bot_methods_3_tgs( self, bot, chat_id, animated_sticker_file, animated_sticker_set ): - sleep(1) + await asyncio.sleep(1) animated_test = f'animated_test_by_{bot.username}' assert await bot.set_sticker_set_thumb(animated_test, chat_id, animated_sticker_file) file_id = animated_sticker_set.stickers[-1].file_id @@ -582,21 +582,21 @@ def test_bot_methods_3_webm(self, bot, chat_id, video_sticker_file, video_sticke @flaky(10, 1) @pytest.mark.asyncio async def test_bot_methods_4_png(self, bot, sticker_set): - sleep(1) + await asyncio.sleep(1) file_id = sticker_set.stickers[-1].file_id assert await bot.delete_sticker_from_set(file_id) @flaky(10, 1) @pytest.mark.asyncio async def test_bot_methods_4_tgs(self, bot, animated_sticker_set): - sleep(1) + await asyncio.sleep(1) file_id = animated_sticker_set.stickers[-1].file_id assert await bot.delete_sticker_from_set(file_id) @flaky(10, 1) @pytest.mark.asyncio async def test_bot_methods_4_webm(self, bot, video_sticker_set): - sleep(1) + await asyncio.sleep(1) file_id = video_sticker_set.stickers[-1].file_id assert await bot.delete_sticker_from_set(file_id) diff --git a/tests/test_updater.py b/tests/test_updater.py index cd8dba3255d..066dcb0f9a5 100644 --- a/tests/test_updater.py +++ b/tests/test_updater.py @@ -103,11 +103,13 @@ async def _send_webhook_message( url=url, method=get_method or 'POST', data=payload, headers=headers ) - def test_slot_behaviour(self, updater, mro_slots): - for at in updater.__slots__: - at = f"_Updater{at}" if at.startswith('__') and not at.endswith('__') else at - assert getattr(updater, at, 'err') != 'err', f"got extra slot '{at}'" - assert len(mro_slots(updater)) == len(set(mro_slots(updater))), "duplicate slot" + @pytest.mark.asyncio + async def test_slot_behaviour(self, updater, mro_slots): + async with updater: + for at in updater.__slots__: + at = f"_Updater{at}" if at.startswith('__') and not at.endswith('__') else at + assert getattr(updater, at, 'err') != 'err', f"got extra slot '{at}'" + assert len(mro_slots(updater)) == len(set(mro_slots(updater))), "duplicate slot" def test_init(self, bot): queue = asyncio.Queue() @@ -310,9 +312,11 @@ async def do_request(*args, **kwargs): self.message_count += 1 raise exception_class(str(self.message_count)) - monkeypatch.setattr(HTTPXRequest, 'do_request', do_request) - async with updater: + # Patch within the context so that updater.bot.initialize can still be called + # by the context manager + monkeypatch.setattr(HTTPXRequest, 'do_request', do_request) + if exception_class == InvalidToken: with pytest.raises(InvalidToken, match='1'): await updater.start_polling(bootstrap_retries=retries) @@ -716,9 +720,11 @@ async def do_request(*args, **kwargs): self.message_count += 1 raise exception_class(str(self.message_count)) - monkeypatch.setattr(HTTPXRequest, 'do_request', do_request) - async with updater: + # Patch within the context so that updater.bot.initialize can still be called + # by the context manager + monkeypatch.setattr(HTTPXRequest, 'do_request', do_request) + if exception_class == InvalidToken: with pytest.raises(InvalidToken, match='1'): await updater.start_webhook(bootstrap_retries=retries) From e2a390745f5b0b78c1867b3b29abced4576ee06a Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Sat, 26 Feb 2022 15:46:00 +0100 Subject: [PATCH 032/153] Get started on application tests --- telegram/ext/_application.py | 3 - tests/conftest.py | 49 +- tests/test_application.py | 1225 ++++++++++++++++++++++++++++++++++ tests/test_updater.py | 6 +- 4 files changed, 1236 insertions(+), 47 deletions(-) create mode 100644 tests/test_application.py diff --git a/telegram/ext/_application.py b/telegram/ext/_application.py index 06558f6d848..dd38f92aceb 100644 --- a/telegram/ext/_application.py +++ b/telegram/ext/_application.py @@ -222,9 +222,6 @@ def __init__( self._concurrent_updates_sem = asyncio.BoundedSemaphore(concurrent_updates or 1) self._concurrent_updates: int = concurrent_updates or 0 - if self.job_queue: - self.job_queue.set_application(self) - self.bot_data = self.context_types.bot_data() self._user_data: DefaultDict[int, UD] = defaultdict(self.context_types.user_data) self._chat_data: DefaultDict[int, CD] = defaultdict(self.context_types.chat_data) diff --git a/tests/conftest.py b/tests/conftest.py index 2aa22e908d9..9c8aa3b0c38 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -23,13 +23,8 @@ import os import re -from collections import defaultdict from pathlib import Path -from queue import Queue -from threading import Thread, Event -from time import sleep from typing import Callable, List, Iterable, Any, Dict -from types import MappingProxyType import pytest import pytz @@ -211,46 +206,17 @@ def provider_token(bot_info): return bot_info['payment_provider_token'] -def create_dp(bot): - # Application is heavy to init (due to many threads and such) so we have a single session - # scoped one here, but before each test, reset it (app fixture below) +@pytest.fixture(scope='function') +@pytest.mark.asyncio +async def app(bot_info): + # We build a new bot each time so that we use `app` in a context manager without problems application = ( - ApplicationBuilder().bot(bot).workers(2).application_class(DictApplication).build() + ApplicationBuilder().bot(make_bot(bot_info)).application_class(DictApplication).build() ) - # TODO: Do we need the thread? - thr = Thread(target=application.start) - thr.start() - sleep(2) yield application - sleep(1) if application.running: - application.stop() - thr.join() - - -@pytest.fixture(scope='session') -def _app(bot): - yield from create_dp(bot) - - -@pytest.fixture(scope='function') -def app(_app): - # Reset the application first - # TODO: consider just using the builder pattern to build a new object - while not _app.update_queue.empty(): - _app.update_queue.get(False) - _app._chat_data = defaultdict(dict) - _app._user_data = defaultdict(dict) - _app.chat_data = MappingProxyType(_app._chat_data) # Rebuild the mapping so it updates - _app.user_data = MappingProxyType(_app._user_data) - _app.bot_data = {} - _app.handlers = {} - _app.error_handlers = {} - _app.__stop_event = Event() - _app.__async_queue = Queue() - _app.__async_threads = set() - _app.persistence = None - yield _app + await application.stop() + await application.shutdown() @pytest.fixture(scope='function') @@ -261,6 +227,7 @@ async def updater(bot_info): yield up if up.running: await up.stop() + await up.shutdown() PROJECT_ROOT_PATH = Path(__file__).parent.parent.resolve() diff --git a/tests/test_application.py b/tests/test_application.py new file mode 100644 index 00000000000..f30cbfd02be --- /dev/null +++ b/tests/test_application.py @@ -0,0 +1,1225 @@ +#!/usr/bin/env python +# +# A library that provides a Python interface to the Telegram Bot API +# Copyright (C) 2015-2022 +# Leandro Toledo de Souza +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser Public License for more details. +# +# You should have received a copy of the GNU Lesser Public License +# along with this program. If not, see [http://www.gnu.org/licenses/]. +import asyncio +from queue import Queue + +import pytest + +from telegram import Bot +from telegram.ext import ( + JobQueue, + CallbackContext, + ApplicationBuilder, + Application, + ContextTypes, + PicklePersistence, + Updater, +) + +from telegram.error import TelegramError + +from tests.conftest import make_message_update + + +class CustomContext(CallbackContext): + pass + + +class TestDispatcher: + message_update = make_message_update(message='Text') + received = None + count = 0 + + @pytest.fixture(autouse=True, name='reset') + def reset_fixture(self): + self.reset() + + def reset(self): + self.received = None + self.count = 0 + + async def error_handler_context(self, update, context): + self.received = context.error.message + + async def error_handler_raise_error(self, update, context): + raise Exception('Failing bigly') + + async def callback_increase_count(self, update, context): + self.count += 1 + + def callback_set_count(self, count): + async def callback(update, context): + self.count = count + + return callback + + async def callback_raise_error(self, update, context): + raise TelegramError(update.message.text) + + async def callback_received(self, update, context): + self.received = update.message + + async def callback_context(self, update, context): + if ( + isinstance(context, CallbackContext) + and isinstance(context.bot, Bot) + and isinstance(context.update_queue, Queue) + and isinstance(context.job_queue, JobQueue) + and isinstance(context.error, TelegramError) + ): + self.received = context.error.message + + def test_slot_behaviour(self, bot, mro_slots): + app = ApplicationBuilder().bot(bot).build() + for at in app.__slots__: + at = f"_Application{at}" if at.startswith('__') and not at.endswith('__') else at + assert getattr(app, at, 'err') != 'err', f"got extra slot '{at}'" + assert len(mro_slots(app)) == len(set(mro_slots(app))), "duplicate slot" + + def test_manual_init_warning(self, recwarn, updater): + Application( + bot=None, + update_queue=None, + job_queue=None, + persistence=None, + context_types=ContextTypes(), + updater=updater, + concurrent_updates=False, + ) + assert len(recwarn) == 1 + assert ( + str(recwarn[-1].message) + == '`Application` instances should be built via the `ApplicationBuilder`.' + ) + assert recwarn[0].filename == __file__, "stacklevel is incorrect!" + + @pytest.mark.parametrize( + 'concurrent_updates, expected', [(0, 0), (4, 4), (False, 0), (True, 4096)] + ) + @pytest.mark.filterwarnings("ignore: `Application` instances should") + def test_init(self, bot, concurrent_updates, expected): + update_queue = asyncio.Queue() + job_queue = JobQueue() + persistence = PicklePersistence('file_path') + context_types = ContextTypes() + updater = Updater(bot=bot, update_queue=update_queue) + app = Application( + bot=bot, + update_queue=update_queue, + job_queue=job_queue, + persistence=persistence, + context_types=context_types, + updater=updater, + concurrent_updates=concurrent_updates, + ) + assert app.bot is bot + assert app.update_queue is update_queue + assert app.job_queue is job_queue + assert app.persistence is persistence + assert app.context_types is context_types + assert app.updater is updater + assert app.update_queue is updater.update_queue + assert app.bot is updater.bot + assert app.concurrent_updates == expected + + # These should be done by the builder + assert app.persistence.bot is None + with pytest.raises(RuntimeError, match='No application was set'): + app.job_queue.application + + with pytest.raises(ValueError, match='must be a non-negative'): + Application( + bot=bot, + update_queue=update_queue, + job_queue=job_queue, + persistence=persistence, + context_types=context_types, + updater=updater, + concurrent_updates=-1, + ) + + @pytest.mark.asyncio + async def test_initialize(self, bot, monkeypatch): + """Initialization of persistence is tested eslewhere""" + # TODO: do this! + self.test_flag = set() + + async def initialize_bot(*args, **kwargs): + self.test_flag.add('bot') + + async def initialize_updater(*args, **kwargs): + self.test_flag.add('updater') + + monkeypatch.setattr(Bot, 'initialize', initialize_bot) + monkeypatch.setattr(Updater, 'initialize', initialize_updater) + + await ApplicationBuilder().token(bot.token).build().initialize() + assert self.test_flag == {'bot', 'updater'} + + @pytest.mark.asyncio + async def test_shutdown(self, bot, monkeypatch): + """Studown of persistence is tested eslewhere""" + # TODO: do this! + self.test_flag = set() + + async def shutdown_bot(*args, **kwargs): + self.test_flag.add('bot') + + async def shutdown_updater(*args, **kwargs): + self.test_flag.add('updater') + + monkeypatch.setattr(Bot, 'shutdown', shutdown_bot) + monkeypatch.setattr(Updater, 'shutdown', shutdown_updater) + + await ApplicationBuilder().token(bot.token).build().shutdown() + assert self.test_flag == {'bot', 'updater'} + + @pytest.mark.asyncio + async def test_context_manager(self, monkeypatch, app): + self.test_flag = set() + + async def initialize(*args, **kwargs): + self.test_flag.add('initialize') + + async def shutdown(*args, **kwargs): + self.test_flag.add('stop') + + monkeypatch.setattr(Application, 'initialize', initialize) + monkeypatch.setattr(Application, 'shutdown', shutdown) + + async with app: + pass + + assert self.test_flag == {'initialize', 'stop'} + + @pytest.mark.asyncio + async def test_context_manager_exception_on_init(self, monkeypatch, app): + async def initialize(*args, **kwargs): + raise RuntimeError('initialize') + + async def shutdown(*args): + self.test_flag = 'stop' + + monkeypatch.setattr(Application, 'initialize', initialize) + monkeypatch.setattr(Application, 'shutdown', shutdown) + + with pytest.raises(RuntimeError, match='initialize'): + async with app: + pass + + assert self.test_flag == 'stop' + + @pytest.mark.parametrize("data", ["chat_data", "user_data"]) + def test_chat_user_data_read_only(self, app, data): + read_only_data = getattr(app, data) + writable_data = getattr(app, f"_{data}") + writable_data[123] = 321 + assert read_only_data == writable_data + with pytest.raises(TypeError): + read_only_data[111] = 123 + + def test_builder(self, app): + builder_1 = app.builder() + builder_2 = app.builder() + assert isinstance(builder_1, ApplicationBuilder) + assert isinstance(builder_2, ApplicationBuilder) + assert builder_1 is not builder_2 + + # Make sure that setting a token doesn't raise an exception + # i.e. check that the builders are "empty"/new + builder_1.token(app.bot.token) + builder_2.token(app.bot.token) + + # + # def test_one_context_per_update(self, app): + # def one(update, context): + # if update.message.text == 'test': + # context.my_flag = True + # + # def two(update, context): + # if update.message.text == 'test': + # if not hasattr(context, 'my_flag'): + # pytest.fail() + # else: + # if hasattr(context, 'my_flag'): + # pytest.fail() + # + # app.add_handler(MessageHandler(filters.Regex('test'), one), group=1) + # app.add_handler(MessageHandler(None, two), group=2) + # u = Update(1, Message(1, None, None, None, text='test')) + # app.process_update(u) + # u.message.text = 'something' + # app.process_update(u) + # + # def test_error_handler(self, app): + # app.add_error_handler(self.error_handler_context) + # error = TelegramError('Unauthorized.') + # app.update_queue.put(error) + # sleep(0.1) + # assert self.received == 'Unauthorized.' + # + # # Remove handler + # app.remove_error_handler(self.error_handler_context) + # self.reset() + # + # app.update_queue.put(error) + # sleep(0.1) + # assert self.received is None + # + # def test_double_add_error_handler(self, app, caplog): + # app.add_error_handler(self.error_handler_context) + # with caplog.at_level(logging.DEBUG): + # app.add_error_handler(self.error_handler_context) + # assert len(caplog.records) == 1 + # assert caplog.records[-1].getMessage().startswith( + # 'The callback is already registered') + # + # def test_construction_with_bad_persistence(self, caplog, bot): + # class my_per: + # def __init__(self): + # self.store_data = PersistenceInput(False, False, False, False) + # + # with pytest.raises( + # TypeError, match='persistence must be based on telegram.ext.BasePersistence' + # ): + # ApplicationBuilder().bot(bot).persistence(my_per()).build() + # + # def test_error_handler_that_raises_errors(self, app): + # """ + # Make sure that errors raised in error handlers don't break the main loop of the + # application + # """ + # handler_raise_error = MessageHandler(filters.ALL, self.callback_raise_error) + # handler_increase_count = MessageHandler(filters.ALL, self.callback_increase_count) + # error = TelegramError('Unauthorized.') + # + # app.add_error_handler(self.error_handler_raise_error) + # + # # From errors caused by handlers + # app.add_handler(handler_raise_error) + # app.update_queue.put(self.message_update) + # sleep(0.1) + # + # # From errors in the update_queue + # app.remove_handler(handler_raise_error) + # app.add_handler(handler_increase_count) + # app.update_queue.put(error) + # app.update_queue.put(self.message_update) + # sleep(0.1) + # + # assert self.count == 1 + # + # @pytest.mark.parametrize(['block', 'expected_output'], [(True, 5), (False, 0)]) + # def test_default_run_async_error_handler(self, app, monkeypatch, block, expected_output): + # def mock_async_err_handler(*args, **kwargs): + # self.count = 5 + # + # # set defaults value to app.bot + # app.bot._defaults = Defaults(block=block) + # try: + # app.add_handler(MessageHandler(filters.ALL, self.callback_raise_error)) + # app.add_error_handler(self.error_handler_context) + # + # monkeypatch.setattr(app, 'block', mock_async_err_handler) + # app.process_update(self.message_update) + # + # assert self.count == expected_output + # + # finally: + # # reset app.bot.defaults values + # app.bot._defaults = None + # + # @pytest.mark.parametrize( + # ['block', 'expected_output'], [(True, 'running async'), (False, None)] + # ) + # def test_default_run_async(self, monkeypatch, app, block, expected_output): + # def mock_run_async(*args, **kwargs): + # self.received = 'running async' + # + # # set defaults value to app.bot + # app.bot._defaults = Defaults(block=block) + # try: + # app.add_handler(MessageHandler(filters.ALL, lambda u, c: None)) + # monkeypatch.setattr(app, 'block', mock_run_async) + # app.process_update(self.message_update) + # assert self.received == expected_output + # + # finally: + # # reset defaults value + # app.bot._defaults = None + # + # def test_run_async_multiple(self, bot, app, dp2): + # def get_dispatcher_name(q): + # q.put(current_thread().name) + # + # q1 = Queue() + # q2 = Queue() + # + # app.block(get_dispatcher_name, q1) + # dp2.block(get_dispatcher_name, q2) + # + # sleep(0.1) + # + # name1 = q1.get() + # name2 = q2.get() + # + # assert name1 != name2 + # + # def test_async_raises_dispatcher_handler_stop(self, app, recwarn): + # def callback(update, context): + # raise ApplicationHandlerStop() + # + # app.add_handler(MessageHandler(filters.ALL, callback, block=True)) + # + # app.update_queue.put(self.message_update) + # sleep(0.1) + # assert len(recwarn) == 1 + # assert str(recwarn[-1].message).startswith( + # 'ApplicationHandlerStop is not supported with async functions' + # ) + # + # def test_add_async_handler(self, app): + # app.add_handler( + # MessageHandler( + # filters.ALL, + # self.callback_received, + # block=True, + # ) + # ) + # + # app.update_queue.put(self.message_update) + # sleep(0.1) + # assert self.received == self.message_update.message + # + # def test_run_async_no_error_handler(self, app, caplog): + # def func(): + # raise RuntimeError('Async Error') + # + # with caplog.at_level(logging.ERROR): + # app.block(func) + # sleep(0.1) + # assert len(caplog.records) == 1 + # assert caplog.records[-1].getMessage().startswith('No error handlers are registered') + # + # def test_async_handler_async_error_handler_context(self, app): + # app.add_handler(MessageHandler(filters.ALL, self.callback_raise_error, block=True)) + # app.add_error_handler(self.error_handler_context, block=True) + # + # app.update_queue.put(self.message_update) + # sleep(2) + # assert self.received == self.message_update.message.text + # + # def test_async_handler_error_handler_that_raises_error(self, app, caplog): + # handler = MessageHandler(filters.ALL, self.callback_raise_error, block=True) + # app.add_handler(handler) + # app.add_error_handler(self.error_handler_raise_error, block=False) + # + # with caplog.at_level(logging.ERROR): + # app.update_queue.put(self.message_update) + # sleep(0.1) + # assert len(caplog.records) == 1 + # assert ( + # caplog.records[-1].getMessage().startswith('An error was raised and an uncaught') + # ) + # + # # Make sure that the main loop still runs + # app.remove_handler(handler) + # app.add_handler(MessageHandler(filters.ALL, self.callback_increase_count, block=True)) + # app.update_queue.put(self.message_update) + # sleep(0.1) + # assert self.count == 1 + # + # def test_async_handler_async_error_handler_that_raises_error(self, app, caplog): + # handler = MessageHandler(filters.ALL, self.callback_raise_error, block=True) + # app.add_handler(handler) + # app.add_error_handler(self.error_handler_raise_error, block=True) + # + # with caplog.at_level(logging.ERROR): + # app.update_queue.put(self.message_update) + # sleep(0.1) + # assert len(caplog.records) == 1 + # assert ( + # caplog.records[-1].getMessage().startswith('An error was raised and an uncaught') + # ) + # + # # Make sure that the main loop still runs + # app.remove_handler(handler) + # app.add_handler(MessageHandler(filters.ALL, self.callback_increase_count, block=True)) + # app.update_queue.put(self.message_update) + # sleep(0.1) + # assert self.count == 1 + # + # def test_error_in_handler(self, app): + # app.add_handler(MessageHandler(filters.ALL, self.callback_raise_error)) + # app.add_error_handler(self.error_handler_context) + # + # app.update_queue.put(self.message_update) + # sleep(0.1) + # assert self.received == self.message_update.message.text + # + # def test_add_remove_handler(self, app): + # handler = MessageHandler(filters.ALL, self.callback_increase_count) + # app.add_handler(handler) + # app.update_queue.put(self.message_update) + # sleep(0.1) + # assert self.count == 1 + # app.remove_handler(handler) + # app.update_queue.put(self.message_update) + # assert self.count == 1 + # + # def test_add_remove_handler_non_default_group(self, app): + # handler = MessageHandler(filters.ALL, self.callback_increase_count) + # app.add_handler(handler, group=2) + # with pytest.raises(KeyError): + # app.remove_handler(handler) + # app.remove_handler(handler, group=2) + # + # def test_error_start_twice(self, app): + # assert app.running + # app.start() + # + # def test_handler_order_in_group(self, app): + # app.add_handler(MessageHandler(filters.PHOTO, self.callback_set_count(1))) + # app.add_handler(MessageHandler(filters.ALL, self.callback_set_count(2))) + # app.add_handler(MessageHandler(filters.TEXT, self.callback_set_count(3))) + # app.update_queue.put(self.message_update) + # sleep(0.1) + # assert self.count == 2 + # + # def test_groups(self, app): + # app.add_handler(MessageHandler(filters.ALL, self.callback_increase_count)) + # app.add_handler(MessageHandler(filters.ALL, self.callback_increase_count), group=2) + # app.add_handler(MessageHandler(filters.ALL, self.callback_increase_count), group=-1) + # + # app.update_queue.put(self.message_update) + # sleep(0.1) + # assert self.count == 3 + # + # def test_add_handlers_complex(self, app): + # """Tests both add_handler & add_handlers together & confirms the correct insertion + # order""" + # msg_handler_set_count = MessageHandler(filters.TEXT, self.callback_set_count(1)) + # msg_handler_inc_count = MessageHandler(filters.PHOTO, self.callback_increase_count) + # + # app.add_handler(msg_handler_set_count, 1) + # app.add_handlers((msg_handler_inc_count, msg_handler_inc_count), 1) + # + # photo_update = Update(2, message=Message(2, None, None, photo=True)) + # # Putting updates in the queue calls the callback + # app.update_queue.put(self.message_update) + # app.update_queue.put(photo_update) + # sleep(0.1) # sleep is required otherwise there is random behaviour + # + # # Test if handler was added to correct group with correct order- + # assert ( + # self.count == 2 + # and len(app.handlers[1]) == 3 + # and app.handlers[1][0] is msg_handler_set_count + # ) + # + # # Now lets test add_handlers when `handlers` is a dict- + # voice_filter_handler_to_check = MessageHandler(filters.VOICE, + # self.callback_increase_count) + # app.add_handlers( + # handlers={ + # 1: [ + # MessageHandler(filters.USER, self.callback_increase_count), + # voice_filter_handler_to_check, + # ], + # -1: [MessageHandler(filters.CAPTION, self.callback_set_count(2))], + # } + # ) + # + # user_update = Update(3, message=Message(3, None, None, from_user=User(1, 's', True))) + # voice_update = Update(4, message=Message(4, None, None, voice=True)) + # app.update_queue.put(user_update) + # app.update_queue.put(voice_update) + # sleep(0.1) + # + # assert ( + # self.count == 4 + # and len(app.handlers[1]) == 5 + # and app.handlers[1][-1] is voice_filter_handler_to_check + # ) + # + # app.update_queue.put(Update(5, message=Message(5, None, None, caption='cap'))) + # sleep(0.1) + # + # assert self.count == 2 and len(app.handlers[-1]) == 1 + # + # # Now lets test the errors which can be produced- + # with pytest.raises(ValueError, match="The `group` argument"): + # app.add_handlers({2: [msg_handler_set_count]}, group=0) + # with pytest.raises(ValueError, match="Handlers for group 3"): + # app.add_handlers({3: msg_handler_set_count}) + # with pytest.raises(ValueError, match="The `handlers` argument must be a sequence"): + # app.add_handlers({msg_handler_set_count}) + # + # def test_add_handler_errors(self, app): + # handler = 'not a handler' + # with pytest.raises(TypeError, match='handler is not an instance of'): + # app.add_handler(handler) + # + # handler = MessageHandler(filters.PHOTO, self.callback_set_count(1)) + # with pytest.raises(TypeError, match='group is not int'): + # app.add_handler(handler, 'one') + # + # def test_flow_stop(self, app, bot): + # passed = [] + # + # def start1(b, u): + # passed.append('start1') + # raise ApplicationHandlerStop + # + # def start2(b, u): + # passed.append('start2') + # + # def start3(b, u): + # passed.append('start3') + # + # def error(b, u, e): + # passed.append('error') + # passed.append(e) + # + # update = Update( + # 1, + # message=Message( + # 1, + # None, + # None, + # None, + # text='/start', + # entities=[ + # MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0, length=len('/start')) + # ], + # bot=bot, + # ), + # ) + # + # # If Stop raised handlers in other groups should not be called. + # passed = [] + # app.add_handler(CommandHandler('start', start1), 1) + # app.add_handler(CommandHandler('start', start3), 1) + # app.add_handler(CommandHandler('start', start2), 2) + # app.process_update(update) + # assert passed == ['start1'] + # + # def test_exception_in_handler(self, app, bot): + # passed = [] + # err = Exception('General exception') + # + # def start1(u, c): + # passed.append('start1') + # raise err + # + # def start2(u, c): + # passed.append('start2') + # + # def start3(u, c): + # passed.append('start3') + # + # def error(u, c): + # passed.append('error') + # passed.append(c.error) + # + # update = Update( + # 1, + # message=Message( + # 1, + # None, + # None, + # None, + # text='/start', + # entities=[ + # MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0, length=len('/start')) + # ], + # bot=bot, + # ), + # ) + # + # # If an unhandled exception was caught, no further handlers from the same group should be + # # called. Also, the error handler should be called and receive the exception + # passed = [] + # app.add_handler(CommandHandler('start', start1), 1) + # app.add_handler(CommandHandler('start', start2), 1) + # app.add_handler(CommandHandler('start', start3), 2) + # app.add_error_handler(error) + # app.process_update(update) + # assert passed == ['start1', 'error', err, 'start3'] + # + # def test_telegram_error_in_handler(self, app, bot): + # passed = [] + # err = TelegramError('Telegram error') + # + # def start1(u, c): + # passed.append('start1') + # raise err + # + # def start2(u, c): + # passed.append('start2') + # + # def start3(u, c): + # passed.append('start3') + # + # def error(u, c): + # passed.append('error') + # passed.append(c.error) + # + # update = Update( + # 1, + # message=Message( + # 1, + # None, + # None, + # None, + # text='/start', + # entities=[ + # MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0, length=len('/start')) + # ], + # bot=bot, + # ), + # ) + # + # # If a TelegramException was caught, an error handler should be called and no further + # # handlers from the same group should be called. + # app.add_handler(CommandHandler('start', start1), 1) + # app.add_handler(CommandHandler('start', start2), 1) + # app.add_handler(CommandHandler('start', start3), 2) + # app.add_error_handler(error) + # app.process_update(update) + # assert passed == ['start1', 'error', err, 'start3'] + # assert passed[2] is err + # + # def test_error_while_saving_chat_data(self, bot): + # increment = [] + # + # class OwnPersistence(BasePersistence): + # def get_callback_data(self): + # return None + # + # def update_callback_data(self, data): + # raise Exception + # + # def get_bot_data(self): + # return {} + # + # def update_bot_data(self, data): + # raise Exception + # + # def drop_chat_data(self, chat_id): + # pass + # + # def drop_user_data(self, user_id): + # pass + # + # def get_chat_data(self): + # return defaultdict(dict) + # + # def update_chat_data(self, chat_id, data): + # raise Exception + # + # def get_user_data(self): + # return defaultdict(dict) + # + # def update_user_data(self, user_id, data): + # raise Exception + # + # def get_conversations(self, name): + # pass + # + # def update_conversation(self, name, key, new_state): + # pass + # + # def refresh_user_data(self, user_id, user_data): + # pass + # + # def refresh_chat_data(self, chat_id, chat_data): + # pass + # + # def refresh_bot_data(self, bot_data): + # pass + # + # def flush(self): + # pass + # + # def start1(u, c): + # pass + # + # def error(u, c): + # increment.append("error") + # + # # If updating a user_data or chat_data from a persistence object throws an error, + # # the error handler should catch it + # + # update = Update( + # 1, + # message=Message( + # 1, + # None, + # Chat(1, "lala"), + # from_user=User(1, "Test", False), + # text='/start', + # entities=[ + # MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0, length=len('/start')) + # ], + # bot=bot, + # ), + # ) + # my_persistence = OwnPersistence() + # app = ApplicationBuilder().bot(bot).persistence(my_persistence).build() + # app.add_handler(CommandHandler('start', start1)) + # app.add_error_handler(error) + # app.process_update(update) + # assert increment == ["error", "error", "error", "error"] + # + # def test_flow_stop_in_error_handler(self, app, bot): + # passed = [] + # err = TelegramError('Telegram error') + # + # def start1(u, c): + # passed.append('start1') + # raise err + # + # def start2(u, c): + # passed.append('start2') + # + # def start3(u, c): + # passed.append('start3') + # + # def error(u, c): + # passed.append('error') + # passed.append(c.error) + # raise ApplicationHandlerStop + # + # update = Update( + # 1, + # message=Message( + # 1, + # None, + # None, + # None, + # text='/start', + # entities=[ + # MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0, length=len('/start')) + # ], + # bot=bot, + # ), + # ) + # + # # If a TelegramException was caught, an error handler should be called and no further + # # handlers from the same group should be called. + # app.add_handler(CommandHandler('start', start1), 1) + # app.add_handler(CommandHandler('start', start2), 1) + # app.add_handler(CommandHandler('start', start3), 2) + # app.add_error_handler(error) + # app.process_update(update) + # assert passed == ['start1', 'error', err] + # assert passed[2] is err + # + # def test_sensible_worker_thread_names(self, dp2): + # thread_names = [thread.name for thread in dp2._Dispatcher__async_threads] + # for thread_name in thread_names: + # assert thread_name.startswith(f"Bot:{dp2.bot.id}:worker:") + # + # @pytest.mark.parametrize( + # 'message', + # [ + # Message(message_id=1, chat=Chat(id=2, type=None), migrate_from_chat_id=1, date=None), + # Message(message_id=1, chat=Chat(id=1, type=None), migrate_to_chat_id=2, date=None), + # Message(message_id=1, chat=Chat(id=1, type=None), date=None), + # None, + # ], + # ) + # @pytest.mark.parametrize('old_chat_id', [None, 1, "1"]) + # @pytest.mark.parametrize('new_chat_id', [None, 2, "1"]) + # def test_migrate_chat_data(self, app, message: 'Message', old_chat_id: int, + # new_chat_id: int): + # def call(match: str): + # with pytest.raises(ValueError, match=match): + # app.migrate_chat_data( + # message=message, old_chat_id=old_chat_id, new_chat_id=new_chat_id + # ) + # + # if message and (old_chat_id or new_chat_id): + # call(r"^Message and chat_id pair are mutually exclusive$") + # return + # + # if not any((message, old_chat_id, new_chat_id)): + # call(r"^chat_id pair or message must be passed$") + # return + # + # if message: + # if message.migrate_from_chat_id is None and message.migrate_to_chat_id is None: + # call(r"^Invalid message instance") + # return + # effective_old_chat_id = message.migrate_from_chat_id or message.chat.id + # effective_new_chat_id = message.migrate_to_chat_id or message.chat.id + # + # elif not (isinstance(old_chat_id, int) and isinstance(new_chat_id, int)): + # call(r"^old_chat_id and new_chat_id must be integers$") + # return + # else: + # effective_old_chat_id = old_chat_id + # effective_new_chat_id = new_chat_id + # + # app.chat_data[effective_old_chat_id]['key'] = "test" + # app.migrate_chat_data(message=message, old_chat_id=old_chat_id, new_chat_id=new_chat_id) + # assert effective_old_chat_id not in app.chat_data + # assert app.chat_data[effective_new_chat_id]['key'] == "test" + # + # def test_error_while_persisting(self, app, caplog): + # class OwnPersistence(BasePersistence): + # def update(self, data): + # raise Exception('PersistenceError') + # + # def update_callback_data(self, data): + # self.update(data) + # + # def update_bot_data(self, data): + # self.update(data) + # + # def update_chat_data(self, chat_id, data): + # self.update(data) + # + # def update_user_data(self, user_id, data): + # self.update(data) + # + # def drop_user_data(self, user_id): + # pass + # + # def drop_chat_data(self, chat_id): + # pass + # + # def get_chat_data(self): + # pass + # + # def get_bot_data(self): + # pass + # + # def get_user_data(self): + # pass + # + # def get_callback_data(self): + # pass + # + # def get_conversations(self, name): + # pass + # + # def update_conversation(self, name, key, new_state): + # pass + # + # def refresh_bot_data(self, bot_data): + # pass + # + # def refresh_user_data(self, user_id, user_data): + # pass + # + # def refresh_chat_data(self, chat_id, chat_data): + # pass + # + # def flush(self): + # pass + # + # def callback(update, context): + # pass + # + # test_flag = [] + # + # def error(update, context): + # nonlocal test_flag + # test_flag.append(str(context.error) == 'PersistenceError') + # raise Exception('ErrorHandlingError') + # + # update = Update( + # 1, message=Message(1, None, Chat(1, ''), from_user=User(1, '', False), text='Text') + # ) + # handler = MessageHandler(filters.ALL, callback) + # app.add_handler(handler) + # app.add_error_handler(error) + # + # app.persistence = OwnPersistence() + # + # with caplog.at_level(logging.ERROR): + # app.process_update(update) + # + # assert test_flag == [True, True, True, True] + # assert len(caplog.records) == 4 + # for record in caplog.records: + # message = record.getMessage() + # assert message.startswith('An error was raised and an uncaught') + # + # def test_persisting_no_user_no_chat(self, app): + # class OwnPersistence(BasePersistence): + # def __init__(self): + # super().__init__() + # self.test_flag_bot_data = False + # self.test_flag_chat_data = False + # self.test_flag_user_data = False + # + # def update_bot_data(self, data): + # self.test_flag_bot_data = True + # + # def update_chat_data(self, chat_id, data): + # self.test_flag_chat_data = True + # + # def update_user_data(self, user_id, data): + # self.test_flag_user_data = True + # + # def update_conversation(self, name, key, new_state): + # pass + # + # def drop_chat_data(self, chat_id): + # pass + # + # def drop_user_data(self, user_id): + # pass + # + # def get_conversations(self, name): + # pass + # + # def get_user_data(self): + # pass + # + # def get_bot_data(self): + # pass + # + # def get_chat_data(self): + # pass + # + # def refresh_bot_data(self, bot_data): + # pass + # + # def refresh_user_data(self, user_id, user_data): + # pass + # + # def refresh_chat_data(self, chat_id, chat_data): + # pass + # + # def get_callback_data(self): + # pass + # + # def update_callback_data(self, data): + # pass + # + # def flush(self): + # pass + # + # def callback(update, context): + # pass + # + # handler = MessageHandler(filters.ALL, callback) + # app.add_handler(handler) + # app.persistence = OwnPersistence() + # + # update = Update( + # 1, message=Message(1, None, None, from_user=User(1, '', False), text='Text') + # ) + # app.process_update(update) + # assert app.persistence.test_flag_bot_data + # assert app.persistence.test_flag_user_data + # assert not app.persistence.test_flag_chat_data + # + # app.persistence.test_flag_bot_data = False + # app.persistence.test_flag_user_data = False + # app.persistence.test_flag_chat_data = False + # update = Update(1, message=Message(1, None, Chat(1, ''), from_user=None, text='Text')) + # app.process_update(update) + # assert app.persistence.test_flag_bot_data + # assert not app.persistence.test_flag_user_data + # assert app.persistence.test_flag_chat_data + # + # @pytest.mark.parametrize( + # "c_id,expected", + # [(321, {222: "remove_me"}), (111, {321: {'not_empty': 'no'}, 222: "remove_me"})], + # ids=["test chat_id removal", "test no key in data (no error)"], + # ) + # def test_drop_chat_data(self, app, c_id, expected): + # app._chat_data.update({321: {'not_empty': 'no'}, 222: "remove_me"}) + # app.drop_chat_data(c_id) + # assert app.chat_data == expected + # + # @pytest.mark.parametrize( + # "u_id,expected", + # [(321, {222: "remove_me"}), (111, {321: {'not_empty': 'no'}, 222: "remove_me"})], + # ids=["test user_id removal", "test no key in data (no error)"], + # ) + # def test_drop_user_data(self, app, u_id, expected): + # app._user_data.update({321: {'not_empty': 'no'}, 222: "remove_me"}) + # app.drop_user_data(u_id) + # assert app.user_data == expected + # + # def test_update_persistence_once_per_update(self, monkeypatch, app): + # def update_persistence(*args, **kwargs): + # self.count += 1 + # + # def dummy_callback(*args): + # pass + # + # monkeypatch.setattr(app, 'update_persistence', update_persistence) + # + # for group in range(5): + # app.add_handler(MessageHandler(filters.TEXT, dummy_callback), group=group) + # + # update = Update(1, message=Message(1, None, Chat(1, ''), from_user=None, text=None)) + # app.process_update(update) + # assert self.count == 0 + # + # update = Update(1, message=Message(1, None, Chat(1, ''), from_user=None, text='text')) + # app.process_update(update) + # assert self.count == 1 + # + # def test_update_persistence_all_async(self, monkeypatch, app): + # def update_persistence(*args, **kwargs): + # self.count += 1 + # + # def dummy_callback(*args, **kwargs): + # pass + # + # monkeypatch.setattr(app, 'update_persistence', update_persistence) + # monkeypatch.setattr(app, 'block', dummy_callback) + # + # for group in range(5): + # app.add_handler( + # MessageHandler(filters.TEXT, dummy_callback, block=True), group=group + # ) + # + # update = Update(1, message=Message(1, None, Chat(1, ''), from_user=None, text='Text')) + # app.process_update(update) + # assert self.count == 0 + # + # app.bot._defaults = Defaults(block=True) + # try: + # for group in range(5): + # app.add_handler(MessageHandler(filters.TEXT, dummy_callback), group=group) + # + # update = Update(1, message=Message(1, None, Chat(1, ''), from_user=None, + # text='Text')) + # app.process_update(update) + # assert self.count == 0 + # finally: + # app.bot._defaults = None + # + # @pytest.mark.parametrize('block', [DEFAULT_FALSE, False]) + # def test_update_persistence_one_sync(self, monkeypatch, app, block): + # def update_persistence(*args, **kwargs): + # self.count += 1 + # + # def dummy_callback(*args, **kwargs): + # pass + # + # monkeypatch.setattr(app, 'update_persistence', update_persistence) + # monkeypatch.setattr(app, 'block', dummy_callback) + # + # for group in range(5): + # app.add_handler( + # MessageHandler(filters.TEXT, dummy_callback, block=True), group=group + # ) + # app.add_handler(MessageHandler(filters.TEXT, dummy_callback, block=block),group=5) + # + # update = Update(1, message=Message(1, None, Chat(1, ''), from_user=None, text='Text')) + # app.process_update(update) + # assert self.count == 1 + # + # @pytest.mark.parametrize('block,expected', [(DEFAULT_FALSE, 1), (False, 1), (True, 0)]) + # def test_update_persistence_defaults_async(self, monkeypatch, app, block, expected): + # def update_persistence(*args, **kwargs): + # self.count += 1 + # + # def dummy_callback(*args, **kwargs): + # pass + # + # monkeypatch.setattr(app, 'update_persistence', update_persistence) + # monkeypatch.setattr(app, 'block', dummy_callback) + # app.bot._defaults = Defaults(block=block) + # + # try: + # for group in range(5): + # app.add_handler(MessageHandler(filters.TEXT, dummy_callback), group=group) + # + # update = Update(1, message=Message(1, None, Chat(1, ''), from_user=None, + # text='Text')) + # app.process_update(update) + # assert self.count == expected + # finally: + # app.bot._defaults = None + # + # def test_custom_context_init(self, bot): + # cc = ContextTypes( + # context=CustomContext, + # user_data=int, + # chat_data=float, + # bot_data=complex, + # ) + # + # application = ApplicationBuilder().bot(bot).context_types(cc).build() + # + # assert isinstance(application.user_data[1], int) + # assert isinstance(application.chat_data[1], float) + # assert isinstance(application.bot_data, complex) + # + # def test_custom_context_error_handler(self, bot): + # def error_handler(_, context): + # self.received = ( + # type(context), + # type(context.user_data), + # type(context.chat_data), + # type(context.bot_data), + # ) + # + # application = ( + # ApplicationBuilder() + # .bot(bot) + # .context_types( + # ContextTypes( + # context=CustomContext, bot_data=int, user_data=float, chat_data=complex + # ) + # ) + # .build() + # ) + # application.add_error_handler(error_handler) + # application.add_handler(MessageHandler(filters.ALL, self.callback_raise_error)) + # + # application.process_update(self.message_update) + # sleep(0.1) + # assert self.received == (CustomContext, float, complex, int) + # + # def test_custom_context_handler_callback(self, bot): + # def callback(_, context): + # self.received = ( + # type(context), + # type(context.user_data), + # type(context.chat_data), + # type(context.bot_data), + # ) + # + # application = ( + # ApplicationBuilder() + # .bot(bot) + # .context_types( + # ContextTypes( + # context=CustomContext, bot_data=int, user_data=float, chat_data=complex + # ) + # ) + # .build() + # ) + # application.add_handler(MessageHandler(filters.ALL, callback)) + # + # application.process_update(self.message_update) + # sleep(0.1) + # assert self.received == (CustomContext, float, complex, int) diff --git a/tests/test_updater.py b/tests/test_updater.py index 066dcb0f9a5..36d2af6fb03 100644 --- a/tests/test_updater.py +++ b/tests/test_updater.py @@ -466,7 +466,7 @@ async def set_webhook(*args, **kwargs): assert updater.running # Now, we send an update to the server - update = make_message_update('Webhook', message_factory=make_message) + update = make_message_update('Webhook') await self._send_webhook_message(ip, port, update.to_json(), 'TOKEN') assert (await updater.update_queue.get()).to_dict() == update.to_dict() @@ -494,7 +494,7 @@ async def set_webhook(*args, **kwargs): url_path='TOKEN', ) assert updater.running - update = make_message_update('Webhook', message_factory=make_message) + update = make_message_update('Webhook') await self._send_webhook_message(ip, port, update.to_json(), 'TOKEN') assert (await updater.update_queue.get()).to_dict() == update.to_dict() await updater.stop() @@ -705,7 +705,7 @@ def webhook_server_init(*args, **kwargs): await updater.start_webhook(ip, port, webhook_url=None, cert=Path(__file__).as_posix()) # Now, we send an update to the server - update = make_message_update(message='test_message', message_factory=make_message) + update = make_message_update(message='test_message') await self._send_webhook_message(ip, port, update.to_json()) assert (await updater.update_queue.get()).to_dict() == update.to_dict() assert self.test_flag == [True, True] From 88978ffe38a90fda37dce8fe88ee9af6cc17d982 Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Sat, 26 Feb 2022 23:27:05 +0100 Subject: [PATCH 033/153] Improve & tests logic if initializing, shutdown, start & stop --- telegram/_bot.py | 20 +++-- telegram/ext/_application.py | 74 +++++++++++------ telegram/ext/_applicationbuilder.py | 2 - telegram/ext/_updater.py | 119 ++++++++++++++++++---------- telegram/request/_httpxrequest.py | 4 + tests/test_application.py | 64 +++++++++++++-- tests/test_bot.py | 29 ++++++- tests/test_photo.py | 1 - tests/test_request.py | 16 +++- tests/test_updater.py | 76 ++++++++++++++++++ 10 files changed, 318 insertions(+), 87 deletions(-) diff --git a/telegram/_bot.py b/telegram/_bot.py index 7e44ebbec97..59e1261e78e 100644 --- a/telegram/_bot.py +++ b/telegram/_bot.py @@ -356,18 +356,24 @@ async def initialize(self) -> None: cache :attr:`bot` and calls :meth:`telegram.request.BaseRequest.initialize` for :attr:`request`. """ - if not self._initialized: - await asyncio.gather(self._request[0].initialize(), self._request[1].initialize()) - await self.get_me() - self._initialized = True + if self._initialized: + self._logger.warning('This Bot is already initialized.') + return + + await asyncio.gather(self._request[0].initialize(), self._request[1].initialize()) + await self.get_me() + self._initialized = True async def shutdown(self) -> None: """Stop & clear resources used by this class. Currently just calls :meth:`telegram.request.BaseRequest.shutdown` for the request objects used by this bot. """ - if self._initialized: - await asyncio.gather(self._request[0].shutdown(), self._request[1].shutdown()) - self._initialized = False + if not self._initialized: + self._logger.warning('This Bot is already shut down.') + return + + await asyncio.gather(self._request[0].shutdown(), self._request[1].shutdown()) + self._initialized = False async def __aenter__(self: BT) -> BT: try: diff --git a/telegram/ext/_application.py b/telegram/ext/_application.py index dd38f92aceb..320db2cbffc 100644 --- a/telegram/ext/_application.py +++ b/telegram/ext/_application.py @@ -270,6 +270,10 @@ def concurrent_updates(self) -> int: return self._concurrent_updates async def initialize(self) -> None: + if self._initialized: + _logger.warning('This Application is already initialized.') + return + await self.bot.initialize() if self.updater: await self.updater.initialize() @@ -296,6 +300,20 @@ async def initialize(self) -> None: self._initialized = True async def shutdown(self) -> None: + """ + + Returns: + + Raises: + :exc:`RuntimeError`: If the application is still :attr:`running`. + """ + if self.running: + raise RuntimeError('This Application is still running!') + + if not self._initialized: + _logger.warning('This Application is already shut down.') + return + await self.bot.shutdown() if self.updater: await self.updater.shutdown() @@ -381,36 +399,42 @@ async def start(self, ready: Event = None) -> None: ready (:obj:`asyncio.Event`, optional): If specified, the event will be set once the application is ready. + Raises: + :exc:`RuntimeError`: If the application is already running or was not initialized. """ if self.running: - _logger.warning('already running') - if ready is not None: - ready.set() - return + raise RuntimeError('This Application is already running!') + if not self._initialized: + raise RuntimeError('This Application is not initialized!') + self._running = True self.__update_persistence_event.clear() - if self.persistence: - self.__update_persistence_task = asyncio.create_task( - self._persistence_updater() - # TODO: Add this once we drop py3.7 - # name=f'Application:{self.bot.id}:persistence_updater' - ) - _logger.debug('Loop for updating persistence started') - if self.job_queue: - self.job_queue.start() - _logger.debug('JobQueue started') + try: + if self.persistence: + self.__update_persistence_task = asyncio.create_task( + self._persistence_updater() + # TODO: Add this once we drop py3.7 + # name=f'Application:{self.bot.id}:persistence_updater' + ) + _logger.debug('Loop for updating persistence started') - self.__update_fetcher_task = asyncio.create_task( - self._update_fetcher(), - # TODO: Add this once we drop py3.7 - # name=f'Application:{self.bot.id}:update_fetcher' - ) - self._running = True - _logger.info('Application started') + if self.job_queue: + self.job_queue.start() + _logger.debug('JobQueue started') - if ready is not None: - ready.set() + self.__update_fetcher_task = asyncio.create_task( + self._update_fetcher(), + # TODO: Add this once we drop py3.7 + # name=f'Application:{self.bot.id}:update_fetcher' + ) + _logger.info('Application started') + + if ready is not None: + ready.set() + except Exception as exc: + self._running = False + raise exc async def stop(self) -> None: """Stops the process after processing any pending updates or tasks created by @@ -452,6 +476,7 @@ async def stop(self) -> None: _logger.debug('Waiting for persistence loop to finish') self.__update_persistence_event.set() await self.__update_persistence_task + self.__update_persistence_event.clear() _logger.info('Application.stop() complete') @@ -978,8 +1003,6 @@ async def __update_persistence(self) -> None: # We don't want to update any data that has been deleted! update_ids -= delete_ids - print('deleting chat_ids', delete_ids) - print('updating chat_ids', update_ids) for chat_id in update_ids: coroutines.add( @@ -1029,7 +1052,6 @@ async def __update_persistence(self) -> None: result = new_state.resolve() effective_new_state = None if result is TrackingDict.DELETED else result - print(name, key, effective_new_state) # TODO: Test that we actually pass `None` here in case the conversation had ended, # i.e. effective_new_state is TrackingDict.DELETED coroutines.add( diff --git a/telegram/ext/_applicationbuilder.py b/telegram/ext/_applicationbuilder.py index c4471ba6f41..ce7d4a78a2d 100644 --- a/telegram/ext/_applicationbuilder.py +++ b/telegram/ext/_applicationbuilder.py @@ -264,8 +264,6 @@ def build( bot = self._updater.bot update_queue = self._updater.update_queue - print(self._concurrent_updates) - application: Application[ BT, CCT, UD, CD, BD, JQ ] = DefaultValue.get_value( # type: ignore[call-arg] # pylint: disable=not-callable diff --git a/telegram/ext/_updater.py b/telegram/ext/_updater.py index 19ab0dad796..45fb11c2363 100644 --- a/telegram/ext/_updater.py +++ b/telegram/ext/_updater.py @@ -73,6 +73,7 @@ class Updater: 'update_queue', 'last_update_id', '_running', + '_initialized', '_httpd', '__lock', '__polling_task', @@ -88,6 +89,7 @@ def __init__( self.last_update_id = 0 self._running = False + self._initialized = False self._httpd: Optional[WebhookServer] = None self.__lock = asyncio.Lock() self.__polling_task: Optional[asyncio.Task] = None @@ -98,10 +100,30 @@ def running(self) -> bool: return self._running async def initialize(self) -> None: + if self._initialized: + self._logger.warning('This Updater is already initialized.') + return + await self.bot.initialize() + self._initialized = True async def shutdown(self) -> None: + """ + + Returns: + + Raises: + :exc:`RuntimeError`: If the updater is still :attr:`running`. + """ + if self.running: + raise RuntimeError('This Updater is still running!') + + if not self._initialized: + self._logger.warning('This Updater is already shut down.') + return + await self.bot.shutdown() + self._initialized = False self._logger.debug('Shut down of Updater complete') async def __aenter__(self: _UpdaterType) -> _UpdaterType: @@ -175,31 +197,37 @@ async def start_polling( async with self.__lock: if self.running: raise RuntimeError('This Updater is already running!') + if not self._initialized: + raise RuntimeError('This Updater is not initialized!') self._running = True - # Create & start tasks - polling_ready = asyncio.Event() - - await self._start_polling( - poll_interval=poll_interval, - timeout=timeout, - read_timeout=read_timeout, - write_timeout=write_timeout, - connect_timeout=connect_timeout, - pool_timeout=pool_timeout, - bootstrap_retries=bootstrap_retries, - drop_pending_updates=drop_pending_updates, - allowed_updates=allowed_updates, - ready=polling_ready, - error_callback=error_callback, - ) + try: + # Create & start tasks + polling_ready = asyncio.Event() + + await self._start_polling( + poll_interval=poll_interval, + timeout=timeout, + read_timeout=read_timeout, + write_timeout=write_timeout, + connect_timeout=connect_timeout, + pool_timeout=pool_timeout, + bootstrap_retries=bootstrap_retries, + drop_pending_updates=drop_pending_updates, + allowed_updates=allowed_updates, + ready=polling_ready, + error_callback=error_callback, + ) - self._logger.debug('Waiting for polling to start') - await polling_ready.wait() - self._logger.debug('Polling to started') + self._logger.debug('Waiting for polling to start') + await polling_ready.wait() + self._logger.debug('Polling to started') - return self.update_queue + return self.update_queue + except Exception as exc: + self._running = False + raise exc async def _start_polling( self, @@ -332,30 +360,36 @@ async def start_webhook( async with self.__lock: if self.running: raise RuntimeError('This Updater is already running!') + if not self._initialized: + raise RuntimeError('This Updater is not initialized!') self._running = True - # Create & start tasks - webhook_ready = asyncio.Event() - - await self._start_webhook( - listen=listen, - port=port, - url_path=url_path, - cert=cert, - key=key, - bootstrap_retries=bootstrap_retries, - drop_pending_updates=drop_pending_updates, - webhook_url=webhook_url, - allowed_updates=allowed_updates, - ready=webhook_ready, - ip_address=ip_address, - max_connections=max_connections, - ) + try: + # Create & start tasks + webhook_ready = asyncio.Event() + + await self._start_webhook( + listen=listen, + port=port, + url_path=url_path, + cert=cert, + key=key, + bootstrap_retries=bootstrap_retries, + drop_pending_updates=drop_pending_updates, + webhook_url=webhook_url, + allowed_updates=allowed_updates, + ready=webhook_ready, + ip_address=ip_address, + max_connections=max_connections, + ) - self._logger.debug('Waiting for webhook server to start') - await webhook_ready.wait() - self._logger.debug('Webhook server started') + self._logger.debug('Waiting for webhook server to start') + await webhook_ready.wait() + self._logger.debug('Webhook server started') + except Exception as exc: + self._running = False + raise exc # Return the update queue so the main thread can insert updates return self.update_queue @@ -592,5 +626,8 @@ async def _stop_polling(self) -> None: if self.__polling_task: self._logger.debug('Waiting background polling task to join.') self.__polling_task.cancel() - await self.__polling_task + try: + await self.__polling_task + except asyncio.CancelledError: + pass self.__polling_task = None diff --git a/telegram/request/_httpxrequest.py b/telegram/request/_httpxrequest.py index 923d6c08d36..8899ac0e41a 100644 --- a/telegram/request/_httpxrequest.py +++ b/telegram/request/_httpxrequest.py @@ -110,6 +110,10 @@ async def initialize(self) -> None: async def shutdown(self) -> None: """See :meth:`BaseRequest.shutdown`.""" + if self._client.is_closed: + _logger.warning('This HTTPXRequest is already shut down.') + return + await self._client.aclose() async def do_request( diff --git a/tests/test_application.py b/tests/test_application.py index f30cbfd02be..9d1113a3b4c 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -17,6 +17,7 @@ # You should have received a copy of the GNU Lesser Public License # along with this program. If not, see [http://www.gnu.org/licenses/]. import asyncio +from collections import defaultdict from queue import Queue import pytest @@ -41,7 +42,7 @@ class CustomContext(CallbackContext): pass -class TestDispatcher: +class TestApplication: message_update = make_message_update(message='Text') received = None count = 0 @@ -187,9 +188,58 @@ async def shutdown_updater(*args, **kwargs): monkeypatch.setattr(Bot, 'shutdown', shutdown_bot) monkeypatch.setattr(Updater, 'shutdown', shutdown_updater) - await ApplicationBuilder().token(bot.token).build().shutdown() + async with ApplicationBuilder().token(bot.token).build(): + pass assert self.test_flag == {'bot', 'updater'} + @pytest.mark.asyncio + async def test_multiple_inits_and_shutdowns(self, app, monkeypatch): + self.received = defaultdict(int) + + async def initialize(*args, **kargs): + self.received['init'] += 1 + + async def shutdown(*args, **kwargs): + self.received['shutdown'] += 1 + + monkeypatch.setattr(app.bot, 'initialize', initialize) + monkeypatch.setattr(app.bot, 'shutdown', shutdown) + + await app.initialize() + await app.initialize() + await app.initialize() + await app.shutdown() + await app.shutdown() + await app.shutdown() + + # 2 instead of 1 since `Updater.initialize` also calls bot.init/shutdown + assert self.received['init'] == 2 + assert self.received['shutdown'] == 2 + + @pytest.mark.asyncio + async def test_start_without_initialize(self, app): + with pytest.raises(RuntimeError, match='not initialized'): + await app.start() + + @pytest.mark.asyncio + async def test_shutdown_while_running(self, app): + async with app: + await app.start() + with pytest.raises(RuntimeError, match='still running'): + await app.shutdown() + await app.stop() + + @pytest.mark.asyncio + async def test_start_not_running_after_failure(self, app): + class Event(asyncio.Event): + def set(self) -> None: + raise Exception('Test Exception') + + async with app: + with pytest.raises(Exception, match='Test Exception'): + await app.start(ready=Event()) + assert app.running is False + @pytest.mark.asyncio async def test_context_manager(self, monkeypatch, app): self.test_flag = set() @@ -365,14 +415,14 @@ def test_builder(self, app): # app.bot._defaults = None # # def test_run_async_multiple(self, bot, app, dp2): - # def get_dispatcher_name(q): + # def get_application_name(q): # q.put(current_thread().name) # # q1 = Queue() # q2 = Queue() # - # app.block(get_dispatcher_name, q1) - # dp2.block(get_dispatcher_name, q2) + # app.block(get_application_name, q1) + # dp2.block(get_application_name, q2) # # sleep(0.1) # @@ -381,7 +431,7 @@ def test_builder(self, app): # # assert name1 != name2 # - # def test_async_raises_dispatcher_handler_stop(self, app, recwarn): + # def test_async_raises_application_handler_stop(self, app, recwarn): # def callback(update, context): # raise ApplicationHandlerStop() # @@ -833,7 +883,7 @@ def test_builder(self, app): # assert passed[2] is err # # def test_sensible_worker_thread_names(self, dp2): - # thread_names = [thread.name for thread in dp2._Dispatcher__async_threads] + # thread_names = [thread.name for thread in dp2._Application__async_threads] # for thread_name in thread_names: # assert thread_name.startswith(f"Bot:{dp2.bot.id}:worker:") # diff --git a/tests/test_bot.py b/tests/test_bot.py index cba01a13e96..326153356dd 100644 --- a/tests/test_bot.py +++ b/tests/test_bot.py @@ -60,7 +60,7 @@ from telegram._utils.datetime import from_timestamp, to_timestamp from telegram._utils.defaultvalue import DefaultValue from telegram.helpers import escape_markdown -from telegram.request import RequestData, BaseRequest +from telegram.request import RequestData, BaseRequest, HTTPXRequest from tests.conftest import ( expect_bad_request, check_defaults_handling, @@ -187,7 +187,7 @@ async def test_invalid_token(self, token): Bot(token) @pytest.mark.asyncio - async def test_initialize_and_stop(self, bot, monkeypatch): + async def test_initialize_and_shutdown(self, bot, monkeypatch): async def initialize(*args, **kwargs): self.test_flag = ['initialize'] @@ -209,6 +209,31 @@ async def stop(*args, **kwargs): finally: await orig_stop() + @pytest.mark.asyncio + async def test_multiple_inits_and_shutdowns(self, bot, monkeypatch): + self.received = defaultdict(int) + + async def initialize(*args, **kwargs): + self.received['init'] += 1 + + async def shutdown(*args, **kwargs): + self.received['shutdown'] += 1 + + monkeypatch.setattr(HTTPXRequest, 'initialize', initialize) + monkeypatch.setattr(HTTPXRequest, 'shutdown', shutdown) + + test_bot = Bot(bot.token) + await test_bot.initialize() + await test_bot.initialize() + await test_bot.initialize() + await test_bot.shutdown() + await test_bot.shutdown() + await test_bot.shutdown() + + # 2 instead of 1 since we have to request objects for each bot + assert self.received['init'] == 2 + assert self.received['shutdown'] == 2 + @pytest.mark.asyncio async def test_context_manager(self, monkeypatch, bot): async def initialize(): diff --git a/tests/test_photo.py b/tests/test_photo.py index 27eb9dcd0ba..6c44f8b29f7 100644 --- a/tests/test_photo.py +++ b/tests/test_photo.py @@ -62,7 +62,6 @@ def thumb(_photo): @pytest.fixture(scope='class') def photo(_photo): - print([ps.to_json() for ps in _photo]) return _photo[-1] diff --git a/tests/test_request.py b/tests/test_request.py index ffc4d73736e..3e80fdc932e 100644 --- a/tests/test_request.py +++ b/tests/test_request.py @@ -345,6 +345,21 @@ class Client: ) assert request._client.timeout == httpx.Timeout(connect=43, read=44, write=45, pool=46) + @pytest.mark.asyncio + async def test_multiple_shutdowns(self, httpx_request, monkeypatch): + self.test_flag = 0 + + async def aclose(*args, **kwargs): + self.test_flag += 1 + + await httpx_request.shutdown() + # only mock after the first run so that AsyncClient can still set the `is_closed` flag + monkeypatch.setattr(httpx.AsyncClient, 'aclose', aclose) + await httpx_request.shutdown() + await httpx_request.shutdown() + + assert self.test_flag == 0 + @pytest.mark.asyncio async def test_context_manager(self, monkeypatch): async def initialize(): @@ -408,7 +423,6 @@ async def test_do_request_manual_timeouts(self, monkeypatch, httpx_request): manual_timeouts = httpx.Timeout(connect=52, read=53, write=54, pool=55) async def make_assertion(_, **kwargs): - print(kwargs.get('timeout'), manual_timeouts) self.test_flag = kwargs.get('timeout') == manual_timeouts return httpx.Response(HTTPStatus.OK) diff --git a/tests/test_updater.py b/tests/test_updater.py index 36d2af6fb03..aa2cb94958b 100644 --- a/tests/test_updater.py +++ b/tests/test_updater.py @@ -18,6 +18,7 @@ # along with this program. If not, see [http://www.gnu.org/licenses/]. import asyncio import logging +from collections import defaultdict from http import HTTPStatus from pathlib import Path from random import randrange @@ -144,6 +145,60 @@ async def shutdown_bot(*args, **kwargs): assert self.test_flag + @pytest.mark.asyncio + async def test_multiple_inits_and_shutdowns(self, updater, monkeypatch): + self.test_flag = defaultdict(int) + + async def initialize(*args, **kargs): + self.test_flag['init'] += 1 + + async def shutdown(*args, **kwargs): + self.test_flag['shutdown'] += 1 + + monkeypatch.setattr(updater.bot, 'initialize', initialize) + monkeypatch.setattr(updater.bot, 'shutdown', shutdown) + + await updater.initialize() + await updater.initialize() + await updater.initialize() + await updater.shutdown() + await updater.shutdown() + await updater.shutdown() + + assert self.test_flag['init'] == 1 + assert self.test_flag['shutdown'] == 1 + + @pytest.mark.asyncio + @pytest.mark.parametrize('method', ['start_polling', 'start_webhook']) + async def test_start_without_initialize(self, updater, method): + with pytest.raises(RuntimeError, match='not initialized'): + await getattr(updater, method)() + + @pytest.mark.asyncio + @pytest.mark.parametrize('method', ['start_polling', 'start_webhook']) + async def test_shutdown_while_running(self, updater, method, monkeypatch): + async def set_webhook(*args, **kwargs): + return True + + monkeypatch.setattr(updater.bot, 'set_webhook', set_webhook) + + ip = '127.0.0.1' + port = randrange(1024, 49152) # Select random port + + async with updater: + if 'webhook' in method: + await getattr(updater, method)( + ip_address=ip, + port=port, + ) + else: + await asyncio.sleep(1) + await getattr(updater, method)() + + with pytest.raises(RuntimeError, match='still running'): + await updater.shutdown() + await updater.stop() + @pytest.mark.asyncio async def test_context_manager(self, monkeypatch, updater): async def initialize(*args, **kwargs): @@ -325,6 +380,7 @@ async def do_request(*args, **kwargs): await updater.start_polling( bootstrap_retries=retries, ) + await updater.stop() @pytest.mark.parametrize( 'error,callback_should_be_called', @@ -430,6 +486,19 @@ async def get_updates(*args, **kwargs): # Make sure that the update_id offset wasn't increased assert self.message_count == 2 + @pytest.mark.asyncio + async def test_start_polling_not_running_after_failure(self, updater, monkeypatch): + # Unfortunately we have to use some internal logic to trigger an exception + async def _start_polling(*args, **kwargs): + raise Exception('Test Exception') + + monkeypatch.setattr(Updater, '_start_polling', _start_polling) + + async with updater: + with pytest.raises(Exception, match='Test Exception'): + await updater.start_polling() + assert updater.running is False + @pytest.mark.asyncio @pytest.mark.parametrize('ext_bot', [True, False]) @pytest.mark.parametrize('drop_pending_updates', (True, False)) @@ -646,6 +715,8 @@ async def return_true(*args, **kwargs): assert isinstance(button.callback_data, InvalidCallbackData) else: assert button.callback_data == 'callback_data' + + await updater.stop() finally: updater.bot.arbitrary_callback_data = False updater.bot.callback_data_cache.clear_callback_data() @@ -674,6 +745,7 @@ async def return_true(*args, **kwargs): webhook_url=None, allowed_updates=None, ) + assert updater.running is False @pytest.mark.asyncio async def test_webhook_ssl_just_for_telegram(self, monkeypatch, updater): @@ -709,6 +781,7 @@ def webhook_server_init(*args, **kwargs): await self._send_webhook_message(ip, port, update.to_json()) assert (await updater.update_queue.get()).to_dict() == update.to_dict() assert self.test_flag == [True, True] + await updater.stop() @pytest.mark.asyncio @pytest.mark.parametrize('exception_class', (InvalidToken, TelegramError)) @@ -733,6 +806,7 @@ async def do_request(*args, **kwargs): await updater.start_webhook( bootstrap_retries=retries, ) + await updater.stop() @pytest.mark.asyncio async def test_webhook_invalid_posts(self, updater, monkeypatch): @@ -773,3 +847,5 @@ async def return_true(*args, **kwargs): # response = await self._send_webhook_message( # ip, port, 'dummy-payload', content_len='not-a-number') # assert response.status_code == HTTPStatus.INTERNAL_SERVER_ERROR + + await updater.stop() From 254a65e570e2e61d33bd444db5f2331f04d2cc79 Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Sun, 27 Feb 2022 09:26:35 +0100 Subject: [PATCH 034/153] Further improve & test logic if initializing, shutdown, start & stop --- telegram/_bot.py | 2 +- telegram/ext/_updater.py | 2 +- telegram/request/_httpxrequest.py | 15 ++++++++-- tests/test_application.py | 8 ++++++ tests/test_bot.py | 9 ++++++ tests/test_request.py | 48 ++++++++++++++++++++++++++----- tests/test_updater.py | 8 ++++++ 7 files changed, 80 insertions(+), 12 deletions(-) diff --git a/telegram/_bot.py b/telegram/_bot.py index 59e1261e78e..93e8efdf33e 100644 --- a/telegram/_bot.py +++ b/telegram/_bot.py @@ -268,7 +268,7 @@ def _insert_defaults(self, data: Dict[str, object]) -> None: # pylint: disable= async def _post( self, - endpoint: str, # 'sendMessage', 'sendPhoto', 'getMe' + endpoint: str, data: JSONDict = None, # {'chat_id': 123, 'text': 'Hello there!'} read_timeout: ODVInput[float] = DEFAULT_NONE, write_timeout: ODVInput[float] = DEFAULT_NONE, diff --git a/telegram/ext/_updater.py b/telegram/ext/_updater.py index 45fb11c2363..f106a6dff7f 100644 --- a/telegram/ext/_updater.py +++ b/telegram/ext/_updater.py @@ -113,7 +113,7 @@ async def shutdown(self) -> None: Returns: Raises: - :exc:`RuntimeError`: If the updater is still :attr:`running`. + :exc:`RuntimeError`: If the updater is still running. """ if self.running: raise RuntimeError('This Updater is still running!') diff --git a/telegram/request/_httpxrequest.py b/telegram/request/_httpxrequest.py index 8899ac0e41a..0038825a1e8 100644 --- a/telegram/request/_httpxrequest.py +++ b/telegram/request/_httpxrequest.py @@ -77,7 +77,7 @@ class HTTPXRequest(BaseRequest): connections in the connection pool! """ - __slots__ = ('_client',) + __slots__ = ('_client', '_client_kwargs') def __init__( self, @@ -98,15 +98,21 @@ def __init__( max_connections=connection_pool_size, max_keepalive_connections=connection_pool_size, ) - - self._client = httpx.AsyncClient( + self._client_kwargs = dict( timeout=timeout, proxies=proxy_url, limits=limits, ) + self._client = self._build_client() + + def _build_client(self) -> httpx.AsyncClient: + return httpx.AsyncClient(**self._client_kwargs) # type: ignore[arg-type] + async def initialize(self) -> None: """See :meth:`BaseRequest.initialize`.""" + if self._client.is_closed: + self._client = self._build_client() async def shutdown(self) -> None: """See :meth:`BaseRequest.shutdown`.""" @@ -127,6 +133,9 @@ async def do_request( pool_timeout: ODVInput[float] = BaseRequest.DEFAULT_NONE, ) -> Tuple[int, bytes]: """See :meth:`BaseRequest.do_request`.""" + if self._client.is_closed: + raise RuntimeError('This HTTPXRequest is not initialized!') + if isinstance(read_timeout, DefaultValue): read_timeout = self._client.timeout.read if isinstance(write_timeout, DefaultValue): diff --git a/tests/test_application.py b/tests/test_application.py index 9d1113a3b4c..04e9fe98443 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -216,6 +216,14 @@ async def shutdown(*args, **kwargs): assert self.received['init'] == 2 assert self.received['shutdown'] == 2 + @pytest.mark.asyncio + async def test_multiple_init_cycles(self, app): + # nothing really to assert - this should just not fail + async with app: + await app.bot.get_me() + async with app: + await app.bot.get_me() + @pytest.mark.asyncio async def test_start_without_initialize(self, app): with pytest.raises(RuntimeError, match='not initialized'): diff --git a/tests/test_bot.py b/tests/test_bot.py index 326153356dd..f26c85810bc 100644 --- a/tests/test_bot.py +++ b/tests/test_bot.py @@ -234,6 +234,15 @@ async def shutdown(*args, **kwargs): assert self.received['init'] == 2 assert self.received['shutdown'] == 2 + @pytest.mark.asyncio + async def test_multiple_init_cycles(self, bot): + # nothing really to assert - this should just not fail + test_bot = Bot(bot.token) + async with test_bot: + await test_bot.get_me() + async with test_bot: + await test_bot.get_me() + @pytest.mark.asyncio async def test_context_manager(self, monkeypatch, bot): async def initialize(): diff --git a/tests/test_request.py b/tests/test_request.py index 3e80fdc932e..f0c4044a694 100644 --- a/tests/test_request.py +++ b/tests/test_request.py @@ -20,6 +20,7 @@ implementations for BaseRequest and we want to test HTTPXRequest anyway.""" import asyncio import json +from collections import defaultdict from dataclasses import dataclass from http import HTTPStatus from typing import Tuple, Any, Coroutine, Callable @@ -346,19 +347,52 @@ class Client: assert request._client.timeout == httpx.Timeout(connect=43, read=44, write=45, pool=46) @pytest.mark.asyncio - async def test_multiple_shutdowns(self, httpx_request, monkeypatch): - self.test_flag = 0 + async def test_multiple_inits_and_shutdowns(self, monkeypatch): + self.test_flag = defaultdict(int) - async def aclose(*args, **kwargs): - self.test_flag += 1 + orig_init = httpx.AsyncClient.__init__ + orig_aclose = httpx.AsyncClient.aclose + class Client(httpx.AsyncClient): + def __init__(*args, **kwargs): + print('this is init') + orig_init(*args, **kwargs) + self.test_flag['init'] += 1 + + async def aclose(*args, **kwargs): + print('this is aclose') + await orig_aclose(*args, **kwargs) + self.test_flag['shutdown'] += 1 + + monkeypatch.setattr(httpx, 'AsyncClient', Client) + + # Create a new one instead of using the fixture so that the mocking can work + httpx_request = HTTPXRequest() + + await httpx_request.initialize() + await httpx_request.initialize() + await httpx_request.initialize() await httpx_request.shutdown() - # only mock after the first run so that AsyncClient can still set the `is_closed` flag - monkeypatch.setattr(httpx.AsyncClient, 'aclose', aclose) await httpx_request.shutdown() await httpx_request.shutdown() - assert self.test_flag == 0 + assert self.test_flag['init'] == 1 + assert self.test_flag['shutdown'] == 1 + + @pytest.mark.asyncio + async def test_multiple_init_cycles(self): + # nothing really to assert - this should just not fail + httpx_request = HTTPXRequest() + async with httpx_request: + await httpx_request.do_request(url='https://python-telegram-bot.org', method='GET') + async with httpx_request: + await httpx_request.do_request(url='https://python-telegram-bot.org', method='GET') + + @pytest.mark.asyncio + async def test_do_request_after_shutdown(self, httpx_request): + await httpx_request.shutdown() + with pytest.raises(RuntimeError, match='not initialized'): + await httpx_request.do_request(url='url', method='GET') @pytest.mark.asyncio async def test_context_manager(self, monkeypatch): diff --git a/tests/test_updater.py b/tests/test_updater.py index aa2cb94958b..2ad09712d2a 100644 --- a/tests/test_updater.py +++ b/tests/test_updater.py @@ -168,6 +168,14 @@ async def shutdown(*args, **kwargs): assert self.test_flag['init'] == 1 assert self.test_flag['shutdown'] == 1 + @pytest.mark.asyncio + async def test_multiple_init_cycles(self, updater): + # nothing really to assert - this should just not fail + async with updater: + await updater.bot.get_me() + async with updater: + await updater.bot.get_me() + @pytest.mark.asyncio @pytest.mark.parametrize('method', ['start_polling', 'start_webhook']) async def test_start_without_initialize(self, updater, method): From f3671fe88124393dd43b6e8b499e00da5e893ebe Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Mon, 28 Feb 2022 22:30:04 +0100 Subject: [PATCH 035/153] Few more application tests --- telegram/ext/_application.py | 2 +- telegram/ext/_messagehandler.py | 4 +- tests/test_application.py | 299 ++++++++++++++++++++++++++++---- 3 files changed, 272 insertions(+), 33 deletions(-) diff --git a/telegram/ext/_application.py b/telegram/ext/_application.py index 320db2cbffc..80c884413a9 100644 --- a/telegram/ext/_application.py +++ b/telegram/ext/_application.py @@ -699,7 +699,7 @@ async def process_update(self, update: object) -> None: try: for handler in handlers: check = handler.check_update(update) - if check is not None and check is not False: + if bool(check): if not context: context = self.context_types.context.from_update(update, self) await context.refresh_data() diff --git a/telegram/ext/_messagehandler.py b/telegram/ext/_messagehandler.py index 486caf8f115..d9b43e00c36 100644 --- a/telegram/ext/_messagehandler.py +++ b/telegram/ext/_messagehandler.py @@ -94,7 +94,9 @@ def check_update(self, update: object) -> Optional[Union[bool, Dict[str, list]]] """ if isinstance(update, Update): - return self.filters.check_update(update) + # The `or False` makes sure that we don't return empty dicts + # TODO: add a test for this to MessageHandler + return self.filters.check_update(update) or False return None def collect_additional_context( diff --git a/tests/test_application.py b/tests/test_application.py index 04e9fe98443..a4d74759677 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -31,6 +31,8 @@ ContextTypes, PicklePersistence, Updater, + filters, + MessageHandler, ) from telegram.error import TelegramError @@ -144,6 +146,10 @@ def test_init(self, bot, concurrent_updates, expected): with pytest.raises(RuntimeError, match='No application was set'): app.job_queue.application + assert isinstance(app.bot_data, dict) + assert isinstance(app.chat_data[1], dict) + assert isinstance(app.user_data[1], dict) + with pytest.raises(ValueError, match='must be a non-negative'): Application( bot=bot, @@ -155,6 +161,20 @@ def test_init(self, bot, concurrent_updates, expected): concurrent_updates=-1, ) + def test_custom_context_init(self, bot): + cc = ContextTypes( + context=CustomContext, + user_data=int, + chat_data=float, + bot_data=complex, + ) + + application = ApplicationBuilder().bot(bot).context_types(cc).build() + + assert isinstance(application.user_data[1], int) + assert isinstance(application.chat_data[1], float) + assert isinstance(application.bot_data, complex) + @pytest.mark.asyncio async def test_initialize(self, bot, monkeypatch): """Initialization of persistence is tested eslewhere""" @@ -304,26 +324,257 @@ def test_builder(self, app): builder_1.token(app.bot.token) builder_2.token(app.bot.token) + @pytest.mark.asyncio + async def test_one_context_per_update(self, app): + self.received = None + + async def one(update, context): + print('handler one for messag', repr(update.message.text)) + self.received = context + + def two(update, context): + if update.message.text == 'test': + if context is not self.received: + pytest.fail('Expected same context object, got different') + else: + if context is self.received: + print(context, self.received) + pytest.fail('First handler was wrongly called') + + app.add_handler(MessageHandler(filters.Regex('test'), one), group=1) + app.add_handler(MessageHandler(filters.ALL, two), group=2) + u = make_message_update(message='test') + await app.process_update(u) + self.received = None + u.message.text = 'something' + await app.process_update(u) + + @pytest.mark.asyncio + async def test_error_in_handler(self, app): + app.add_handler(MessageHandler(filters.ALL, self.callback_raise_error)) + app.add_error_handler(self.error_handler_context) + + async with app: + await app.start() + await app.update_queue.put(self.message_update) + await asyncio.sleep(0.1) + await app.stop() + + assert self.received == self.message_update.message.text + + @pytest.mark.asyncio + async def test_add_remove_handler(self, app): + handler = MessageHandler(filters.ALL, self.callback_increase_count) + app.add_handler(handler) + + async with app: + await app.start() + await app.update_queue.put(self.message_update) + await asyncio.sleep(0.1) + assert self.count == 1 + app.remove_handler(handler) + await app.update_queue.put(self.message_update) + assert self.count == 1 + await app.stop() + + @pytest.mark.asyncio + async def test_add_remove_handler_non_default_group(self, app): + handler = MessageHandler(filters.ALL, self.callback_increase_count) + app.add_handler(handler, group=2) + with pytest.raises(KeyError): + app.remove_handler(handler) + app.remove_handler(handler, group=2) + # - # def test_one_context_per_update(self, app): - # def one(update, context): - # if update.message.text == 'test': - # context.my_flag = True + @pytest.mark.asyncio + async def test_handler_order_in_group(self, app): + app.add_handler(MessageHandler(filters.PHOTO, self.callback_set_count(1))) + app.add_handler(MessageHandler(filters.ALL, self.callback_set_count(2))) + app.add_handler(MessageHandler(filters.TEXT, self.callback_set_count(3))) + async with app: + await app.start() + await app.update_queue.put(self.message_update) + await asyncio.sleep(0.1) + assert self.count == 2 + await app.stop() + + @pytest.mark.asyncio + async def test_groups(self, app): + app.add_handler(MessageHandler(filters.ALL, self.callback_increase_count)) + app.add_handler(MessageHandler(filters.ALL, self.callback_increase_count), group=2) + app.add_handler(MessageHandler(filters.ALL, self.callback_increase_count), group=-1) + + async with app: + await app.start() + await app.update_queue.put(self.message_update) + await asyncio.sleep(0.1) + assert self.count == 3 + await app.stop() + # - # def two(update, context): - # if update.message.text == 'test': - # if not hasattr(context, 'my_flag'): - # pytest.fail() - # else: - # if hasattr(context, 'my_flag'): - # pytest.fail() + # @pytest.mark.asyncio + # async def test_add_handlers_complex(self, app): + # """Tests both add_handler & add_handlers together & confirms the correct insertion + # order""" + # msg_handler_set_count = MessageHandler(filters.TEXT, self.callback_set_count(1)) + # msg_handler_inc_count = MessageHandler(filters.PHOTO, self.callback_increase_count) # - # app.add_handler(MessageHandler(filters.Regex('test'), one), group=1) - # app.add_handler(MessageHandler(None, two), group=2) - # u = Update(1, Message(1, None, None, None, text='test')) - # app.process_update(u) - # u.message.text = 'something' - # app.process_update(u) + # app.add_handler(msg_handler_set_count, 1) + # app.add_handlers((msg_handler_inc_count, msg_handler_inc_count), 1) + # + # photo_update = Update(2, message=Message(2, None, None, photo=True)) + # # Putting updates in the queue calls the callback + # app.update_queue.put(self.message_update) + # app.update_queue.put(photo_update) + # sleep(0.1) # sleep is required otherwise there is random behaviour + # + # # Test if handler was added to correct group with correct order- + # assert ( + # self.count == 2 + # and len(app.handlers[1]) == 3 + # and app.handlers[1][0] is msg_handler_set_count + # ) + # + # # Now lets test add_handlers when `handlers` is a dict- + # voice_filter_handler_to_check = MessageHandler(filters.VOICE, + # self.callback_increase_count) + # app.add_handlers( + # handlers={ + # 1: [ + # MessageHandler(filters.USER, self.callback_increase_count), + # voice_filter_handler_to_check, + # ], + # -1: [MessageHandler(filters.CAPTION, self.callback_set_count(2))], + # } + # ) + # + # user_update = Update(3, message=Message(3, None, None, from_user=User(1, 's', True))) + # voice_update = Update(4, message=Message(4, None, None, voice=True)) + # app.update_queue.put(user_update) + # app.update_queue.put(voice_update) + # sleep(0.1) + # + # assert ( + # self.count == 4 + # and len(app.handlers[1]) == 5 + # and app.handlers[1][-1] is voice_filter_handler_to_check + # ) + # + # app.update_queue.put(Update(5, message=Message(5, None, None, caption='cap'))) + # sleep(0.1) + # + # assert self.count == 2 and len(app.handlers[-1]) == 1 + # + # # Now lets test the errors which can be produced- + # with pytest.raises(ValueError, match="The `group` argument"): + # app.add_handlers({2: [msg_handler_set_count]}, group=0) + # with pytest.raises(ValueError, match="Handlers for group 3"): + # app.add_handlers({3: msg_handler_set_count}) + # with pytest.raises(ValueError, match="The `handlers` argument must be a sequence"): + # app.add_handlers({msg_handler_set_count}) + # + # @pytest.mark.asyncio + # async def test_add_handler_errors(self, app): + # handler = 'not a handler' + # with pytest.raises(TypeError, match='handler is not an instance of'): + # app.add_handler(handler) + # + # handler = MessageHandler(filters.PHOTO, self.callback_set_count(1)) + # with pytest.raises(TypeError, match='group is not int'): + # app.add_handler(handler, 'one') + # + # @pytest.mark.asyncio + # async def test_flow_stop(self, app, bot): + # passed = [] + # + # def start1(b, u): + # passed.append('start1') + # raise ApplicationHandlerStop + # + # def start2(b, u): + # passed.append('start2') + # + # def start3(b, u): + # passed.append('start3') + # + # def error(b, u, e): + # passed.append('error') + # passed.append(e) + # + # update = Update( + # 1, + # message=Message( + # 1, + # None, + # None, + # None, + # text='/start', + # entities=[ + # MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0, length=len('/start')) + # ], + # bot=bot, + # ), + # ) + # + # # If Stop raised handlers in other groups should not be called. + # passed = [] + # app.add_handler(CommandHandler('start', start1), 1) + # app.add_handler(CommandHandler('start', start3), 1) + # app.add_handler(CommandHandler('start', start2), 2) + # app.process_update(update) + # assert passed == ['start1'] + # + # @pytest.mark.asyncio + # async def test_exception_in_handler(self, app, bot): + # passed = [] + # err = Exception('General exception') + # + # def start1(u, c): + # passed.append('start1') + # raise err + # + # def start2(u, c): + # passed.append('start2') + # + # def start3(u, c): + # passed.append('start3') + # + # def error(u, c): + # passed.append('error') + # passed.append(c.error) + # + # update = Update( + # 1, + # message=Message( + # 1, + # None, + # None, + # None, + # text='/start', + # entities=[ + # MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0, length=len('/start')) + # ], + # bot=bot, + # ), + # ) + # + # # If an unhandled exception was caught, no further handlers from the same group should be + # # called. Also, the error handler should be called and receive the exception + # passed = [] + # app.add_handler(CommandHandler('start', start1), 1) + # app.add_handler(CommandHandler('start', start2), 1) + # app.add_handler(CommandHandler('start', start3), 2) + # app.add_error_handler(error) + # app.process_update(update) + # assert passed == ['start1', 'error', err, 'start3'] + + # + + # TODProperly test app.start! + # @pytest.mark.asyncio + # async def test_error_start_twice(self, app): + # assert app.running + # app.start() # # def test_error_handler(self, app): # app.add_error_handler(self.error_handler_context) @@ -1217,20 +1468,6 @@ def test_builder(self, app): # finally: # app.bot._defaults = None # - # def test_custom_context_init(self, bot): - # cc = ContextTypes( - # context=CustomContext, - # user_data=int, - # chat_data=float, - # bot_data=complex, - # ) - # - # application = ApplicationBuilder().bot(bot).context_types(cc).build() - # - # assert isinstance(application.user_data[1], int) - # assert isinstance(application.chat_data[1], float) - # assert isinstance(application.bot_data, complex) - # # def test_custom_context_error_handler(self, bot): # def error_handler(_, context): # self.received = ( From 1ac116e9832cef377f2424a39875f77212a57a24 Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Tue, 1 Mar 2022 21:15:48 +0100 Subject: [PATCH 036/153] Few more application tests --- telegram/ext/_application.py | 67 ++--- telegram/ext/_updater.py | 20 +- tests/test_application.py | 515 ++++++++++++++++++++--------------- tests/test_updater.py | 4 + 4 files changed, 347 insertions(+), 259 deletions(-) diff --git a/telegram/ext/_application.py b/telegram/ext/_application.py index 80c884413a9..3a0beda9e27 100644 --- a/telegram/ext/_application.py +++ b/telegram/ext/_application.py @@ -446,39 +446,44 @@ async def stop(self) -> None: Warning: Once this method is called, no more updates will be fetched from :attr:`update_queue`, even if it's not empty. - """ - if self.running: - self._running = False - _logger.info('Application is stopping. This might take a moment.') - if self.updater and self.updater.running: - _logger.debug('Waiting for updater to stop fetching updates') - await self.updater.stop() - - # Stop listening for new updates and handle all pending ones - await self.update_queue.put(_STOP_SIGNAL) - _logger.debug('Waiting for update_queue to join') - await self.update_queue.join() - if self.__update_fetcher_task: - await self.__update_fetcher_task - _logger.debug("Application stopped fetching of updates.") - - if self.job_queue: - _logger.debug('Waiting for running jobs to finish') - await self.job_queue.stop(wait=True) - _logger.debug('JobQueue stopped') - - _logger.debug('Waiting for `create_task` calls to be processed') - await asyncio.gather(*self.__create_task_tasks, return_exceptions=True) - - # Make sure that this is the *last* step of stopping the application! - if self.persistence and self.__update_persistence_task: - _logger.debug('Waiting for persistence loop to finish') - self.__update_persistence_event.set() - await self.__update_persistence_task - self.__update_persistence_event.clear() + Raises: + :exc:`RuntimeError`: If the application is not running. + """ + if not self.running: + raise RuntimeError('This Application is not running!') - _logger.info('Application.stop() complete') + self._running = False + _logger.info('Application is stopping. This might take a moment.') + + if self.updater and self.updater.running: + _logger.debug('Waiting for updater to stop fetching updates') + await self.updater.stop() + + # Stop listening for new updates and handle all pending ones + await self.update_queue.put(_STOP_SIGNAL) + _logger.debug('Waiting for update_queue to join') + await self.update_queue.join() + if self.__update_fetcher_task: + await self.__update_fetcher_task + _logger.debug("Application stopped fetching of updates.") + + if self.job_queue: + _logger.debug('Waiting for running jobs to finish') + await self.job_queue.stop(wait=True) + _logger.debug('JobQueue stopped') + + _logger.debug('Waiting for `create_task` calls to be processed') + await asyncio.gather(*self.__create_task_tasks, return_exceptions=True) + + # Make sure that this is the *last* step of stopping the application! + if self.persistence and self.__update_persistence_task: + _logger.debug('Waiting for persistence loop to finish') + self.__update_persistence_event.set() + await self.__update_persistence_task + self.__update_persistence_event.clear() + + _logger.info('Application.stop() complete') def run_polling( self, diff --git a/telegram/ext/_updater.py b/telegram/ext/_updater.py index f106a6dff7f..cc41d786c96 100644 --- a/telegram/ext/_updater.py +++ b/telegram/ext/_updater.py @@ -604,17 +604,23 @@ def bootstrap_on_err_cb(exc: Exception) -> None: ) async def stop(self) -> None: - """Stops the polling/webhook.""" + """Stops the polling/webhook. + + Raises: + :exc:`RuntimeError`: If the updater is not running. + """ async with self.__lock: - if self.running: - self._logger.debug('Stopping Updater') + if not self.running: + raise RuntimeError('This Updater is not running!') - self._running = False + self._logger.debug('Stopping Updater') + + self._running = False - await self._stop_httpd() - await self._stop_polling() + await self._stop_httpd() + await self._stop_polling() - self._logger.debug('Updater.stop() is complete') + self._logger.debug('Updater.stop() is complete') async def _stop_httpd(self) -> None: if self._httpd: diff --git a/tests/test_application.py b/tests/test_application.py index a4d74759677..f84d72a21df 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -22,7 +22,7 @@ import pytest -from telegram import Bot +from telegram import Bot, Message, User, MessageEntity from telegram.ext import ( JobQueue, CallbackContext, @@ -33,6 +33,10 @@ Updater, filters, MessageHandler, + Handler, + ApplicationHandlerStop, + CommandHandler, + TypeHandler, ) from telegram.error import TelegramError @@ -324,12 +328,56 @@ def test_builder(self, app): builder_1.token(app.bot.token) builder_2.token(app.bot.token) + @pytest.mark.asyncio + async def test_start_stop_processing_updates(self, app): + # TODO: repeat a similar test for create_task, persistence processing and job queue + async def callback(u, c): + self.received = u + + assert not app.running + assert not app.updater.running + app.add_handler(TypeHandler(object, callback)) + + await app.update_queue.put(1) + await asyncio.sleep(0.05) + assert not app.update_queue.empty() + assert self.received is None + + async with app: + await app.start() + assert app.running + assert not app.updater.running + await asyncio.sleep(0.05) + assert app.update_queue.empty() + assert self.received == 1 + + await app.stop() + assert not app.running + assert not app.updater.running + await app.update_queue.put(2) + await asyncio.sleep(0.05) + assert not app.update_queue.empty() + assert self.received != 2 + assert self.received == 1 + + @pytest.mark.asyncio + async def test_error_start_stop_twice(self, app): + async with app: + await app.start() + assert app.running + with pytest.raises(RuntimeError, match='already running'): + await app.start() + + await app.stop() + assert not app.running + with pytest.raises(RuntimeError, match='not running'): + await app.stop() + @pytest.mark.asyncio async def test_one_context_per_update(self, app): self.received = None async def one(update, context): - print('handler one for messag', repr(update.message.text)) self.received = context def two(update, context): @@ -349,18 +397,14 @@ def two(update, context): u.message.text = 'something' await app.process_update(u) - @pytest.mark.asyncio - async def test_error_in_handler(self, app): - app.add_handler(MessageHandler(filters.ALL, self.callback_raise_error)) - app.add_error_handler(self.error_handler_context) + def test_add_handler_errors(self, app): + handler = 'not a handler' + with pytest.raises(TypeError, match='handler is not an instance of'): + app.add_handler(handler) - async with app: - await app.start() - await app.update_queue.put(self.message_update) - await asyncio.sleep(0.1) - await app.stop() - - assert self.received == self.message_update.message.text + handler = MessageHandler(filters.PHOTO, self.callback_set_count(1)) + with pytest.raises(TypeError, match='group is not int'): + app.add_handler(handler, 'one') @pytest.mark.asyncio async def test_add_remove_handler(self, app): @@ -370,7 +414,7 @@ async def test_add_remove_handler(self, app): async with app: await app.start() await app.update_queue.put(self.message_update) - await asyncio.sleep(0.1) + await asyncio.sleep(0.05) assert self.count == 1 app.remove_handler(handler) await app.update_queue.put(self.message_update) @@ -394,7 +438,7 @@ async def test_handler_order_in_group(self, app): async with app: await app.start() await app.update_queue.put(self.message_update) - await asyncio.sleep(0.1) + await asyncio.sleep(0.05) assert self.count == 2 await app.stop() @@ -407,188 +451,217 @@ async def test_groups(self, app): async with app: await app.start() await app.update_queue.put(self.message_update) - await asyncio.sleep(0.1) + await asyncio.sleep(0.05) assert self.count == 3 await app.stop() - # - # @pytest.mark.asyncio - # async def test_add_handlers_complex(self, app): - # """Tests both add_handler & add_handlers together & confirms the correct insertion - # order""" - # msg_handler_set_count = MessageHandler(filters.TEXT, self.callback_set_count(1)) - # msg_handler_inc_count = MessageHandler(filters.PHOTO, self.callback_increase_count) - # - # app.add_handler(msg_handler_set_count, 1) - # app.add_handlers((msg_handler_inc_count, msg_handler_inc_count), 1) - # - # photo_update = Update(2, message=Message(2, None, None, photo=True)) - # # Putting updates in the queue calls the callback - # app.update_queue.put(self.message_update) - # app.update_queue.put(photo_update) - # sleep(0.1) # sleep is required otherwise there is random behaviour - # - # # Test if handler was added to correct group with correct order- - # assert ( - # self.count == 2 - # and len(app.handlers[1]) == 3 - # and app.handlers[1][0] is msg_handler_set_count - # ) - # - # # Now lets test add_handlers when `handlers` is a dict- - # voice_filter_handler_to_check = MessageHandler(filters.VOICE, - # self.callback_increase_count) - # app.add_handlers( - # handlers={ - # 1: [ - # MessageHandler(filters.USER, self.callback_increase_count), - # voice_filter_handler_to_check, - # ], - # -1: [MessageHandler(filters.CAPTION, self.callback_set_count(2))], - # } - # ) - # - # user_update = Update(3, message=Message(3, None, None, from_user=User(1, 's', True))) - # voice_update = Update(4, message=Message(4, None, None, voice=True)) - # app.update_queue.put(user_update) - # app.update_queue.put(voice_update) - # sleep(0.1) - # - # assert ( - # self.count == 4 - # and len(app.handlers[1]) == 5 - # and app.handlers[1][-1] is voice_filter_handler_to_check - # ) - # - # app.update_queue.put(Update(5, message=Message(5, None, None, caption='cap'))) - # sleep(0.1) - # - # assert self.count == 2 and len(app.handlers[-1]) == 1 - # - # # Now lets test the errors which can be produced- - # with pytest.raises(ValueError, match="The `group` argument"): - # app.add_handlers({2: [msg_handler_set_count]}, group=0) - # with pytest.raises(ValueError, match="Handlers for group 3"): - # app.add_handlers({3: msg_handler_set_count}) - # with pytest.raises(ValueError, match="The `handlers` argument must be a sequence"): - # app.add_handlers({msg_handler_set_count}) - # - # @pytest.mark.asyncio - # async def test_add_handler_errors(self, app): - # handler = 'not a handler' - # with pytest.raises(TypeError, match='handler is not an instance of'): - # app.add_handler(handler) - # - # handler = MessageHandler(filters.PHOTO, self.callback_set_count(1)) - # with pytest.raises(TypeError, match='group is not int'): - # app.add_handler(handler, 'one') - # - # @pytest.mark.asyncio - # async def test_flow_stop(self, app, bot): - # passed = [] - # - # def start1(b, u): - # passed.append('start1') - # raise ApplicationHandlerStop - # - # def start2(b, u): - # passed.append('start2') - # - # def start3(b, u): - # passed.append('start3') - # - # def error(b, u, e): - # passed.append('error') - # passed.append(e) - # - # update = Update( - # 1, - # message=Message( - # 1, - # None, - # None, - # None, - # text='/start', - # entities=[ - # MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0, length=len('/start')) - # ], - # bot=bot, - # ), - # ) - # - # # If Stop raised handlers in other groups should not be called. - # passed = [] - # app.add_handler(CommandHandler('start', start1), 1) - # app.add_handler(CommandHandler('start', start3), 1) - # app.add_handler(CommandHandler('start', start2), 2) - # app.process_update(update) - # assert passed == ['start1'] - # - # @pytest.mark.asyncio - # async def test_exception_in_handler(self, app, bot): - # passed = [] - # err = Exception('General exception') - # - # def start1(u, c): - # passed.append('start1') - # raise err - # - # def start2(u, c): - # passed.append('start2') - # - # def start3(u, c): - # passed.append('start3') - # - # def error(u, c): - # passed.append('error') - # passed.append(c.error) - # - # update = Update( - # 1, - # message=Message( - # 1, - # None, - # None, - # None, - # text='/start', - # entities=[ - # MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0, length=len('/start')) - # ], - # bot=bot, - # ), - # ) - # - # # If an unhandled exception was caught, no further handlers from the same group should be - # # called. Also, the error handler should be called and receive the exception - # passed = [] - # app.add_handler(CommandHandler('start', start1), 1) - # app.add_handler(CommandHandler('start', start2), 1) - # app.add_handler(CommandHandler('start', start3), 2) - # app.add_error_handler(error) - # app.process_update(update) - # assert passed == ['start1', 'error', err, 'start3'] + @pytest.mark.asyncio + async def test_add_handlers(self, app): + """Tests both add_handler & add_handlers together & confirms the correct insertion + order""" + msg_handler_set_count = MessageHandler(filters.TEXT, self.callback_set_count(1)) + msg_handler_inc_count = MessageHandler(filters.PHOTO, self.callback_increase_count) - # + app.add_handler(msg_handler_set_count, 1) + app.add_handlers((msg_handler_inc_count, msg_handler_inc_count), 1) + + photo_update = make_message_update(message=Message(2, None, None, photo=True)) + + async with app: + await app.start() + # Putting updates in the queue calls the callback + await app.update_queue.put(self.message_update) + await app.update_queue.put(photo_update) + await asyncio.sleep(0.05) # sleep is required otherwise there is random behaviour + + # Test if handler was added to correct group with correct order- + assert ( + self.count == 2 + and len(app.handlers[1]) == 3 + and app.handlers[1][0] is msg_handler_set_count + ) + + # Now lets test add_handlers when `handlers` is a dict- + voice_filter_handler_to_check = MessageHandler( + filters.VOICE, self.callback_increase_count + ) + app.add_handlers( + handlers={ + 1: [ + MessageHandler(filters.USER, self.callback_increase_count), + voice_filter_handler_to_check, + ], + -1: [MessageHandler(filters.CAPTION, self.callback_set_count(2))], + } + ) + + user_update = make_message_update( + message=Message(3, None, None, from_user=User(1, 's', True)) + ) + voice_update = make_message_update(message=Message(4, None, None, voice=True)) + await app.update_queue.put(user_update) + await app.update_queue.put(voice_update) + await asyncio.sleep(0.05) + + assert ( + self.count == 4 + and len(app.handlers[1]) == 5 + and app.handlers[1][-1] is voice_filter_handler_to_check + ) + + await app.update_queue.put( + make_message_update(message=Message(5, None, None, caption='cap')) + ) + await asyncio.sleep(0.05) + + assert self.count == 2 and len(app.handlers[-1]) == 1 + + # Now lets test the errors which can be produced- + with pytest.raises(ValueError, match="The `group` argument"): + app.add_handlers({2: [msg_handler_set_count]}, group=0) + with pytest.raises(ValueError, match="Handlers for group 3"): + app.add_handlers({3: msg_handler_set_count}) + with pytest.raises(ValueError, match="The `handlers` argument must be a sequence"): + app.add_handlers({msg_handler_set_count}) + + await app.stop() + + @pytest.mark.asyncio + async def test_check_update(self, app): + class TestHandler(Handler): + def check_update(_, update: object): + self.received = object() + + def handle_update( + _, + update, + application, + check_result, + context, + ): + assert application is app + assert check_result is not self.received + + async with app: + app.add_handler(TestHandler('callback')) + await app.start() + await app.update_queue.put(object()) + await asyncio.sleep(0.05) + await app.stop() + + @pytest.mark.asyncio + async def test_flow_stop(self, app, bot): + passed = [] + + def start1(b, u): + passed.append('start1') + raise ApplicationHandlerStop + + def start2(b, u): + passed.append('start2') + + def start3(b, u): + passed.append('start3') + + def error(b, u, e): + passed.append('error') + passed.append(e) + + update = make_message_update( + message=Message( + 1, + None, + None, + None, + text='/start', + entities=[ + MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0, length=len('/start')) + ], + bot=bot, + ), + ) + + # If ApplicationHandlerStop raised handlers in other groups should not be called. + passed = [] + app.add_handler(CommandHandler('start', start1), 1) + app.add_handler(CommandHandler('start', start3), 1) + app.add_handler(CommandHandler('start', start2), 2) + await app.process_update(update) + assert passed == ['start1'] + + @pytest.mark.asyncio + async def test_error_in_handler_part_1(self, app): + app.add_handler(MessageHandler(filters.ALL, self.callback_raise_error)) + app.add_handler(MessageHandler(filters.ALL, self.callback_set_count(42)), group=1) + app.add_error_handler(self.error_handler_context) + + async with app: + await app.start() + await app.update_queue.put(self.message_update) + await asyncio.sleep(0.05) + await app.stop() + + assert self.received == self.message_update.message.text + # Higher groups should still be called + assert self.count == 42 + + @pytest.mark.asyncio + async def test_error_in_handler_part_2(self, app, bot): + passed = [] + err = Exception('General exception') + + async def start1(u, c): + passed.append('start1') + raise err + + async def start2(u, c): + passed.append('start2') + + async def start3(u, c): + passed.append('start3') + + async def error(u, c): + passed.append('error') + passed.append(c.error) + + update = make_message_update( + message=Message( + 1, + None, + None, + None, + text='/start', + entities=[ + MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0, length=len('/start')) + ], + bot=bot, + ), + ) + + # If an unhandled exception was caught, no further handlers from the same group should be + # called. Also, the error handler should be called and receive the exception + passed = [] + app.add_handler(CommandHandler('start', start1), 1) + app.add_handler(CommandHandler('start', start2), 1) + app.add_handler(CommandHandler('start', start3), 2) + app.add_error_handler(error) + await app.process_update(update) + assert passed == ['start1', 'error', err, 'start3'] - # TODProperly test app.start! - # @pytest.mark.asyncio - # async def test_error_start_twice(self, app): - # assert app.running - # app.start() # # def test_error_handler(self, app): # app.add_error_handler(self.error_handler_context) # error = TelegramError('Unauthorized.') - # app.update_queue.put(error) - # sleep(0.1) + # await app.update_queue.put(error) + # await asyncio.sleep(0.05) # assert self.received == 'Unauthorized.' # # # Remove handler # app.remove_error_handler(self.error_handler_context) # self.reset() # - # app.update_queue.put(error) - # sleep(0.1) + # await app.update_queue.put(error) + # await asyncio.sleep(0.05) # assert self.received is None # # def test_double_add_error_handler(self, app, caplog): @@ -622,15 +695,15 @@ async def test_groups(self, app): # # # From errors caused by handlers # app.add_handler(handler_raise_error) - # app.update_queue.put(self.message_update) - # sleep(0.1) + # await app.update_queue.put(self.message_update) + # await asyncio.sleep(0.05) # # # From errors in the update_queue # app.remove_handler(handler_raise_error) # app.add_handler(handler_increase_count) - # app.update_queue.put(error) - # app.update_queue.put(self.message_update) - # sleep(0.1) + # await app.update_queue.put(error) + # await app.update_queue.put(self.message_update) + # await asyncio.sleep(0.05) # # assert self.count == 1 # @@ -683,7 +756,7 @@ async def test_groups(self, app): # app.block(get_application_name, q1) # dp2.block(get_application_name, q2) # - # sleep(0.1) + # await asyncio.sleep(0.05) # # name1 = q1.get() # name2 = q2.get() @@ -696,8 +769,8 @@ async def test_groups(self, app): # # app.add_handler(MessageHandler(filters.ALL, callback, block=True)) # - # app.update_queue.put(self.message_update) - # sleep(0.1) + # await app.update_queue.put(self.message_update) + # await asyncio.sleep(0.05) # assert len(recwarn) == 1 # assert str(recwarn[-1].message).startswith( # 'ApplicationHandlerStop is not supported with async functions' @@ -712,8 +785,8 @@ async def test_groups(self, app): # ) # ) # - # app.update_queue.put(self.message_update) - # sleep(0.1) + # await app.update_queue.put(self.message_update) + # await asyncio.sleep(0.05) # assert self.received == self.message_update.message # # def test_run_async_no_error_handler(self, app, caplog): @@ -722,7 +795,7 @@ async def test_groups(self, app): # # with caplog.at_level(logging.ERROR): # app.block(func) - # sleep(0.1) + # await asyncio.sleep(0.05) # assert len(caplog.records) == 1 # assert caplog.records[-1].getMessage().startswith('No error handlers are registered') # @@ -730,7 +803,7 @@ async def test_groups(self, app): # app.add_handler(MessageHandler(filters.ALL, self.callback_raise_error, block=True)) # app.add_error_handler(self.error_handler_context, block=True) # - # app.update_queue.put(self.message_update) + # await app.update_queue.put(self.message_update) # sleep(2) # assert self.received == self.message_update.message.text # @@ -740,8 +813,8 @@ async def test_groups(self, app): # app.add_error_handler(self.error_handler_raise_error, block=False) # # with caplog.at_level(logging.ERROR): - # app.update_queue.put(self.message_update) - # sleep(0.1) + # await app.update_queue.put(self.message_update) + # await asyncio.sleep(0.05) # assert len(caplog.records) == 1 # assert ( # caplog.records[-1].getMessage().startswith('An error was raised and an uncaught') @@ -750,8 +823,8 @@ async def test_groups(self, app): # # Make sure that the main loop still runs # app.remove_handler(handler) # app.add_handler(MessageHandler(filters.ALL, self.callback_increase_count, block=True)) - # app.update_queue.put(self.message_update) - # sleep(0.1) + # await app.update_queue.put(self.message_update) + # await asyncio.sleep(0.05) # assert self.count == 1 # # def test_async_handler_async_error_handler_that_raises_error(self, app, caplog): @@ -760,8 +833,8 @@ async def test_groups(self, app): # app.add_error_handler(self.error_handler_raise_error, block=True) # # with caplog.at_level(logging.ERROR): - # app.update_queue.put(self.message_update) - # sleep(0.1) + # await app.update_queue.put(self.message_update) + # await asyncio.sleep(0.05) # assert len(caplog.records) == 1 # assert ( # caplog.records[-1].getMessage().startswith('An error was raised and an uncaught') @@ -770,26 +843,26 @@ async def test_groups(self, app): # # Make sure that the main loop still runs # app.remove_handler(handler) # app.add_handler(MessageHandler(filters.ALL, self.callback_increase_count, block=True)) - # app.update_queue.put(self.message_update) - # sleep(0.1) + # await app.update_queue.put(self.message_update) + # await asyncio.sleep(0.05) # assert self.count == 1 # # def test_error_in_handler(self, app): # app.add_handler(MessageHandler(filters.ALL, self.callback_raise_error)) # app.add_error_handler(self.error_handler_context) # - # app.update_queue.put(self.message_update) - # sleep(0.1) + # await app.update_queue.put(self.message_update) + # await asyncio.sleep(0.05) # assert self.received == self.message_update.message.text # # def test_add_remove_handler(self, app): # handler = MessageHandler(filters.ALL, self.callback_increase_count) # app.add_handler(handler) - # app.update_queue.put(self.message_update) - # sleep(0.1) + # await app.update_queue.put(self.message_update) + # await asyncio.sleep(0.05) # assert self.count == 1 # app.remove_handler(handler) - # app.update_queue.put(self.message_update) + # await app.update_queue.put(self.message_update) # assert self.count == 1 # # def test_add_remove_handler_non_default_group(self, app): @@ -807,8 +880,8 @@ async def test_groups(self, app): # app.add_handler(MessageHandler(filters.PHOTO, self.callback_set_count(1))) # app.add_handler(MessageHandler(filters.ALL, self.callback_set_count(2))) # app.add_handler(MessageHandler(filters.TEXT, self.callback_set_count(3))) - # app.update_queue.put(self.message_update) - # sleep(0.1) + # await app.update_queue.put(self.message_update) + # await asyncio.sleep(0.05) # assert self.count == 2 # # def test_groups(self, app): @@ -816,8 +889,8 @@ async def test_groups(self, app): # app.add_handler(MessageHandler(filters.ALL, self.callback_increase_count), group=2) # app.add_handler(MessageHandler(filters.ALL, self.callback_increase_count), group=-1) # - # app.update_queue.put(self.message_update) - # sleep(0.1) + # await app.update_queue.put(self.message_update) + # await asyncio.sleep(0.05) # assert self.count == 3 # # def test_add_handlers_complex(self, app): @@ -831,9 +904,9 @@ async def test_groups(self, app): # # photo_update = Update(2, message=Message(2, None, None, photo=True)) # # Putting updates in the queue calls the callback - # app.update_queue.put(self.message_update) - # app.update_queue.put(photo_update) - # sleep(0.1) # sleep is required otherwise there is random behaviour + # await app.update_queue.put(self.message_update) + # await app.update_queue.put(photo_update) + # await asyncio.sleep(0.05) # sleep is required otherwise there is random behaviour # # # Test if handler was added to correct group with correct order- # assert ( @@ -857,9 +930,9 @@ async def test_groups(self, app): # # user_update = Update(3, message=Message(3, None, None, from_user=User(1, 's', True))) # voice_update = Update(4, message=Message(4, None, None, voice=True)) - # app.update_queue.put(user_update) - # app.update_queue.put(voice_update) - # sleep(0.1) + # await app.update_queue.put(user_update) + # await app.update_queue.put(voice_update) + # await asyncio.sleep(0.05) # # assert ( # self.count == 4 @@ -867,8 +940,8 @@ async def test_groups(self, app): # and app.handlers[1][-1] is voice_filter_handler_to_check # ) # - # app.update_queue.put(Update(5, message=Message(5, None, None, caption='cap'))) - # sleep(0.1) + # await app.update_queue.put(Update(5, message=Message(5, None, None, caption='cap'))) + # await asyncio.sleep(0.05) # # assert self.count == 2 and len(app.handlers[-1]) == 1 # @@ -1491,7 +1564,7 @@ async def test_groups(self, app): # application.add_handler(MessageHandler(filters.ALL, self.callback_raise_error)) # # application.process_update(self.message_update) - # sleep(0.1) + # await asyncio.sleep(0.05) # assert self.received == (CustomContext, float, complex, int) # # def test_custom_context_handler_callback(self, bot): @@ -1516,5 +1589,5 @@ async def test_groups(self, app): # application.add_handler(MessageHandler(filters.ALL, callback)) # # application.process_update(self.message_update) - # sleep(0.1) + # await asyncio.sleep(0.05) # assert self.received == (CustomContext, float, complex, int) diff --git a/tests/test_updater.py b/tests/test_updater.py index 2ad09712d2a..fb11769f0f6 100644 --- a/tests/test_updater.py +++ b/tests/test_updater.py @@ -304,6 +304,8 @@ async def test_start_polling_already_running(self, updater): with pytest.raises(RuntimeError, match='already running'): await task await updater.stop() + with pytest.raises(RuntimeError, match='not running'): + await updater.stop() @pytest.mark.asyncio async def test_start_polling_get_updates_parameters(self, updater, monkeypatch): @@ -593,6 +595,8 @@ async def return_true(*args, **kwargs): with pytest.raises(RuntimeError, match='already running'): await task await updater.stop() + with pytest.raises(RuntimeError, match='not running'): + await updater.stop() @pytest.mark.asyncio async def test_start_webhook_parameters_passing(self, updater, monkeypatch): From d6f95a96307c4feeb0b5a6186146c2343d6ee0ed Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Wed, 2 Mar 2022 09:10:59 +0100 Subject: [PATCH 037/153] fix updater tests --- tests/test_updater.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/test_updater.py b/tests/test_updater.py index fb11769f0f6..4d751625b27 100644 --- a/tests/test_updater.py +++ b/tests/test_updater.py @@ -390,7 +390,6 @@ async def do_request(*args, **kwargs): await updater.start_polling( bootstrap_retries=retries, ) - await updater.stop() @pytest.mark.parametrize( 'error,callback_should_be_called', @@ -818,7 +817,6 @@ async def do_request(*args, **kwargs): await updater.start_webhook( bootstrap_retries=retries, ) - await updater.stop() @pytest.mark.asyncio async def test_webhook_invalid_posts(self, updater, monkeypatch): From af42aaf80a6d5482a8c971738dd17cea15838371 Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Sun, 6 Mar 2022 16:00:11 +0100 Subject: [PATCH 038/153] More application tests - move tests for persistence integration into a standalone test file --- telegram/ext/_application.py | 15 +- tests/test_application.py | 1225 +++++++------------------ tests/test_persistence_integration.py | 361 ++++++++ tests/test_updater.py | 2 - 4 files changed, 724 insertions(+), 879 deletions(-) create mode 100644 tests/test_persistence_integration.py diff --git a/telegram/ext/_application.py b/telegram/ext/_application.py index 3a0beda9e27..136d9b1b63e 100644 --- a/telegram/ext/_application.py +++ b/telegram/ext/_application.py @@ -617,7 +617,7 @@ def __create_task( if self.running: self.__create_task_tasks.add(task) - task.add_done_callback(self.__create_task_tasks.discard) + task.add_done_callback(self.__create_task_done_callback) else: _logger.warning( "Tasks created via `Application.create_task` while the application is not " @@ -626,6 +626,12 @@ def __create_task( return task + def __create_task_done_callback(self, task: asyncio.Task) -> None: + # We just retrieve the eventual exception so that asyncio doesn't complain in case + # it's not retrieved somewhere else + task.exception() + self.__create_task_tasks.discard(task) + async def __create_task_callback( self, coroutine: Coroutine[Any, Any, _RT], @@ -637,7 +643,9 @@ async def __create_task_callback( except Exception as exception: if isinstance(exception, ApplicationHandlerStop): warn( - 'ApplicationHandlerStop is not supported with asynchronously running handlers.' + 'ApplicationHandlerStop is not supported with asynchronously ' + 'running handlers.', + stacklevel=1, ) # Avoid infinite recursion of error handlers. @@ -654,6 +662,7 @@ async def __create_task_callback( # So we can and must handle it await self.dispatch_error(update, exception, coroutine=coroutine) + # Raise exception so that it can be set on the task raise exception finally: self._mark_for_persistence_update(update=update) @@ -704,7 +713,7 @@ async def process_update(self, update: object) -> None: try: for handler in handlers: check = handler.check_update(update) - if bool(check): + if not (check is None or check is False): if not context: context = self.context_types.context.from_update(update, self) await context.refresh_data() diff --git a/tests/test_application.py b/tests/test_application.py index f84d72a21df..c722c2e32f4 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -16,13 +16,17 @@ # # You should have received a copy of the GNU Lesser Public License # along with this program. If not, see [http://www.gnu.org/licenses/]. +"""The integration of persistence into the application is tested in test_persistence_integration. +""" import asyncio +import logging from collections import defaultdict +from pathlib import Path from queue import Queue import pytest -from telegram import Bot, Message, User, MessageEntity +from telegram import Bot, Message, User, MessageEntity, Chat from telegram.ext import ( JobQueue, CallbackContext, @@ -40,8 +44,9 @@ ) from telegram.error import TelegramError +from telegram.warnings import PTBUserWarning -from tests.conftest import make_message_update +from tests.conftest import make_message_update, PROJECT_ROOT_PATH class CustomContext(CallbackContext): @@ -49,6 +54,10 @@ class CustomContext(CallbackContext): class TestApplication: + """The integration of persistence into the application is tested in + test_persistence_integration. + """ + message_update = make_message_update(message='Text') received = None count = 0 @@ -70,14 +79,19 @@ async def error_handler_raise_error(self, update, context): async def callback_increase_count(self, update, context): self.count += 1 - def callback_set_count(self, count): + def callback_set_count(self, count, sleep: float = None): async def callback(update, context): + if sleep: + await asyncio.sleep(sleep) self.count = count return callback - async def callback_raise_error(self, update, context): - raise TelegramError(update.message.text) + def callback_raise_error(self, error_message: str): + async def callback(update, context): + raise TelegramError(error_message) + + return callback async def callback_received(self, update, context): self.received = update.message @@ -553,20 +567,16 @@ def handle_update( async def test_flow_stop(self, app, bot): passed = [] - def start1(b, u): + async def start1(b, u): passed.append('start1') raise ApplicationHandlerStop - def start2(b, u): + async def start2(b, u): passed.append('start2') - def start3(b, u): + async def start3(b, u): passed.append('start3') - def error(b, u, e): - passed.append('error') - passed.append(e) - update = make_message_update( message=Message( 1, @@ -589,9 +599,43 @@ def error(b, u, e): await app.process_update(update) assert passed == ['start1'] + @pytest.mark.asyncio + async def test_flow_stop_by_error_handler(self, app, bot): + passed = [] + exception = Exception('General excepition') + + async def start1(b, u): + passed.append('start1') + raise exception + + async def start2(b, u): + passed.append('start2') + + async def start3(b, u): + passed.append('start3') + + async def error(u, c): + passed.append('error') + passed.append(c.error) + raise ApplicationHandlerStop + + # If ApplicationHandlerStop raised handlers in other groups should not be called. + passed = [] + app.add_error_handler(error) + app.add_handler(TypeHandler(object, start1), 1) + app.add_handler(TypeHandler(object, start2), 1) + app.add_handler(TypeHandler(object, start3), 2) + await app.process_update(1) + assert passed == ['start1', 'error', exception] + @pytest.mark.asyncio async def test_error_in_handler_part_1(self, app): - app.add_handler(MessageHandler(filters.ALL, self.callback_raise_error)) + app.add_handler( + MessageHandler( + filters.ALL, + self.callback_raise_error(error_message=self.message_update.message.text), + ) + ) app.add_handler(MessageHandler(filters.ALL, self.callback_set_count(42)), group=1) app.add_error_handler(self.error_handler_context) @@ -648,64 +692,232 @@ async def error(u, c): await app.process_update(update) assert passed == ['start1', 'error', err, 'start3'] + @pytest.mark.asyncio + async def test_error_handler(self, app): + app.add_error_handler(self.error_handler_context) + app.add_handler(TypeHandler(object, self.callback_raise_error('TestError'))) + + async with app: + await app.start() + await app.update_queue.put(1) + await asyncio.sleep(0.05) + assert self.received == 'TestError' + + # Remove handler + app.remove_error_handler(self.error_handler_context) + self.reset() + + await app.update_queue.put(1) + await asyncio.sleep(0.05) + assert self.received is None + + await app.stop() + + def test_double_add_error_handler(self, app, caplog): + app.add_error_handler(self.error_handler_context) + with caplog.at_level(logging.DEBUG): + app.add_error_handler(self.error_handler_context) + assert len(caplog.records) == 1 + assert caplog.records[-1].getMessage().startswith('The callback is already registered') + + @pytest.mark.asyncio + async def test_error_handler_that_raises_errors(self, app, caplog): + """Make sure that errors raised in error handlers don't break the main loop of the + application + """ + handler_raise_error = TypeHandler( + int, self.callback_raise_error(error_message='TestError') + ) + handler_increase_count = TypeHandler(str, self.callback_increase_count) + + app.add_error_handler(self.error_handler_raise_error) + app.add_handler(handler_raise_error) + app.add_handler(handler_increase_count) + + with caplog.at_level(logging.ERROR): + async with app: + await app.start() + await app.update_queue.put(1) + await asyncio.sleep(0.05) + assert self.count == 0 + assert self.received is None + assert len(caplog.records) > 0 + log_messages = (record.getMessage() for record in caplog.records) + assert any( + 'uncaught error was raised while handling the error with an error_handler' + in message + for message in log_messages + ) + + await app.update_queue.put('1') + self.received = None + caplog.clear() + await asyncio.sleep(0.05) + assert self.count == 1 + assert self.received is None + assert not caplog.records + + await app.stop() + + @pytest.mark.asyncio + async def test_custom_context_error_handler(self, bot): + async def error_handler(_, context): + self.received = ( + type(context), + type(context.user_data), + type(context.chat_data), + type(context.bot_data), + ) + + application = ( + ApplicationBuilder() + .bot(bot) + .context_types( + ContextTypes( + context=CustomContext, bot_data=int, user_data=float, chat_data=complex + ) + ) + .build() + ) + application.add_error_handler(error_handler) + application.add_handler( + MessageHandler(filters.ALL, self.callback_raise_error('TestError')) + ) + + await application.process_update(self.message_update) + await asyncio.sleep(0.05) + assert self.received == (CustomContext, float, complex, int) + + @pytest.mark.asyncio + async def test_custom_context_handler_callback(self, bot): + def callback(_, context): + self.received = ( + type(context), + type(context.user_data), + type(context.chat_data), + type(context.bot_data), + ) + + application = ( + ApplicationBuilder() + .bot(bot) + .context_types( + ContextTypes( + context=CustomContext, bot_data=int, user_data=float, chat_data=complex + ) + ) + .build() + ) + application.add_handler(MessageHandler(filters.ALL, callback)) + + await application.process_update(self.message_update) + await asyncio.sleep(0.05) + assert self.received == (CustomContext, float, complex, int) + + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'check,expected', + [(True, True), (None, False), (False, False), ({}, True), ('', True), ('check', True)], + ) + async def test_check_update_handling(self, app, check, expected): + class MyHandler(Handler): + def check_update(self, update: object): + return check + + async def handle_update( + _, + update, + application, + check_result, + context, + ): + await super().handle_update( + update=update, + application=application, + check_result=check_result, + context=context, + ) + self.received = check_result + + app.add_handler(MyHandler(self.callback_increase_count)) + await app.process_update(1) + assert self.count == (1 if expected else 0) + if expected: + assert self.received == check + else: + assert self.received is None + + @pytest.mark.asyncio + async def test_non_blocking_handler(self, app): + event = asyncio.Event() + + async def callback(update, context): + await event.wait() + self.count = 42 + + app.add_handler(TypeHandler(object, callback, block=False)) + app.add_handler(TypeHandler(object, self.callback_increase_count), group=1) + + async with app: + await app.start() + await app.update_queue.put(1) + task = asyncio.create_task(app.stop()) + await asyncio.sleep(0.05) + assert self.count == 1 + # Make sure that app stops only once all non blocking callbacks are done + assert not task.done() + event.set() + await asyncio.sleep(0.05) + assert self.count == 42 + assert task.done() + + @pytest.mark.asyncio + async def test_non_blocking_handler_applicationhandlerstop(self, app, recwarn): + async def callback(update, context): + raise ApplicationHandlerStop + + app.add_handler(TypeHandler(object, callback, block=False)) + + async with app: + await app.start() + await app.update_queue.put(1) + await asyncio.sleep(0.05) + await app.stop() + + assert len(recwarn) == 1 + assert recwarn[0].category is PTBUserWarning + assert ( + str(recwarn[0].message) + == 'ApplicationHandlerStop is not supported with asynchronously running handlers.' + ) + assert ( + Path(recwarn[0].filename) == PROJECT_ROOT_PATH / 'telegram' / 'ext' / '_application.py' + ), "incorrect stacklevel!" + + @pytest.mark.asyncio + async def test_run_async_no_error_handler(self, app, caplog): + app.add_handler(TypeHandler(object, self.callback_raise_error, block=False)) + + with caplog.at_level(logging.ERROR): + async with app: + await app.start() + await app.update_queue.put(1) + await asyncio.sleep(0.05) + assert len(caplog.records) == 1 + assert ( + caplog.records[-1].getMessage().startswith('No error handlers are registered') + ) + await app.stop() + # - # def test_error_handler(self, app): - # app.add_error_handler(self.error_handler_context) - # error = TelegramError('Unauthorized.') - # await app.update_queue.put(error) - # await asyncio.sleep(0.05) - # assert self.received == 'Unauthorized.' - # - # # Remove handler - # app.remove_error_handler(self.error_handler_context) - # self.reset() - # - # await app.update_queue.put(error) - # await asyncio.sleep(0.05) - # assert self.received is None - # - # def test_double_add_error_handler(self, app, caplog): - # app.add_error_handler(self.error_handler_context) - # with caplog.at_level(logging.DEBUG): - # app.add_error_handler(self.error_handler_context) - # assert len(caplog.records) == 1 - # assert caplog.records[-1].getMessage().startswith( - # 'The callback is already registered') - # - # def test_construction_with_bad_persistence(self, caplog, bot): - # class my_per: - # def __init__(self): - # self.store_data = PersistenceInput(False, False, False, False) - # - # with pytest.raises( - # TypeError, match='persistence must be based on telegram.ext.BasePersistence' - # ): - # ApplicationBuilder().bot(bot).persistence(my_per()).build() - # - # def test_error_handler_that_raises_errors(self, app): - # """ - # Make sure that errors raised in error handlers don't break the main loop of the - # application - # """ - # handler_raise_error = MessageHandler(filters.ALL, self.callback_raise_error) - # handler_increase_count = MessageHandler(filters.ALL, self.callback_increase_count) - # error = TelegramError('Unauthorized.') - # - # app.add_error_handler(self.error_handler_raise_error) - # - # # From errors caused by handlers - # app.add_handler(handler_raise_error) - # await app.update_queue.put(self.message_update) - # await asyncio.sleep(0.05) + # def test_async_handler_async_error_handler_context(self, app): + # app.add_handler(MessageHandler(filters.ALL, self.callback_raise_error, block=True)) + # app.add_error_handler(self.error_handler_context, block=True) # - # # From errors in the update_queue - # app.remove_handler(handler_raise_error) - # app.add_handler(handler_increase_count) - # await app.update_queue.put(error) # await app.update_queue.put(self.message_update) - # await asyncio.sleep(0.05) - # - # assert self.count == 1 + # sleep(2) + # assert self.received == self.message_update.message.text + # # @pytest.mark.parametrize(['block', 'expected_output'], [(True, 5), (False, 0)]) # def test_default_run_async_error_handler(self, app, monkeypatch, block, expected_output): @@ -746,67 +958,6 @@ async def error(u, c): # # reset defaults value # app.bot._defaults = None # - # def test_run_async_multiple(self, bot, app, dp2): - # def get_application_name(q): - # q.put(current_thread().name) - # - # q1 = Queue() - # q2 = Queue() - # - # app.block(get_application_name, q1) - # dp2.block(get_application_name, q2) - # - # await asyncio.sleep(0.05) - # - # name1 = q1.get() - # name2 = q2.get() - # - # assert name1 != name2 - # - # def test_async_raises_application_handler_stop(self, app, recwarn): - # def callback(update, context): - # raise ApplicationHandlerStop() - # - # app.add_handler(MessageHandler(filters.ALL, callback, block=True)) - # - # await app.update_queue.put(self.message_update) - # await asyncio.sleep(0.05) - # assert len(recwarn) == 1 - # assert str(recwarn[-1].message).startswith( - # 'ApplicationHandlerStop is not supported with async functions' - # ) - # - # def test_add_async_handler(self, app): - # app.add_handler( - # MessageHandler( - # filters.ALL, - # self.callback_received, - # block=True, - # ) - # ) - # - # await app.update_queue.put(self.message_update) - # await asyncio.sleep(0.05) - # assert self.received == self.message_update.message - # - # def test_run_async_no_error_handler(self, app, caplog): - # def func(): - # raise RuntimeError('Async Error') - # - # with caplog.at_level(logging.ERROR): - # app.block(func) - # await asyncio.sleep(0.05) - # assert len(caplog.records) == 1 - # assert caplog.records[-1].getMessage().startswith('No error handlers are registered') - # - # def test_async_handler_async_error_handler_context(self, app): - # app.add_handler(MessageHandler(filters.ALL, self.callback_raise_error, block=True)) - # app.add_error_handler(self.error_handler_context, block=True) - # - # await app.update_queue.put(self.message_update) - # sleep(2) - # assert self.received == self.message_update.message.text - # # def test_async_handler_error_handler_that_raises_error(self, app, caplog): # handler = MessageHandler(filters.ALL, self.callback_raise_error, block=True) # app.add_handler(handler) @@ -846,748 +997,74 @@ async def error(u, c): # await app.update_queue.put(self.message_update) # await asyncio.sleep(0.05) # assert self.count == 1 - # - # def test_error_in_handler(self, app): - # app.add_handler(MessageHandler(filters.ALL, self.callback_raise_error)) - # app.add_error_handler(self.error_handler_context) - # - # await app.update_queue.put(self.message_update) - # await asyncio.sleep(0.05) - # assert self.received == self.message_update.message.text - # - # def test_add_remove_handler(self, app): - # handler = MessageHandler(filters.ALL, self.callback_increase_count) - # app.add_handler(handler) - # await app.update_queue.put(self.message_update) - # await asyncio.sleep(0.05) - # assert self.count == 1 - # app.remove_handler(handler) - # await app.update_queue.put(self.message_update) - # assert self.count == 1 - # - # def test_add_remove_handler_non_default_group(self, app): - # handler = MessageHandler(filters.ALL, self.callback_increase_count) - # app.add_handler(handler, group=2) - # with pytest.raises(KeyError): - # app.remove_handler(handler) - # app.remove_handler(handler, group=2) - # - # def test_error_start_twice(self, app): - # assert app.running - # app.start() - # - # def test_handler_order_in_group(self, app): - # app.add_handler(MessageHandler(filters.PHOTO, self.callback_set_count(1))) - # app.add_handler(MessageHandler(filters.ALL, self.callback_set_count(2))) - # app.add_handler(MessageHandler(filters.TEXT, self.callback_set_count(3))) - # await app.update_queue.put(self.message_update) - # await asyncio.sleep(0.05) - # assert self.count == 2 - # - # def test_groups(self, app): - # app.add_handler(MessageHandler(filters.ALL, self.callback_increase_count)) - # app.add_handler(MessageHandler(filters.ALL, self.callback_increase_count), group=2) - # app.add_handler(MessageHandler(filters.ALL, self.callback_increase_count), group=-1) - # - # await app.update_queue.put(self.message_update) - # await asyncio.sleep(0.05) - # assert self.count == 3 - # - # def test_add_handlers_complex(self, app): - # """Tests both add_handler & add_handlers together & confirms the correct insertion - # order""" - # msg_handler_set_count = MessageHandler(filters.TEXT, self.callback_set_count(1)) - # msg_handler_inc_count = MessageHandler(filters.PHOTO, self.callback_increase_count) - # - # app.add_handler(msg_handler_set_count, 1) - # app.add_handlers((msg_handler_inc_count, msg_handler_inc_count), 1) - # - # photo_update = Update(2, message=Message(2, None, None, photo=True)) - # # Putting updates in the queue calls the callback - # await app.update_queue.put(self.message_update) - # await app.update_queue.put(photo_update) - # await asyncio.sleep(0.05) # sleep is required otherwise there is random behaviour - # - # # Test if handler was added to correct group with correct order- - # assert ( - # self.count == 2 - # and len(app.handlers[1]) == 3 - # and app.handlers[1][0] is msg_handler_set_count - # ) - # - # # Now lets test add_handlers when `handlers` is a dict- - # voice_filter_handler_to_check = MessageHandler(filters.VOICE, - # self.callback_increase_count) - # app.add_handlers( - # handlers={ - # 1: [ - # MessageHandler(filters.USER, self.callback_increase_count), - # voice_filter_handler_to_check, - # ], - # -1: [MessageHandler(filters.CAPTION, self.callback_set_count(2))], - # } - # ) - # - # user_update = Update(3, message=Message(3, None, None, from_user=User(1, 's', True))) - # voice_update = Update(4, message=Message(4, None, None, voice=True)) - # await app.update_queue.put(user_update) - # await app.update_queue.put(voice_update) - # await asyncio.sleep(0.05) - # - # assert ( - # self.count == 4 - # and len(app.handlers[1]) == 5 - # and app.handlers[1][-1] is voice_filter_handler_to_check - # ) - # - # await app.update_queue.put(Update(5, message=Message(5, None, None, caption='cap'))) - # await asyncio.sleep(0.05) - # - # assert self.count == 2 and len(app.handlers[-1]) == 1 - # - # # Now lets test the errors which can be produced- - # with pytest.raises(ValueError, match="The `group` argument"): - # app.add_handlers({2: [msg_handler_set_count]}, group=0) - # with pytest.raises(ValueError, match="Handlers for group 3"): - # app.add_handlers({3: msg_handler_set_count}) - # with pytest.raises(ValueError, match="The `handlers` argument must be a sequence"): - # app.add_handlers({msg_handler_set_count}) - # - # def test_add_handler_errors(self, app): - # handler = 'not a handler' - # with pytest.raises(TypeError, match='handler is not an instance of'): - # app.add_handler(handler) - # - # handler = MessageHandler(filters.PHOTO, self.callback_set_count(1)) - # with pytest.raises(TypeError, match='group is not int'): - # app.add_handler(handler, 'one') - # - # def test_flow_stop(self, app, bot): - # passed = [] - # - # def start1(b, u): - # passed.append('start1') - # raise ApplicationHandlerStop - # - # def start2(b, u): - # passed.append('start2') - # - # def start3(b, u): - # passed.append('start3') - # - # def error(b, u, e): - # passed.append('error') - # passed.append(e) - # - # update = Update( - # 1, - # message=Message( - # 1, - # None, - # None, - # None, - # text='/start', - # entities=[ - # MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0, length=len('/start')) - # ], - # bot=bot, - # ), - # ) - # - # # If Stop raised handlers in other groups should not be called. - # passed = [] - # app.add_handler(CommandHandler('start', start1), 1) - # app.add_handler(CommandHandler('start', start3), 1) - # app.add_handler(CommandHandler('start', start2), 2) - # app.process_update(update) - # assert passed == ['start1'] - # - # def test_exception_in_handler(self, app, bot): - # passed = [] - # err = Exception('General exception') - # - # def start1(u, c): - # passed.append('start1') - # raise err - # - # def start2(u, c): - # passed.append('start2') - # - # def start3(u, c): - # passed.append('start3') - # - # def error(u, c): - # passed.append('error') - # passed.append(c.error) - # - # update = Update( - # 1, - # message=Message( - # 1, - # None, - # None, - # None, - # text='/start', - # entities=[ - # MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0, length=len('/start')) - # ], - # bot=bot, - # ), - # ) - # - # # If an unhandled exception was caught, no further handlers from the same group should be - # # called. Also, the error handler should be called and receive the exception - # passed = [] - # app.add_handler(CommandHandler('start', start1), 1) - # app.add_handler(CommandHandler('start', start2), 1) - # app.add_handler(CommandHandler('start', start3), 2) - # app.add_error_handler(error) - # app.process_update(update) - # assert passed == ['start1', 'error', err, 'start3'] - # - # def test_telegram_error_in_handler(self, app, bot): - # passed = [] - # err = TelegramError('Telegram error') - # - # def start1(u, c): - # passed.append('start1') - # raise err - # - # def start2(u, c): - # passed.append('start2') - # - # def start3(u, c): - # passed.append('start3') - # - # def error(u, c): - # passed.append('error') - # passed.append(c.error) - # - # update = Update( - # 1, - # message=Message( - # 1, - # None, - # None, - # None, - # text='/start', - # entities=[ - # MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0, length=len('/start')) - # ], - # bot=bot, - # ), - # ) - # - # # If a TelegramException was caught, an error handler should be called and no further - # # handlers from the same group should be called. - # app.add_handler(CommandHandler('start', start1), 1) - # app.add_handler(CommandHandler('start', start2), 1) - # app.add_handler(CommandHandler('start', start3), 2) - # app.add_error_handler(error) - # app.process_update(update) - # assert passed == ['start1', 'error', err, 'start3'] - # assert passed[2] is err - # - # def test_error_while_saving_chat_data(self, bot): - # increment = [] - # - # class OwnPersistence(BasePersistence): - # def get_callback_data(self): - # return None - # - # def update_callback_data(self, data): - # raise Exception - # - # def get_bot_data(self): - # return {} - # - # def update_bot_data(self, data): - # raise Exception - # - # def drop_chat_data(self, chat_id): - # pass - # - # def drop_user_data(self, user_id): - # pass - # - # def get_chat_data(self): - # return defaultdict(dict) - # - # def update_chat_data(self, chat_id, data): - # raise Exception - # - # def get_user_data(self): - # return defaultdict(dict) - # - # def update_user_data(self, user_id, data): - # raise Exception - # - # def get_conversations(self, name): - # pass - # - # def update_conversation(self, name, key, new_state): - # pass - # - # def refresh_user_data(self, user_id, user_data): - # pass - # - # def refresh_chat_data(self, chat_id, chat_data): - # pass - # - # def refresh_bot_data(self, bot_data): - # pass - # - # def flush(self): - # pass - # - # def start1(u, c): - # pass - # - # def error(u, c): - # increment.append("error") - # - # # If updating a user_data or chat_data from a persistence object throws an error, - # # the error handler should catch it - # - # update = Update( - # 1, - # message=Message( - # 1, - # None, - # Chat(1, "lala"), - # from_user=User(1, "Test", False), - # text='/start', - # entities=[ - # MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0, length=len('/start')) - # ], - # bot=bot, - # ), - # ) - # my_persistence = OwnPersistence() - # app = ApplicationBuilder().bot(bot).persistence(my_persistence).build() - # app.add_handler(CommandHandler('start', start1)) - # app.add_error_handler(error) - # app.process_update(update) - # assert increment == ["error", "error", "error", "error"] - # - # def test_flow_stop_in_error_handler(self, app, bot): - # passed = [] - # err = TelegramError('Telegram error') - # - # def start1(u, c): - # passed.append('start1') - # raise err - # - # def start2(u, c): - # passed.append('start2') - # - # def start3(u, c): - # passed.append('start3') - # - # def error(u, c): - # passed.append('error') - # passed.append(c.error) - # raise ApplicationHandlerStop - # - # update = Update( - # 1, - # message=Message( - # 1, - # None, - # None, - # None, - # text='/start', - # entities=[ - # MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0, length=len('/start')) - # ], - # bot=bot, - # ), - # ) - # - # # If a TelegramException was caught, an error handler should be called and no further - # # handlers from the same group should be called. - # app.add_handler(CommandHandler('start', start1), 1) - # app.add_handler(CommandHandler('start', start2), 1) - # app.add_handler(CommandHandler('start', start3), 2) - # app.add_error_handler(error) - # app.process_update(update) - # assert passed == ['start1', 'error', err] - # assert passed[2] is err - # - # def test_sensible_worker_thread_names(self, dp2): - # thread_names = [thread.name for thread in dp2._Application__async_threads] - # for thread_name in thread_names: - # assert thread_name.startswith(f"Bot:{dp2.bot.id}:worker:") - # - # @pytest.mark.parametrize( - # 'message', - # [ - # Message(message_id=1, chat=Chat(id=2, type=None), migrate_from_chat_id=1, date=None), - # Message(message_id=1, chat=Chat(id=1, type=None), migrate_to_chat_id=2, date=None), - # Message(message_id=1, chat=Chat(id=1, type=None), date=None), - # None, - # ], - # ) - # @pytest.mark.parametrize('old_chat_id', [None, 1, "1"]) - # @pytest.mark.parametrize('new_chat_id', [None, 2, "1"]) - # def test_migrate_chat_data(self, app, message: 'Message', old_chat_id: int, - # new_chat_id: int): - # def call(match: str): - # with pytest.raises(ValueError, match=match): - # app.migrate_chat_data( - # message=message, old_chat_id=old_chat_id, new_chat_id=new_chat_id - # ) - # - # if message and (old_chat_id or new_chat_id): - # call(r"^Message and chat_id pair are mutually exclusive$") - # return - # - # if not any((message, old_chat_id, new_chat_id)): - # call(r"^chat_id pair or message must be passed$") - # return - # - # if message: - # if message.migrate_from_chat_id is None and message.migrate_to_chat_id is None: - # call(r"^Invalid message instance") - # return - # effective_old_chat_id = message.migrate_from_chat_id or message.chat.id - # effective_new_chat_id = message.migrate_to_chat_id or message.chat.id - # - # elif not (isinstance(old_chat_id, int) and isinstance(new_chat_id, int)): - # call(r"^old_chat_id and new_chat_id must be integers$") - # return - # else: - # effective_old_chat_id = old_chat_id - # effective_new_chat_id = new_chat_id - # - # app.chat_data[effective_old_chat_id]['key'] = "test" - # app.migrate_chat_data(message=message, old_chat_id=old_chat_id, new_chat_id=new_chat_id) - # assert effective_old_chat_id not in app.chat_data - # assert app.chat_data[effective_new_chat_id]['key'] == "test" - # - # def test_error_while_persisting(self, app, caplog): - # class OwnPersistence(BasePersistence): - # def update(self, data): - # raise Exception('PersistenceError') - # - # def update_callback_data(self, data): - # self.update(data) - # - # def update_bot_data(self, data): - # self.update(data) - # - # def update_chat_data(self, chat_id, data): - # self.update(data) - # - # def update_user_data(self, user_id, data): - # self.update(data) - # - # def drop_user_data(self, user_id): - # pass - # - # def drop_chat_data(self, chat_id): - # pass - # - # def get_chat_data(self): - # pass - # - # def get_bot_data(self): - # pass - # - # def get_user_data(self): - # pass - # - # def get_callback_data(self): - # pass - # - # def get_conversations(self, name): - # pass - # - # def update_conversation(self, name, key, new_state): - # pass - # - # def refresh_bot_data(self, bot_data): - # pass - # - # def refresh_user_data(self, user_id, user_data): - # pass - # - # def refresh_chat_data(self, chat_id, chat_data): - # pass - # - # def flush(self): - # pass - # - # def callback(update, context): - # pass - # - # test_flag = [] - # - # def error(update, context): - # nonlocal test_flag - # test_flag.append(str(context.error) == 'PersistenceError') - # raise Exception('ErrorHandlingError') - # - # update = Update( - # 1, message=Message(1, None, Chat(1, ''), from_user=User(1, '', False), text='Text') - # ) - # handler = MessageHandler(filters.ALL, callback) - # app.add_handler(handler) - # app.add_error_handler(error) - # - # app.persistence = OwnPersistence() - # - # with caplog.at_level(logging.ERROR): - # app.process_update(update) - # - # assert test_flag == [True, True, True, True] - # assert len(caplog.records) == 4 - # for record in caplog.records: - # message = record.getMessage() - # assert message.startswith('An error was raised and an uncaught') - # - # def test_persisting_no_user_no_chat(self, app): - # class OwnPersistence(BasePersistence): - # def __init__(self): - # super().__init__() - # self.test_flag_bot_data = False - # self.test_flag_chat_data = False - # self.test_flag_user_data = False - # - # def update_bot_data(self, data): - # self.test_flag_bot_data = True - # - # def update_chat_data(self, chat_id, data): - # self.test_flag_chat_data = True - # - # def update_user_data(self, user_id, data): - # self.test_flag_user_data = True - # - # def update_conversation(self, name, key, new_state): - # pass - # - # def drop_chat_data(self, chat_id): - # pass - # - # def drop_user_data(self, user_id): - # pass - # - # def get_conversations(self, name): - # pass - # - # def get_user_data(self): - # pass - # - # def get_bot_data(self): - # pass - # - # def get_chat_data(self): - # pass - # - # def refresh_bot_data(self, bot_data): - # pass - # - # def refresh_user_data(self, user_id, user_data): - # pass - # - # def refresh_chat_data(self, chat_id, chat_data): - # pass - # - # def get_callback_data(self): - # pass - # - # def update_callback_data(self, data): - # pass - # - # def flush(self): - # pass - # - # def callback(update, context): - # pass - # - # handler = MessageHandler(filters.ALL, callback) - # app.add_handler(handler) - # app.persistence = OwnPersistence() - # - # update = Update( - # 1, message=Message(1, None, None, from_user=User(1, '', False), text='Text') - # ) - # app.process_update(update) - # assert app.persistence.test_flag_bot_data - # assert app.persistence.test_flag_user_data - # assert not app.persistence.test_flag_chat_data - # - # app.persistence.test_flag_bot_data = False - # app.persistence.test_flag_user_data = False - # app.persistence.test_flag_chat_data = False - # update = Update(1, message=Message(1, None, Chat(1, ''), from_user=None, text='Text')) - # app.process_update(update) - # assert app.persistence.test_flag_bot_data - # assert not app.persistence.test_flag_user_data - # assert app.persistence.test_flag_chat_data - # - # @pytest.mark.parametrize( - # "c_id,expected", - # [(321, {222: "remove_me"}), (111, {321: {'not_empty': 'no'}, 222: "remove_me"})], - # ids=["test chat_id removal", "test no key in data (no error)"], - # ) - # def test_drop_chat_data(self, app, c_id, expected): - # app._chat_data.update({321: {'not_empty': 'no'}, 222: "remove_me"}) - # app.drop_chat_data(c_id) - # assert app.chat_data == expected - # - # @pytest.mark.parametrize( - # "u_id,expected", - # [(321, {222: "remove_me"}), (111, {321: {'not_empty': 'no'}, 222: "remove_me"})], - # ids=["test user_id removal", "test no key in data (no error)"], - # ) - # def test_drop_user_data(self, app, u_id, expected): - # app._user_data.update({321: {'not_empty': 'no'}, 222: "remove_me"}) - # app.drop_user_data(u_id) - # assert app.user_data == expected - # - # def test_update_persistence_once_per_update(self, monkeypatch, app): - # def update_persistence(*args, **kwargs): - # self.count += 1 - # - # def dummy_callback(*args): - # pass - # - # monkeypatch.setattr(app, 'update_persistence', update_persistence) - # - # for group in range(5): - # app.add_handler(MessageHandler(filters.TEXT, dummy_callback), group=group) - # - # update = Update(1, message=Message(1, None, Chat(1, ''), from_user=None, text=None)) - # app.process_update(update) - # assert self.count == 0 - # - # update = Update(1, message=Message(1, None, Chat(1, ''), from_user=None, text='text')) - # app.process_update(update) - # assert self.count == 1 - # - # def test_update_persistence_all_async(self, monkeypatch, app): - # def update_persistence(*args, **kwargs): - # self.count += 1 - # - # def dummy_callback(*args, **kwargs): - # pass - # - # monkeypatch.setattr(app, 'update_persistence', update_persistence) - # monkeypatch.setattr(app, 'block', dummy_callback) - # - # for group in range(5): - # app.add_handler( - # MessageHandler(filters.TEXT, dummy_callback, block=True), group=group - # ) - # - # update = Update(1, message=Message(1, None, Chat(1, ''), from_user=None, text='Text')) - # app.process_update(update) - # assert self.count == 0 - # - # app.bot._defaults = Defaults(block=True) - # try: - # for group in range(5): - # app.add_handler(MessageHandler(filters.TEXT, dummy_callback), group=group) - # - # update = Update(1, message=Message(1, None, Chat(1, ''), from_user=None, - # text='Text')) - # app.process_update(update) - # assert self.count == 0 - # finally: - # app.bot._defaults = None - # - # @pytest.mark.parametrize('block', [DEFAULT_FALSE, False]) - # def test_update_persistence_one_sync(self, monkeypatch, app, block): - # def update_persistence(*args, **kwargs): - # self.count += 1 - # - # def dummy_callback(*args, **kwargs): - # pass - # - # monkeypatch.setattr(app, 'update_persistence', update_persistence) - # monkeypatch.setattr(app, 'block', dummy_callback) - # - # for group in range(5): - # app.add_handler( - # MessageHandler(filters.TEXT, dummy_callback, block=True), group=group - # ) - # app.add_handler(MessageHandler(filters.TEXT, dummy_callback, block=block),group=5) - # - # update = Update(1, message=Message(1, None, Chat(1, ''), from_user=None, text='Text')) - # app.process_update(update) - # assert self.count == 1 - # - # @pytest.mark.parametrize('block,expected', [(DEFAULT_FALSE, 1), (False, 1), (True, 0)]) - # def test_update_persistence_defaults_async(self, monkeypatch, app, block, expected): - # def update_persistence(*args, **kwargs): - # self.count += 1 - # - # def dummy_callback(*args, **kwargs): - # pass - # - # monkeypatch.setattr(app, 'update_persistence', update_persistence) - # monkeypatch.setattr(app, 'block', dummy_callback) - # app.bot._defaults = Defaults(block=block) - # - # try: - # for group in range(5): - # app.add_handler(MessageHandler(filters.TEXT, dummy_callback), group=group) - # - # update = Update(1, message=Message(1, None, Chat(1, ''), from_user=None, - # text='Text')) - # app.process_update(update) - # assert self.count == expected - # finally: - # app.bot._defaults = None - # - # def test_custom_context_error_handler(self, bot): - # def error_handler(_, context): - # self.received = ( - # type(context), - # type(context.user_data), - # type(context.chat_data), - # type(context.bot_data), - # ) - # - # application = ( - # ApplicationBuilder() - # .bot(bot) - # .context_types( - # ContextTypes( - # context=CustomContext, bot_data=int, user_data=float, chat_data=complex - # ) - # ) - # .build() - # ) - # application.add_error_handler(error_handler) - # application.add_handler(MessageHandler(filters.ALL, self.callback_raise_error)) - # - # application.process_update(self.message_update) - # await asyncio.sleep(0.05) - # assert self.received == (CustomContext, float, complex, int) - # - # def test_custom_context_handler_callback(self, bot): - # def callback(_, context): - # self.received = ( - # type(context), - # type(context.user_data), - # type(context.chat_data), - # type(context.bot_data), - # ) - # - # application = ( - # ApplicationBuilder() - # .bot(bot) - # .context_types( - # ContextTypes( - # context=CustomContext, bot_data=int, user_data=float, chat_data=complex - # ) - # ) - # .build() - # ) - # application.add_handler(MessageHandler(filters.ALL, callback)) - # - # application.process_update(self.message_update) - # await asyncio.sleep(0.05) - # assert self.received == (CustomContext, float, complex, int) + + @pytest.mark.parametrize( + 'message', + [ + Message(message_id=1, chat=Chat(id=2, type=None), migrate_from_chat_id=1, date=None), + Message(message_id=1, chat=Chat(id=1, type=None), migrate_to_chat_id=2, date=None), + Message(message_id=1, chat=Chat(id=1, type=None), date=None), + None, + ], + ) + @pytest.mark.parametrize('old_chat_id', [None, 1, "1"]) + @pytest.mark.parametrize('new_chat_id', [None, 2, "1"]) + def test_migrate_chat_data(self, app, message: 'Message', old_chat_id: int, new_chat_id: int): + def call(match: str): + with pytest.raises(ValueError, match=match): + app.migrate_chat_data( + message=message, old_chat_id=old_chat_id, new_chat_id=new_chat_id + ) + + if message and (old_chat_id or new_chat_id): + call(r"^Message and chat_id pair are mutually exclusive$") + return + + if not any((message, old_chat_id, new_chat_id)): + call(r"^chat_id pair or message must be passed$") + return + + if message: + if message.migrate_from_chat_id is None and message.migrate_to_chat_id is None: + call(r"^Invalid message instance") + return + effective_old_chat_id = message.migrate_from_chat_id or message.chat.id + effective_new_chat_id = message.migrate_to_chat_id or message.chat.id + + elif not (isinstance(old_chat_id, int) and isinstance(new_chat_id, int)): + call(r"^old_chat_id and new_chat_id must be integers$") + return + else: + effective_old_chat_id = old_chat_id + effective_new_chat_id = new_chat_id + + app.chat_data[effective_old_chat_id]['key'] = "test" + app.migrate_chat_data(message=message, old_chat_id=old_chat_id, new_chat_id=new_chat_id) + assert effective_old_chat_id not in app.chat_data + assert app.chat_data[effective_new_chat_id]['key'] == "test" + + @pytest.mark.parametrize( + "c_id,expected", + [(321, {222: "remove_me"}), (111, {321: {'not_empty': 'no'}, 222: "remove_me"})], + ids=["test chat_id removal", "test no key in data (no error)"], + ) + def test_drop_chat_data(self, app, c_id, expected): + app._chat_data.update({321: {'not_empty': 'no'}, 222: "remove_me"}) + app.drop_chat_data(c_id) + assert app.chat_data == expected + + @pytest.mark.parametrize( + "u_id,expected", + [(321, {222: "remove_me"}), (111, {321: {'not_empty': 'no'}, 222: "remove_me"})], + ids=["test user_id removal", "test no key in data (no error)"], + ) + def test_drop_user_data(self, app, u_id, expected): + app._user_data.update({321: {'not_empty': 'no'}, 222: "remove_me"}) + app.drop_user_data(u_id) + assert app.user_data == expected + + # TODO: + # * Test stop() with updater running + # * Test run_polling/webhook + # * Test concurrent updates + # * Test create_task diff --git a/tests/test_persistence_integration.py b/tests/test_persistence_integration.py new file mode 100644 index 00000000000..ab43151f6a1 --- /dev/null +++ b/tests/test_persistence_integration.py @@ -0,0 +1,361 @@ +#!/usr/bin/env python +# +# A library that provides a Python interface to the Telegram Bot API +# Copyright (C) 2015-2022 +# Leandro Toledo de Souza +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser Public License for more details. +# +# You should have received a copy of the GNU Lesser Public License +# along with this program. If not, see [http://www.gnu.org/licenses/]. +import pytest + +from telegram.ext import ( + ApplicationBuilder, + PersistenceInput, +) + + +class TestPersistenceIntegration: + # TODO: + # * Test add_handler with persistent conversationhandler + # * Test migrate_chat_data + # * Test drop_chat/user_data + # * Test update_persistence & flush getting called on shutdown + + def test_construction_with_bad_persistence(self, caplog, bot): + class MyPersistence: + def __init__(self): + self.store_data = PersistenceInput(False, False, False, False) + + with pytest.raises( + TypeError, match='persistence must be based on telegram.ext.BasePersistence' + ): + ApplicationBuilder().bot(bot).persistence(MyPersistence()).build() + + # + # def test_error_while_saving_chat_data(self, bot): + # increment = [] + # + # class OwnPersistence(BasePersistence): + # def get_callback_data(self): + # return None + # + # def update_callback_data(self, data): + # raise Exception + # + # def get_bot_data(self): + # return {} + # + # def update_bot_data(self, data): + # raise Exception + # + # def drop_chat_data(self, chat_id): + # pass + # + # def drop_user_data(self, user_id): + # pass + # + # def get_chat_data(self): + # return defaultdict(dict) + # + # def update_chat_data(self, chat_id, data): + # raise Exception + # + # def get_user_data(self): + # return defaultdict(dict) + # + # def update_user_data(self, user_id, data): + # raise Exception + # + # def get_conversations(self, name): + # pass + # + # def update_conversation(self, name, key, new_state): + # pass + # + # def refresh_user_data(self, user_id, user_data): + # pass + # + # def refresh_chat_data(self, chat_id, chat_data): + # pass + # + # def refresh_bot_data(self, bot_data): + # pass + # + # def flush(self): + # pass + # + # def start1(u, c): + # pass + # + # def error(u, c): + # increment.append("error") + # + # # If updating a user_data or chat_data from a persistence object throws an error, + # # the error handler should catch it + # + # update = Update( + # 1, + # message=Message( + # 1, + # None, + # Chat(1, "lala"), + # from_user=User(1, "Test", False), + # text='/start', + # entities=[ + # MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0, length=len('/start')) + # ], + # bot=bot, + # ), + # ) + # my_persistence = OwnPersistence() + # app = ApplicationBuilder().bot(bot).persistence(my_persistence).build() + # app.add_handler(CommandHandler('start', start1)) + # app.add_error_handler(error) + # app.process_update(update) + # assert increment == ["error", "error", "error", "error"] + # + # def test_error_while_persisting(self, app, caplog): + # class OwnPersistence(BasePersistence): + # def update(self, data): + # raise Exception('PersistenceError') + # + # def update_callback_data(self, data): + # self.update(data) + # + # def update_bot_data(self, data): + # self.update(data) + # + # def update_chat_data(self, chat_id, data): + # self.update(data) + # + # def update_user_data(self, user_id, data): + # self.update(data) + # + # def drop_user_data(self, user_id): + # pass + # + # def drop_chat_data(self, chat_id): + # pass + # + # def get_chat_data(self): + # pass + # + # def get_bot_data(self): + # pass + # + # def get_user_data(self): + # pass + # + # def get_callback_data(self): + # pass + # + # def get_conversations(self, name): + # pass + # + # def update_conversation(self, name, key, new_state): + # pass + # + # def refresh_bot_data(self, bot_data): + # pass + # + # def refresh_user_data(self, user_id, user_data): + # pass + # + # def refresh_chat_data(self, chat_id, chat_data): + # pass + # + # def flush(self): + # pass + # + # def callback(update, context): + # pass + # + # test_flag = [] + # + # def error(update, context): + # nonlocal test_flag + # test_flag.append(str(context.error) == 'PersistenceError') + # raise Exception('ErrorHandlingError') + # + # update = Update( + # 1, message=Message(1, None, Chat(1, ''), from_user=User(1, '', False), text='Text') + # ) + # handler = MessageHandler(filters.ALL, callback) + # app.add_handler(handler) + # app.add_error_handler(error) + # + # app.persistence = OwnPersistence() + # + # with caplog.at_level(logging.ERROR): + # app.process_update(update) + # + # assert test_flag == [True, True, True, True] + # assert len(caplog.records) == 4 + # for record in caplog.records: + # message = record.getMessage() + # assert message.startswith('An error was raised and an uncaught') + # + # def test_persisting_no_user_no_chat(self, app): + # class OwnPersistence(BasePersistence): + # def __init__(self): + # super().__init__() + # self.test_flag_bot_data = False + # self.test_flag_chat_data = False + # self.test_flag_user_data = False + # + # def update_bot_data(self, data): + # self.test_flag_bot_data = True + # + # def update_chat_data(self, chat_id, data): + # self.test_flag_chat_data = True + # + # def update_user_data(self, user_id, data): + # self.test_flag_user_data = True + # + # def update_conversation(self, name, key, new_state): + # pass + # + # def drop_chat_data(self, chat_id): + # pass + # + # def drop_user_data(self, user_id): + # pass + # + # def get_conversations(self, name): + # pass + # + # def get_user_data(self): + # pass + # + # def get_bot_data(self): + # pass + # + # def get_chat_data(self): + # pass + # + # def refresh_bot_data(self, bot_data): + # pass + # + # def refresh_user_data(self, user_id, user_data): + # pass + # + # def refresh_chat_data(self, chat_id, chat_data): + # pass + # + # def get_callback_data(self): + # pass + # + # def update_callback_data(self, data): + # pass + # + # def flush(self): + # pass + # + # def callback(update, context): + # pass + # + # handler = MessageHandler(filters.ALL, callback) + # app.add_handler(handler) + # app.persistence = OwnPersistence() + # + # update = Update( + # 1, message=Message(1, None, None, from_user=User(1, '', False), text='Text') + # ) + # app.process_update(update) + # assert app.persistence.test_flag_bot_data + # assert app.persistence.test_flag_user_data + # assert not app.persistence.test_flag_chat_data + # + # app.persistence.test_flag_bot_data = False + # app.persistence.test_flag_user_data = False + # app.persistence.test_flag_chat_data = False + # update = Update(1, message=Message(1, None, Chat(1, ''), from_user=None, text='Text')) + # app.process_update(update) + # assert app.persistence.test_flag_bot_data + # assert not app.persistence.test_flag_user_data + # assert app.persistence.test_flag_chat_data + # + # def test_update_persistence_all_async(self, monkeypatch, app): + # def update_persistence(*args, **kwargs): + # self.count += 1 + # + # def dummy_callback(*args, **kwargs): + # pass + # + # monkeypatch.setattr(app, 'update_persistence', update_persistence) + # monkeypatch.setattr(app, 'block', dummy_callback) + # + # for group in range(5): + # app.add_handler( + # MessageHandler(filters.TEXT, dummy_callback, block=True), group=group + # ) + # + # update = Update(1, message=Message(1, None, Chat(1, ''), from_user=None, text='Text')) + # app.process_update(update) + # assert self.count == 0 + # + # app.bot._defaults = Defaults(block=True) + # try: + # for group in range(5): + # app.add_handler(MessageHandler(filters.TEXT, dummy_callback), group=group) + # + # update = Update(1, message=Message(1, None, Chat(1, ''), from_user=None, + # text='Text')) + # app.process_update(update) + # assert self.count == 0 + # finally: + # app.bot._defaults = None + # + # @pytest.mark.parametrize('block', [DEFAULT_FALSE, False]) + # def test_update_persistence_one_sync(self, monkeypatch, app, block): + # def update_persistence(*args, **kwargs): + # self.count += 1 + # + # def dummy_callback(*args, **kwargs): + # pass + # + # monkeypatch.setattr(app, 'update_persistence', update_persistence) + # monkeypatch.setattr(app, 'block', dummy_callback) + # + # for group in range(5): + # app.add_handler( + # MessageHandler(filters.TEXT, dummy_callback, block=True), group=group + # ) + # app.add_handler(MessageHandler(filters.TEXT, dummy_callback, block=block),group=5) + # + # update = Update(1, message=Message(1, None, Chat(1, ''), from_user=None, text='Text')) + # app.process_update(update) + # assert self.count == 1 + # + # @pytest.mark.parametrize('block,expected', [(DEFAULT_FALSE, 1), (False, 1), (True, 0)]) + # def test_update_persistence_defaults_async(self, monkeypatch, app, block, expected): + # def update_persistence(*args, **kwargs): + # self.count += 1 + # + # def dummy_callback(*args, **kwargs): + # pass + # + # monkeypatch.setattr(app, 'update_persistence', update_persistence) + # monkeypatch.setattr(app, 'block', dummy_callback) + # app.bot._defaults = Defaults(block=block) + # + # try: + # for group in range(5): + # app.add_handler(MessageHandler(filters.TEXT, dummy_callback), group=group) + # + # update = Update(1, message=Message(1, None, Chat(1, ''), from_user=None, + # text='Text')) + # app.process_update(update) + # assert self.count == expected + # finally: + # app.bot._defaults = None diff --git a/tests/test_updater.py b/tests/test_updater.py index fb11769f0f6..4d751625b27 100644 --- a/tests/test_updater.py +++ b/tests/test_updater.py @@ -390,7 +390,6 @@ async def do_request(*args, **kwargs): await updater.start_polling( bootstrap_retries=retries, ) - await updater.stop() @pytest.mark.parametrize( 'error,callback_should_be_called', @@ -818,7 +817,6 @@ async def do_request(*args, **kwargs): await updater.start_webhook( bootstrap_retries=retries, ) - await updater.stop() @pytest.mark.asyncio async def test_webhook_invalid_posts(self, updater, monkeypatch): From a7b5fbd217e44dd2d28cde7fb9d0364e844d4a20 Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Sun, 6 Mar 2022 17:39:45 +0100 Subject: [PATCH 039/153] jobqueue tests - jk, it's still application tests --- telegram/ext/_application.py | 15 ++--- tests/test_application.py | 109 +++++++++++++++++++++++++---------- 2 files changed, 84 insertions(+), 40 deletions(-) diff --git a/telegram/ext/_application.py b/telegram/ext/_application.py index 136d9b1b63e..b32456962fb 100644 --- a/telegram/ext/_application.py +++ b/telegram/ext/_application.py @@ -1109,14 +1109,6 @@ def add_error_handler( _logger.warning('The callback is already registered as an error handler. Ignoring.') return - if ( - block is DEFAULT_TRUE - and isinstance(self.bot, ExtBot) - and self.bot.defaults - and not self.bot.defaults.block - ): - block = False - self.error_handlers[callback] = block def remove_error_handler(self, callback: Callable[[object, CCT], None]) -> None: @@ -1170,7 +1162,12 @@ async def dispatch_error( job=job, coroutine=coroutine, ) - if not block: + if not block or ( + block is DEFAULT_TRUE + and isinstance(self.bot, ExtBot) + and self.bot.defaults + and not self.bot.defaults.block + ): self.__create_task( callback(update, context), update=update, is_error_handler=True ) diff --git a/tests/test_application.py b/tests/test_application.py index c722c2e32f4..b5e366d764f 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -41,6 +41,7 @@ ApplicationHandlerStop, CommandHandler, TypeHandler, + Defaults, ) from telegram.error import TelegramError @@ -693,9 +694,10 @@ async def error(u, c): assert passed == ['start1', 'error', err, 'start3'] @pytest.mark.asyncio - async def test_error_handler(self, app): + @pytest.mark.parametrize('block', (True, False)) + async def test_error_handler(self, app, block): app.add_error_handler(self.error_handler_context) - app.add_handler(TypeHandler(object, self.callback_raise_error('TestError'))) + app.add_handler(TypeHandler(object, self.callback_raise_error('TestError'), block=block)) async with app: await app.start() @@ -895,7 +897,7 @@ async def callback(update, context): ), "incorrect stacklevel!" @pytest.mark.asyncio - async def test_run_async_no_error_handler(self, app, caplog): + async def test_non_blocking_no_error_handler(self, app, caplog): app.add_handler(TypeHandler(object, self.callback_raise_error, block=False)) with caplog.at_level(logging.ERROR): @@ -909,35 +911,80 @@ async def test_run_async_no_error_handler(self, app, caplog): ) await app.stop() - # - # def test_async_handler_async_error_handler_context(self, app): - # app.add_handler(MessageHandler(filters.ALL, self.callback_raise_error, block=True)) - # app.add_error_handler(self.error_handler_context, block=True) - # - # await app.update_queue.put(self.message_update) - # sleep(2) - # assert self.received == self.message_update.message.text + @pytest.mark.asyncio + @pytest.mark.parametrize('handler_block', (True, False)) + async def test_non_blocking_error_handler(self, app, handler_block): + event = asyncio.Event() + + async def async_error_handler(update, context): + await event.wait() + self.received = 'done' + + async def normal_error_handler(update, context): + self.count = 42 + + app.add_error_handler(async_error_handler, block=False) + app.add_error_handler(normal_error_handler) + app.add_handler(TypeHandler(object, self.callback_raise_error, block=handler_block)) + + async with app: + await app.start() + await app.update_queue.put(self.message_update) + task = asyncio.create_task(app.stop()) + await asyncio.sleep(0.05) + assert self.count == 42 + assert self.received is None + event.set() + await asyncio.sleep(0.05) + assert self.received == 'done' + assert task.done() + + @pytest.mark.asyncio + @pytest.mark.parametrize('handler_block', (True, False)) + async def test_non_blocking_error_handler_applicationhandlerstop( + self, app, recwarn, handler_block + ): + async def callback(update, context): + raise RuntimeError() + + async def error_handler(update, context): + raise ApplicationHandlerStop + + app.add_handler(TypeHandler(object, callback, block=handler_block)) + app.add_error_handler(error_handler, block=False) + + async with app: + await app.start() + await app.update_queue.put(1) + await asyncio.sleep(0.05) + await app.stop() + + assert len(recwarn) == 1 + assert recwarn[0].category is PTBUserWarning + assert ( + str(recwarn[0].message) + == 'ApplicationHandlerStop is not supported with asynchronously running handlers.' + ) + assert ( + Path(recwarn[0].filename) == PROJECT_ROOT_PATH / 'telegram' / 'ext' / '_application.py' + ), "incorrect stacklevel!" + + @pytest.mark.parametrize(['block', 'expected_output'], [(False, 0), (True, 5)]) + @pytest.mark.asyncio + async def test_default_block_error_handler(self, bot, monkeypatch, block, expected_output): + async def error_handler(*args, **kwargs): + await asyncio.sleep(0.1) + self.count = 5 + + app = Application.builder().token(bot.token).defaults(Defaults(block=block)).build() + app.add_handler(TypeHandler(object, self.callback_raise_error)) + app.add_error_handler(error_handler) + await app.process_update(1) + await asyncio.sleep(0.05) + assert self.count == expected_output + await asyncio.sleep(0.1) + assert self.count == 5 - # - # @pytest.mark.parametrize(['block', 'expected_output'], [(True, 5), (False, 0)]) - # def test_default_run_async_error_handler(self, app, monkeypatch, block, expected_output): - # def mock_async_err_handler(*args, **kwargs): - # self.count = 5 - # - # # set defaults value to app.bot - # app.bot._defaults = Defaults(block=block) - # try: - # app.add_handler(MessageHandler(filters.ALL, self.callback_raise_error)) - # app.add_error_handler(self.error_handler_context) - # - # monkeypatch.setattr(app, 'block', mock_async_err_handler) - # app.process_update(self.message_update) - # - # assert self.count == expected_output - # - # finally: - # # reset app.bot.defaults values - # app.bot._defaults = None # # @pytest.mark.parametrize( # ['block', 'expected_output'], [(True, 'running async'), (False, None)] From 59c96faf353bb0f7c4109966df62dcccd617ebd1 Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Mon, 7 Mar 2022 21:13:47 +0100 Subject: [PATCH 040/153] you guessed it: application tests --- telegram/_bot.py | 2 +- telegram/ext/_application.py | 24 ++- telegram/ext/_jobqueue.py | 6 + telegram/ext/_updater.py | 2 +- tests/test_application.py | 258 +++++++++++++++++++------- tests/test_persistence_integration.py | 1 + 6 files changed, 221 insertions(+), 72 deletions(-) diff --git a/telegram/_bot.py b/telegram/_bot.py index 93e8efdf33e..b61d8ea786f 100644 --- a/telegram/_bot.py +++ b/telegram/_bot.py @@ -357,7 +357,7 @@ async def initialize(self) -> None: :attr:`request`. """ if self._initialized: - self._logger.warning('This Bot is already initialized.') + self._logger.debug('This Bot is already initialized.') return await asyncio.gather(self._request[0].initialize(), self._request[1].initialize()) diff --git a/telegram/ext/_application.py b/telegram/ext/_application.py index b32456962fb..ffafb5a8a2d 100644 --- a/telegram/ext/_application.py +++ b/telegram/ext/_application.py @@ -271,7 +271,7 @@ def concurrent_updates(self) -> int: async def initialize(self) -> None: if self._initialized: - _logger.warning('This Application is already initialized.') + _logger.debug('This Application is already initialized.') return await self.bot.initialize() @@ -392,7 +392,7 @@ async def start(self, ready: Event = None) -> None: :attr:`persistence` is set. Note: - This does *not* start fetching updates from Telegram. You need either start + This does *not* start fetching updates from Telegram. You need to either start :attr:`updater` manually or use one of :meth:`run_polling` or :meth:`run_webhook`. Args: @@ -438,8 +438,7 @@ async def start(self, ready: Event = None) -> None: async def stop(self) -> None: """Stops the process after processing any pending updates or tasks created by - :meth:`create_task`. Also stops :attr:`job_queue`, if set and :attr:`updater`, if set and - running. + :meth:`create_task`. Also stops :attr:`job_queue`, if set. Finally, calls :meth:`update_persistence` and :meth:`BasePersistence.flush` on :attr:`persistence`, if set. @@ -447,6 +446,11 @@ async def stop(self) -> None: Once this method is called, no more updates will be fetched from :attr:`update_queue`, even if it's not empty. + Note: + This does *not* stop :attr:`updater`. You need to either manually call + :meth:`telegram.ext.Updater.stop` or use one of :meth:`run_polling` or + :meth:`run_webhook`. + Raises: :exc:`RuntimeError`: If the application is not running. """ @@ -718,10 +722,16 @@ async def process_update(self, update: object) -> None: context = self.context_types.context.from_update(update, self) await context.refresh_data() coroutine: Coroutine = handler.handle_update(update, self, check, context) - if handler.block: - await coroutine - else: + + if not handler.block or ( + handler.block is DEFAULT_TRUE + and isinstance(self.bot, ExtBot) + and self.bot.defaults + and not self.bot.defaults.block + ): self.create_task(coroutine, update=update) + else: + await coroutine break # Stop processing with any other handler. diff --git a/telegram/ext/_jobqueue.py b/telegram/ext/_jobqueue.py index 5e40c3d04ed..fb11624836d 100644 --- a/telegram/ext/_jobqueue.py +++ b/telegram/ext/_jobqueue.py @@ -502,6 +502,9 @@ async def stop(self, wait: bool = True) -> None: have finished. Defaults to :obj:`True`. """ + # the interface methods of AsyncIOExecutor are currently not really asyncio-compatible + # so we apply some small tweaks here to try and smoothen the integration into PTB + # TODO: When APS 4.0 hits, we should be able to remove the tweaks if wait: # Unfortunately AsyncIOExecutor just cancels them all ... await asyncio.gather( @@ -510,6 +513,9 @@ async def stop(self, wait: bool = True) -> None: ) if self.scheduler.running: self.scheduler.shutdown(wait=wait) + # scheduler.shutdown schedules a task in the event loop but immediatel returns + # so give it a tiny bit of time to actually shut down. + await asyncio.sleep(0.01) def jobs(self) -> Tuple['Job', ...]: """Returns a tuple of all *scheduled* jobs that are currently in the :class:`JobQueue`.""" diff --git a/telegram/ext/_updater.py b/telegram/ext/_updater.py index cc41d786c96..a12252e9ffe 100644 --- a/telegram/ext/_updater.py +++ b/telegram/ext/_updater.py @@ -101,7 +101,7 @@ def running(self) -> bool: async def initialize(self) -> None: if self._initialized: - self._logger.warning('This Updater is already initialized.') + self._logger.debug('This Updater is already initialized.') return await self.bot.initialize() diff --git a/tests/test_application.py b/tests/test_application.py index b5e366d764f..e2c3daa1a6a 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -351,6 +351,7 @@ async def callback(u, c): assert not app.running assert not app.updater.running + assert not app.job_queue.scheduler.running app.add_handler(TypeHandler(object, callback)) await app.update_queue.put(1) @@ -361,6 +362,7 @@ async def callback(u, c): async with app: await app.start() assert app.running + assert app.job_queue.scheduler.running assert not app.updater.running await asyncio.sleep(0.05) assert app.update_queue.empty() @@ -369,6 +371,8 @@ async def callback(u, c): await app.stop() assert not app.running assert not app.updater.running + print(app.job_queue.scheduler.running) + assert not app.job_queue.scheduler.running await app.update_queue.put(2) await asyncio.sleep(0.05) assert not app.update_queue.empty() @@ -971,7 +975,7 @@ async def error_handler(update, context): @pytest.mark.parametrize(['block', 'expected_output'], [(False, 0), (True, 5)]) @pytest.mark.asyncio - async def test_default_block_error_handler(self, bot, monkeypatch, block, expected_output): + async def test_default_block_error_handler(self, bot, block, expected_output): async def error_handler(*args, **kwargs): await asyncio.sleep(0.1) self.count = 5 @@ -985,65 +989,47 @@ async def error_handler(*args, **kwargs): await asyncio.sleep(0.1) assert self.count == 5 - # - # @pytest.mark.parametrize( - # ['block', 'expected_output'], [(True, 'running async'), (False, None)] - # ) - # def test_default_run_async(self, monkeypatch, app, block, expected_output): - # def mock_run_async(*args, **kwargs): - # self.received = 'running async' - # - # # set defaults value to app.bot - # app.bot._defaults = Defaults(block=block) - # try: - # app.add_handler(MessageHandler(filters.ALL, lambda u, c: None)) - # monkeypatch.setattr(app, 'block', mock_run_async) - # app.process_update(self.message_update) - # assert self.received == expected_output - # - # finally: - # # reset defaults value - # app.bot._defaults = None - # - # def test_async_handler_error_handler_that_raises_error(self, app, caplog): - # handler = MessageHandler(filters.ALL, self.callback_raise_error, block=True) - # app.add_handler(handler) - # app.add_error_handler(self.error_handler_raise_error, block=False) - # - # with caplog.at_level(logging.ERROR): - # await app.update_queue.put(self.message_update) - # await asyncio.sleep(0.05) - # assert len(caplog.records) == 1 - # assert ( - # caplog.records[-1].getMessage().startswith('An error was raised and an uncaught') - # ) - # - # # Make sure that the main loop still runs - # app.remove_handler(handler) - # app.add_handler(MessageHandler(filters.ALL, self.callback_increase_count, block=True)) - # await app.update_queue.put(self.message_update) - # await asyncio.sleep(0.05) - # assert self.count == 1 - # - # def test_async_handler_async_error_handler_that_raises_error(self, app, caplog): - # handler = MessageHandler(filters.ALL, self.callback_raise_error, block=True) - # app.add_handler(handler) - # app.add_error_handler(self.error_handler_raise_error, block=True) - # - # with caplog.at_level(logging.ERROR): - # await app.update_queue.put(self.message_update) - # await asyncio.sleep(0.05) - # assert len(caplog.records) == 1 - # assert ( - # caplog.records[-1].getMessage().startswith('An error was raised and an uncaught') - # ) - # - # # Make sure that the main loop still runs - # app.remove_handler(handler) - # app.add_handler(MessageHandler(filters.ALL, self.callback_increase_count, block=True)) - # await app.update_queue.put(self.message_update) - # await asyncio.sleep(0.05) - # assert self.count == 1 + @pytest.mark.parametrize(['block', 'expected_output'], [(False, 0), (True, 5)]) + @pytest.mark.asyncio + async def test_default_block_handler(self, bot, block, expected_output): + app = Application.builder().token(bot.token).defaults(Defaults(block=block)).build() + app.add_handler(TypeHandler(object, self.callback_set_count(5, sleep=0.1))) + await app.process_update(1) + await asyncio.sleep(0.05) + assert self.count == expected_output + await asyncio.sleep(0.15) + assert self.count == 5 + + @pytest.mark.asyncio + @pytest.mark.parametrize('handler_block', (True, False)) + @pytest.mark.parametrize('error_handler_block', (True, False)) + async def test_nonblocking_handler_raises_and_non_blocking_error_handler_raises( + self, app, caplog, handler_block, error_handler_block + ): + handler = TypeHandler(object, self.callback_raise_error, block=handler_block) + app.add_handler(handler) + app.add_error_handler(self.error_handler_raise_error, block=error_handler_block) + + async with app: + await app.start() + with caplog.at_level(logging.ERROR): + await app.update_queue.put(1) + await asyncio.sleep(0.05) + assert len(caplog.records) == 1 + assert ( + caplog.records[-1] + .getMessage() + .startswith('An error was raised and an uncaught') + ) + + # Make sure that the main loop still runs + app.remove_handler(handler) + app.add_handler(MessageHandler(filters.ALL, self.callback_increase_count, block=True)) + await app.update_queue.put(self.message_update) + await asyncio.sleep(0.05) + assert self.count == 1 + + await app.stop() @pytest.mark.parametrize( 'message', @@ -1110,8 +1096,154 @@ def test_drop_user_data(self, app, u_id, expected): app.drop_user_data(u_id) assert app.user_data == expected + @pytest.mark.asyncio + async def test_create_task_basic(self, app): + async def callback(): + await asyncio.sleep(0.05) + self.count = 42 + return 43 + + task = app.create_task(callback()) + await asyncio.sleep(0.01) + assert not task.done() + out = await task + assert task.done() + assert self.count == 42 + assert out == 43 + + @pytest.mark.asyncio + @pytest.mark.parametrize('running', (True, False)) + async def test_create_task_awaiting_warning(self, app, running, caplog): + async def callback(): + await asyncio.sleep(0.1) + return 43 + + async with app: + if running: + await app.start() + + with caplog.at_level(logging.WARNING): + task = app.create_task(callback()) + + if running: + assert len(caplog.records) == 0 + assert not task.done() + await app.stop() + assert task.done() + assert task.result() == 43 + else: + assert len(caplog.records) == 1 + assert "won't be automatically awaited" in caplog.records[-1].getMessage() + assert not task.done() + await task + + @pytest.mark.asyncio + @pytest.mark.parametrize('update', (None, object())) + async def test_create_task_error_handling(self, app, update): + exception = RuntimeError('TestError') + + async def callback(): + raise exception + + async def error(update_arg, context): + self.received = update_arg, context.error + + app.add_error_handler(error) + if update: + task = app.create_task(callback(), update=update) + else: + task = app.create_task(callback()) + + with pytest.raises(RuntimeError, match='TestError'): + await task + assert task.exception() is exception + assert isinstance(self.received, tuple) + assert self.received[0] is update + assert self.received[1] is exception + + @pytest.mark.asyncio + async def test_await_create_task_tasks_on_stop(self, app): + async def callback_1(): + await asyncio.sleep(0.5) + + async def callback_2(): + await asyncio.sleep(0.1) + + async with app: + await app.start() + task_1 = app.create_task(callback_1()) + task_2 = app.create_task(callback_2()) + await task_2 + assert not task_1.done() + stop_task = asyncio.create_task(app.stop()) + assert not stop_task.done() + await asyncio.sleep(0.3) + assert not stop_task.done() + await asyncio.sleep(0.15) + assert stop_task.done() + + @pytest.mark.asyncio + async def test_no_concurrent_updates(self, app): + queue = asyncio.Queue() + event_1 = asyncio.Event() + event_2 = asyncio.Event() + await queue.put(event_1) + await queue.put(event_2) + + async def callback(u, c): + await asyncio.sleep(0.1) + event = await queue.get() + event.set() + + app.add_handler(TypeHandler(object, callback)) + async with app: + await app.start() + await app.update_queue.put(1) + await app.update_queue.put(2) + assert not event_1.is_set() + assert not event_2.is_set() + await asyncio.sleep(0.15) + assert event_1.is_set() + assert not event_2.is_set() + await asyncio.sleep(0.1) + assert event_1.is_set() + assert event_2.is_set() + + await app.stop() + + @pytest.mark.asyncio + @pytest.mark.parametrize('concurrent_updates', (True, 15, 50, 256)) + async def test_concurrent_updates(self, bot, concurrent_updates): + app = Application.builder().bot(bot).concurrent_updates(concurrent_updates).build() + events = {i: asyncio.Event() for i in range(app.concurrent_updates + 10)} + queue = asyncio.Queue() + for event in events.values(): + await queue.put(event) + + async def callback(u, c): + await asyncio.sleep(0.5) + (await queue.get()).set() + + app.add_handler(TypeHandler(object, callback)) + async with app: + await app.start() + for i in range(app.concurrent_updates + 10): + await app.update_queue.put(i) + + for i in range(app.concurrent_updates + 10): + assert not events[i].is_set() + + await asyncio.sleep(0.9) + for i in range(app.concurrent_updates): + assert events[i].is_set() + for i in range(app.concurrent_updates, app.concurrent_updates + 10): + assert not events[i].is_set() + + await asyncio.sleep(0.5) + for i in range(app.concurrent_updates + 10): + assert events[i].is_set() + + await app.stop() + # TODO: - # * Test stop() with updater running # * Test run_polling/webhook - # * Test concurrent updates - # * Test create_task diff --git a/tests/test_persistence_integration.py b/tests/test_persistence_integration.py index ab43151f6a1..89166b294b9 100644 --- a/tests/test_persistence_integration.py +++ b/tests/test_persistence_integration.py @@ -30,6 +30,7 @@ class TestPersistenceIntegration: # * Test migrate_chat_data # * Test drop_chat/user_data # * Test update_persistence & flush getting called on shutdown + # * Test the update parameter of create_task def test_construction_with_bad_persistence(self, caplog, bot): class MyPersistence: From ddf4bb34b585335498e90b3b7fde0a0368db19b9 Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Mon, 7 Mar 2022 21:34:54 +0100 Subject: [PATCH 041/153] Fix tests --- tests/test_application.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/test_application.py b/tests/test_application.py index e2c3daa1a6a..166314fc467 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -371,7 +371,6 @@ async def callback(u, c): await app.stop() assert not app.running assert not app.updater.running - print(app.job_queue.scheduler.running) assert not app.job_queue.scheduler.running await app.update_queue.put(2) await asyncio.sleep(0.05) @@ -405,7 +404,6 @@ def two(update, context): pytest.fail('Expected same context object, got different') else: if context is self.received: - print(context, self.received) pytest.fail('First handler was wrongly called') app.add_handler(MessageHandler(filters.Regex('test'), one), group=1) @@ -1214,7 +1212,7 @@ async def callback(u, c): @pytest.mark.asyncio @pytest.mark.parametrize('concurrent_updates', (True, 15, 50, 256)) async def test_concurrent_updates(self, bot, concurrent_updates): - app = Application.builder().bot(bot).concurrent_updates(concurrent_updates).build() + app = Application.builder().token(bot.token).concurrent_updates(concurrent_updates).build() events = {i: asyncio.Event() for i in range(app.concurrent_updates + 10)} queue = asyncio.Queue() for event in events.values(): From ed0ae0ab3698baee05b257634c381b231a38552a Mon Sep 17 00:00:00 2001 From: Bibo-Joshi <22366557+Bibo-Joshi@users.noreply.github.com> Date: Wed, 9 Mar 2022 21:11:17 +0100 Subject: [PATCH 042/153] Simplify a type hint Co-authored-by: Harshil <37377066+harshil21@users.noreply.github.com> --- telegram/ext/_applicationbuilder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/telegram/ext/_applicationbuilder.py b/telegram/ext/_applicationbuilder.py index ce7d4a78a2d..a7fd0bb91cf 100644 --- a/telegram/ext/_applicationbuilder.py +++ b/telegram/ext/_applicationbuilder.py @@ -732,7 +732,7 @@ def context_types( self._context_types = context_types return self # type: ignore[return-value] - def updater(self: BuilderType, updater: Union[Updater, None]) -> BuilderType: + def updater(self: BuilderType, updater: Optional[Updater]) -> BuilderType: """Sets a :class:`telegram.ext.Updater` instance to be used for :attr:`telegram.ext.Application.updater`. The :attr:`telegram.ext.Updater.bot` and :attr:`telegram.ext.Updater.update_queue` be used for :attr:`telegram.ext.Application.bot` From 30bf4753d1d12fbd620fddcbc22874a2970a8f30 Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Wed, 9 Mar 2022 22:40:08 +0100 Subject: [PATCH 043/153] Test run_polling/updater --- telegram/ext/_application.py | 27 ++-- tests/conftest.py | 36 +++++- tests/test_application.py | 234 ++++++++++++++++++++++++++++++++++- tests/test_request.py | 2 - tests/test_updater.py | 67 +++------- 5 files changed, 304 insertions(+), 62 deletions(-) diff --git a/telegram/ext/_application.py b/telegram/ext/_application.py index ffafb5a8a2d..9d984491822 100644 --- a/telegram/ext/_application.py +++ b/telegram/ext/_application.py @@ -501,6 +501,7 @@ def run_polling( allowed_updates: List[str] = None, drop_pending_updates: bool = None, ready: asyncio.Event = None, + close_loop: bool = True, ) -> None: """Temp docstring to make this referencable #TODO: Adda meaningful description @@ -527,6 +528,7 @@ def error_callback(exc: TelegramError) -> None: error_callback=error_callback, ), ready=ready, + close_loop=close_loop, ) def run_webhook( @@ -543,6 +545,7 @@ def run_webhook( ip_address: str = None, max_connections: int = 40, ready: asyncio.Event = None, + close_loop: bool = True, ) -> None: """Temp docstring to make this referencable #TODO: Adda meaningful description @@ -567,22 +570,32 @@ def run_webhook( max_connections=max_connections, ), ready=ready, + close_loop=close_loop, ) - def __run(self, updater_coroutine: Coroutine, ready: asyncio.Event = None) -> None: - # TODO: get_event_loop is deprecated - switch to get_running_loop() - loop = asyncio.get_event_loop() # get_running_loop() + def __run( + self, updater_coroutine: Coroutine, ready: asyncio.Event = None, close_loop: bool = True + ) -> None: + # Calling get_event_loop() should still be okay even in py3.10+ as long as there is a + # running event loop or we are in the main thread, which are the intended use cases. + # See the docs of get_event_loop() and get_running_loop() for more info + loop = asyncio.get_event_loop() loop.run_until_complete(self.initialize()) - loop.run_until_complete(self.start(ready=ready)) loop.run_until_complete(updater_coroutine) + loop.run_until_complete(self.start(ready=ready)) try: loop.run_forever() # TODO: maybe allow for custom exception classes to catch here? Or provide a custom one? except (KeyboardInterrupt, SystemExit): - loop.run_until_complete(self.stop()) - loop.run_until_complete(self.shutdown()) + pass finally: - loop.close() + # We arrive here either by catching the exceptions above or if the loop gets stopped + try: + loop.run_until_complete(self.stop()) + loop.run_until_complete(self.shutdown()) + finally: + if close_loop: + loop.close() def create_task(self, coroutine: Coroutine, update: object = None) -> asyncio.Task: """Thin wrapper around :func:`asyncio.create_task` that handles exceptions raised by diff --git a/tests/conftest.py b/tests/conftest.py index 9c8aa3b0c38..7ee280c7f76 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -24,10 +24,11 @@ import os import re from pathlib import Path -from typing import Callable, List, Iterable, Any, Dict +from typing import Callable, List, Iterable, Any, Dict, Optional import pytest import pytz +from httpx import AsyncClient, Response from telegram import ( Message, @@ -775,3 +776,36 @@ def check_input_media(m: Dict): bot._defaults = None return True + + +async def send_webhook_message( + ip: str, + port: int, + payload_str: Optional[str], + url_path: str = '', + content_len: int = -1, + content_type: str = 'application/json', + get_method: str = None, +) -> Response: + headers = { + 'content-type': content_type, + } + + if not payload_str: + content_len = None + payload = None + else: + payload = bytes(payload_str, encoding='utf-8') + + if content_len == -1: + content_len = len(payload) + + if content_len is not None: + headers['content-length'] = str(content_len) + + url = f'http://{ip}:{port}/{url_path}' + + async with AsyncClient() as client: + return await client.request( + url=url, method=get_method or 'POST', data=payload, headers=headers + ) diff --git a/tests/test_application.py b/tests/test_application.py index 166314fc467..8fb5b6f626c 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -19,10 +19,18 @@ """The integration of persistence into the application is tested in test_persistence_integration. """ import asyncio +import inspect import logging +import os +import platform +import signal +import threading +import time from collections import defaultdict from pathlib import Path from queue import Queue +from random import randrange +from threading import Thread import pytest @@ -47,7 +55,7 @@ from telegram.error import TelegramError from telegram.warnings import PTBUserWarning -from tests.conftest import make_message_update, PROJECT_ROOT_PATH +from tests.conftest import make_message_update, PROJECT_ROOT_PATH, send_webhook_message class CustomContext(CallbackContext): @@ -1243,5 +1251,225 @@ async def callback(u, c): await app.stop() - # TODO: - # * Test run_polling/webhook + @pytest.mark.skipif( + platform.system() == 'Windows', + reason="Can't send signals without stopping whole process on windows", + ) + def test_run_polling_basic(self, app, monkeypatch): + ready_event = threading.Event() + exception_event = threading.Event() + update_event = threading.Event() + exception = TelegramError('This is a test error') + assertions = {} + + async def get_updates(*args, **kwargs): + if exception_event.is_set(): + raise exception + # This makes sure that other coroutines have a chance of running as well + await asyncio.sleep(0) + update_event.set() + return [self.message_update] + + def thread_target(): + ready_event.wait() + + # Check that everything's running + assertions['app_running'] = app.running + assertions['updater_running'] = app.updater.running + assertions['job_queue_running'] = app.job_queue.scheduler.running + + # Check that we're getting updates + update_event.wait() + time.sleep(0.05) + assertions['getting_updates'] = self.count == 42 + + # Check that errors are properly handled during polling + exception_event.set() + time.sleep(0.05) + assertions['exception_handling'] = self.received == exception.message + + os.kill(os.getpid(), signal.SIGINT) + time.sleep(0.1) + + # # Assert that everything has stopped running + assertions['app_not_running'] = not app.running + assertions['updater_not_running'] = not app.updater.running + assertions['job_queue_not_running'] = not app.job_queue.scheduler.running + + monkeypatch.setattr(app.bot, 'get_updates', get_updates) + app.add_error_handler(self.error_handler_context) + app.add_handler(TypeHandler(object, self.callback_set_count(42))) + + thread = Thread(target=thread_target) + thread.start() + app.run_polling(drop_pending_updates=True, ready=ready_event, close_loop=False) + thread.join() + + assert len(assertions) == 8 + for key, value in assertions.items(): + assert value, f"assertion '{key}' failed!" + + def test_run_polling_parameters_passing(self, app, monkeypatch): + ready_event = threading.Event() + + # First check that the default values match and that we have all arguments there + updater_signature = inspect.signature(app.updater.start_polling) + app_signature = inspect.signature(app.run_polling) + + for name, param in updater_signature.parameters.items(): + if name == 'error_callback': + assert name not in app_signature.parameters + continue + assert name in app_signature.parameters + assert param.kind == app_signature.parameters[name].kind + assert param.default == app_signature.parameters[name].default + + # Check that we pass them correctly + async def start_polling(_, **kwargs): + self.received = kwargs + return True + + def thread_target(): + ready_event.wait() + time.sleep(0.1) + os.kill(os.getpid(), signal.SIGINT) + + monkeypatch.setattr(Updater, 'start_polling', start_polling) + thread = Thread(target=thread_target) + thread.start() + app.run_polling(ready=ready_event, close_loop=False) + thread.join() + ready_event.clear() + + assert set(self.received.keys()) == set(updater_signature.parameters.keys()) + for name, param in updater_signature.parameters.items(): + if name == 'error_callback': + assert self.received[name] is not None + else: + assert self.received[name] == param.default + + expected = { + name: name for name in updater_signature.parameters if name != 'error_callback' + } + thread = Thread(target=thread_target) + thread.start() + app.run_polling(ready=ready_event, close_loop=False, **expected) + thread.join() + ready_event.clear() + + assert set(self.received.keys()) == set(updater_signature.parameters.keys()) + assert self.received.pop('error_callback', None) + assert self.received == expected + + @pytest.mark.skipif( + platform.system() == 'Windows', + reason="Can't send signals without stopping whole process on windows", + ) + def test_run_webhook_basic(self, app, monkeypatch): + ready_event = threading.Event() + assertions = {} + + async def delete_webhook(*args, **kwargs): + return True + + async def set_webhook(*args, **kwargs): + return True + + def thread_target(): + ready_event.wait() + + # Check that everything's running + assertions['app_running'] = app.running + assertions['updater_running'] = app.updater.running + assertions['job_queue_running'] = app.job_queue.scheduler.running + + # Check that we're getting updates + loop = asyncio.new_event_loop() + loop.run_until_complete( + send_webhook_message(ip, port, self.message_update.to_json(), 'TOKEN') + ) + loop.close() + time.sleep(0.05) + assertions['getting_updates'] = self.count == 42 + + os.kill(os.getpid(), signal.SIGINT) + time.sleep(0.1) + + # # Assert that everything has stopped running + assertions['app_not_running'] = not app.running + assertions['updater_not_running'] = not app.updater.running + assertions['job_queue_not_running'] = not app.job_queue.scheduler.running + + monkeypatch.setattr(app.bot, 'set_webhook', set_webhook) + monkeypatch.setattr(app.bot, 'delete_webhook', delete_webhook) + app.add_handler(TypeHandler(object, self.callback_set_count(42))) + + thread = Thread(target=thread_target) + thread.start() + + ip = '127.0.0.1' + port = randrange(1024, 49152) + + app.run_webhook( + ip_address=ip, + port=port, + url_path='TOKEN', + drop_pending_updates=True, + ready=ready_event, + close_loop=False, + ) + thread.join() + + assert len(assertions) == 7 + for key, value in assertions.items(): + assert value, f"assertion '{key}' failed!" + + def test_run_webhook_parameters_passing(self, bot, monkeypatch): + # Check that we pass them correctly + + async def start_webhook(_, **kwargs): + self.received = kwargs + return True + + ready_event = threading.Event() + + # First check that the default values match and that we have all arguments there + updater_signature = inspect.signature(Updater.start_webhook) + + monkeypatch.setattr(Updater, 'start_webhook', start_webhook) + app = ApplicationBuilder().token(bot.token).build() + app_signature = inspect.signature(app.run_webhook) + + for name, param in updater_signature.parameters.items(): + if name == 'self': + continue + assert name in app_signature.parameters + assert param.kind == app_signature.parameters[name].kind + assert param.default == app_signature.parameters[name].default + + def thread_target(): + ready_event.wait() + time.sleep(0.1) + os.kill(os.getpid(), signal.SIGINT) + + thread = Thread(target=thread_target) + thread.start() + app.run_webhook(ready=ready_event, close_loop=False) + thread.join() + ready_event.clear() + + assert set(self.received.keys()) == set(updater_signature.parameters.keys()) - {'self'} + for name, param in updater_signature.parameters.items(): + if name == 'self': + continue + assert self.received[name] == param.default + + expected = {name: name for name in updater_signature.parameters if name != 'self'} + thread = Thread(target=thread_target) + thread.start() + app.run_webhook(ready=ready_event, close_loop=False, **expected) + thread.join() + ready_event.clear() + + assert set(self.received.keys()) == set(expected.keys()) + assert self.received == expected diff --git a/tests/test_request.py b/tests/test_request.py index f0c4044a694..ad273d7f86a 100644 --- a/tests/test_request.py +++ b/tests/test_request.py @@ -355,12 +355,10 @@ async def test_multiple_inits_and_shutdowns(self, monkeypatch): class Client(httpx.AsyncClient): def __init__(*args, **kwargs): - print('this is init') orig_init(*args, **kwargs) self.test_flag['init'] += 1 async def aclose(*args, **kwargs): - print('this is aclose') await orig_aclose(*args, **kwargs) self.test_flag['shutdown'] += 1 diff --git a/tests/test_updater.py b/tests/test_updater.py index 4d751625b27..48160c53e12 100644 --- a/tests/test_updater.py +++ b/tests/test_updater.py @@ -22,10 +22,8 @@ from http import HTTPStatus from pathlib import Path from random import randrange -from typing import Optional import pytest -from httpx import AsyncClient, Response from telegram import ( Bot, @@ -42,7 +40,13 @@ ) from telegram.ext._utils.webhookhandler import WebhookServer from telegram.request import HTTPXRequest -from tests.conftest import make_message_update, make_message, DictBot, data_file +from tests.conftest import ( + make_message_update, + make_message, + DictBot, + data_file, + send_webhook_message, +) class TestUpdater: @@ -71,39 +75,6 @@ def callback(self, update, context): self.received = update.message.text self.cb_handler_called.set() - @staticmethod - async def _send_webhook_message( - ip: str, - port: int, - payload_str: Optional[str], - url_path: str = '', - content_len: int = -1, - content_type: str = 'application/json', - get_method: str = None, - ) -> Response: - headers = { - 'content-type': content_type, - } - - if not payload_str: - content_len = None - payload = None - else: - payload = bytes(payload_str, encoding='utf-8') - - if content_len == -1: - content_len = len(payload) - - if content_len is not None: - headers['content-length'] = str(content_len) - - url = f'http://{ip}:{port}/{url_path}' - - async with AsyncClient() as client: - return await client.request( - url=url, method=get_method or 'POST', data=payload, headers=headers - ) - @pytest.mark.asyncio async def test_slot_behaviour(self, updater, mro_slots): async with updater: @@ -545,15 +516,15 @@ async def set_webhook(*args, **kwargs): # Now, we send an update to the server update = make_message_update('Webhook') - await self._send_webhook_message(ip, port, update.to_json(), 'TOKEN') + await send_webhook_message(ip, port, update.to_json(), 'TOKEN') assert (await updater.update_queue.get()).to_dict() == update.to_dict() # Returns Not Found if path is incorrect - response = await self._send_webhook_message(ip, port, '123456', 'webhook_handler.py') + response = await send_webhook_message(ip, port, '123456', 'webhook_handler.py') assert response.status_code == HTTPStatus.NOT_FOUND # Returns METHOD_NOT_ALLOWED if method is not allowed - response = await self._send_webhook_message(ip, port, None, 'TOKEN', get_method='HEAD') + response = await send_webhook_message(ip, port, None, 'TOKEN', get_method='HEAD') assert response.status_code == HTTPStatus.METHOD_NOT_ALLOWED await updater.stop() @@ -573,7 +544,7 @@ async def set_webhook(*args, **kwargs): ) assert updater.running update = make_message_update('Webhook') - await self._send_webhook_message(ip, port, update.to_json(), 'TOKEN') + await send_webhook_message(ip, port, update.to_json(), 'TOKEN') assert (await updater.update_queue.get()).to_dict() == update.to_dict() await updater.stop() assert not updater.running @@ -711,7 +682,7 @@ async def return_true(*args, **kwargs): user=updater.bot.bot, ) - await self._send_webhook_message(ip, port, update.to_json(), 'TOKEN') + await send_webhook_message(ip, port, update.to_json(), 'TOKEN') received_update = await updater.update_queue.get() assert received_update.update_id == update.update_id @@ -789,7 +760,7 @@ def webhook_server_init(*args, **kwargs): # Now, we send an update to the server update = make_message_update(message='test_message') - await self._send_webhook_message(ip, port, update.to_json()) + await send_webhook_message(ip, port, update.to_json()) assert (await updater.update_queue.get()).to_dict() == update.to_dict() assert self.test_flag == [True, True] await updater.stop() @@ -832,10 +803,10 @@ async def return_true(*args, **kwargs): async with updater: await updater.start_webhook(listen=ip, port=port) - response = await self._send_webhook_message(ip, port, None, content_type='invalid') + response = await send_webhook_message(ip, port, None, content_type='invalid') assert response.status_code == HTTPStatus.FORBIDDEN - response = await self._send_webhook_message( + response = await send_webhook_message( ip, port, payload_str='data', @@ -843,18 +814,16 @@ async def return_true(*args, **kwargs): ) assert response.status_code == HTTPStatus.FORBIDDEN - response = await self._send_webhook_message( - ip, port, 'dummy-payload', content_len=None - ) + response = await send_webhook_message(ip, port, 'dummy-payload', content_len=None) assert response.status_code == HTTPStatus.INTERNAL_SERVER_ERROR # httpx already complains about bad content length in _send_webhook_message # before the requests below reach the webhook, but not testing this is probably # okay - # response = await self._send_webhook_message( + # response = await send_webhook_message( # ip, port, 'dummy-payload', content_len=-2) # assert response.status_code == HTTPStatus.FORBIDDEN - # response = await self._send_webhook_message( + # response = await send_webhook_message( # ip, port, 'dummy-payload', content_len='not-a-number') # assert response.status_code == HTTPStatus.INTERNAL_SERVER_ERROR From 24dc3e6a8c3c6c175d1ebb46b9eb0d67fe2994ad Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Thu, 10 Mar 2022 08:39:17 +0100 Subject: [PATCH 044/153] skip more tests on windows --- tests/test_application.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/test_application.py b/tests/test_application.py index 8fb5b6f626c..b503254aded 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -1309,6 +1309,10 @@ def thread_target(): for key, value in assertions.items(): assert value, f"assertion '{key}' failed!" + @pytest.mark.skipif( + platform.system() == 'Windows', + reason="Can't send signals without stopping whole process on windows", + ) def test_run_polling_parameters_passing(self, app, monkeypatch): ready_event = threading.Event() @@ -1424,6 +1428,10 @@ def thread_target(): for key, value in assertions.items(): assert value, f"assertion '{key}' failed!" + @pytest.mark.skipif( + platform.system() == 'Windows', + reason="Can't send signals without stopping whole process on windows", + ) def test_run_webhook_parameters_passing(self, bot, monkeypatch): # Check that we pass them correctly From 8ab79f3ef92c663ab6141c32f58565220bb4e51c Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Thu, 10 Mar 2022 09:22:05 +0100 Subject: [PATCH 045/153] Small fix for applicationbuilder --- telegram/ext/_applicationbuilder.py | 6 +++++- ...st_builders.py => test_applicationbuilder.py} | 16 ++++++++++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) rename tests/{test_builders.py => test_applicationbuilder.py} (95%) diff --git a/telegram/ext/_applicationbuilder.py b/telegram/ext/_applicationbuilder.py index a7fd0bb91cf..3516651d5ff 100644 --- a/telegram/ext/_applicationbuilder.py +++ b/telegram/ext/_applicationbuilder.py @@ -258,7 +258,11 @@ def build( else: bot = self._bot update_queue = DefaultValue.get_value(self._update_queue) - updater = Updater(bot=bot, update_queue=update_queue) + + if self._updater is None: + updater = None + else: + updater = Updater(bot=bot, update_queue=update_queue) else: updater = self._updater bot = self._updater.bot diff --git a/tests/test_builders.py b/tests/test_applicationbuilder.py similarity index 95% rename from tests/test_builders.py rename to tests/test_applicationbuilder.py index cc6a6078d50..78ea9d9958b 100644 --- a/tests/test_builders.py +++ b/tests/test_applicationbuilder.py @@ -354,3 +354,19 @@ def test_all_private_key_input_types(self, builder, bot, input_type): ) bot = builder.build().bot assert bot.private_key + + def test_no_updater(self, bot, builder): + app = builder.token(bot.token).updater(None).build() + assert app.bot.token == bot.token + assert app.updater is None + assert isinstance(app.update_queue, asyncio.Queue) + assert isinstance(app.job_queue, JobQueue) + assert app.job_queue.application is app + + def test_no_job_queue(self, bot, builder): + app = builder.token(bot.token).job_queue(None).build() + assert app.bot.token == bot.token + assert app.bot.token == bot.token + assert app.job_queue is None + assert isinstance(app.update_queue, asyncio.Queue) + assert isinstance(app.updater, Updater) From c97a5323df1c1418a0d3aabe8899ff65191a802d Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Thu, 10 Mar 2022 09:46:35 +0100 Subject: [PATCH 046/153] Test a few more edge cases for application --- tests/test_application.py | 69 ++++++++++++++++++++++++++++++--------- 1 file changed, 53 insertions(+), 16 deletions(-) diff --git a/tests/test_application.py b/tests/test_application.py index b503254aded..7616f877004 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -203,9 +203,9 @@ def test_custom_context_init(self, bot): assert isinstance(application.bot_data, complex) @pytest.mark.asyncio - async def test_initialize(self, bot, monkeypatch): - """Initialization of persistence is tested eslewhere""" - # TODO: do this! + @pytest.mark.asyncio('updater', (True, False)) + async def test_initialize(self, bot, monkeypatch, updater): + """Initialization of persistence is tested test_persistence_integration""" self.test_flag = set() async def initialize_bot(*args, **kwargs): @@ -217,13 +217,17 @@ async def initialize_updater(*args, **kwargs): monkeypatch.setattr(Bot, 'initialize', initialize_bot) monkeypatch.setattr(Updater, 'initialize', initialize_updater) - await ApplicationBuilder().token(bot.token).build().initialize() - assert self.test_flag == {'bot', 'updater'} + if updater: + await ApplicationBuilder().token(bot.token).build().initialize() + assert self.test_flag == {'bot', 'updater'} + else: + await ApplicationBuilder().token(bot.token).updater(None).build().initialize() + assert self.test_flag == {'bot'} @pytest.mark.asyncio - async def test_shutdown(self, bot, monkeypatch): - """Studown of persistence is tested eslewhere""" - # TODO: do this! + @pytest.mark.parametrize('updater', (True, False)) + async def test_shutdown(self, bot, monkeypatch, updater): + """Shutdown of persistence is tested in test_persistence_integration""" self.test_flag = set() async def shutdown_bot(*args, **kwargs): @@ -235,9 +239,14 @@ async def shutdown_updater(*args, **kwargs): monkeypatch.setattr(Bot, 'shutdown', shutdown_bot) monkeypatch.setattr(Updater, 'shutdown', shutdown_updater) - async with ApplicationBuilder().token(bot.token).build(): - pass - assert self.test_flag == {'bot', 'updater'} + if updater: + async with ApplicationBuilder().token(bot.token).build(): + pass + assert self.test_flag == {'bot', 'updater'} + else: + async with ApplicationBuilder().token(bot.token).updater(None).build(): + pass + assert self.test_flag == {'bot'} @pytest.mark.asyncio async def test_multiple_inits_and_shutdowns(self, app, monkeypatch): @@ -352,14 +361,23 @@ def test_builder(self, app): builder_2.token(app.bot.token) @pytest.mark.asyncio - async def test_start_stop_processing_updates(self, app): + @pytest.mark.parametrize('job_queue', (True, False)) + async def test_start_stop_processing_updates(self, bot, job_queue): # TODO: repeat a similar test for create_task, persistence processing and job queue + if job_queue: + app = ApplicationBuilder().token(bot.token).build() + else: + app = ApplicationBuilder().token(bot.token).job_queue(None).build() + async def callback(u, c): self.received = u assert not app.running assert not app.updater.running - assert not app.job_queue.scheduler.running + if job_queue: + assert not app.job_queue.scheduler.running + else: + assert app.job_queue is None app.add_handler(TypeHandler(object, callback)) await app.update_queue.put(1) @@ -370,7 +388,10 @@ async def callback(u, c): async with app: await app.start() assert app.running - assert app.job_queue.scheduler.running + if job_queue: + assert app.job_queue.scheduler.running + else: + assert app.job_queue is None assert not app.updater.running await asyncio.sleep(0.05) assert app.update_queue.empty() @@ -379,7 +400,10 @@ async def callback(u, c): await app.stop() assert not app.running assert not app.updater.running - assert not app.job_queue.scheduler.running + if job_queue: + assert not app.job_queue.scheduler.running + else: + assert app.job_queue is None await app.update_queue.put(2) await asyncio.sleep(0.05) assert not app.update_queue.empty() @@ -432,9 +456,12 @@ def test_add_handler_errors(self, app): app.add_handler(handler, 'one') @pytest.mark.asyncio - async def test_add_remove_handler(self, app): + @pytest.mark.parametrize('group_empty', (True, False)) + async def test_add_remove_handler(self, app, group_empty): handler = MessageHandler(filters.ALL, self.callback_increase_count) app.add_handler(handler) + if not group_empty: + app.add_handler(handler) async with app: await app.start() @@ -442,6 +469,7 @@ async def test_add_remove_handler(self, app): await asyncio.sleep(0.05) assert self.count == 1 app.remove_handler(handler) + assert (0 in app.handlers) == (not group_empty) await app.update_queue.put(self.message_update) assert self.count == 1 await app.stop() @@ -1481,3 +1509,12 @@ def thread_target(): assert set(self.received.keys()) == set(expected.keys()) assert self.received == expected + + def test_run_without_updater(self, bot): + app = ApplicationBuilder().token(bot.token).updater(None).build() + + with pytest.raises(RuntimeError, match='only available if the application has an Updater'): + app.run_webhook() + + with pytest.raises(RuntimeError, match='only available if the application has an Updater'): + app.run_polling() From 272da63371dfe52410a220c5b5ea02fd1a7f3859 Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Thu, 10 Mar 2022 09:48:33 +0100 Subject: [PATCH 047/153] Make test suit run on PRs against asyncio branch --- .github/workflows/test.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 73a639512b9..0c28ee6ccde 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -4,6 +4,7 @@ on: branches: - master - v14 + - asyncio push: branches: - master From b1301158aec51c16e2330fb384700bdd327bcc9d Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Thu, 10 Mar 2022 17:48:42 +0100 Subject: [PATCH 048/153] Remove doubled assertion --- tests/test_applicationbuilder.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_applicationbuilder.py b/tests/test_applicationbuilder.py index 78ea9d9958b..6ee13d1605d 100644 --- a/tests/test_applicationbuilder.py +++ b/tests/test_applicationbuilder.py @@ -366,7 +366,6 @@ def test_no_updater(self, bot, builder): def test_no_job_queue(self, bot, builder): app = builder.token(bot.token).job_queue(None).build() assert app.bot.token == bot.token - assert app.bot.token == bot.token assert app.job_queue is None assert isinstance(app.update_queue, asyncio.Queue) assert isinstance(app.updater, Updater) From 9a347caf58eac769f407747aec60dbfa316a1397 Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Thu, 10 Mar 2022 18:51:07 +0100 Subject: [PATCH 049/153] Slowly get started on testing persistence integration --- tests/test_persistence_integration.py | 190 ++++++++++++++++++++++++++ 1 file changed, 190 insertions(+) diff --git a/tests/test_persistence_integration.py b/tests/test_persistence_integration.py index 89166b294b9..79cf8d23e84 100644 --- a/tests/test_persistence_integration.py +++ b/tests/test_persistence_integration.py @@ -16,14 +16,178 @@ # # You should have received a copy of the GNU Lesser Public License # along with this program. If not, see [http://www.gnu.org/licenses/]. +import collections +import time +from typing import NamedTuple + import pytest from telegram.ext import ( ApplicationBuilder, PersistenceInput, + BasePersistence, + Application, ) +class TrackingPersistence(BasePersistence): + """A dummy implementation of BasePersistence that will help us a great deal in keeping + the individual tests as short as reasonably possible.""" + + def __init__( + self, + store_data: PersistenceInput = None, + update_interval: float = 60, + fill_data: bool = False, + ): + super().__init__(store_data=store_data, update_interval=update_interval) + self.updated_chat_ids = collections.Counter() + self.updated_user_ids = collections.Counter() + self.refreshed_chat_ids = collections.Counter() + self.refreshed_user_ids = collections.Counter() + self.updated_conversations = collections.defaultdict(collections.Counter) + self.updated_bot_data: bool = False + self.refreshed_bot_data: bool = False + self.updated_callback_data: bool = False + self.flushed = False + + self.chat_data = dict() + self.user_data = dict() + self.conversations = collections.defaultdict(dict) + self.bot_data = dict() + self.callback_data = ([], {}) + + if fill_data: + self.fill() + + def fill(self): + self.chat_data[1]['key'] = 'entry' + self.chat_data[2]['foo'] = 'bar' + self.user_data[1]['key'] = 'entry' + self.user_data[2]['foo'] = 'bar' + self.bot_data['key'] = 'entry' + self.conversations['conv_1'][(1, 1, 1)] = 'STATE_1' + self.conversations['conv_1'][(2, 2, 2)] = 'STATE_2' + self.conversations['conv_2'][(3, 3, 3)] = 'STATE_3' + self.conversations['conv_2'][(4, 4, 4)] = 'STATE_4' + self.callback_data = ( + [('uuid', time.time(), {'uuid4': 'callback_data'})], + {'query_id', 'keyboard_id'}, + ) + + def reset_tracking(self): + self.updated_user_ids.clear() + self.updated_chat_ids.clear() + self.refreshed_chat_ids = collections.Counter() + self.refreshed_user_ids = collections.Counter() + self.updated_conversations.clear() + self.updated_bot_data = False + self.refreshed_bot_data = False + self.updated_callback_data = False + self.flushed = False + + self.chat_data = dict() + self.user_data = dict() + self.conversations = collections.defaultdict(dict) + self.bot_data = None + self.callback_data = None + + async def update_bot_data(self, data): + self.updated_bot_data = True + self.bot_data = data + + async def update_chat_data(self, chat_id: int, data): + self.updated_chat_ids[chat_id] += 1 + self.chat_data[chat_id] = data + + async def update_user_data(self, user_id: int, data): + self.updated_user_ids[user_id] += 1 + self.user_data[user_id] = data + + async def update_conversation(self, name: str, key, new_state): + self.updated_conversations[name][key] += 1 + self.conversations[name][key] = new_state + + async def update_callback_data(self, data): + self.updated_callback_data = True + self.callback_data = data + + async def get_conversations(self, name): + return self.conversations[name] + + async def get_bot_data(self): + return self.bot_data + + async def get_chat_data(self): + return self.chat_data + + async def get_user_data(self): + return self.user_data + + async def get_callback_data(self): + return self.callback_data + + async def drop_chat_data(self, chat_id): + self.chat_data.pop(chat_id, None) + + async def drop_user_data(self, user_id): + self.user_data.pop(user_id, None) + + async def refresh_user_data(self, user_id: int, user_data: dict): + self.refreshed_user_ids[user_id] += 1 + user_data['refreshed'] = True + + async def refresh_chat_data(self, chat_id: int, chat_data: dict): + self.refreshed_chat_ids[chat_id] += 1 + chat_data['refreshed'] = True + + async def refresh_bot_data(self, bot_data: dict): + self.refreshed_bot_data = True + bot_data['refreshed'] = True + + async def flush(self) -> None: + self.flushed = True + + +class PappInput(NamedTuple): + bot_data: bool = None + chat_data: bool = None + user_data: bool = None + callback_data: bool = None + update_interval: float = None + fill_data: bool = False + + +def build_papp( + token: str, store_data: dict = None, update_interval: float = None, fill_data: bool = False +) -> Application: + store_data = PersistenceInput(**store_data) + if update_interval is not None: + persistence = TrackingPersistence( + store_data=store_data, update_interval=update_interval, fill_data=fill_data + ) + else: + persistence = TrackingPersistence(store_data=store_data, fill_data=fill_data) + + return ApplicationBuilder().token(token).persistence(persistence).build() + + +@pytest.fixture(scope='function') +def papp(request, bot) -> Application: + papp_input = request.param + store_data = dict() + if papp_input.bot_data is not None: + store_data['bot_data'] = papp_input.bot_data + if papp_input.chat_data is not None: + store_data['chat_data'] = papp_input.chat_data + if papp_input.user_data is not None: + store_data['user_data'] = papp_input.user_data + if papp_input.callback_data is not None: + store_data['callback_data'] = papp_input.callback_data + + return build_papp(bot.token, store_data=store_data, update_interval=papp_input.update_interval) + + class TestPersistenceIntegration: # TODO: # * Test add_handler with persistent conversationhandler @@ -42,6 +206,32 @@ def __init__(self): ): ApplicationBuilder().bot(bot).persistence(MyPersistence()).build() + @pytest.mark.parametrize( + 'papp', + [PappInput(fill_data=True), PappInput(False, False, False, False, fill_data=True)], + indirect=True, + ) + @pytest.mark.asyncio + async def test_initialization_basic(self, papp: Application): + assert not papp.chat_data + assert not papp.user_data + assert not papp.bot_data + assert papp.bot.callback_data_cache.persistence_data == ([], {}) + async with papp: + # We check just bot_data because we set all to the same value + if papp.persistence.store_data.bot_data: + assert papp.chat_data == papp.persistence.chat_data + assert papp.user_data == papp.persistence.user_data + assert papp.bot_data == papp.persistence.bot_data + assert ( + papp.bot.callback_data_cache.persistence_data == papp.persistence.callback_data + ) + else: + assert not papp.chat_data + assert not papp.user_data + assert not papp.bot_data + assert papp.bot.callback_data_cache.persistence_data == ([], {}) + # # def test_error_while_saving_chat_data(self, bot): # increment = [] From 83d7367cc10de3aba6900d757e62720a6fc7241c Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Thu, 10 Mar 2022 21:10:37 +0100 Subject: [PATCH 050/153] also test conversationhandler-persistence-integration --- tests/test_persistence_integration.py | 152 +++++++++++++++++++++++--- 1 file changed, 134 insertions(+), 18 deletions(-) diff --git a/tests/test_persistence_integration.py b/tests/test_persistence_integration.py index 79cf8d23e84..fcddd39843d 100644 --- a/tests/test_persistence_integration.py +++ b/tests/test_persistence_integration.py @@ -17,17 +17,41 @@ # You should have received a copy of the GNU Lesser Public License # along with this program. If not, see [http://www.gnu.org/licenses/]. import collections +import enum +import functools import time from typing import NamedTuple import pytest +from telegram import User, Chat from telegram.ext import ( ApplicationBuilder, PersistenceInput, BasePersistence, Application, + ConversationHandler, + MessageHandler, + filters, + Handler, ) +from tests.conftest import make_message_update + + +class HandlerStates(int, enum.Enum): + END = ConversationHandler.END + STATE_1 = 1 + STATE_2 = 2 + STATE_3 = 3 + STATE_4 = 4 + + def next(self): + cls = self.__class__ + members = list(cls) + index = members.index(self) + 1 + if index >= len(members): + index = 0 + return members[index] class TrackingPersistence(BasePersistence): @@ -51,29 +75,31 @@ def __init__( self.updated_callback_data: bool = False self.flushed = False - self.chat_data = dict() - self.user_data = dict() + self.chat_data = collections.defaultdict(dict) + self.user_data = collections.defaultdict(dict) self.conversations = collections.defaultdict(dict) self.bot_data = dict() - self.callback_data = ([], {}) + self.callback_data = ([], dict()) if fill_data: self.fill() + CALLBACK_DATA = ( + [('uuid', time.time(), {'uuid4': 'callback_data'})], + {'query_id': 'keyboard_id'}, + ) + def fill(self): self.chat_data[1]['key'] = 'entry' self.chat_data[2]['foo'] = 'bar' self.user_data[1]['key'] = 'entry' self.user_data[2]['foo'] = 'bar' self.bot_data['key'] = 'entry' - self.conversations['conv_1'][(1, 1, 1)] = 'STATE_1' - self.conversations['conv_1'][(2, 2, 2)] = 'STATE_2' - self.conversations['conv_2'][(3, 3, 3)] = 'STATE_3' - self.conversations['conv_2'][(4, 4, 4)] = 'STATE_4' - self.callback_data = ( - [('uuid', time.time(), {'uuid4': 'callback_data'})], - {'query_id', 'keyboard_id'}, - ) + self.conversations['conv_1'][(1, 1)] = HandlerStates.STATE_1 + self.conversations['conv_1'][(2, 2)] = HandlerStates.STATE_2 + self.conversations['conv_2'][(3, 3)] = HandlerStates.STATE_3 + self.conversations['conv_2'][(4, 4)] = HandlerStates.STATE_4 + self.callback_data = self.CALLBACK_DATA def reset_tracking(self): self.updated_user_ids.clear() @@ -113,7 +139,7 @@ async def update_callback_data(self, data): self.callback_data = data async def get_conversations(self, name): - return self.conversations[name] + return self.conversations.get(name, {}) async def get_bot_data(self): return self.bot_data @@ -149,11 +175,39 @@ async def flush(self) -> None: self.flushed = True +class TrackingConversationHandler(ConversationHandler): + def __init__(self, *args, **kwargs): + fallbacks = [] + states = {state.value: [self.build_handler(state)] for state in HandlerStates} + entry_points = [self.build_handler(HandlerStates.END)] + super().__init__( + *args, **kwargs, fallbacks=fallbacks, states=states, entry_points=entry_points + ) + + @staticmethod + async def callback(update, context, state): + return state.next() + + @staticmethod + def build_update(state: HandlerStates, chat_id: int): + user = User(id=chat_id, first_name='', is_bot=False) + chat = Chat(id=chat_id, type='') + return make_message_update(message=str(state.value), user=user, chat=chat) + + @classmethod + def build_handler(cls, state: HandlerStates): + return MessageHandler( + filters.Regex(f'^{state.value}$'), + functools.partial(cls.callback, state=state.value), + ) + + class PappInput(NamedTuple): bot_data: bool = None chat_data: bool = None user_data: bool = None callback_data: bool = None + conversations: bool = True update_interval: float = None fill_data: bool = False @@ -172,6 +226,10 @@ def build_papp( return ApplicationBuilder().token(token).persistence(persistence).build() +def build_conversation_handler(name: str, persistent: bool = True) -> Handler: + return TrackingConversationHandler(name=name, persistent=persistent) + + @pytest.fixture(scope='function') def papp(request, bot) -> Application: papp_input = request.param @@ -185,7 +243,21 @@ def papp(request, bot) -> Application: if papp_input.callback_data is not None: store_data['callback_data'] = papp_input.callback_data - return build_papp(bot.token, store_data=store_data, update_interval=papp_input.update_interval) + app = build_papp( + bot.token, + store_data=store_data, + update_interval=papp_input.update_interval, + fill_data=papp_input.fill_data, + ) + + app.add_handlers( + [ + build_conversation_handler(name='conv_1', persistent=papp_input.conversations), + build_conversation_handler(name='conv_2', persistent=papp_input.conversations), + ] + ) + + return app class TestPersistenceIntegration: @@ -208,29 +280,73 @@ def __init__(self): @pytest.mark.parametrize( 'papp', - [PappInput(fill_data=True), PappInput(False, False, False, False, fill_data=True)], + [PappInput(fill_data=True), PappInput(False, False, False, False, False, fill_data=True)], indirect=True, ) @pytest.mark.asyncio async def test_initialization_basic(self, papp: Application): + # Check that no data is there before init assert not papp.chat_data assert not papp.user_data assert not papp.bot_data assert papp.bot.callback_data_cache.persistence_data == ([], {}) + assert not papp.handlers[0][0].check_update( + TrackingConversationHandler.build_update(HandlerStates.STATE_1, chat_id=1) + ) + assert not papp.handlers[0][0].check_update( + TrackingConversationHandler.build_update(HandlerStates.STATE_2, chat_id=2) + ) + assert not papp.handlers[0][1].check_update( + TrackingConversationHandler.build_update(HandlerStates.STATE_3, chat_id=3) + ) + assert not papp.handlers[0][1].check_update( + TrackingConversationHandler.build_update(HandlerStates.STATE_4, chat_id=4) + ) + async with papp: + # Check that data is loaded on init + # We check just bot_data because we set all to the same value if papp.persistence.store_data.bot_data: - assert papp.chat_data == papp.persistence.chat_data - assert papp.user_data == papp.persistence.user_data - assert papp.bot_data == papp.persistence.bot_data + assert papp.chat_data[1]['key'] == 'entry' + assert papp.chat_data[2]['foo'] == 'bar' + assert papp.user_data[1]['key'] == 'entry' + assert papp.user_data[2]['foo'] == 'bar' + assert papp.bot_data == {'key': 'entry'} assert ( - papp.bot.callback_data_cache.persistence_data == papp.persistence.callback_data + papp.bot.callback_data_cache.persistence_data + == TrackingPersistence.CALLBACK_DATA + ) + + assert papp.handlers[0][0].check_update( + TrackingConversationHandler.build_update(HandlerStates.STATE_1, chat_id=1) + ) + assert papp.handlers[0][0].check_update( + TrackingConversationHandler.build_update(HandlerStates.STATE_2, chat_id=2) + ) + assert papp.handlers[0][1].check_update( + TrackingConversationHandler.build_update(HandlerStates.STATE_3, chat_id=3) + ) + assert papp.handlers[0][1].check_update( + TrackingConversationHandler.build_update(HandlerStates.STATE_4, chat_id=4) ) else: assert not papp.chat_data assert not papp.user_data assert not papp.bot_data assert papp.bot.callback_data_cache.persistence_data == ([], {}) + assert not papp.handlers[0][0].check_update( + TrackingConversationHandler.build_update(HandlerStates.STATE_1, chat_id=1) + ) + assert not papp.handlers[0][0].check_update( + TrackingConversationHandler.build_update(HandlerStates.STATE_2, chat_id=2) + ) + assert not papp.handlers[0][1].check_update( + TrackingConversationHandler.build_update(HandlerStates.STATE_3, chat_id=3) + ) + assert not papp.handlers[0][1].check_update( + TrackingConversationHandler.build_update(HandlerStates.STATE_4, chat_id=4) + ) # # def test_error_while_saving_chat_data(self, bot): From 98d35af0d1df979a26200014b75a72f7f7bce235 Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Thu, 10 Mar 2022 21:26:46 +0100 Subject: [PATCH 051/153] Actually write some tests --- tests/test_persistence_integration.py | 59 ++++++++++++++++++++++++++- 1 file changed, 58 insertions(+), 1 deletion(-) diff --git a/tests/test_persistence_integration.py b/tests/test_persistence_integration.py index fcddd39843d..81fc8812376 100644 --- a/tests/test_persistence_integration.py +++ b/tests/test_persistence_integration.py @@ -20,6 +20,7 @@ import enum import functools import time +from pathlib import Path from typing import NamedTuple import pytest @@ -35,7 +36,8 @@ filters, Handler, ) -from tests.conftest import make_message_update +from telegram.warnings import PTBUserWarning +from tests.conftest import make_message_update, PROJECT_ROOT_PATH class HandlerStates(int, enum.Enum): @@ -348,6 +350,61 @@ async def test_initialization_basic(self, papp: Application): TrackingConversationHandler.build_update(HandlerStates.STATE_4, chat_id=4) ) + @pytest.mark.parametrize( + 'papp', + [PappInput(fill_data=True)], + indirect=True, + ) + @pytest.mark.asyncio + async def test_initialization_invalid_bot_data(self, papp: Application, monkeypatch): + async def get_bot_data(*args, **kwargs): + return 'invalid' + + monkeypatch.setattr(papp.persistence, 'get_bot_data', get_bot_data) + + with pytest.raises(ValueError, match='bot_data must be'): + await papp.initialize() + + @pytest.mark.parametrize( + 'papp', + [PappInput(fill_data=True)], + indirect=True, + ) + @pytest.mark.parametrize('callback_data', ('invalid', (1, 2, 3))) + @pytest.mark.asyncio + async def test_initialization_invalid_callback_data( + self, papp: Application, callback_data, monkeypatch + ): + async def get_callback_data(*args, **kwargs): + return callback_data + + monkeypatch.setattr(papp.persistence, 'get_callback_data', get_callback_data) + + with pytest.raises(ValueError, match='callback_data must be'): + await papp.initialize() + + @pytest.mark.parametrize( + 'papp', + [PappInput()], + indirect=True, + ) + @pytest.mark.asyncio + async def test_add_conversation_handler_after_init(self, papp: Application, recwarn): + async with papp: + papp.add_handler(build_conversation_handler('name', persistent=True)) + + assert len(recwarn) == 1 + assert recwarn[0].category is PTBUserWarning + assert 'after `Application.initialize` was called' in str(recwarn[-1].message) + assert ( + Path(recwarn[-1].filename) + == PROJECT_ROOT_PATH / 'telegram' / 'ext' / '_application.py' + ), "incorrect stacklevel!" + + def test_add_conversation_without_persistence(self, app): + with pytest.raises(ValueError, match='if application has no persistence'): + app.add_handler(build_conversation_handler('name', persistent=True)) + # # def test_error_while_saving_chat_data(self, bot): # increment = [] From d6c8652c06977a498fd15f13be225275d203281c Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Thu, 10 Mar 2022 21:27:16 +0100 Subject: [PATCH 052/153] Small fix for persistence init --- telegram/ext/_application.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/telegram/ext/_application.py b/telegram/ext/_application.py index 9d984491822..654e911790a 100644 --- a/telegram/ext/_application.py +++ b/telegram/ext/_application.py @@ -361,7 +361,7 @@ async def _initialize_persistence(self) -> None: if self.persistence.store_data.callback_data: persistent_data = await self.persistence.get_callback_data() if persistent_data is not None: - if not isinstance(persistent_data, tuple) and len(persistent_data) != 2: + if not isinstance(persistent_data, tuple) or len(persistent_data) != 2: raise ValueError('callback_data must be a tuple of length 2') # Mypy doesn't know that persistence.set_bot (see above) already checks that # self.bot is an instance of ExtBot if callback_data should be stored ... @@ -803,7 +803,8 @@ def add_handler(self, handler: Handler[Any, CCT], group: int = DEFAULT_GROUP) -> warn( 'A persistent `ConversationHandler` was passed to `add_handler`, ' 'after `Application.initialize` was called. Conversation states will not be ' - 'loaded from persistence! ' + 'loaded from persistence!', + stacklevel=1, ) if group not in self.handlers: From 988b300993ae66fb4475863ceffb07f69aeb6a87 Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Fri, 11 Mar 2022 22:19:51 +0100 Subject: [PATCH 053/153] More persistence integration tests --- tests/test_persistence_integration.py | 396 ++++++++++++++++++++++++-- 1 file changed, 378 insertions(+), 18 deletions(-) diff --git a/tests/test_persistence_integration.py b/tests/test_persistence_integration.py index 81fc8812376..04a823fd79e 100644 --- a/tests/test_persistence_integration.py +++ b/tests/test_persistence_integration.py @@ -16,6 +16,7 @@ # # You should have received a copy of the GNU Lesser Public License # along with this program. If not, see [http://www.gnu.org/licenses/]. +import asyncio import collections import enum import functools @@ -25,7 +26,7 @@ import pytest -from telegram import User, Chat +from telegram import User, Chat, InlineKeyboardMarkup, InlineKeyboardButton from telegram.ext import ( ApplicationBuilder, PersistenceInput, @@ -35,9 +36,10 @@ MessageHandler, filters, Handler, + ApplicationHandlerStop, ) from telegram.warnings import PTBUserWarning -from tests.conftest import make_message_update, PROJECT_ROOT_PATH +from tests.conftest import make_message_update, PROJECT_ROOT_PATH, DictApplication class HandlerStates(int, enum.Enum): @@ -71,6 +73,8 @@ def __init__( self.updated_user_ids = collections.Counter() self.refreshed_chat_ids = collections.Counter() self.refreshed_user_ids = collections.Counter() + self.dropped_chat_ids = collections.Counter() + self.dropped_user_ids = collections.Counter() self.updated_conversations = collections.defaultdict(collections.Counter) self.updated_bot_data: bool = False self.refreshed_bot_data: bool = False @@ -80,8 +84,8 @@ def __init__( self.chat_data = collections.defaultdict(dict) self.user_data = collections.defaultdict(dict) self.conversations = collections.defaultdict(dict) - self.bot_data = dict() - self.callback_data = ([], dict()) + self.bot_data = {} + self.callback_data = ([], {}) if fill_data: self.fill() @@ -92,11 +96,11 @@ def __init__( ) def fill(self): - self.chat_data[1]['key'] = 'entry' + self.chat_data[1]['key'] = 'value' self.chat_data[2]['foo'] = 'bar' - self.user_data[1]['key'] = 'entry' + self.user_data[1]['key'] = 'value' self.user_data[2]['foo'] = 'bar' - self.bot_data['key'] = 'entry' + self.bot_data['key'] = 'value' self.conversations['conv_1'][(1, 1)] = HandlerStates.STATE_1 self.conversations['conv_1'][(2, 2)] = HandlerStates.STATE_2 self.conversations['conv_2'][(3, 3)] = HandlerStates.STATE_3 @@ -106,6 +110,8 @@ def fill(self): def reset_tracking(self): self.updated_user_ids.clear() self.updated_chat_ids.clear() + self.dropped_user_ids.clear() + self.dropped_chat_ids.clear() self.refreshed_chat_ids = collections.Counter() self.refreshed_user_ids = collections.Counter() self.updated_conversations.clear() @@ -114,11 +120,11 @@ def reset_tracking(self): self.updated_callback_data = False self.flushed = False - self.chat_data = dict() - self.user_data = dict() + self.chat_data = {} + self.user_data = {} self.conversations = collections.defaultdict(dict) - self.bot_data = None - self.callback_data = None + self.bot_data = {} + self.callback_data = ([], {}) async def update_bot_data(self, data): self.updated_bot_data = True @@ -156,9 +162,11 @@ async def get_callback_data(self): return self.callback_data async def drop_chat_data(self, chat_id): + self.dropped_chat_ids[chat_id] += 1 self.chat_data.pop(chat_id, None) async def drop_user_data(self, user_id): + self.dropped_user_ids[user_id] += 1 self.user_data.pop(user_id, None) async def refresh_user_data(self, user_id: int, user_data: dict): @@ -200,7 +208,7 @@ def build_update(state: HandlerStates, chat_id: int): def build_handler(cls, state: HandlerStates): return MessageHandler( filters.Regex(f'^{state.value}$'), - functools.partial(cls.callback, state=state.value), + functools.partial(cls.callback, state=state), ) @@ -225,7 +233,14 @@ def build_papp( else: persistence = TrackingPersistence(store_data=store_data, fill_data=fill_data) - return ApplicationBuilder().token(token).persistence(persistence).build() + return ( + ApplicationBuilder() + .token(token) + .persistence(persistence) + .application_class(DictApplication) + .arbitrary_callback_data(True) + .build() + ) def build_conversation_handler(name: str, persistent: bool = True) -> Handler: @@ -235,7 +250,7 @@ def build_conversation_handler(name: str, persistent: bool = True) -> Handler: @pytest.fixture(scope='function') def papp(request, bot) -> Application: papp_input = request.param - store_data = dict() + store_data = {} if papp_input.bot_data is not None: store_data['bot_data'] = papp_input.bot_data if papp_input.chat_data is not None: @@ -269,6 +284,31 @@ class TestPersistenceIntegration: # * Test drop_chat/user_data # * Test update_persistence & flush getting called on shutdown # * Test the update parameter of create_task + # * conversations: pending states, ending conversations, unresolved pending states + + async def job_callback(self, context): + pass + + def handler_callback(self, chat_id: int = None, sleep: float = None): + async def callback(update, context): + if sleep: + await asyncio.sleep(sleep) + + context.user_data['key'] = 'value' + context.chat_data['key'] = 'value' + context.bot_data['key'] = 'value' + + if chat_id: + await context.bot.send_message( + chat_id=chat_id, + text='text', + reply_markup=InlineKeyboardMarkup.from_button( + InlineKeyboardButton(text='text', callback_data='callback_data') + ), + ) + raise ApplicationHandlerStop + + return callback def test_construction_with_bad_persistence(self, caplog, bot): class MyPersistence: @@ -310,11 +350,11 @@ async def test_initialization_basic(self, papp: Application): # We check just bot_data because we set all to the same value if papp.persistence.store_data.bot_data: - assert papp.chat_data[1]['key'] == 'entry' + assert papp.chat_data[1]['key'] == 'value' assert papp.chat_data[2]['foo'] == 'bar' - assert papp.user_data[1]['key'] == 'entry' + assert papp.user_data[1]['key'] == 'value' assert papp.user_data[2]['foo'] == 'bar' - assert papp.bot_data == {'key': 'entry'} + assert papp.bot_data == {'key': 'value'} assert ( papp.bot.callback_data_cache.persistence_data == TrackingPersistence.CALLBACK_DATA @@ -405,7 +445,327 @@ def test_add_conversation_without_persistence(self, app): with pytest.raises(ValueError, match='if application has no persistence'): app.add_handler(build_conversation_handler('name', persistent=True)) - # + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'papp', + [ + PappInput(update_interval=1), + ], + indirect=True, + ) + async def test_update_interval(self, papp: Application, monkeypatch): + call_times = [] + + async def update_persistence(*args, **kwargs): + call_times.append(time.time()) + + monkeypatch.setattr(papp, 'update_persistence', update_persistence) + async with papp: + await papp.start() + await asyncio.sleep(3) + await papp.stop() + + diffs = [j - i for i, j in zip(call_times[:-1], call_times[1:])] + for diff in diffs: + assert diff == pytest.approx(papp.persistence.update_interval, rel=1e-1) + + @pytest.mark.parametrize( + 'papp', + [ + PappInput(), + PappInput(bot_data=False), + PappInput(chat_data=False), + PappInput(user_data=False), + PappInput(callback_data=False), + PappInput(False, False, False, False), + ], + ids=( + 'all_data', + 'no_bot_data', + 'no_chat_data', + 'not_user_data', + 'no_callback_data', + 'no_data', + ), + indirect=True, + ) + @pytest.mark.asyncio + async def test_update_persistence_loop_call_count_update_handling(self, papp: Application): + async with papp: + for _ in range(5): + # second pass processes update in conv_2 + await papp.process_update( + TrackingConversationHandler.build_update(HandlerStates.END, chat_id=1) + ) + assert not papp.persistence.updated_bot_data + assert not papp.persistence.updated_chat_ids + assert not papp.persistence.updated_user_ids + assert not papp.persistence.dropped_chat_ids + assert not papp.persistence.dropped_user_ids + assert not papp.persistence.updated_callback_data + assert not papp.persistence.updated_conversations + + await papp.update_persistence() + assert not papp.persistence.dropped_chat_ids + assert not papp.persistence.dropped_user_ids + assert papp.persistence.updated_bot_data == papp.persistence.store_data.bot_data + assert ( + papp.persistence.updated_callback_data == papp.persistence.store_data.callback_data + ) + if papp.persistence.store_data.user_data: + assert papp.persistence.updated_user_ids == {1: 1} + else: + assert not papp.persistence.updated_user_ids + if papp.persistence.store_data.chat_data: + assert papp.persistence.updated_chat_ids == {1: 1} + else: + assert not papp.persistence.updated_chat_ids + assert papp.persistence.updated_conversations == { + 'conv_1': {(1, 1): 1}, + 'conv_2': {(1, 1): 1}, + } + + # Nothing should have been updated after handling nothing + papp.persistence.reset_tracking() + await papp.update_persistence() + assert papp.persistence.updated_bot_data == papp.persistence.store_data.bot_data + assert ( + papp.persistence.updated_callback_data == papp.persistence.store_data.callback_data + ) + assert not papp.persistence.updated_chat_ids + assert not papp.persistence.updated_user_ids + assert not papp.persistence.updated_conversations + assert not papp.persistence.dropped_chat_ids + assert not papp.persistence.dropped_user_ids + + # Nothing should have been updated after handling an update without associated + # user/chat_data + papp.persistence.reset_tracking() + await papp.process_update('string_update') + await papp.update_persistence() + assert papp.persistence.updated_bot_data == papp.persistence.store_data.bot_data + assert ( + papp.persistence.updated_callback_data == papp.persistence.store_data.callback_data + ) + assert not papp.persistence.updated_chat_ids + assert not papp.persistence.updated_user_ids + assert not papp.persistence.updated_conversations + assert not papp.persistence.dropped_chat_ids + assert not papp.persistence.dropped_user_ids + + @pytest.mark.parametrize( + 'papp', + [ + PappInput(), + PappInput(bot_data=False), + PappInput(chat_data=False), + PappInput(user_data=False), + PappInput(callback_data=False), + PappInput(False, False, False, False), + ], + ids=( + 'all_data', + 'no_bot_data', + 'no_chat_data', + 'not_user_data', + 'no_callback_data', + 'no_data', + ), + indirect=True, + ) + @pytest.mark.asyncio + async def test_update_persistence_loop_call_count_job(self, papp: Application): + async with papp: + papp.job_queue.start() + papp.job_queue.run_once(self.job_callback, when=0.05, chat_id=1, user_id=1) + await asyncio.sleep(0.1) + assert not papp.persistence.updated_bot_data + assert not papp.persistence.updated_chat_ids + assert not papp.persistence.updated_user_ids + assert not papp.persistence.dropped_chat_ids + assert not papp.persistence.dropped_user_ids + assert not papp.persistence.updated_callback_data + assert not papp.persistence.updated_conversations + + await papp.update_persistence() + assert not papp.persistence.dropped_chat_ids + assert not papp.persistence.dropped_user_ids + assert papp.persistence.updated_bot_data == papp.persistence.store_data.bot_data + assert ( + papp.persistence.updated_callback_data == papp.persistence.store_data.callback_data + ) + if papp.persistence.store_data.user_data: + assert papp.persistence.updated_user_ids == {1: 1} + else: + assert not papp.persistence.updated_user_ids + if papp.persistence.store_data.chat_data: + assert papp.persistence.updated_chat_ids == {1: 1} + else: + assert not papp.persistence.updated_chat_ids + assert not papp.persistence.updated_conversations + + # Nothing should have been updated after no job ran + papp.persistence.reset_tracking() + await papp.update_persistence() + assert papp.persistence.updated_bot_data == papp.persistence.store_data.bot_data + assert ( + papp.persistence.updated_callback_data == papp.persistence.store_data.callback_data + ) + assert not papp.persistence.updated_chat_ids + assert not papp.persistence.updated_user_ids + assert not papp.persistence.updated_conversations + assert not papp.persistence.dropped_chat_ids + assert not papp.persistence.dropped_user_ids + + # Nothing should have been updated after running job without associated user/chat_data + papp.persistence.reset_tracking() + papp.job_queue.run_once(self.job_callback, when=0.1) + await asyncio.sleep(0.2) + await papp.update_persistence() + assert papp.persistence.updated_bot_data == papp.persistence.store_data.bot_data + assert ( + papp.persistence.updated_callback_data == papp.persistence.store_data.callback_data + ) + assert not papp.persistence.updated_chat_ids + assert not papp.persistence.updated_user_ids + assert not papp.persistence.updated_conversations + assert not papp.persistence.dropped_chat_ids + assert not papp.persistence.dropped_user_ids + + @pytest.mark.parametrize('papp', [PappInput()], indirect=True) + @pytest.mark.asyncio + async def test_calls_on_shutdown(self, papp, chat_id): + papp.add_handler( + MessageHandler(filters.ALL, callback=self.handler_callback(chat_id=chat_id)), group=-1 + ) + + async with papp: + await papp.process_update( + TrackingConversationHandler.build_update(HandlerStates.STATE_1, chat_id=1) + ) + + # Make sure this this outside the context manager, which is where shutdown is called! + assert papp.persistence.bot_data == {'key': 'value', 'refreshed': True} + assert papp.persistence.callback_data[1] == {} + assert len(papp.persistence.callback_data[0]) == 1 + assert papp.persistence.user_data == {1: {'key': 'value', 'refreshed': True}} + assert papp.persistence.chat_data == {1: {'key': 'value', 'refreshed': True}} + assert not papp.persistence.conversations + assert papp.persistence.flushed + + @pytest.mark.parametrize( + 'papp', + [ + PappInput(), + PappInput(bot_data=False), + PappInput(chat_data=False), + PappInput(user_data=False), + PappInput(callback_data=False), + PappInput(False, False, False, False), + ], + ids=( + 'all_data', + 'no_bot_data', + 'no_chat_data', + 'not_user_data', + 'no_callback_data', + 'no_data', + ), + indirect=True, + ) + @pytest.mark.asyncio + async def test_update_persistence_loop_saved_data_update_handling( + self, papp: Application, chat_id + ): + papp.add_handler( + MessageHandler(filters.ALL, callback=self.handler_callback(chat_id=chat_id)), group=-1 + ) + + async with papp: + await papp.process_update( + TrackingConversationHandler.build_update(HandlerStates.STATE_1, chat_id=1) + ) + assert not papp.persistence.bot_data + assert not papp.persistence.chat_data + assert not papp.persistence.user_data + assert papp.persistence.callback_data == ([], {}) + assert not papp.persistence.conversations + + await papp.update_persistence() + if papp.persistence.store_data.bot_data: + assert papp.persistence.bot_data == {'key': 'value', 'refreshed': True} + else: + assert not papp.persistence.bot_data + if papp.persistence.store_data.callback_data: + assert papp.persistence.callback_data[1] == {} + assert len(papp.persistence.callback_data[0]) == 1 + else: + assert papp.persistence.callback_data == ([], {}) + assert ( + papp.persistence.updated_callback_data == papp.persistence.store_data.callback_data + ) + if papp.persistence.store_data.user_data: + assert papp.persistence.user_data == {1: {'key': 'value', 'refreshed': True}} + else: + assert not papp.persistence.user_data + if papp.persistence.store_data.chat_data: + assert papp.persistence.chat_data == {1: {'key': 'value', 'refreshed': True}} + else: + assert not papp.persistence.chat_data + assert not papp.persistence.conversations + + @pytest.mark.parametrize('papp', [PappInput()], indirect=True) + @pytest.mark.parametrize('delay_type', ('job', 'handler', 'task')) + @pytest.mark.asyncio + async def test_update_persistence_loop_async_logic( + self, papp: Application, delay_type: str, chat_id + ): + sleep = 0.1 + update = TrackingConversationHandler.build_update(HandlerStates.STATE_1, chat_id=1) + + async with papp: + if delay_type == 'job': + papp.job_queue.start() + papp.job_queue.run_once(self.job_callback, when=sleep, chat_id=1, user_id=1) + elif delay_type == 'handler': + papp.add_handler( + MessageHandler( + filters.ALL, + self.handler_callback(sleep=sleep), + block=False, + ) + ) + await papp.process_update(update) + else: + papp.create_task(asyncio.sleep(sleep), update=update) + + await papp.update_persistence() + assert papp.persistence.updated_bot_data + assert not papp.persistence.updated_chat_ids + assert not papp.persistence.updated_user_ids + assert not papp.persistence.dropped_chat_ids + assert not papp.persistence.dropped_user_ids + assert papp.persistence.updated_callback_data + assert not papp.persistence.updated_conversations + + await asyncio.sleep(sleep + 0.05) + await papp.update_persistence() + assert not papp.persistence.dropped_chat_ids + assert not papp.persistence.dropped_user_ids + assert papp.persistence.updated_bot_data == papp.persistence.store_data.bot_data + assert ( + papp.persistence.updated_callback_data == papp.persistence.store_data.callback_data + ) + if papp.persistence.store_data.user_data: + assert papp.persistence.updated_user_ids == {1: 1} + else: + assert not papp.persistence.updated_user_ids + if papp.persistence.store_data.chat_data: + assert papp.persistence.updated_chat_ids == {1: 1} + else: + assert not papp.persistence.updated_chat_ids + assert not papp.persistence.updated_conversations + # def test_error_while_saving_chat_data(self, bot): # increment = [] # From 5509d71cd8497df47ef0e4e1f34c70482d665d84 Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Fri, 11 Mar 2022 22:20:35 +0100 Subject: [PATCH 054/153] fix smaller bugs that surfaced while testing persistence integration --- telegram/ext/_application.py | 23 +++++++++++++++-------- telegram/ext/_conversationhandler.py | 6 ++++-- telegram/ext/_jobqueue.py | 2 ++ 3 files changed, 21 insertions(+), 10 deletions(-) diff --git a/telegram/ext/_application.py b/telegram/ext/_application.py index 654e911790a..0858b383470 100644 --- a/telegram/ext/_application.py +++ b/telegram/ext/_application.py @@ -725,6 +725,7 @@ async def process_update(self, update: object) -> None: """ context = None + any_blocking = False for handlers in self.handlers.values(): try: @@ -744,6 +745,7 @@ async def process_update(self, update: object) -> None: ): self.create_task(coroutine, update=update) else: + any_blocking = True await coroutine break @@ -758,7 +760,10 @@ async def process_update(self, update: object) -> None: _logger.debug('Error handler stopped further handlers.') break - self._mark_for_persistence_update(update=update) + if any_blocking: + # Only need to mark the update for persistence if there was at least one + # blocking handler - the non-blocking handlers mark the update again when finished + self._mark_for_persistence_update(update=update) def add_handler(self, handler: Handler[Any, CCT], group: int = DEFAULT_GROUP) -> None: """Register a handler. @@ -1088,15 +1093,17 @@ async def __update_persistence(self) -> None: result = new_state.old_state else: result = new_state.resolve() + else: + result = new_state - effective_new_state = None if result is TrackingDict.DELETED else result - # TODO: Test that we actually pass `None` here in case the conversation had ended, - # i.e. effective_new_state is TrackingDict.DELETED - coroutines.add( - self.persistence.update_conversation( - name=name, key=key, new_state=effective_new_state - ) + effective_new_state = None if result is TrackingDict.DELETED else result + # TODO: Test that we actually pass `None` here in case the conversation had ended, + # i.e. effective_new_state is TrackingDict.DELETED + coroutines.add( + self.persistence.update_conversation( + name=name, key=key, new_state=effective_new_state ) + ) results = await asyncio.gather(*coroutines, return_exceptions=True) _logger.debug('Finished updating persistence.') diff --git a/telegram/ext/_conversationhandler.py b/telegram/ext/_conversationhandler.py index 1d08a17d696..26a8b1580f0 100644 --- a/telegram/ext/_conversationhandler.py +++ b/telegram/ext/_conversationhandler.py @@ -39,6 +39,8 @@ ) from telegram import Update +from telegram._utils.defaultvalue import DEFAULT_TRUE +from telegram._utils.types import DVInput from telegram.ext import ( CallbackContext, CallbackQueryHandler, @@ -221,7 +223,7 @@ class ConversationHandler(Handler[Update, CCT]): .. versionadded:: 13.2 .. versionchanged:: 14.0 - No longer overrides the handlers settings + No longer *overrides* the handlers settings. Raises: ValueError @@ -278,7 +280,7 @@ def __init__( name: str = None, persistent: bool = False, map_to_parent: Dict[object, object] = None, - block: bool = False, + block: DVInput[bool] = DEFAULT_TRUE, ): # these imports need to be here because of circular import error otherwise from telegram.ext import ( # pylint: disable=import-outside-toplevel diff --git a/telegram/ext/_jobqueue.py b/telegram/ext/_jobqueue.py index fb11624836d..ae4a2aae75c 100644 --- a/telegram/ext/_jobqueue.py +++ b/telegram/ext/_jobqueue.py @@ -490,6 +490,8 @@ def run_custom( return job def start(self) -> None: + # TODO: Make this async - not needed yet, but it's probably saver to have it async already + # in case future versions need that """Starts the job_queue thread.""" if not self.scheduler.running: self.scheduler.start() From 464253f5e533a9e6780c3f54c72fb138810fdc15 Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Sat, 12 Mar 2022 11:01:18 +0100 Subject: [PATCH 055/153] Review + according tests --- telegram/ext/_application.py | 15 ++++---- telegram/ext/_applicationbuilder.py | 11 +----- tests/test_application.py | 53 ++++++++++++++++++++++++++++- 3 files changed, 61 insertions(+), 18 deletions(-) diff --git a/telegram/ext/_application.py b/telegram/ext/_application.py index 0858b383470..690195fcef1 100644 --- a/telegram/ext/_application.py +++ b/telegram/ext/_application.py @@ -460,10 +460,6 @@ async def stop(self) -> None: self._running = False _logger.info('Application is stopping. This might take a moment.') - if self.updater and self.updater.running: - _logger.debug('Waiting for updater to stop fetching updates') - await self.updater.stop() - # Stop listening for new updates and handle all pending ones await self.update_queue.put(_STOP_SIGNAL) _logger.debug('Waiting for update_queue to join') @@ -591,6 +587,8 @@ def __run( finally: # We arrive here either by catching the exceptions above or if the loop gets stopped try: + # Mypy doesn't know that we already check if updater is None + loop.run_until_complete(self.updater.stop()) # type: ignore[union-attr] loop.run_until_complete(self.stop()) loop.run_until_complete(self.shutdown()) finally: @@ -644,10 +642,13 @@ def __create_task( return task def __create_task_done_callback(self, task: asyncio.Task) -> None: + self.__create_task_tasks.discard(task) # We just retrieve the eventual exception so that asyncio doesn't complain in case # it's not retrieved somewhere else - task.exception() - self.__create_task_tasks.discard(task) + try: + task.exception() + except (asyncio.CancelledError, asyncio.InvalidStateError): + pass async def __create_task_callback( self, @@ -701,7 +702,7 @@ async def _update_fetcher(self) -> None: _logger.debug('Processing update %s', update) if self._concurrent_updates: - asyncio.create_task(self.__process_update_wrapper(update)) + self.create_task(self.__process_update_wrapper(update), update=update) else: await self.__process_update_wrapper(update) diff --git a/telegram/ext/_applicationbuilder.py b/telegram/ext/_applicationbuilder.py index 3516651d5ff..aeef5f8f008 100644 --- a/telegram/ext/_applicationbuilder.py +++ b/telegram/ext/_applicationbuilder.py @@ -23,8 +23,6 @@ TypeVar, Generic, TYPE_CHECKING, - Callable, - Any, Dict, Union, Type, @@ -47,18 +45,13 @@ # Type hinting is a bit complicated here because we try to get to a sane level of # leveraging generics and therefore need a number of type variables. -OAppT = TypeVar('OAppT', bound=Union[None, Application]) -AppT = TypeVar('AppT', bound=Application) InBT = TypeVar('InBT', bound=Bot) InJQ = TypeVar('InJQ', bound=Union[None, JobQueue]) -InPT = TypeVar('InPT', bound=Union[None, 'BasePersistence']) -InAppT = TypeVar('InAppT', bound=Union[None, Application]) InCCT = TypeVar('InCCT', bound='CallbackContext') InUD = TypeVar('InUD') InCD = TypeVar('InCD') InBD = TypeVar('InBD') BuilderType = TypeVar('BuilderType', bound='ApplicationBuilder') -CT = TypeVar('CT', bound=Callable[..., Any]) if TYPE_CHECKING: DEF_CCT = CallbackContext.DEFAULT_TYPE # type: ignore[misc] @@ -76,6 +69,7 @@ _BOT_CHECKS = [ ('request', 'request instance'), + ('get_updates_request', 'get_updates_request instance'), ('connection_pool_size', 'connection_pool_size'), ('proxy_url', 'proxy_url'), ('pool_timeout', 'pool_timeout'), @@ -88,7 +82,6 @@ ('get_updates_connect_timeout', 'get_updates_connect_timeout'), ('get_updates_read_timeout', 'get_updates_read_timeout'), ('get_updates_write_timeout', 'get_updates_write_timeout'), - ('get_updates_request', 'get_updates_request instance'), ('base_file_url', 'base_file_url'), ('base_url', 'base_url'), ('token', 'token'), @@ -755,8 +748,6 @@ def updater(self: BuilderType, updater: Optional[Updater]) -> BuilderType: for attr, error in ( (self._bot, 'bot instance'), - (self._request, 'request instance'), - (self._get_updates_request, 'get_updates_request instance'), (self._update_queue, 'update_queue'), ): if not isinstance(attr, DefaultValue): diff --git a/tests/test_application.py b/tests/test_application.py index b503254aded..9080014ec65 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -371,14 +371,17 @@ async def callback(u, c): await app.start() assert app.running assert app.job_queue.scheduler.running + # app.start() should not start the updater! assert not app.updater.running await asyncio.sleep(0.05) assert app.update_queue.empty() assert self.received == 1 + await app.updater.start_polling() await app.stop() assert not app.running - assert not app.updater.running + # app.stop() should not stop the updater! + assert app.updater.running assert not app.job_queue.scheduler.running await app.update_queue.put(2) await asyncio.sleep(0.05) @@ -386,6 +389,8 @@ async def callback(u, c): assert self.received != 2 assert self.received == 1 + await app.updater.stop() + @pytest.mark.asyncio async def test_error_start_stop_twice(self, app): async with app: @@ -1167,6 +1172,32 @@ async def error(update_arg, context): assert self.received[0] is update assert self.received[1] is exception + @pytest.mark.asyncio + async def test_create_task_cancel_task(self, app): + async def callback(): + await asyncio.sleep(1) + + async def error(update_arg, context): + self.received = update_arg, context.error + + app.add_error_handler(error) + async with app: + await app.start() + task = app.create_task(callback()) + await asyncio.sleep(0.05) + task.cancel() + + with pytest.raises(asyncio.CancelledError): + await task + with pytest.raises(asyncio.CancelledError): + assert task.exception() + + # Error handlers should not be called if task was cancelled + assert self.received is None + + # make sure that the cancelled task doesn't block the stopping of the app + await app.stop() + @pytest.mark.asyncio async def test_await_create_task_tasks_on_stop(self, app): async def callback_1(): @@ -1251,6 +1282,26 @@ async def callback(u, c): await app.stop() + @pytest.mark.asyncio + async def test_concurrent_updates_done_on_shutdown(self, bot): + app = Application.builder().token(bot.token).concurrent_updates(True).build() + event = asyncio.Event() + + async def callback(update, context): + await event.wait() + + app.add_handler(TypeHandler(object, callback)) + + async with app: + await app.start() + await app.update_queue.put(1) + stop_task = asyncio.create_task(app.stop()) + await asyncio.sleep(0.1) + assert not stop_task.done() + event.set() + await asyncio.sleep(0.05) + assert stop_task.done() + @pytest.mark.skipif( platform.system() == 'Windows', reason="Can't send signals without stopping whole process on windows", From 14e4277bf9869d1f4717ba1fc0867eb7e25892cf Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Sat, 12 Mar 2022 13:53:22 +0100 Subject: [PATCH 056/153] Try fixing tests --- tests/test_application.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/tests/test_application.py b/tests/test_application.py index 9080014ec65..12b8d363cf3 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -1249,8 +1249,10 @@ async def callback(u, c): await app.stop() @pytest.mark.asyncio - @pytest.mark.parametrize('concurrent_updates', (True, 15, 50, 256)) + @pytest.mark.parametrize('concurrent_updates', (15, 50, 100)) async def test_concurrent_updates(self, bot, concurrent_updates): + # We don't test with `True` since the large number of parallel coroutines quickly leads + # to test instabilities app = Application.builder().token(bot.token).concurrent_updates(concurrent_updates).build() events = {i: asyncio.Event() for i in range(app.concurrent_updates + 10)} queue = asyncio.Queue() @@ -1384,12 +1386,16 @@ async def start_polling(_, **kwargs): self.received = kwargs return True + async def stop(_, **kwargs): + return True + def thread_target(): ready_event.wait() time.sleep(0.1) os.kill(os.getpid(), signal.SIGINT) monkeypatch.setattr(Updater, 'start_polling', start_polling) + monkeypatch.setattr(Updater, 'stop', stop) thread = Thread(target=thread_target) thread.start() app.run_polling(ready=ready_event, close_loop=False) @@ -1490,12 +1496,16 @@ async def start_webhook(_, **kwargs): self.received = kwargs return True + async def stop(_, **kwargs): + return True + ready_event = threading.Event() # First check that the default values match and that we have all arguments there updater_signature = inspect.signature(Updater.start_webhook) monkeypatch.setattr(Updater, 'start_webhook', start_webhook) + monkeypatch.setattr(Updater, 'stop', stop) app = ApplicationBuilder().token(bot.token).build() app_signature = inspect.signature(app.run_webhook) From b2a7520606aabf5d5422c442383fa0ff8a3822b2 Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Sat, 12 Mar 2022 14:34:44 +0100 Subject: [PATCH 057/153] Try harder --- telegram/ext/_application.py | 4 ++++ tests/test_application.py | 15 ++++++++++----- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/telegram/ext/_application.py b/telegram/ext/_application.py index 690195fcef1..6e54b5ce6f5 100644 --- a/telegram/ext/_application.py +++ b/telegram/ext/_application.py @@ -658,6 +658,10 @@ async def __create_task_callback( ) -> _RT: try: return await coroutine + except asyncio.CancelledError as cancel: + # TODO: in py3.8+, CancelledError is a subclass of BaseException, so we can drop this + # close when we drop py3.7 + raise cancel except Exception as exception: if isinstance(exception, ApplicationHandlerStop): warn( diff --git a/tests/test_application.py b/tests/test_application.py index 12b8d363cf3..9ebf51f9d17 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -1175,7 +1175,7 @@ async def error(update_arg, context): @pytest.mark.asyncio async def test_create_task_cancel_task(self, app): async def callback(): - await asyncio.sleep(1) + await asyncio.sleep(10) async def error(update_arg, context): self.received = update_arg, context.error @@ -1200,23 +1200,28 @@ async def error(update_arg, context): @pytest.mark.asyncio async def test_await_create_task_tasks_on_stop(self, app): + event_1 = asyncio.Event() + event_2 = asyncio.Event() + async def callback_1(): - await asyncio.sleep(0.5) + await event_1.wait() async def callback_2(): - await asyncio.sleep(0.1) + await event_2.wait() async with app: await app.start() task_1 = app.create_task(callback_1()) task_2 = app.create_task(callback_2()) + event_2.set() await task_2 assert not task_1.done() stop_task = asyncio.create_task(app.stop()) assert not stop_task.done() - await asyncio.sleep(0.3) + await asyncio.sleep(0.1) assert not stop_task.done() - await asyncio.sleep(0.15) + event_1.set() + await asyncio.sleep(0.05) assert stop_task.done() @pytest.mark.asyncio From 326117a2b40ae435573aaf5978579908ff443933 Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Sat, 12 Mar 2022 16:01:04 +0100 Subject: [PATCH 058/153] rename test_persistence_integration.py to test_basepersistence.py --- tests/test_application.py | 8 ++++---- ...persistence_integration.py => test_basepersistence.py} | 5 ++++- 2 files changed, 8 insertions(+), 5 deletions(-) rename tests/{test_persistence_integration.py => test_basepersistence.py} (99%) diff --git a/tests/test_application.py b/tests/test_application.py index dc86e0d7d35..8231801e540 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -16,7 +16,7 @@ # # You should have received a copy of the GNU Lesser Public License # along with this program. If not, see [http://www.gnu.org/licenses/]. -"""The integration of persistence into the application is tested in test_persistence_integration. +"""The integration of persistence into the application is tested in test_basepersistence. """ import asyncio import inspect @@ -64,7 +64,7 @@ class CustomContext(CallbackContext): class TestApplication: """The integration of persistence into the application is tested in - test_persistence_integration. + test_basepersistence. """ message_update = make_message_update(message='Text') @@ -205,7 +205,7 @@ def test_custom_context_init(self, bot): @pytest.mark.asyncio @pytest.mark.asyncio('updater', (True, False)) async def test_initialize(self, bot, monkeypatch, updater): - """Initialization of persistence is tested test_persistence_integration""" + """Initialization of persistence is tested test_basepersistence""" self.test_flag = set() async def initialize_bot(*args, **kwargs): @@ -227,7 +227,7 @@ async def initialize_updater(*args, **kwargs): @pytest.mark.asyncio @pytest.mark.parametrize('updater', (True, False)) async def test_shutdown(self, bot, monkeypatch, updater): - """Shutdown of persistence is tested in test_persistence_integration""" + """Shutdown of persistence is tested in test_basepersistence""" self.test_flag = set() async def shutdown_bot(*args, **kwargs): diff --git a/tests/test_persistence_integration.py b/tests/test_basepersistence.py similarity index 99% rename from tests/test_persistence_integration.py rename to tests/test_basepersistence.py index 04a823fd79e..8b839c01a2b 100644 --- a/tests/test_persistence_integration.py +++ b/tests/test_basepersistence.py @@ -277,7 +277,10 @@ def papp(request, bot) -> Application: return app -class TestPersistenceIntegration: +class TestBasePersistence: + """Tests basic bahvior of BasePersistence and (most importantly) the integration of persistence + into the Application.""" + # TODO: # * Test add_handler with persistent conversationhandler # * Test migrate_chat_data From 037f53c78894011c80ad2e7c20321da6ab88eef1 Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Sat, 12 Mar 2022 17:09:46 +0100 Subject: [PATCH 059/153] Add some basic BasePersistence tests & adjust to the dropped replace/insert_bot --- tests/conftest.py | 9 ++- tests/test_basepersistence.py | 122 ++++++++++++++++++++++++++++------ 2 files changed, 109 insertions(+), 22 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 7ee280c7f76..dac54b3d1b6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -400,10 +400,15 @@ def timezone(tzinfo): @pytest.fixture() def mro_slots(): - def _mro_slots(_class): + def _mro_slots(_class, only_parents: bool = False): + if only_parents: + classes = _class.__class__.__mro__[1:-1] + else: + classes = _class.__class__.__mro__[:-1] + return [ attr - for cls in _class.__class__.__mro__[:-1] + for cls in classes if hasattr(cls, '__slots__') # The Exception class doesn't have slots for attr in cls.__slots__ if attr != '__dict__' # left here for classes which still has __dict__ diff --git a/tests/test_basepersistence.py b/tests/test_basepersistence.py index 8b839c01a2b..ce1b02f6da8 100644 --- a/tests/test_basepersistence.py +++ b/tests/test_basepersistence.py @@ -18,8 +18,10 @@ # along with this program. If not, see [http://www.gnu.org/licenses/]. import asyncio import collections +import copy import enum import functools +import logging import time from pathlib import Path from typing import NamedTuple @@ -150,16 +152,16 @@ async def get_conversations(self, name): return self.conversations.get(name, {}) async def get_bot_data(self): - return self.bot_data + return copy.deepcopy(self.bot_data) async def get_chat_data(self): - return self.chat_data + return copy.deepcopy(self.chat_data) async def get_user_data(self): - return self.user_data + return copy.deepcopy(self.user_data) async def get_callback_data(self): - return self.callback_data + return copy.deepcopy(self.callback_data) async def drop_chat_data(self, chat_id): self.dropped_chat_ids[chat_id] += 1 @@ -313,6 +315,45 @@ async def callback(update, context): return callback + def test_slot_behaviour(self, mro_slots): + inst = TrackingPersistence() + for attr in inst.__slots__: + assert getattr(inst, attr, 'err') != 'err', f"got extra slot '{attr}'" + # We're interested in BasePersistence, not in the implementation + slots = mro_slots(inst, only_parents=True) + print(slots) + assert len(slots) == len(set(slots)), "duplicate slot" + + @pytest.mark.parametrize('bot_data', (True, False)) + @pytest.mark.parametrize('chat_data', (True, False)) + @pytest.mark.parametrize('user_data', (True, False)) + @pytest.mark.parametrize('callback_data', (True, False)) + def test_init_store_data_update_interval(self, bot_data, chat_data, user_data, callback_data): + store_data = PersistenceInput( + bot_data=bot_data, + chat_data=chat_data, + user_data=user_data, + callback_data=callback_data, + ) + persistence = TrackingPersistence(store_data=store_data, update_interval=3.14) + assert persistence.store_data.bot_data == bot_data + assert persistence.store_data.chat_data == chat_data + assert persistence.store_data.user_data == user_data + assert persistence.store_data.callback_data == callback_data + + def test_abstract_methods(self): + with pytest.raises( + TypeError, + match=( + 'drop_chat_data, drop_user_data, flush, get_bot_data, get_callback_data, ' + 'get_chat_data, get_conversations, ' + 'get_user_data, refresh_bot_data, refresh_chat_data, ' + 'refresh_user_data, update_bot_data, update_callback_data, ' + 'update_chat_data, update_conversation, update_user_data' + ), + ): + BasePersistence() + def test_construction_with_bad_persistence(self, caplog, bot): class MyPersistence: def __init__(self): @@ -448,6 +489,16 @@ def test_add_conversation_without_persistence(self, app): with pytest.raises(ValueError, match='if application has no persistence'): app.add_handler(build_conversation_handler('name', persistent=True)) + @pytest.mark.parametrize( + 'papp', + [PappInput()], + indirect=True, + ) + @pytest.mark.asyncio + async def test_add_conversation_handler_without_name(self, papp: Application): + with pytest.raises(ValueError, match="when handler is unnamed"): + papp.add_handler(build_conversation_handler(name=None, persistent=True)) + @pytest.mark.asyncio @pytest.mark.parametrize( 'papp', @@ -493,7 +544,9 @@ async def update_persistence(*args, **kwargs): indirect=True, ) @pytest.mark.asyncio - async def test_update_persistence_loop_call_count_update_handling(self, papp: Application): + async def test_update_persistence_loop_call_count_update_handling( + self, papp: Application, caplog + ): async with papp: for _ in range(5): # second pass processes update in conv_2 @@ -530,7 +583,11 @@ async def test_update_persistence_loop_call_count_update_handling(self, papp: Ap # Nothing should have been updated after handling nothing papp.persistence.reset_tracking() - await papp.update_persistence() + with caplog.at_level(logging.ERROR): + await papp.update_persistence() + # Make sure that "nothing updated" is not just due to an error + assert not caplog.text + assert papp.persistence.updated_bot_data == papp.persistence.store_data.bot_data assert ( papp.persistence.updated_callback_data == papp.persistence.store_data.callback_data @@ -545,7 +602,10 @@ async def test_update_persistence_loop_call_count_update_handling(self, papp: Ap # user/chat_data papp.persistence.reset_tracking() await papp.process_update('string_update') - await papp.update_persistence() + with caplog.at_level(logging.ERROR): + await papp.update_persistence() + # Make sure that "nothing updated" is not just due to an error + assert not caplog.text assert papp.persistence.updated_bot_data == papp.persistence.store_data.bot_data assert ( papp.persistence.updated_callback_data == papp.persistence.store_data.callback_data @@ -577,7 +637,7 @@ async def test_update_persistence_loop_call_count_update_handling(self, papp: Ap indirect=True, ) @pytest.mark.asyncio - async def test_update_persistence_loop_call_count_job(self, papp: Application): + async def test_update_persistence_loop_call_count_job(self, papp: Application, caplog): async with papp: papp.job_queue.start() papp.job_queue.run_once(self.job_callback, when=0.05, chat_id=1, user_id=1) @@ -609,7 +669,10 @@ async def test_update_persistence_loop_call_count_job(self, papp: Application): # Nothing should have been updated after no job ran papp.persistence.reset_tracking() - await papp.update_persistence() + with caplog.at_level(logging.ERROR): + await papp.update_persistence() + # Make sure that "nothing updated" is not just due to an error + assert not caplog.text assert papp.persistence.updated_bot_data == papp.persistence.store_data.bot_data assert ( papp.persistence.updated_callback_data == papp.persistence.store_data.callback_data @@ -624,7 +687,10 @@ async def test_update_persistence_loop_call_count_job(self, papp: Application): papp.persistence.reset_tracking() papp.job_queue.run_once(self.job_callback, when=0.1) await asyncio.sleep(0.2) - await papp.update_persistence() + with caplog.at_level(logging.ERROR): + await papp.update_persistence() + # Make sure that "nothing updated" is not just due to an error + assert not caplog.text assert papp.persistence.updated_bot_data == papp.persistence.store_data.bot_data assert ( papp.persistence.updated_callback_data == papp.persistence.store_data.callback_data @@ -689,32 +755,48 @@ async def test_update_persistence_loop_saved_data_update_handling( TrackingConversationHandler.build_update(HandlerStates.STATE_1, chat_id=1) ) assert not papp.persistence.bot_data + assert papp.persistence.bot_data is not papp.bot_data assert not papp.persistence.chat_data + assert papp.persistence.chat_data is not papp.chat_data assert not papp.persistence.user_data + assert papp.persistence.user_data is not papp.user_data assert papp.persistence.callback_data == ([], {}) + assert ( + papp.persistence.callback_data is not papp.bot.callback_data_cache.persistence_data + ) assert not papp.persistence.conversations await papp.update_persistence() + + assert papp.persistence.bot_data is not papp.bot_data if papp.persistence.store_data.bot_data: assert papp.persistence.bot_data == {'key': 'value', 'refreshed': True} else: assert not papp.persistence.bot_data - if papp.persistence.store_data.callback_data: - assert papp.persistence.callback_data[1] == {} - assert len(papp.persistence.callback_data[0]) == 1 + + assert papp.persistence.chat_data is not papp.chat_data + if papp.persistence.store_data.chat_data: + assert papp.persistence.chat_data == {1: {'key': 'value', 'refreshed': True}} + assert papp.persistence.chat_data[1] is not papp.chat_data[1] else: - assert papp.persistence.callback_data == ([], {}) - assert ( - papp.persistence.updated_callback_data == papp.persistence.store_data.callback_data - ) + assert not papp.persistence.chat_data + + assert papp.persistence.user_data is not papp.user_data if papp.persistence.store_data.user_data: assert papp.persistence.user_data == {1: {'key': 'value', 'refreshed': True}} + assert papp.persistence.user_data[1] is not papp.chat_data[1] else: assert not papp.persistence.user_data - if papp.persistence.store_data.chat_data: - assert papp.persistence.chat_data == {1: {'key': 'value', 'refreshed': True}} + + assert ( + papp.persistence.callback_data is not papp.bot.callback_data_cache.persistence_data + ) + if papp.persistence.store_data.callback_data: + assert papp.persistence.callback_data[1] == {} + assert len(papp.persistence.callback_data[0]) == 1 else: - assert not papp.persistence.chat_data + assert papp.persistence.callback_data == ([], {}) + assert not papp.persistence.conversations @pytest.mark.parametrize('papp', [PappInput()], indirect=True) From 82bfb017f4c1d3ec9a8c7515b18203aefe38f3e7 Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Sun, 13 Mar 2022 16:34:54 +0100 Subject: [PATCH 060/153] Try stabilizing on macOS --- tests/test_basepersistence.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/tests/test_basepersistence.py b/tests/test_basepersistence.py index ce1b02f6da8..e2dc6ec8048 100644 --- a/tests/test_basepersistence.py +++ b/tests/test_basepersistence.py @@ -27,6 +27,7 @@ from typing import NamedTuple import pytest +from flaky import flaky from telegram import User, Chat, InlineKeyboardMarkup, InlineKeyboardButton from telegram.ext import ( @@ -477,6 +478,8 @@ async def test_add_conversation_handler_after_init(self, papp: Application, recw async with papp: papp.add_handler(build_conversation_handler('name', persistent=True)) + # For debugging if this test fails in CI + print([rec.message for rec in recwarn]) assert len(recwarn) == 1 assert recwarn[0].category is PTBUserWarning assert 'after `Application.initialize` was called' in str(recwarn[-1].message) @@ -499,6 +502,7 @@ async def test_add_conversation_handler_without_name(self, papp: Application): with pytest.raises(ValueError, match="when handler is unnamed"): papp.add_handler(build_conversation_handler(name=None, persistent=True)) + @flaky(3, 1) @pytest.mark.asyncio @pytest.mark.parametrize( 'papp', @@ -508,6 +512,8 @@ async def test_add_conversation_handler_without_name(self, papp: Application): indirect=True, ) async def test_update_interval(self, papp: Application, monkeypatch): + """If we don't want this test to take much longer to run, the accuracy will be a bit low. + A few tenths of seconds are easy to go astray ... That's why it's flaky.""" call_times = [] async def update_persistence(*args, **kwargs): @@ -640,8 +646,8 @@ async def test_update_persistence_loop_call_count_update_handling( async def test_update_persistence_loop_call_count_job(self, papp: Application, caplog): async with papp: papp.job_queue.start() - papp.job_queue.run_once(self.job_callback, when=0.05, chat_id=1, user_id=1) - await asyncio.sleep(0.1) + papp.job_queue.run_once(self.job_callback, when=0.1, chat_id=1, user_id=1) + await asyncio.sleep(0.2) assert not papp.persistence.updated_bot_data assert not papp.persistence.updated_chat_ids assert not papp.persistence.updated_user_ids @@ -833,7 +839,7 @@ async def test_update_persistence_loop_async_logic( assert papp.persistence.updated_callback_data assert not papp.persistence.updated_conversations - await asyncio.sleep(sleep + 0.05) + await asyncio.sleep(sleep + 0.1) await papp.update_persistence() assert not papp.persistence.dropped_chat_ids assert not papp.persistence.dropped_user_ids From cb7934ec3ecbeebd5187914927538cc7b52f4346 Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Sun, 13 Mar 2022 16:58:17 +0100 Subject: [PATCH 061/153] Try again --- tests/test_basepersistence.py | 63 ++++++++++++----------------------- 1 file changed, 22 insertions(+), 41 deletions(-) diff --git a/tests/test_basepersistence.py b/tests/test_basepersistence.py index e2dc6ec8048..2b56520e21c 100644 --- a/tests/test_basepersistence.py +++ b/tests/test_basepersistence.py @@ -478,15 +478,18 @@ async def test_add_conversation_handler_after_init(self, papp: Application, recw async with papp: papp.add_handler(build_conversation_handler('name', persistent=True)) - # For debugging if this test fails in CI - print([rec.message for rec in recwarn]) - assert len(recwarn) == 1 - assert recwarn[0].category is PTBUserWarning - assert 'after `Application.initialize` was called' in str(recwarn[-1].message) - assert ( - Path(recwarn[-1].filename) - == PROJECT_ROOT_PATH / 'telegram' / 'ext' / '_application.py' - ), "incorrect stacklevel!" + assert len(recwarn) >= 1 + found = False + for warning in recwarn: + if 'after `Application.initialize` was called' in str(warning.message): + found = True + assert warning.category is PTBUserWarning + assert ( + Path(warning.filename) + == PROJECT_ROOT_PATH / 'telegram' / 'ext' / '_application.py' + ), "incorrect stacklevel!" + + assert found def test_add_conversation_without_persistence(self, app): with pytest.raises(ValueError, match='if application has no persistence'): @@ -507,7 +510,7 @@ async def test_add_conversation_handler_without_name(self, papp: Application): @pytest.mark.parametrize( 'papp', [ - PappInput(update_interval=1), + PappInput(update_interval=1.5), ], indirect=True, ) @@ -522,29 +525,23 @@ async def update_persistence(*args, **kwargs): monkeypatch.setattr(papp, 'update_persistence', update_persistence) async with papp: await papp.start() - await asyncio.sleep(3) + await asyncio.sleep(5) await papp.stop() + # Make assertions before calling shutdown, as that calls update_persistence again! diffs = [j - i for i, j in zip(call_times[:-1], call_times[1:])] - for diff in diffs: - assert diff == pytest.approx(papp.persistence.update_interval, rel=1e-1) + assert sum(diffs) / len(diffs) == pytest.approx( + papp.persistence.update_interval, rel=1e-1 + ) @pytest.mark.parametrize( 'papp', [ PappInput(), - PappInput(bot_data=False), - PappInput(chat_data=False), - PappInput(user_data=False), - PappInput(callback_data=False), PappInput(False, False, False, False), ], ids=( 'all_data', - 'no_bot_data', - 'no_chat_data', - 'not_user_data', - 'no_callback_data', 'no_data', ), indirect=True, @@ -626,18 +623,10 @@ async def test_update_persistence_loop_call_count_update_handling( 'papp', [ PappInput(), - PappInput(bot_data=False), - PappInput(chat_data=False), - PappInput(user_data=False), - PappInput(callback_data=False), PappInput(False, False, False, False), ], ids=( 'all_data', - 'no_bot_data', - 'no_chat_data', - 'not_user_data', - 'no_callback_data', 'no_data', ), indirect=True, @@ -646,8 +635,8 @@ async def test_update_persistence_loop_call_count_update_handling( async def test_update_persistence_loop_call_count_job(self, papp: Application, caplog): async with papp: papp.job_queue.start() - papp.job_queue.run_once(self.job_callback, when=0.1, chat_id=1, user_id=1) - await asyncio.sleep(0.2) + papp.job_queue.run_once(self.job_callback, when=1.5, chat_id=1, user_id=1) + await asyncio.sleep(2.5) assert not papp.persistence.updated_bot_data assert not papp.persistence.updated_chat_ids assert not papp.persistence.updated_user_ids @@ -732,18 +721,10 @@ async def test_calls_on_shutdown(self, papp, chat_id): 'papp', [ PappInput(), - PappInput(bot_data=False), - PappInput(chat_data=False), - PappInput(user_data=False), - PappInput(callback_data=False), PappInput(False, False, False, False), ], ids=( 'all_data', - 'no_bot_data', - 'no_chat_data', - 'not_user_data', - 'no_callback_data', 'no_data', ), indirect=True, @@ -811,7 +792,7 @@ async def test_update_persistence_loop_saved_data_update_handling( async def test_update_persistence_loop_async_logic( self, papp: Application, delay_type: str, chat_id ): - sleep = 0.1 + sleep = 1.5 update = TrackingConversationHandler.build_update(HandlerStates.STATE_1, chat_id=1) async with papp: @@ -839,7 +820,7 @@ async def test_update_persistence_loop_async_logic( assert papp.persistence.updated_callback_data assert not papp.persistence.updated_conversations - await asyncio.sleep(sleep + 0.1) + await asyncio.sleep(sleep + 1) await papp.update_persistence() assert not papp.persistence.dropped_chat_ids assert not papp.persistence.dropped_user_ids From 8fd66593dc000d0f778156831db28d73c5cf44a3 Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Sun, 13 Mar 2022 20:32:16 +0100 Subject: [PATCH 062/153] Test drop/migrate data --- tests/test_basepersistence.py | 152 ++++++++++++++++++++++------------ 1 file changed, 97 insertions(+), 55 deletions(-) diff --git a/tests/test_basepersistence.py b/tests/test_basepersistence.py index 2b56520e21c..24f27048466 100644 --- a/tests/test_basepersistence.py +++ b/tests/test_basepersistence.py @@ -280,16 +280,28 @@ def papp(request, bot) -> Application: return app +# Decorator shortcuts +default_papp = pytest.mark.parametrize('papp', [PappInput()], indirect=True) +filled_papp = pytest.mark.parametrize('papp', [PappInput(fill_data=True)], indirect=True) +papp_store_all_or_none = pytest.mark.parametrize( + 'papp', + [ + PappInput(), + PappInput(False, False, False, False), + ], + ids=( + 'all_data', + 'no_data', + ), + indirect=True, +) + + class TestBasePersistence: - """Tests basic bahvior of BasePersistence and (most importantly) the integration of persistence - into the Application.""" + """Tests basic behavior of BasePersistence and (most importantly) the integration of + persistence into the Application.""" # TODO: - # * Test add_handler with persistent conversationhandler - # * Test migrate_chat_data - # * Test drop_chat/user_data - # * Test update_persistence & flush getting called on shutdown - # * Test the update parameter of create_task # * conversations: pending states, ending conversations, unresolved pending states async def job_callback(self, context): @@ -468,11 +480,7 @@ async def get_callback_data(*args, **kwargs): with pytest.raises(ValueError, match='callback_data must be'): await papp.initialize() - @pytest.mark.parametrize( - 'papp', - [PappInput()], - indirect=True, - ) + @default_papp @pytest.mark.asyncio async def test_add_conversation_handler_after_init(self, papp: Application, recwarn): async with papp: @@ -495,11 +503,7 @@ def test_add_conversation_without_persistence(self, app): with pytest.raises(ValueError, match='if application has no persistence'): app.add_handler(build_conversation_handler('name', persistent=True)) - @pytest.mark.parametrize( - 'papp', - [PappInput()], - indirect=True, - ) + @default_papp @pytest.mark.asyncio async def test_add_conversation_handler_without_name(self, papp: Application): with pytest.raises(ValueError, match="when handler is unnamed"): @@ -534,18 +538,7 @@ async def update_persistence(*args, **kwargs): papp.persistence.update_interval, rel=1e-1 ) - @pytest.mark.parametrize( - 'papp', - [ - PappInput(), - PappInput(False, False, False, False), - ], - ids=( - 'all_data', - 'no_data', - ), - indirect=True, - ) + @papp_store_all_or_none @pytest.mark.asyncio async def test_update_persistence_loop_call_count_update_handling( self, papp: Application, caplog @@ -619,18 +612,7 @@ async def test_update_persistence_loop_call_count_update_handling( assert not papp.persistence.dropped_chat_ids assert not papp.persistence.dropped_user_ids - @pytest.mark.parametrize( - 'papp', - [ - PappInput(), - PappInput(False, False, False, False), - ], - ids=( - 'all_data', - 'no_data', - ), - indirect=True, - ) + @papp_store_all_or_none @pytest.mark.asyncio async def test_update_persistence_loop_call_count_job(self, papp: Application, caplog): async with papp: @@ -696,7 +678,7 @@ async def test_update_persistence_loop_call_count_job(self, papp: Application, c assert not papp.persistence.dropped_chat_ids assert not papp.persistence.dropped_user_ids - @pytest.mark.parametrize('papp', [PappInput()], indirect=True) + @default_papp @pytest.mark.asyncio async def test_calls_on_shutdown(self, papp, chat_id): papp.add_handler( @@ -707,28 +689,28 @@ async def test_calls_on_shutdown(self, papp, chat_id): await papp.process_update( TrackingConversationHandler.build_update(HandlerStates.STATE_1, chat_id=1) ) + assert not papp.persistence.updated_bot_data + assert not papp.persistence.updated_callback_data + assert not papp.persistence.updated_user_ids + assert not papp.persistence.updated_chat_ids + assert not papp.persistence.updated_conversations + assert not papp.persistence.flushed # Make sure this this outside the context manager, which is where shutdown is called! + assert papp.persistence.updated_bot_data assert papp.persistence.bot_data == {'key': 'value', 'refreshed': True} + assert papp.persistence.updated_callback_data assert papp.persistence.callback_data[1] == {} assert len(papp.persistence.callback_data[0]) == 1 + assert papp.persistence.updated_user_ids == {1: 1} assert papp.persistence.user_data == {1: {'key': 'value', 'refreshed': True}} + assert papp.persistence.updated_chat_ids == {1: 1} assert papp.persistence.chat_data == {1: {'key': 'value', 'refreshed': True}} + assert not papp.persistence.updated_conversations assert not papp.persistence.conversations assert papp.persistence.flushed - @pytest.mark.parametrize( - 'papp', - [ - PappInput(), - PappInput(False, False, False, False), - ], - ids=( - 'all_data', - 'no_data', - ), - indirect=True, - ) + @papp_store_all_or_none @pytest.mark.asyncio async def test_update_persistence_loop_saved_data_update_handling( self, papp: Application, chat_id @@ -786,7 +768,7 @@ async def test_update_persistence_loop_saved_data_update_handling( assert not papp.persistence.conversations - @pytest.mark.parametrize('papp', [PappInput()], indirect=True) + @default_papp @pytest.mark.parametrize('delay_type', ('job', 'handler', 'task')) @pytest.mark.asyncio async def test_update_persistence_loop_async_logic( @@ -838,6 +820,66 @@ async def test_update_persistence_loop_async_logic( assert not papp.persistence.updated_chat_ids assert not papp.persistence.updated_conversations + @filled_papp + @pytest.mark.asyncio + async def test_drop_chat_data(self, papp: Application): + async with papp: + assert papp.persistence.chat_data == {1: {'key': 'value'}, 2: {'foo': 'bar'}} + assert not papp.persistence.dropped_chat_ids + assert not papp.persistence.updated_chat_ids + + papp.drop_chat_data(1) + + assert papp.persistence.chat_data == {1: {'key': 'value'}, 2: {'foo': 'bar'}} + assert not papp.persistence.dropped_chat_ids + assert not papp.persistence.updated_chat_ids + + await papp.update_persistence() + + assert papp.persistence.chat_data == {2: {'foo': 'bar'}} + assert papp.persistence.dropped_chat_ids == {1: 1} + assert not papp.persistence.updated_chat_ids + + @filled_papp + @pytest.mark.asyncio + async def test_drop_user_data(self, papp: Application): + async with papp: + assert papp.persistence.user_data == {1: {'key': 'value'}, 2: {'foo': 'bar'}} + assert not papp.persistence.dropped_user_ids + assert not papp.persistence.updated_user_ids + + papp.drop_user_data(1) + + assert papp.persistence.user_data == {1: {'key': 'value'}, 2: {'foo': 'bar'}} + assert not papp.persistence.dropped_user_ids + assert not papp.persistence.updated_user_ids + + await papp.update_persistence() + + assert papp.persistence.user_data == {2: {'foo': 'bar'}} + assert papp.persistence.dropped_user_ids == {1: 1} + assert not papp.persistence.updated_user_ids + + @filled_papp + @pytest.mark.asyncio + async def test_migrate_chat_data(self, papp: Application): + async with papp: + assert papp.persistence.chat_data == {1: {'key': 'value'}, 2: {'foo': 'bar'}} + assert not papp.persistence.dropped_chat_ids + assert not papp.persistence.updated_chat_ids + + papp.migrate_chat_data(old_chat_id=1, new_chat_id=2) + + assert papp.persistence.chat_data == {1: {'key': 'value'}, 2: {'foo': 'bar'}} + assert not papp.persistence.dropped_chat_ids + assert not papp.persistence.updated_chat_ids + + await papp.update_persistence() + + assert papp.persistence.chat_data == {2: {'key': 'value'}} + assert papp.persistence.dropped_chat_ids == {1: 1} + assert papp.persistence.updated_chat_ids == {2: 1} + # def test_error_while_saving_chat_data(self, bot): # increment = [] # From a3f15f9921d7c2bc4b4681b6ce149eb04a76a817 Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Sun, 13 Mar 2022 21:19:51 +0100 Subject: [PATCH 063/153] Add a convenience utility to conftest --- tests/conftest.py | 33 +++++++++++++++++++++++++++++++++ tests/test_application.py | 12 +++++++----- 2 files changed, 40 insertions(+), 5 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index dac54b3d1b6..120444a05d3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -417,6 +417,39 @@ def _mro_slots(_class, only_parents: bool = False): return _mro_slots +def call_after(function: Callable, after: Callable): + """Run a callable after another has executed. Useful when trying to make sure that a function + did actually run, but just monkeypatching it doesn't work because this would break some other + functionality. + + Example usage: + + def test_stuff(self, bot, monkeypatch): + + def after(arg): + # arg is the return value of `send_message` + self.received = arg + + monkeypatch.setattr(bot, 'send_message', call_after(bot.send_message, after) + + """ + if asyncio.iscoroutinefunction(function): + + async def wrapped(*args, **kwargs): + out = await function(*args, **kwargs) + after(out) + return out + + else: + + def wrapped(*args, **kwargs): + out = function(*args, **kwargs) + after(out) + return out + + return wrapped + + async def expect_bad_request(func, message, reason): """ Wrapper for testing bot functions expected to result in an :class:`telegram.error.BadRequest`. diff --git a/tests/test_application.py b/tests/test_application.py index 8231801e540..9c73475619a 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -55,7 +55,7 @@ from telegram.error import TelegramError from telegram.warnings import PTBUserWarning -from tests.conftest import make_message_update, PROJECT_ROOT_PATH, send_webhook_message +from tests.conftest import make_message_update, PROJECT_ROOT_PATH, send_webhook_message, call_after class CustomContext(CallbackContext): @@ -230,14 +230,16 @@ async def test_shutdown(self, bot, monkeypatch, updater): """Shutdown of persistence is tested in test_basepersistence""" self.test_flag = set() - async def shutdown_bot(*args, **kwargs): + def after_bot_shutdown(*args, **kwargs): self.test_flag.add('bot') - async def shutdown_updater(*args, **kwargs): + def after_updater_shutdown(*args, **kwargs): self.test_flag.add('updater') - monkeypatch.setattr(Bot, 'shutdown', shutdown_bot) - monkeypatch.setattr(Updater, 'shutdown', shutdown_updater) + monkeypatch.setattr(Bot, 'shutdown', call_after(Bot.shutdown, after_bot_shutdown)) + monkeypatch.setattr( + Updater, 'shutdown', call_after(Updater.shutdown, after_updater_shutdown) + ) if updater: async with ApplicationBuilder().token(bot.token).build(): From 41224509ced353f84731e800db14fd5fa6e41793 Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Tue, 15 Mar 2022 08:16:51 +0100 Subject: [PATCH 064/153] test Updater.start_* return value --- tests/test_updater.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/test_updater.py b/tests/test_updater.py index 48160c53e12..8308b20a67c 100644 --- a/tests/test_updater.py +++ b/tests/test_updater.py @@ -235,7 +235,8 @@ async def delete_webhook(*args, **kwargs): monkeypatch.setattr(updater.bot, 'delete_webhook', delete_webhook) async with updater: - await updater.start_polling(drop_pending_updates=drop_pending_updates) + return_value = await updater.start_polling(drop_pending_updates=drop_pending_updates) + assert return_value is updater.update_queue assert updater.running await updates.join() await updater.stop() @@ -506,12 +507,13 @@ async def set_webhook(*args, **kwargs): port = randrange(1024, 49152) # Select random port async with updater: - await updater.start_webhook( + return_value = await updater.start_webhook( drop_pending_updates=drop_pending_updates, ip_address=ip, port=port, url_path='TOKEN', ) + assert return_value is updater.update_queue assert updater.running # Now, we send an update to the server From 9e6c4d85ce8fb2e4a4c54a2fc07fb36ade2a04fc Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Thu, 17 Mar 2022 19:53:29 +0100 Subject: [PATCH 065/153] Large parts of review --- telegram/ext/_application.py | 49 ++++++++-------- telegram/ext/_callbackqueryhandler.py | 7 ++- telegram/ext/_conversationhandler.py | 18 ++++-- telegram/ext/_dictpersistence.py | 6 +- telegram/ext/_updater.py | 80 ++++++++++++--------------- tests/test_application.py | 12 ++-- 6 files changed, 92 insertions(+), 80 deletions(-) diff --git a/telegram/ext/_application.py b/telegram/ext/_application.py index 6e54b5ce6f5..f9e938cca4d 100644 --- a/telegram/ext/_application.py +++ b/telegram/ext/_application.py @@ -21,7 +21,6 @@ import inspect import itertools import logging -from asyncio import Event from collections import defaultdict from copy import deepcopy from pathlib import Path @@ -60,10 +59,11 @@ from telegram import Message from telegram.ext._jobqueue import Job from telegram.ext._applicationbuilder import InitApplicationBuilder + from telegram.ext import ConversationHandler DEFAULT_GROUP: int = 0 -_DispType = TypeVar('_DispType', bound="Application") +_AppType = TypeVar('_AppType', bound="Application") _RT = TypeVar('_RT') _STOP_SIGNAL = object() @@ -326,7 +326,7 @@ async def shutdown(self) -> None: self._initialized = False - async def __aenter__(self: _DispType) -> _DispType: + async def __aenter__(self: _AppType) -> _AppType: try: await self.initialize() return self @@ -382,7 +382,7 @@ def builder() -> 'InitApplicationBuilder': return ApplicationBuilder() - async def start(self, ready: Event = None) -> None: + async def start(self) -> None: """Starts * a background task that fetches updates from :attr:`update_queue` and @@ -395,17 +395,15 @@ async def start(self, ready: Event = None) -> None: This does *not* start fetching updates from Telegram. You need to either start :attr:`updater` manually or use one of :meth:`run_polling` or :meth:`run_webhook`. - Args: - ready (:obj:`asyncio.Event`, optional): If specified, the event will be set once the - application is ready. - Raises: :exc:`RuntimeError`: If the application is already running or was not initialized. """ if self.running: raise RuntimeError('This Application is already running!') if not self._initialized: - raise RuntimeError('This Application is not initialized!') + raise RuntimeError( + 'This Application was not initialized via `Application.initialize`!' + ) self._running = True self.__update_persistence_event.clear() @@ -430,8 +428,6 @@ async def start(self, ready: Event = None) -> None: ) _logger.info('Application started') - if ready is not None: - ready.set() except Exception as exc: self._running = False raise exc @@ -496,7 +492,6 @@ def run_polling( pool_timeout: ODVInput[float] = DEFAULT_NONE, allowed_updates: List[str] = None, drop_pending_updates: bool = None, - ready: asyncio.Event = None, close_loop: bool = True, ) -> None: """Temp docstring to make this referencable @@ -523,7 +518,6 @@ def error_callback(exc: TelegramError) -> None: drop_pending_updates=drop_pending_updates, error_callback=error_callback, ), - ready=ready, close_loop=close_loop, ) @@ -540,7 +534,6 @@ def run_webhook( drop_pending_updates: bool = None, ip_address: str = None, max_connections: int = 40, - ready: asyncio.Event = None, close_loop: bool = True, ) -> None: """Temp docstring to make this referencable @@ -565,20 +558,17 @@ def run_webhook( ip_address=ip_address, max_connections=max_connections, ), - ready=ready, close_loop=close_loop, ) - def __run( - self, updater_coroutine: Coroutine, ready: asyncio.Event = None, close_loop: bool = True - ) -> None: + def __run(self, updater_coroutine: Coroutine, close_loop: bool = True) -> None: # Calling get_event_loop() should still be okay even in py3.10+ as long as there is a # running event loop or we are in the main thread, which are the intended use cases. # See the docs of get_event_loop() and get_running_loop() for more info loop = asyncio.get_event_loop() loop.run_until_complete(self.initialize()) loop.run_until_complete(updater_coroutine) - loop.run_until_complete(self.start(ready=ready)) + loop.run_until_complete(self.start()) try: loop.run_forever() # TODO: maybe allow for custom exception classes to catch here? Or provide a custom one? @@ -770,6 +760,13 @@ async def process_update(self, update: object) -> None: # blocking handler - the non-blocking handlers mark the update again when finished self._mark_for_persistence_update(update=update) + async def _add_ch_after_init(self, handler: 'ConversationHandler') -> None: + self._conversation_handler_conversations[ + handler.name # type: ignore[index] + ] = await handler._initialize_persistence( # pylint: disable=protected-access + self + ) + def add_handler(self, handler: Handler[Any, CCT], group: int = DEFAULT_GROUP) -> None: """Register a handler. @@ -790,6 +787,13 @@ def add_handler(self, handler: Handler[Any, CCT], group: int = DEFAULT_GROUP) -> group will not be used. The order in which handlers were added to the group defines the priority. + Warning: + Adding persistent :class:`telegram.ext.ConversationHandler` after the application has + been initialized is discouraged. This is because the persisted conversation states need + to be loaded into memory while the application is already processing updates, which + might lead to race conditions and undesired behavior. In particular, current + conversation states may be overridden by the loaded data. + Args: handler (:class:`telegram.ext.Handler`): A Handler instance. group (:obj:`int`, optional): The group identifier. Default is 0. @@ -810,10 +814,11 @@ def add_handler(self, handler: Handler[Any, CCT], group: int = DEFAULT_GROUP) -> f"can not be persistent if application has no persistence" ) if self._initialized: + self.create_task(self._add_ch_after_init(handler)) warn( 'A persistent `ConversationHandler` was passed to `add_handler`, ' - 'after `Application.initialize` was called. Conversation states will not be ' - 'loaded from persistence!', + 'after `Application.initialize` was called. This is discouraged.' + 'See the docs of `Application.add_handler` for details.', stacklevel=1, ) @@ -970,7 +975,7 @@ def migrate_chat_data( self.drop_chat_data(old_chat_id) self._chat_ids_to_be_updated_in_persistence.add(new_chat_id) - self._chat_ids_to_be_deleted_in_persistence.add(old_chat_id) + # old_chat_id is marked for deletion by drop_chat_data above def _mark_for_persistence_update(self, *, update: object = None, job: 'Job' = None) -> None: # TODO: This should be at the end of `Application.process_update`, when the task created diff --git a/telegram/ext/_callbackqueryhandler.py b/telegram/ext/_callbackqueryhandler.py index 553e1a32b20..96b81a12f6b 100644 --- a/telegram/ext/_callbackqueryhandler.py +++ b/telegram/ext/_callbackqueryhandler.py @@ -17,7 +17,7 @@ # You should have received a copy of the GNU Lesser Public License # along with this program. If not, see [http://www.gnu.org/licenses/]. """This module contains the CallbackQueryHandler class.""" - +import asyncio import re from typing import ( TYPE_CHECKING, @@ -115,6 +115,11 @@ def __init__( ): super().__init__(callback, block=block) + if callable(pattern) and asyncio.iscoroutinefunction(pattern): + raise ValueError( + 'The `pattern` must not be a coroutine function! Use an ordinary function instead.' + ) + if isinstance(pattern, str): pattern = re.compile(pattern) diff --git a/telegram/ext/_conversationhandler.py b/telegram/ext/_conversationhandler.py index 26a8b1580f0..54415947bda 100644 --- a/telegram/ext/_conversationhandler.py +++ b/telegram/ext/_conversationhandler.py @@ -542,11 +542,19 @@ async def _initialize_persistence( 'persistence!' ) - self._conversations = cast( - TrackingDict[Tuple[int, ...], object], - TrackingDict(), - ) - self._conversations.update(await application.persistence.get_conversations(self.name)) + with self._conversations_lock: + current_conversations = self._conversations + self._conversations = cast( + TrackingDict[Tuple[int, ...], object], + TrackingDict(), + ) + # In the conversation already processed updates + self._conversations.update(current_conversations) + # above might be partly overridden but that's okay since we warn about that in + # add_handler + self._conversations.update_no_track( + await application.persistence.get_conversations(self.name) + ) for handler in self._child_conversations: await handler._initialize_persistence( # pylint: disable=protected-access diff --git a/telegram/ext/_dictpersistence.py b/telegram/ext/_dictpersistence.py index 4cb6686a720..f73f0214620 100644 --- a/telegram/ext/_dictpersistence.py +++ b/telegram/ext/_dictpersistence.py @@ -421,7 +421,7 @@ async def flush(self) -> None: """ @staticmethod - def _encode_conversations_to_json(conversations: Dict[str, Dict[Tuple, object]]) -> str: + def _encode_conversations_to_json(conversations: Dict[str, ConversationDict]) -> str: """Helper method to encode a conversations dict (that uses tuples as keys) to a JSON-serializable way. Use :meth:`self._decode_conversations_from_json` to decode. @@ -450,12 +450,12 @@ def _decode_conversations_from_json(json_string: str) -> Dict[str, ConversationD :obj:`dict`: The conversations dict after decoding """ tmp = json.loads(json_string) - conversations: Dict[str, Dict[Tuple, object]] = {} + conversations: Dict[str, ConversationDict] = {} for handler, states in tmp.items(): conversations[handler] = {} for key, state in states.items(): conversations[handler][tuple(json.loads(key))] = state - return conversations # type: ignore[return-value] + return conversations @staticmethod def _decode_user_chat_data_from_json(data: str) -> Dict[int, Dict[object, object]]: diff --git a/telegram/ext/_updater.py b/telegram/ext/_updater.py index a12252e9ffe..c3ca9902cf9 100644 --- a/telegram/ext/_updater.py +++ b/telegram/ext/_updater.py @@ -71,7 +71,7 @@ class Updater: 'bot', '_logger', 'update_queue', - 'last_update_id', + '_last_update_id', '_running', '_initialized', '_httpd', @@ -87,7 +87,7 @@ def __init__( self.bot = bot self.update_queue = update_queue - self.last_update_id = 0 + self._last_update_id = 0 self._running = False self._initialized = False self._httpd: Optional[WebhookServer] = None @@ -187,6 +187,11 @@ async def start_polling( while calling :meth:`telegram.Bot.get_updates` during polling. Defaults to :obj:`None`, in which case errors will be logged. + Note: + The :paramref:`error_callback` must *not* be a coroutine function! If + asynchorous behavior of the callback is wanted, please schedule a task from + within the callback. + Returns: :class:`asyncio.Queue`: The update queue that can be filled from the main thread. @@ -198,7 +203,7 @@ async def start_polling( if self.running: raise RuntimeError('This Updater is already running!') if not self._initialized: - raise RuntimeError('This Updater is not initialized!') + raise RuntimeError('This Updater was not initialized via `Updater.initialize`!') self._running = True @@ -222,7 +227,7 @@ async def start_polling( self._logger.debug('Waiting for polling to start') await polling_ready.wait() - self._logger.debug('Polling to started') + self._logger.debug('Polling updates from Telegram started') return self.update_queue except Exception as exc: @@ -257,7 +262,7 @@ async def _start_polling( async def polling_action_cb() -> bool: updates = await self.bot.get_updates( - offset=self.last_update_id, + offset=self._last_update_id, timeout=timeout, read_timeout=read_timeout, connect_timeout=connect_timeout, @@ -275,7 +280,7 @@ async def polling_action_cb() -> bool: else: for update in updates: await self.update_queue.put(update) - self.last_update_id = updates[-1].update_id + 1 + self._last_update_id = updates[-1].update_id + 1 return True @@ -445,25 +450,15 @@ async def _start_webhook( ) # We pass along the cert to the webhook if present. - if cert is not None: - await self._bootstrap( - cert=cert, - max_retries=bootstrap_retries, - drop_pending_updates=drop_pending_updates, - webhook_url=webhook_url, - allowed_updates=allowed_updates, - ip_address=ip_address, - max_connections=max_connections, - ) - else: - await self._bootstrap( - max_retries=bootstrap_retries, - drop_pending_updates=drop_pending_updates, - webhook_url=webhook_url, - allowed_updates=allowed_updates, - ip_address=ip_address, - max_connections=max_connections, - ) + await self._bootstrap( + cert=cert, + max_retries=bootstrap_retries, + drop_pending_updates=drop_pending_updates, + webhook_url=webhook_url, + allowed_updates=allowed_updates, + ip_address=ip_address, + max_connections=max_connections, + ) await self._httpd.serve_forever(ready=ready) @@ -514,7 +509,9 @@ async def _network_loop_retry( except TelegramError as telegram_exc: self._logger.error('Error while %s: %s', description, telegram_exc) on_err_cb(telegram_exc) - cur_interval = self._increase_poll_interval(cur_interval) + + # increase waiting times on subsequent errors up to 30secs + cur_interval = 1 if cur_interval == 0 else min(30, 1.5 * cur_interval) else: cur_interval = interval @@ -525,17 +522,6 @@ async def _network_loop_retry( self._logger.debug('Network loop retry %s was cancelled', description) break - @staticmethod - def _increase_poll_interval(current_interval: float) -> float: - # increase waiting times on subsequent errors up to 30secs - if current_interval == 0: - current_interval = 1 - elif current_interval < 30: - current_interval *= 1.5 - else: - current_interval = min(30.0, current_interval) - return current_interval - async def _bootstrap( self, max_retries: int, @@ -547,7 +533,7 @@ async def _bootstrap( ip_address: str = None, max_connections: int = 40, ) -> None: - retries = [0] + retries = 0 async def bootstrap_del_webhook() -> bool: self._logger.debug('Deleting webhook') @@ -571,13 +557,17 @@ async def bootstrap_set_webhook() -> bool: return False def bootstrap_on_err_cb(exc: Exception) -> None: - if not isinstance(exc, InvalidToken) and (max_retries < 0 or retries[0] < max_retries): - retries[0] += 1 + # We need this since retries is an immutable object otherwise and the changes + # wouldn't propagate outside of thi function + nonlocal retries + + if not isinstance(exc, InvalidToken) and (max_retries < 0 or retries < max_retries): + retries += 1 self._logger.warning( - 'Failed bootstrap phase; try=%s max_retries=%s', retries[0], max_retries + 'Failed bootstrap phase; try=%s max_retries=%s', retries, max_retries ) else: - self._logger.error('Failed bootstrap phase after %s retries (%s)', retries[0], exc) + self._logger.error('Failed bootstrap phase after %s retries (%s)', retries, exc) raise exc # Dropping pending updates from TG can be efficiently done with the drop_pending_updates @@ -591,7 +581,9 @@ def bootstrap_on_err_cb(exc: Exception) -> None: 'bootstrap del webhook', bootstrap_interval, ) - retries[0] = 0 + + # Reset the retries counter for the next _network_loop_retry call + retries = 0 # Restore/set webhook settings, if needed. Again, we don't know ahead if a webhook is set, # so we set it anyhow. @@ -630,7 +622,7 @@ async def _stop_httpd(self) -> None: async def _stop_polling(self) -> None: if self.__polling_task: - self._logger.debug('Waiting background polling task to join.') + self._logger.debug('Waiting background polling task to finish up.') self.__polling_task.cancel() try: await self.__polling_task diff --git a/tests/test_application.py b/tests/test_application.py index 9ebf51f9d17..0ce7cbf3472 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -285,14 +285,16 @@ async def test_shutdown_while_running(self, app): await app.stop() @pytest.mark.asyncio - async def test_start_not_running_after_failure(self, app): - class Event(asyncio.Event): - def set(self) -> None: - raise Exception('Test Exception') + async def test_start_not_running_after_failure(self, bot, monkeypatch): + def start(_): + raise Exception('Test Exception') + + monkeypatch.setattr(JobQueue, 'start', start) + app = ApplicationBuilder().token(bot.token).job_queue(JobQueue()).build() async with app: with pytest.raises(Exception, match='Test Exception'): - await app.start(ready=Event()) + await app.start() assert app.running is False @pytest.mark.asyncio From cee0d3c49a6af258f12e71fe4b491b07c1d698bf Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Fri, 18 Mar 2022 21:25:32 +0100 Subject: [PATCH 066/153] More review --- telegram/ext/_updater.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/telegram/ext/_updater.py b/telegram/ext/_updater.py index c3ca9902cf9..97c18995260 100644 --- a/telegram/ext/_updater.py +++ b/telegram/ext/_updater.py @@ -624,8 +624,5 @@ async def _stop_polling(self) -> None: if self.__polling_task: self._logger.debug('Waiting background polling task to finish up.') self.__polling_task.cancel() - try: - await self.__polling_task - except asyncio.CancelledError: - pass + await self.__polling_task self.__polling_task = None From 234858c4049bd9655a46f869ccc5d9e19cf7596d Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Fri, 18 Mar 2022 21:55:26 +0100 Subject: [PATCH 067/153] make meta-tests path-agnostic --- tests/test_meta.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/tests/test_meta.py b/tests/test_meta.py index 7aceda7f936..4d85c5f4dce 100644 --- a/tests/test_meta.py +++ b/tests/test_meta.py @@ -22,16 +22,22 @@ from tests.conftest import env_var_2_bool - -@pytest.mark.skipif( +skip_disabled = pytest.mark.skipif( not env_var_2_bool(os.getenv('TEST_BUILD', False)), reason='TEST_BUILD not enabled' ) + + +# To make the tests agnostic of the cwd +@pytest.fixture(autouse=True) +def change_test_dir(request, monkeypatch): + monkeypatch.chdir(request.config.rootdir) + + +@skip_disabled def test_build(): assert os.system('python setup.py bdist_dumb') == 0 # pragma: no cover -@pytest.mark.skipif( - not env_var_2_bool(os.getenv('TEST_BUILD', False)), reason='TEST_BUILD not enabled' -) +@skip_disabled def test_build_raw(): assert os.system('python setup-raw.py bdist_dumb') == 0 # pragma: no cover From 3ef4d67b104d2bf3ea38b2308d9a13502988c6ba Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Fri, 18 Mar 2022 22:06:38 +0100 Subject: [PATCH 068/153] fix join request test --- tests/test_bot.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/test_bot.py b/tests/test_bot.py index 67cec4f3b36..23bbdfe1f58 100644 --- a/tests/test_bot.py +++ b/tests/test_bot.py @@ -2301,7 +2301,10 @@ async def test_decline_chat_join_request(self, bot, chat_id, channel_id): # TODO: Need incoming join request to properly test # Since we can't create join requests on the fly, we just tests the call to TG # by checking that it complains about declining a user who is already in the chat - with pytest.raises(BadRequest, match='User_already_participant'): + # + # The error message Hide_requester_missing started showing up instead of + # User_already_participant. Don't know why … + with pytest.raises(BadRequest, match='User_already_participant|Hide_requester_missing'): await bot.decline_chat_join_request(chat_id=channel_id, user_id=chat_id) @flaky(3, 1) From 745bfd15f6872a4d2e0f94652d32a23f6365a0de Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Fri, 18 Mar 2022 22:26:15 +0100 Subject: [PATCH 069/153] Remove a sleep --- tests/test_updater.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_updater.py b/tests/test_updater.py index 8308b20a67c..678e315dd0c 100644 --- a/tests/test_updater.py +++ b/tests/test_updater.py @@ -171,7 +171,6 @@ async def set_webhook(*args, **kwargs): port=port, ) else: - await asyncio.sleep(1) await getattr(updater, method)() with pytest.raises(RuntimeError, match='still running'): From 02940bb1f6e5a210b608c1551598bf2c9ed03f9f Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Fri, 18 Mar 2022 22:26:40 +0100 Subject: [PATCH 070/153] revert catching cancelledError on polling stop --- telegram/ext/_updater.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/telegram/ext/_updater.py b/telegram/ext/_updater.py index 97c18995260..e07fe564a8b 100644 --- a/telegram/ext/_updater.py +++ b/telegram/ext/_updater.py @@ -624,5 +624,12 @@ async def _stop_polling(self) -> None: if self.__polling_task: self._logger.debug('Waiting background polling task to finish up.') self.__polling_task.cancel() - await self.__polling_task + + try: + await self.__polling_task + except asyncio.CancelledError: + # This only happens in rare edge-cases, e.g. when `stop()` is called directly + # after start_polling(), but let's better be safe than sorry ... + pass + self.__polling_task = None From 257b161de2d1a66374a2bbb98bd1e1ffd22ec569 Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Sat, 19 Mar 2022 09:09:54 +0100 Subject: [PATCH 071/153] Adjust tests for now bahvior of adding ConversationHandler after init --- tests/test_basepersistence.py | 50 +++++++++++++++++++++++++++++++++-- 1 file changed, 48 insertions(+), 2 deletions(-) diff --git a/tests/test_basepersistence.py b/tests/test_basepersistence.py index 24f27048466..db8c39b9fda 100644 --- a/tests/test_basepersistence.py +++ b/tests/test_basepersistence.py @@ -40,6 +40,7 @@ filters, Handler, ApplicationHandlerStop, + CallbackContext, ) from telegram.warnings import PTBUserWarning from tests.conftest import make_message_update, PROJECT_ROOT_PATH, DictApplication @@ -480,11 +481,42 @@ async def get_callback_data(*args, **kwargs): with pytest.raises(ValueError, match='callback_data must be'): await papp.initialize() - @default_papp + @filled_papp @pytest.mark.asyncio async def test_add_conversation_handler_after_init(self, papp: Application, recwarn): + context = CallbackContext(application=papp) + + # Set it up such that the handler has a conversation in progress that's not persisted + papp.persistence.conversations['conv_1'].pop((2, 2)) + conversation = build_conversation_handler('conv_1', persistent=True) + update = TrackingConversationHandler.build_update(state=HandlerStates.END, chat_id=2) + check = conversation.check_update(update=update) + await conversation.handle_update( + update=update, check_result=check, application=papp, context=context + ) + + assert conversation.check_update( + TrackingConversationHandler.build_update(state=HandlerStates.STATE_1, chat_id=2) + ) + + # and another one that will be overridden + update = TrackingConversationHandler.build_update(state=HandlerStates.END, chat_id=1) + check = conversation.check_update(update=update) + await conversation.handle_update( + update=update, check_result=check, application=papp, context=context + ) + update = TrackingConversationHandler.build_update(state=HandlerStates.STATE_1, chat_id=1) + check = conversation.check_update(update=update) + await conversation.handle_update( + update=update, check_result=check, application=papp, context=context + ) + + assert conversation.check_update( + TrackingConversationHandler.build_update(state=HandlerStates.STATE_2, chat_id=1) + ) + async with papp: - papp.add_handler(build_conversation_handler('name', persistent=True)) + papp.add_handler(conversation) assert len(recwarn) >= 1 found = False @@ -499,6 +531,20 @@ async def test_add_conversation_handler_after_init(self, papp: Application, recw assert found + await asyncio.sleep(0.05) + # conversation with chat_id 2 must not have been overridden + assert conversation.check_update( + TrackingConversationHandler.build_update(HandlerStates.STATE_1, chat_id=2) + ) + + # conversation with chat_id 1 must have been overridden + assert not conversation.check_update( + TrackingConversationHandler.build_update(state=HandlerStates.STATE_2, chat_id=1) + ) + assert conversation.check_update( + TrackingConversationHandler.build_update(state=HandlerStates.STATE_1, chat_id=1) + ) + def test_add_conversation_without_persistence(self, app): with pytest.raises(ValueError, match='if application has no persistence'): app.add_handler(build_conversation_handler('name', persistent=True)) From 38df2b74350fe9c09a552eb5c7ac66f33be846e6 Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Sat, 19 Mar 2022 10:58:08 +0100 Subject: [PATCH 072/153] Allow App.process_update only after App. initialize was called --- telegram/ext/_application.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/telegram/ext/_application.py b/telegram/ext/_application.py index f9e938cca4d..51b5dfc31a7 100644 --- a/telegram/ext/_application.py +++ b/telegram/ext/_application.py @@ -255,6 +255,12 @@ def __init__( self.__update_persistence_lock = asyncio.Lock() self.__create_task_tasks: Set[asyncio.Task] = set() + def _check_initialized(self) -> None: + if not self._initialized: + raise RuntimeError( + 'This Application was not initialized via `Application.initialize`!' + ) + @property def running(self) -> bool: """:obj:`bool`: Indicates if this application is running. @@ -400,10 +406,7 @@ async def start(self) -> None: """ if self.running: raise RuntimeError('This Application is already running!') - if not self._initialized: - raise RuntimeError( - 'This Application was not initialized via `Application.initialize`!' - ) + self._check_initialized() self._running = True self.__update_persistence_event.clear() @@ -718,7 +721,12 @@ async def process_update(self, update: object) -> None: :class:`telegram.error.TelegramError`): The update to process. + Raises: + :exc:`RuntimeError`: If the application was not initialized. """ + # Processing updates before initialize() is a problem e.g. if persistence is used + self._check_initialized() + context = None any_blocking = False From aa864e987b671bd0bbdffa04a71e9a0537c5e7a4 Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Sat, 19 Mar 2022 11:03:43 +0100 Subject: [PATCH 073/153] adjust tests --- tests/test_application.py | 142 ++++++++++++++++++++------------------ 1 file changed, 76 insertions(+), 66 deletions(-) diff --git a/tests/test_application.py b/tests/test_application.py index 518792c7e37..9b2c7af5fd0 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -115,12 +115,13 @@ async def callback_context(self, update, context): ): self.received = context.error.message - def test_slot_behaviour(self, bot, mro_slots): - app = ApplicationBuilder().bot(bot).build() - for at in app.__slots__: - at = f"_Application{at}" if at.startswith('__') and not at.endswith('__') else at - assert getattr(app, at, 'err') != 'err', f"got extra slot '{at}'" - assert len(mro_slots(app)) == len(set(mro_slots(app))), "duplicate slot" + @pytest.mark.asyncio + async def test_slot_behaviour(self, bot, mro_slots): + async with ApplicationBuilder().token(bot.token).build() as app: + for at in app.__slots__: + at = f"_Application{at}" if at.startswith('__') and not at.endswith('__') else at + assert getattr(app, at, 'err') != 'err', f"got extra slot '{at}'" + assert len(mro_slots(app)) == len(set(mro_slots(app))), "duplicate slot" def test_manual_init_warning(self, recwarn, updater): Application( @@ -196,7 +197,7 @@ def test_custom_context_init(self, bot): bot_data=complex, ) - application = ApplicationBuilder().bot(bot).context_types(cc).build() + application = ApplicationBuilder().token(bot.token).context_types(cc).build() assert isinstance(application.user_data[1], int) assert isinstance(application.chat_data[1], float) @@ -447,13 +448,14 @@ def two(update, context): if context is self.received: pytest.fail('First handler was wrongly called') - app.add_handler(MessageHandler(filters.Regex('test'), one), group=1) - app.add_handler(MessageHandler(filters.ALL, two), group=2) - u = make_message_update(message='test') - await app.process_update(u) - self.received = None - u.message.text = 'something' - await app.process_update(u) + async with app: + app.add_handler(MessageHandler(filters.Regex('test'), one), group=1) + app.add_handler(MessageHandler(filters.ALL, two), group=2) + u = make_message_update(message='test') + await app.process_update(u) + self.received = None + u.message.text = 'something' + await app.process_update(u) def test_add_handler_errors(self, app): handler = 'not a handler' @@ -639,13 +641,14 @@ async def start3(b, u): ), ) - # If ApplicationHandlerStop raised handlers in other groups should not be called. - passed = [] - app.add_handler(CommandHandler('start', start1), 1) - app.add_handler(CommandHandler('start', start3), 1) - app.add_handler(CommandHandler('start', start2), 2) - await app.process_update(update) - assert passed == ['start1'] + async with app: + # If ApplicationHandlerStop raised handlers in other groups should not be called. + passed = [] + app.add_handler(CommandHandler('start', start1), 1) + app.add_handler(CommandHandler('start', start3), 1) + app.add_handler(CommandHandler('start', start2), 2) + await app.process_update(update) + assert passed == ['start1'] @pytest.mark.asyncio async def test_flow_stop_by_error_handler(self, app, bot): @@ -667,14 +670,15 @@ async def error(u, c): passed.append(c.error) raise ApplicationHandlerStop - # If ApplicationHandlerStop raised handlers in other groups should not be called. - passed = [] - app.add_error_handler(error) - app.add_handler(TypeHandler(object, start1), 1) - app.add_handler(TypeHandler(object, start2), 1) - app.add_handler(TypeHandler(object, start3), 2) - await app.process_update(1) - assert passed == ['start1', 'error', exception] + async with app: + # If ApplicationHandlerStop raised handlers in other groups should not be called. + passed = [] + app.add_error_handler(error) + app.add_handler(TypeHandler(object, start1), 1) + app.add_handler(TypeHandler(object, start2), 1) + app.add_handler(TypeHandler(object, start3), 2) + await app.process_update(1) + assert passed == ['start1', 'error', exception] @pytest.mark.asyncio async def test_error_in_handler_part_1(self, app): @@ -730,15 +734,16 @@ async def error(u, c): ), ) - # If an unhandled exception was caught, no further handlers from the same group should be - # called. Also, the error handler should be called and receive the exception - passed = [] - app.add_handler(CommandHandler('start', start1), 1) - app.add_handler(CommandHandler('start', start2), 1) - app.add_handler(CommandHandler('start', start3), 2) - app.add_error_handler(error) - await app.process_update(update) - assert passed == ['start1', 'error', err, 'start3'] + async with app: + # If an unhandled exception was caught, no further handlers from the same group should + # be called. Also, the error handler should be called and receive the exception + passed = [] + app.add_handler(CommandHandler('start', start1), 1) + app.add_handler(CommandHandler('start', start2), 1) + app.add_handler(CommandHandler('start', start3), 2) + app.add_error_handler(error) + await app.process_update(update) + assert passed == ['start1', 'error', err, 'start3'] @pytest.mark.asyncio @pytest.mark.parametrize('block', (True, False)) @@ -820,7 +825,7 @@ async def error_handler(_, context): application = ( ApplicationBuilder() - .bot(bot) + .token(bot.token) .context_types( ContextTypes( context=CustomContext, bot_data=int, user_data=float, chat_data=complex @@ -833,9 +838,10 @@ async def error_handler(_, context): MessageHandler(filters.ALL, self.callback_raise_error('TestError')) ) - await application.process_update(self.message_update) - await asyncio.sleep(0.05) - assert self.received == (CustomContext, float, complex, int) + async with application: + await application.process_update(self.message_update) + await asyncio.sleep(0.05) + assert self.received == (CustomContext, float, complex, int) @pytest.mark.asyncio async def test_custom_context_handler_callback(self, bot): @@ -849,7 +855,7 @@ def callback(_, context): application = ( ApplicationBuilder() - .bot(bot) + .token(bot.token) .context_types( ContextTypes( context=CustomContext, bot_data=int, user_data=float, chat_data=complex @@ -859,9 +865,10 @@ def callback(_, context): ) application.add_handler(MessageHandler(filters.ALL, callback)) - await application.process_update(self.message_update) - await asyncio.sleep(0.05) - assert self.received == (CustomContext, float, complex, int) + async with application: + await application.process_update(self.message_update) + await asyncio.sleep(0.05) + assert self.received == (CustomContext, float, complex, int) @pytest.mark.asyncio @pytest.mark.parametrize( @@ -888,13 +895,14 @@ async def handle_update( ) self.received = check_result - app.add_handler(MyHandler(self.callback_increase_count)) - await app.process_update(1) - assert self.count == (1 if expected else 0) - if expected: - assert self.received == check - else: - assert self.received is None + async with app: + app.add_handler(MyHandler(self.callback_increase_count)) + await app.process_update(1) + assert self.count == (1 if expected else 0) + if expected: + assert self.received == check + else: + assert self.received is None @pytest.mark.asyncio async def test_non_blocking_handler(self, app): @@ -1024,24 +1032,26 @@ async def error_handler(*args, **kwargs): self.count = 5 app = Application.builder().token(bot.token).defaults(Defaults(block=block)).build() - app.add_handler(TypeHandler(object, self.callback_raise_error)) - app.add_error_handler(error_handler) - await app.process_update(1) - await asyncio.sleep(0.05) - assert self.count == expected_output - await asyncio.sleep(0.1) - assert self.count == 5 + async with app: + app.add_handler(TypeHandler(object, self.callback_raise_error)) + app.add_error_handler(error_handler) + await app.process_update(1) + await asyncio.sleep(0.05) + assert self.count == expected_output + await asyncio.sleep(0.1) + assert self.count == 5 @pytest.mark.parametrize(['block', 'expected_output'], [(False, 0), (True, 5)]) @pytest.mark.asyncio async def test_default_block_handler(self, bot, block, expected_output): app = Application.builder().token(bot.token).defaults(Defaults(block=block)).build() - app.add_handler(TypeHandler(object, self.callback_set_count(5, sleep=0.1))) - await app.process_update(1) - await asyncio.sleep(0.05) - assert self.count == expected_output - await asyncio.sleep(0.15) - assert self.count == 5 + async with app: + app.add_handler(TypeHandler(object, self.callback_set_count(5, sleep=0.1))) + await app.process_update(1) + await asyncio.sleep(0.05) + assert self.count == expected_output + await asyncio.sleep(0.15) + assert self.count == 5 @pytest.mark.asyncio @pytest.mark.parametrize('handler_block', (True, False)) From 27d439a1ee073a1c21769f39f174fa751c7a0367 Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Sat, 19 Mar 2022 11:34:08 +0100 Subject: [PATCH 074/153] Fix run_* tests on non-windows --- tests/test_application.py | 51 ++++++++++++++++++++++++--------------- 1 file changed, 31 insertions(+), 20 deletions(-) diff --git a/tests/test_application.py b/tests/test_application.py index 9b2c7af5fd0..715d012f627 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -1356,7 +1356,6 @@ async def callback(update, context): reason="Can't send signals without stopping whole process on windows", ) def test_run_polling_basic(self, app, monkeypatch): - ready_event = threading.Event() exception_event = threading.Event() update_event = threading.Event() exception = TelegramError('This is a test error') @@ -1371,7 +1370,12 @@ async def get_updates(*args, **kwargs): return [self.message_update] def thread_target(): - ready_event.wait() + waited = 0 + while not app.running: + time.sleep(0.05) + waited += 0.05 + if waited > 5: + pytest.fail("App apparently won't start") # Check that everything's running assertions['app_running'] = app.running @@ -1402,7 +1406,7 @@ def thread_target(): thread = Thread(target=thread_target) thread.start() - app.run_polling(drop_pending_updates=True, ready=ready_event, close_loop=False) + app.run_polling(drop_pending_updates=True, close_loop=False) thread.join() assert len(assertions) == 8 @@ -1414,8 +1418,6 @@ def thread_target(): reason="Can't send signals without stopping whole process on windows", ) def test_run_polling_parameters_passing(self, app, monkeypatch): - ready_event = threading.Event() - # First check that the default values match and that we have all arguments there updater_signature = inspect.signature(app.updater.start_polling) app_signature = inspect.signature(app.run_polling) @@ -1437,7 +1439,13 @@ async def stop(_, **kwargs): return True def thread_target(): - ready_event.wait() + waited = 0 + while not app.running: + time.sleep(0.05) + waited += 0.05 + if waited > 5: + pytest.fail("App apparently won't start") + time.sleep(0.1) os.kill(os.getpid(), signal.SIGINT) @@ -1445,9 +1453,8 @@ def thread_target(): monkeypatch.setattr(Updater, 'stop', stop) thread = Thread(target=thread_target) thread.start() - app.run_polling(ready=ready_event, close_loop=False) + app.run_polling(close_loop=False) thread.join() - ready_event.clear() assert set(self.received.keys()) == set(updater_signature.parameters.keys()) for name, param in updater_signature.parameters.items(): @@ -1461,9 +1468,8 @@ def thread_target(): } thread = Thread(target=thread_target) thread.start() - app.run_polling(ready=ready_event, close_loop=False, **expected) + app.run_polling(close_loop=False, **expected) thread.join() - ready_event.clear() assert set(self.received.keys()) == set(updater_signature.parameters.keys()) assert self.received.pop('error_callback', None) @@ -1474,7 +1480,6 @@ def thread_target(): reason="Can't send signals without stopping whole process on windows", ) def test_run_webhook_basic(self, app, monkeypatch): - ready_event = threading.Event() assertions = {} async def delete_webhook(*args, **kwargs): @@ -1484,7 +1489,12 @@ async def set_webhook(*args, **kwargs): return True def thread_target(): - ready_event.wait() + waited = 0 + while not app.running: + time.sleep(0.05) + waited += 0.05 + if waited > 5: + pytest.fail("App apparently won't start") # Check that everything's running assertions['app_running'] = app.running @@ -1523,7 +1533,6 @@ def thread_target(): port=port, url_path='TOKEN', drop_pending_updates=True, - ready=ready_event, close_loop=False, ) thread.join() @@ -1546,8 +1555,6 @@ async def start_webhook(_, **kwargs): async def stop(_, **kwargs): return True - ready_event = threading.Event() - # First check that the default values match and that we have all arguments there updater_signature = inspect.signature(Updater.start_webhook) @@ -1564,15 +1571,20 @@ async def stop(_, **kwargs): assert param.default == app_signature.parameters[name].default def thread_target(): - ready_event.wait() + waited = 0 + while not app.running: + time.sleep(0.05) + waited += 0.05 + if waited > 5: + pytest.fail("App apparently won't start") + time.sleep(0.1) os.kill(os.getpid(), signal.SIGINT) thread = Thread(target=thread_target) thread.start() - app.run_webhook(ready=ready_event, close_loop=False) + app.run_webhook(close_loop=False) thread.join() - ready_event.clear() assert set(self.received.keys()) == set(updater_signature.parameters.keys()) - {'self'} for name, param in updater_signature.parameters.items(): @@ -1583,9 +1595,8 @@ def thread_target(): expected = {name: name for name in updater_signature.parameters if name != 'self'} thread = Thread(target=thread_target) thread.start() - app.run_webhook(ready=ready_event, close_loop=False, **expected) + app.run_webhook(close_loop=False, **expected) thread.join() - ready_event.clear() assert set(self.received.keys()) == set(expected.keys()) assert self.received == expected From 2e897f0f37da32a45d9c15c1b3b2c28be7ef5986 Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Tue, 22 Mar 2022 20:44:19 +0100 Subject: [PATCH 075/153] Review --- telegram/ext/_application.py | 2 +- telegram/ext/_updater.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/telegram/ext/_application.py b/telegram/ext/_application.py index 51b5dfc31a7..d321c639ea9 100644 --- a/telegram/ext/_application.py +++ b/telegram/ext/_application.py @@ -827,7 +827,7 @@ def add_handler(self, handler: Handler[Any, CCT], group: int = DEFAULT_GROUP) -> 'A persistent `ConversationHandler` was passed to `add_handler`, ' 'after `Application.initialize` was called. This is discouraged.' 'See the docs of `Application.add_handler` for details.', - stacklevel=1, + stacklevel=2, ) if group not in self.handlers: diff --git a/telegram/ext/_updater.py b/telegram/ext/_updater.py index e07fe564a8b..6b0df0be735 100644 --- a/telegram/ext/_updater.py +++ b/telegram/ext/_updater.py @@ -366,7 +366,7 @@ async def start_webhook( if self.running: raise RuntimeError('This Updater is already running!') if not self._initialized: - raise RuntimeError('This Updater is not initialized!') + raise RuntimeError('This Updater was not initialized via `Updater.initialize`!') self._running = True From af8a57aae59f4be9f6f5bfbf19117fed0e1ba6ab Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Tue, 22 Mar 2022 21:18:44 +0100 Subject: [PATCH 076/153] More review --- telegram/_bot.py | 2 +- telegram/ext/_application.py | 2 +- telegram/ext/_updater.py | 8 +++++++- telegram/ext/_utils/webhookhandler.py | 2 +- telegram/request/_httpxrequest.py | 2 +- 5 files changed, 11 insertions(+), 5 deletions(-) diff --git a/telegram/_bot.py b/telegram/_bot.py index 01b933de643..c528fe60318 100644 --- a/telegram/_bot.py +++ b/telegram/_bot.py @@ -379,7 +379,7 @@ async def shutdown(self) -> None: :meth:`telegram.request.BaseRequest.shutdown` for the request objects used by this bot. """ if not self._initialized: - self._logger.warning('This Bot is already shut down.') + self._logger.debug('This Bot is already shut down. Returning.') return await asyncio.gather(self._request[0].shutdown(), self._request[1].shutdown()) diff --git a/telegram/ext/_application.py b/telegram/ext/_application.py index d321c639ea9..d76c3c33be3 100644 --- a/telegram/ext/_application.py +++ b/telegram/ext/_application.py @@ -317,7 +317,7 @@ async def shutdown(self) -> None: raise RuntimeError('This Application is still running!') if not self._initialized: - _logger.warning('This Application is already shut down.') + _logger.debug('This Application is already shut down. Returning.') return await self.bot.shutdown() diff --git a/telegram/ext/_updater.py b/telegram/ext/_updater.py index 6b0df0be735..31118b208ae 100644 --- a/telegram/ext/_updater.py +++ b/telegram/ext/_updater.py @@ -119,7 +119,7 @@ async def shutdown(self) -> None: raise RuntimeError('This Updater is still running!') if not self._initialized: - self._logger.warning('This Updater is already shut down.') + self._logger.debug('This Updater is already shut down. Returning.') return await self.bot.shutdown() @@ -199,6 +199,12 @@ async def start_polling( :exc:`RuntimeError`: If the updater is already running. """ + if error_callback and asyncio.iscoroutinefunction(error_callback): + raise ValueError( + 'The `error_callback` must not be a coroutine function! Use an ordinary function ' + 'instead. ' + ) + async with self.__lock: if self.running: raise RuntimeError('This Updater is already running!') diff --git a/telegram/ext/_utils/webhookhandler.py b/telegram/ext/_utils/webhookhandler.py index 1b7631c37f0..90f2db08d21 100644 --- a/telegram/ext/_utils/webhookhandler.py +++ b/telegram/ext/_utils/webhookhandler.py @@ -76,7 +76,7 @@ async def serve_forever(self, ready: asyncio.Event = None) -> None: async def shutdown(self) -> None: async with self._shutdown_lock: if not self.is_running: - self._logger.warning('Webhook Server already stopped.') + self._logger.debug('Webhook Server is already shut down. Returning') return self.is_running = False self._http_server.stop() diff --git a/telegram/request/_httpxrequest.py b/telegram/request/_httpxrequest.py index 0038825a1e8..425b92224e7 100644 --- a/telegram/request/_httpxrequest.py +++ b/telegram/request/_httpxrequest.py @@ -117,7 +117,7 @@ async def initialize(self) -> None: async def shutdown(self) -> None: """See :meth:`BaseRequest.shutdown`.""" if self._client.is_closed: - _logger.warning('This HTTPXRequest is already shut down.') + _logger.debug('This HTTPXRequest is already shut down. Returning.') return await self._client.aclose() From 09965590b13baee9b94b00c0aa86d7812ee6aae7 Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Tue, 22 Mar 2022 21:21:48 +0100 Subject: [PATCH 077/153] Change error types --- telegram/ext/_callbackqueryhandler.py | 2 +- telegram/ext/_updater.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/telegram/ext/_callbackqueryhandler.py b/telegram/ext/_callbackqueryhandler.py index 96b81a12f6b..2af9a474f56 100644 --- a/telegram/ext/_callbackqueryhandler.py +++ b/telegram/ext/_callbackqueryhandler.py @@ -116,7 +116,7 @@ def __init__( super().__init__(callback, block=block) if callable(pattern) and asyncio.iscoroutinefunction(pattern): - raise ValueError( + raise TypeError( 'The `pattern` must not be a coroutine function! Use an ordinary function instead.' ) diff --git a/telegram/ext/_updater.py b/telegram/ext/_updater.py index 31118b208ae..9dae213ea56 100644 --- a/telegram/ext/_updater.py +++ b/telegram/ext/_updater.py @@ -200,7 +200,7 @@ async def start_polling( """ if error_callback and asyncio.iscoroutinefunction(error_callback): - raise ValueError( + raise TypeError( 'The `error_callback` must not be a coroutine function! Use an ordinary function ' 'instead. ' ) From ab891d1b0f60199f9229ba1fc4135c0dd187f662 Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Tue, 22 Mar 2022 21:47:25 +0100 Subject: [PATCH 078/153] adjust tests --- tests/test_basepersistence.py | 7 ++----- tests/test_updater.py | 3 +++ 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/test_basepersistence.py b/tests/test_basepersistence.py index db8c39b9fda..24ece7e7906 100644 --- a/tests/test_basepersistence.py +++ b/tests/test_basepersistence.py @@ -43,7 +43,7 @@ CallbackContext, ) from telegram.warnings import PTBUserWarning -from tests.conftest import make_message_update, PROJECT_ROOT_PATH, DictApplication +from tests.conftest import make_message_update, DictApplication class HandlerStates(int, enum.Enum): @@ -524,10 +524,7 @@ async def test_add_conversation_handler_after_init(self, papp: Application, recw if 'after `Application.initialize` was called' in str(warning.message): found = True assert warning.category is PTBUserWarning - assert ( - Path(warning.filename) - == PROJECT_ROOT_PATH / 'telegram' / 'ext' / '_application.py' - ), "incorrect stacklevel!" + assert Path(warning.filename) == Path(__file__), "incorrect stacklevel!" assert found diff --git a/tests/test_updater.py b/tests/test_updater.py index 678e315dd0c..5e31cccc76f 100644 --- a/tests/test_updater.py +++ b/tests/test_updater.py @@ -388,6 +388,9 @@ async def get_updates(*args, **kwargs): monkeypatch.setattr(updater.bot, 'get_updates', get_updates) monkeypatch.setattr(updater.bot, 'set_webhook', lambda *args, **kwargs: True) + with pytest.raises(TypeError, match='`error_callback` must not be a coroutine function'): + await updater.start_polling(error_callback=get_updates) + async with updater: self.err_handler_called = asyncio.Event() From 767278973e8cb59859482983880f254a6981630c Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Wed, 23 Mar 2022 09:01:27 +0100 Subject: [PATCH 079/153] Two more BP tests --- telegram/ext/_application.py | 8 +- tests/test_basepersistence.py | 437 +++++++++++----------------------- 2 files changed, 142 insertions(+), 303 deletions(-) diff --git a/telegram/ext/_application.py b/telegram/ext/_application.py index d76c3c33be3..a644f717f21 100644 --- a/telegram/ext/_application.py +++ b/telegram/ext/_application.py @@ -1128,9 +1128,11 @@ async def __update_persistence(self) -> None: # dispatch any errors await asyncio.gather( - self.dispatch_error(update=None, error=result) - for result in results - if isinstance(result, Exception) + *( + self.dispatch_error(update=None, error=result) + for result in results + if isinstance(result, Exception) + ) ) def add_error_handler( diff --git a/tests/test_basepersistence.py b/tests/test_basepersistence.py index 24ece7e7906..d6d9c957308 100644 --- a/tests/test_basepersistence.py +++ b/tests/test_basepersistence.py @@ -305,8 +305,24 @@ class TestBasePersistence: # TODO: # * conversations: pending states, ending conversations, unresolved pending states - async def job_callback(self, context): - pass + def job_callback(self, chat_id: int = None): + async def callback(context): + if context.user_data: + context.user_data['key'] = 'value' + if context.chat_data: + context.chat_data['key'] = 'value' + context.bot_data['key'] = 'value' + + if chat_id: + await context.bot.send_message( + chat_id=chat_id, + text='text', + reply_markup=InlineKeyboardMarkup.from_button( + InlineKeyboardButton(text='text', callback_data='callback_data') + ), + ) + + return callback def handler_callback(self, chat_id: int = None, sleep: float = None): async def callback(update, context): @@ -660,7 +676,7 @@ async def test_update_persistence_loop_call_count_update_handling( async def test_update_persistence_loop_call_count_job(self, papp: Application, caplog): async with papp: papp.job_queue.start() - papp.job_queue.run_once(self.job_callback, when=1.5, chat_id=1, user_id=1) + papp.job_queue.run_once(self.job_callback(), when=1.5, chat_id=1, user_id=1) await asyncio.sleep(2.5) assert not papp.persistence.updated_bot_data assert not papp.persistence.updated_chat_ids @@ -705,7 +721,7 @@ async def test_update_persistence_loop_call_count_job(self, papp: Application, c # Nothing should have been updated after running job without associated user/chat_data papp.persistence.reset_tracking() - papp.job_queue.run_once(self.job_callback, when=0.1) + papp.job_queue.run_once(self.job_callback(), when=0.1) await asyncio.sleep(0.2) with caplog.at_level(logging.ERROR): await papp.update_persistence() @@ -811,19 +827,78 @@ async def test_update_persistence_loop_saved_data_update_handling( assert not papp.persistence.conversations + @papp_store_all_or_none + @pytest.mark.asyncio + async def test_update_persistence_loop_saved_data_job(self, papp: Application, chat_id): + papp.add_handler( + MessageHandler(filters.ALL, callback=self.handler_callback(chat_id=chat_id)), group=-1 + ) + + async with papp: + papp.job_queue.start() + papp.job_queue.run_once(self.job_callback(), when=1.5, chat_id=1, user_id=1) + await asyncio.sleep(2.5) + + assert not papp.persistence.bot_data + assert papp.persistence.bot_data is not papp.bot_data + assert not papp.persistence.chat_data + assert papp.persistence.chat_data is not papp.chat_data + assert not papp.persistence.user_data + assert papp.persistence.user_data is not papp.user_data + assert papp.persistence.callback_data == ([], {}) + assert ( + papp.persistence.callback_data is not papp.bot.callback_data_cache.persistence_data + ) + assert not papp.persistence.conversations + + await papp.update_persistence() + + assert papp.persistence.bot_data is not papp.bot_data + if papp.persistence.store_data.bot_data: + assert papp.persistence.bot_data == {'key': 'value', 'refreshed': True} + else: + assert not papp.persistence.bot_data + + assert papp.persistence.chat_data is not papp.chat_data + if papp.persistence.store_data.chat_data: + assert papp.persistence.chat_data == {1: {'key': 'value', 'refreshed': True}} + assert papp.persistence.chat_data[1] is not papp.chat_data[1] + else: + assert not papp.persistence.chat_data + + assert papp.persistence.user_data is not papp.user_data + if papp.persistence.store_data.user_data: + assert papp.persistence.user_data == {1: {'key': 'value', 'refreshed': True}} + assert papp.persistence.user_data[1] is not papp.chat_data[1] + else: + assert not papp.persistence.user_data + + assert ( + papp.persistence.callback_data is not papp.bot.callback_data_cache.persistence_data + ) + if papp.persistence.store_data.callback_data: + assert papp.persistence.callback_data[1] == {} + assert len(papp.persistence.callback_data[0]) == 1 + else: + assert papp.persistence.callback_data == ([], {}) + + assert not papp.persistence.conversations + @default_papp @pytest.mark.parametrize('delay_type', ('job', 'handler', 'task')) @pytest.mark.asyncio async def test_update_persistence_loop_async_logic( self, papp: Application, delay_type: str, chat_id ): + """All three kinds of 'asyncio background processes' should mark things for update once + they're done.""" sleep = 1.5 update = TrackingConversationHandler.build_update(HandlerStates.STATE_1, chat_id=1) async with papp: if delay_type == 'job': papp.job_queue.start() - papp.job_queue.run_once(self.job_callback, when=sleep, chat_id=1, user_id=1) + papp.job_queue.run_once(self.job_callback(), when=sleep, chat_id=1, user_id=1) elif delay_type == 'handler': papp.add_handler( MessageHandler( @@ -845,6 +920,7 @@ async def test_update_persistence_loop_async_logic( assert papp.persistence.updated_callback_data assert not papp.persistence.updated_conversations + # Wait for the asyncio process to be done await asyncio.sleep(sleep + 1) await papp.update_persistence() assert not papp.persistence.dropped_chat_ids @@ -923,301 +999,62 @@ async def test_migrate_chat_data(self, papp: Application): assert papp.persistence.dropped_chat_ids == {1: 1} assert papp.persistence.updated_chat_ids == {2: 1} - # def test_error_while_saving_chat_data(self, bot): - # increment = [] - # - # class OwnPersistence(BasePersistence): - # def get_callback_data(self): - # return None - # - # def update_callback_data(self, data): - # raise Exception - # - # def get_bot_data(self): - # return {} - # - # def update_bot_data(self, data): - # raise Exception - # - # def drop_chat_data(self, chat_id): - # pass - # - # def drop_user_data(self, user_id): - # pass - # - # def get_chat_data(self): - # return defaultdict(dict) - # - # def update_chat_data(self, chat_id, data): - # raise Exception - # - # def get_user_data(self): - # return defaultdict(dict) - # - # def update_user_data(self, user_id, data): - # raise Exception - # - # def get_conversations(self, name): - # pass - # - # def update_conversation(self, name, key, new_state): - # pass - # - # def refresh_user_data(self, user_id, user_data): - # pass - # - # def refresh_chat_data(self, chat_id, chat_data): - # pass - # - # def refresh_bot_data(self, bot_data): - # pass - # - # def flush(self): - # pass - # - # def start1(u, c): - # pass - # - # def error(u, c): - # increment.append("error") - # - # # If updating a user_data or chat_data from a persistence object throws an error, - # # the error handler should catch it - # - # update = Update( - # 1, - # message=Message( - # 1, - # None, - # Chat(1, "lala"), - # from_user=User(1, "Test", False), - # text='/start', - # entities=[ - # MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0, length=len('/start')) - # ], - # bot=bot, - # ), - # ) - # my_persistence = OwnPersistence() - # app = ApplicationBuilder().bot(bot).persistence(my_persistence).build() - # app.add_handler(CommandHandler('start', start1)) - # app.add_error_handler(error) - # app.process_update(update) - # assert increment == ["error", "error", "error", "error"] - # - # def test_error_while_persisting(self, app, caplog): - # class OwnPersistence(BasePersistence): - # def update(self, data): - # raise Exception('PersistenceError') - # - # def update_callback_data(self, data): - # self.update(data) - # - # def update_bot_data(self, data): - # self.update(data) - # - # def update_chat_data(self, chat_id, data): - # self.update(data) - # - # def update_user_data(self, user_id, data): - # self.update(data) - # - # def drop_user_data(self, user_id): - # pass - # - # def drop_chat_data(self, chat_id): - # pass - # - # def get_chat_data(self): - # pass - # - # def get_bot_data(self): - # pass - # - # def get_user_data(self): - # pass - # - # def get_callback_data(self): - # pass - # - # def get_conversations(self, name): - # pass - # - # def update_conversation(self, name, key, new_state): - # pass - # - # def refresh_bot_data(self, bot_data): - # pass - # - # def refresh_user_data(self, user_id, user_data): - # pass - # - # def refresh_chat_data(self, chat_id, chat_data): - # pass - # - # def flush(self): - # pass - # - # def callback(update, context): - # pass - # - # test_flag = [] - # - # def error(update, context): - # nonlocal test_flag - # test_flag.append(str(context.error) == 'PersistenceError') - # raise Exception('ErrorHandlingError') - # - # update = Update( - # 1, message=Message(1, None, Chat(1, ''), from_user=User(1, '', False), text='Text') - # ) - # handler = MessageHandler(filters.ALL, callback) - # app.add_handler(handler) - # app.add_error_handler(error) - # - # app.persistence = OwnPersistence() - # - # with caplog.at_level(logging.ERROR): - # app.process_update(update) - # - # assert test_flag == [True, True, True, True] - # assert len(caplog.records) == 4 - # for record in caplog.records: - # message = record.getMessage() - # assert message.startswith('An error was raised and an uncaught') - # - # def test_persisting_no_user_no_chat(self, app): - # class OwnPersistence(BasePersistence): - # def __init__(self): - # super().__init__() - # self.test_flag_bot_data = False - # self.test_flag_chat_data = False - # self.test_flag_user_data = False - # - # def update_bot_data(self, data): - # self.test_flag_bot_data = True - # - # def update_chat_data(self, chat_id, data): - # self.test_flag_chat_data = True - # - # def update_user_data(self, user_id, data): - # self.test_flag_user_data = True - # - # def update_conversation(self, name, key, new_state): - # pass - # - # def drop_chat_data(self, chat_id): - # pass - # - # def drop_user_data(self, user_id): - # pass - # - # def get_conversations(self, name): - # pass - # - # def get_user_data(self): - # pass - # - # def get_bot_data(self): - # pass - # - # def get_chat_data(self): - # pass - # - # def refresh_bot_data(self, bot_data): - # pass - # - # def refresh_user_data(self, user_id, user_data): - # pass - # - # def refresh_chat_data(self, chat_id, chat_data): - # pass - # - # def get_callback_data(self): - # pass - # - # def update_callback_data(self, data): - # pass - # - # def flush(self): - # pass - # - # def callback(update, context): - # pass - # - # handler = MessageHandler(filters.ALL, callback) - # app.add_handler(handler) - # app.persistence = OwnPersistence() - # - # update = Update( - # 1, message=Message(1, None, None, from_user=User(1, '', False), text='Text') - # ) - # app.process_update(update) - # assert app.persistence.test_flag_bot_data - # assert app.persistence.test_flag_user_data - # assert not app.persistence.test_flag_chat_data - # - # app.persistence.test_flag_bot_data = False - # app.persistence.test_flag_user_data = False - # app.persistence.test_flag_chat_data = False - # update = Update(1, message=Message(1, None, Chat(1, ''), from_user=None, text='Text')) - # app.process_update(update) - # assert app.persistence.test_flag_bot_data - # assert not app.persistence.test_flag_user_data - # assert app.persistence.test_flag_chat_data - # - # def test_update_persistence_all_async(self, monkeypatch, app): - # def update_persistence(*args, **kwargs): - # self.count += 1 - # - # def dummy_callback(*args, **kwargs): - # pass - # - # monkeypatch.setattr(app, 'update_persistence', update_persistence) - # monkeypatch.setattr(app, 'block', dummy_callback) - # - # for group in range(5): - # app.add_handler( - # MessageHandler(filters.TEXT, dummy_callback, block=True), group=group - # ) - # - # update = Update(1, message=Message(1, None, Chat(1, ''), from_user=None, text='Text')) - # app.process_update(update) - # assert self.count == 0 - # - # app.bot._defaults = Defaults(block=True) - # try: - # for group in range(5): - # app.add_handler(MessageHandler(filters.TEXT, dummy_callback), group=group) - # - # update = Update(1, message=Message(1, None, Chat(1, ''), from_user=None, - # text='Text')) - # app.process_update(update) - # assert self.count == 0 - # finally: - # app.bot._defaults = None - # - # @pytest.mark.parametrize('block', [DEFAULT_FALSE, False]) - # def test_update_persistence_one_sync(self, monkeypatch, app, block): - # def update_persistence(*args, **kwargs): - # self.count += 1 - # - # def dummy_callback(*args, **kwargs): - # pass - # - # monkeypatch.setattr(app, 'update_persistence', update_persistence) - # monkeypatch.setattr(app, 'block', dummy_callback) - # - # for group in range(5): - # app.add_handler( - # MessageHandler(filters.TEXT, dummy_callback, block=True), group=group - # ) - # app.add_handler(MessageHandler(filters.TEXT, dummy_callback, block=block),group=5) - # - # update = Update(1, message=Message(1, None, Chat(1, ''), from_user=None, text='Text')) - # app.process_update(update) - # assert self.count == 1 - # + @pytest.mark.asyncio + async def test_errors_while_persisting(self, bot, caplog): + class ErrorPersistence(TrackingPersistence): + def raise_error(self): + raise Exception('PersistenceError') + + async def update_callback_data(self, data): + self.raise_error() + + async def update_bot_data(self, data): + self.raise_error() + + async def update_chat_data(self, chat_id, data): + self.raise_error() + + async def update_user_data(self, user_id, data): + self.raise_error() + + async def drop_user_data(self, user_id): + self.raise_error() + + async def drop_chat_data(self, chat_id): + self.raise_error() + + async def update_conversation(self, name, key, new_state): + self.raise_error() + + test_flag = [] + + async def error(update, context): + test_flag.append(str(context.error) == 'PersistenceError') + raise Exception('ErrorHandlingError') + + app = ApplicationBuilder().token(bot.token).persistence(ErrorPersistence()).build() + + async with app: + app.add_error_handler(error) + for _ in range(5): + # second pass processes update in conv_2 + await app.process_update( + TrackingConversationHandler.build_update(HandlerStates.END, chat_id=1) + ) + app.drop_chat_data(7) + app.drop_user_data(42) + + assert not caplog.records + + with caplog.at_level(logging.ERROR): + await app.update_persistence() + + assert len(caplog.records) == 6 + assert test_flag == [True, True, True, True, True, True] + for record in caplog.records: + message = record.getMessage() + assert message.startswith('An error was raised and an uncaught') + # @pytest.mark.parametrize('block,expected', [(DEFAULT_FALSE, 1), (False, 1), (True, 0)]) # def test_update_persistence_defaults_async(self, monkeypatch, app, block, expected): # def update_persistence(*args, **kwargs): From 43ede97114d23b6f20be232577930181d135a537 Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Wed, 23 Mar 2022 20:12:43 +0100 Subject: [PATCH 080/153] fix a test --- tests/test_basepersistence.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/test_basepersistence.py b/tests/test_basepersistence.py index d6d9c957308..f5700206894 100644 --- a/tests/test_basepersistence.py +++ b/tests/test_basepersistence.py @@ -836,7 +836,9 @@ async def test_update_persistence_loop_saved_data_job(self, papp: Application, c async with papp: papp.job_queue.start() - papp.job_queue.run_once(self.job_callback(), when=1.5, chat_id=1, user_id=1) + papp.job_queue.run_once( + self.job_callback(chat_id=chat_id), when=1.5, chat_id=1, user_id=1 + ) await asyncio.sleep(2.5) assert not papp.persistence.bot_data From bdef5fd9e65a732380e848e3193bdce4019cd14b Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Wed, 23 Mar 2022 20:42:56 +0100 Subject: [PATCH 081/153] one more BP test --- tests/test_basepersistence.py | 90 ++++++++++++++++++++++++++--------- 1 file changed, 68 insertions(+), 22 deletions(-) diff --git a/tests/test_basepersistence.py b/tests/test_basepersistence.py index f5700206894..d5393146892 100644 --- a/tests/test_basepersistence.py +++ b/tests/test_basepersistence.py @@ -1057,25 +1057,71 @@ async def error(update, context): message = record.getMessage() assert message.startswith('An error was raised and an uncaught') - # @pytest.mark.parametrize('block,expected', [(DEFAULT_FALSE, 1), (False, 1), (True, 0)]) - # def test_update_persistence_defaults_async(self, monkeypatch, app, block, expected): - # def update_persistence(*args, **kwargs): - # self.count += 1 - # - # def dummy_callback(*args, **kwargs): - # pass - # - # monkeypatch.setattr(app, 'update_persistence', update_persistence) - # monkeypatch.setattr(app, 'block', dummy_callback) - # app.bot._defaults = Defaults(block=block) - # - # try: - # for group in range(5): - # app.add_handler(MessageHandler(filters.TEXT, dummy_callback), group=group) - # - # update = Update(1, message=Message(1, None, Chat(1, ''), from_user=None, - # text='Text')) - # app.process_update(update) - # assert self.count == expected - # finally: - # app.bot._defaults = None + @default_papp + @pytest.mark.parametrize( + 'delay_type', ('job', 'blocking_handler', 'nonblocking_handler', 'task') + ) + @pytest.mark.asyncio + async def test_update_persistence_after_exception( + self, papp: Application, delay_type: str, chat_id + ): + """Makes sure that persistence is updated even if an exception happened in a callback.""" + sleep = 1.5 + update = TrackingConversationHandler.build_update(HandlerStates.STATE_1, chat_id=1) + errors = 0 + + async def error(_, __): + nonlocal errors + errors += 1 + + async def raise_error(*args, **kwargs): + raise Exception + + async with papp: + papp.add_error_handler(error) + + await papp.update_persistence() + assert papp.persistence.updated_bot_data + assert not papp.persistence.updated_chat_ids + assert not papp.persistence.updated_user_ids + assert not papp.persistence.dropped_chat_ids + assert not papp.persistence.dropped_user_ids + assert papp.persistence.updated_callback_data + assert not papp.persistence.updated_conversations + assert errors == 0 + + if delay_type == 'job': + papp.job_queue.start() + papp.job_queue.run_once(raise_error, when=sleep, chat_id=1, user_id=1) + elif delay_type.endswith('_handler'): + papp.add_handler( + MessageHandler( + filters.ALL, + raise_error, + block=delay_type.startswith('blocking'), + ) + ) + await papp.process_update(update) + else: + papp.create_task(raise_error(), update=update) + + # Wait for the asyncio process to be done + await asyncio.sleep(sleep + 1) + + assert errors == 1 + await papp.update_persistence() + assert not papp.persistence.dropped_chat_ids + assert not papp.persistence.dropped_user_ids + assert papp.persistence.updated_bot_data == papp.persistence.store_data.bot_data + assert ( + papp.persistence.updated_callback_data == papp.persistence.store_data.callback_data + ) + if papp.persistence.store_data.user_data: + assert papp.persistence.updated_user_ids == {1: 1} + else: + assert not papp.persistence.updated_user_ids + if papp.persistence.store_data.chat_data: + assert papp.persistence.updated_chat_ids == {1: 1} + else: + assert not papp.persistence.updated_chat_ids + assert not papp.persistence.updated_conversations From b3564c2d41beffbb2e2086d0594fdcfaac71f052 Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Wed, 23 Mar 2022 22:01:44 +0100 Subject: [PATCH 082/153] review --- telegram/ext/_application.py | 4 +-- telegram/ext/_basepersistence.py | 5 ++-- telegram/ext/_conversationhandler.py | 42 +++++++++++++++------------- telegram/ext/_dictpersistence.py | 6 ++-- telegram/ext/_picklepersistence.py | 5 ++-- telegram/ext/_utils/types.py | 5 ++-- 6 files changed, 34 insertions(+), 33 deletions(-) diff --git a/telegram/ext/_application.py b/telegram/ext/_application.py index d76c3c33be3..1d54124dc56 100644 --- a/telegram/ext/_application.py +++ b/telegram/ext/_application.py @@ -52,7 +52,7 @@ from telegram._utils.defaultvalue import DefaultValue, DEFAULT_TRUE, DEFAULT_NONE from telegram._utils.warnings import warn from telegram.ext._utils.trackingdict import TrackingDict -from telegram.ext._utils.types import CCT, UD, CD, BD, BT, JQ, HandlerCallback +from telegram.ext._utils.types import CCT, UD, CD, BD, BT, JQ, HandlerCallback, ConversationKey from telegram.ext._utils.stack import was_called_by if TYPE_CHECKING: @@ -243,7 +243,7 @@ def __init__( # This attribute will hold references to the conversation dicts of all conversation # handlers so that we can extract the changed states during `update_persistence` self._conversation_handler_conversations: Dict[ - str, TrackingDict[Tuple[int, ...], object] + str, TrackingDict[ConversationKey, object] ] = {} # A number of low-level helpers for the internal logic diff --git a/telegram/ext/_basepersistence.py b/telegram/ext/_basepersistence.py index 5cbe381f6aa..992eb1ee6d0 100644 --- a/telegram/ext/_basepersistence.py +++ b/telegram/ext/_basepersistence.py @@ -21,7 +21,6 @@ from typing import ( Dict, Optional, - Tuple, Generic, NamedTuple, NoReturn, @@ -30,7 +29,7 @@ from telegram import Bot from telegram.ext import ExtBot -from telegram.ext._utils.types import UD, CD, BD, ConversationDict, CDCData +from telegram.ext._utils.types import UD, CD, BD, ConversationDict, CDCData, ConversationKey class PersistenceInput(NamedTuple): # skipcq: PYL-E0239 @@ -256,7 +255,7 @@ async def get_conversations(self, name: str) -> ConversationDict: @abstractmethod async def update_conversation( - self, name: str, key: Tuple[int, ...], new_state: Optional[object] + self, name: str, key: ConversationKey, new_state: Optional[object] ) -> None: """Will be called when a :class:`telegram.ext.ConversationHandler` changes states. This allows the storage of the new state in the persistence. diff --git a/telegram/ext/_conversationhandler.py b/telegram/ext/_conversationhandler.py index 54415947bda..d69aff70493 100644 --- a/telegram/ext/_conversationhandler.py +++ b/telegram/ext/_conversationhandler.py @@ -54,12 +54,12 @@ ) from telegram._utils.warnings import warn from telegram.ext._utils.trackingdict import TrackingDict -from telegram.ext._utils.types import ConversationDict +from telegram.ext._utils.types import ConversationDict, ConversationKey from telegram.ext._utils.types import CCT if TYPE_CHECKING: from telegram.ext import Application, Job, JobQueue -CheckUpdateType = Tuple[object, Tuple[int, ...], Handler, object] +CheckUpdateType = Tuple[object, ConversationKey, Handler, object] _logger = logging.getLogger(__name__) @@ -68,7 +68,7 @@ class _ConversationTimeoutContext(Generic[CCT]): __slots__ = ('conversation_key', 'update', 'application', 'callback_context') - conversation_key: Tuple[int, ...] + conversation_key: ConversationKey update: Update application: 'Application[Any, CCT, Any, Any, Any, JobQueue]' callback_context: CallbackContext @@ -238,7 +238,6 @@ class ConversationHandler(Handler[Update, CCT]): """ __slots__ = ( - '__aplication', '_allow_reentry', '_child_conversations', '_conversation_timeout', @@ -246,13 +245,11 @@ class ConversationHandler(Handler[Update, CCT]): '_conversations_lock', '_entry_points', '_fallbacks', - '_logger', '_map_to_parent', '_name', '_per_chat', '_per_message', '_per_user', - '_persistence', '_states', '_timeout_jobs_lock', 'persistent', @@ -304,7 +301,7 @@ def __init__( self._name = name self._map_to_parent = map_to_parent - self.timeout_jobs: Dict[Tuple[int, ...], 'Job'] = {} + self.timeout_jobs: Dict[ConversationKey, 'Job'] = {} self._timeout_jobs_lock = asyncio.Lock() self._conversations: ConversationDict = {} # TODO: Do we still need this lock? @@ -527,7 +524,7 @@ def map_to_parent(self, value: object) -> NoReturn: async def _initialize_persistence( self, application: 'Application' - ) -> TrackingDict[Tuple[int, ...], object]: + ) -> TrackingDict[ConversationKey, object]: """Initializes the persistence for this handler. While this method is marked as protected, we expect it to be called by the Application/parent conversations. It's just protected to hide it from users. @@ -545,7 +542,7 @@ async def _initialize_persistence( with self._conversations_lock: current_conversations = self._conversations self._conversations = cast( - TrackingDict[Tuple[int, ...], object], + TrackingDict[ConversationKey, object], TrackingDict(), ) # In the conversation already processed updates @@ -563,23 +560,28 @@ async def _initialize_persistence( return self._conversations - def _get_key(self, update: Update) -> Tuple[int, ...]: + def _get_key(self, update: Update) -> ConversationKey: chat = update.effective_chat user = update.effective_user - key = [] + key: List[Union[int, str]] = [] if self.per_chat: - key.append(chat.id) # type: ignore[union-attr] + if chat is None: + raise RuntimeError("Can't build key for update without effective chat!") + key.append(chat.id) - if self.per_user and user is not None: + if self.per_user: + if user is None: + raise RuntimeError("Can't build key for update without effective user!") key.append(user.id) if self.per_message: - key.append( - update.callback_query.inline_message_id # type: ignore[union-attr] - or update.callback_query.message.message_id # type: ignore[union-attr] - ) + if update.callback_query is None: + raise RuntimeError("Can't build key for update without CallbackQuery!") + if update.callback_query.inline_message_id: + key.append(update.callback_query.inline_message_id) + key.append(update.callback_query.message.message_id) # type: ignore[union-attr] return tuple(key) @@ -589,7 +591,7 @@ async def _schedule_job_delayed( application: 'Application[Any, CCT, Any, Any, Any, JobQueue]', update: Update, context: CallbackContext, - conversation_key: Tuple[int, ...], + conversation_key: ConversationKey, ) -> None: try: effective_new_state = await new_state @@ -614,7 +616,7 @@ def _schedule_job( application: 'Application[Any, CCT, Any, Any, Any, JobQueue]', update: Update, context: CallbackContext, - conversation_key: Tuple[int, ...], + conversation_key: ConversationKey, ) -> None: if new_state == self.END: return @@ -799,7 +801,7 @@ async def handle_update( # type: ignore[override] raise ApplicationHandlerStop() return None - def _update_state(self, new_state: object, key: Tuple[int, ...]) -> None: + def _update_state(self, new_state: object, key: ConversationKey) -> None: if new_state == self.END: with self._conversations_lock: if key in self._conversations: diff --git a/telegram/ext/_dictpersistence.py b/telegram/ext/_dictpersistence.py index f73f0214620..8ab8f8db0aa 100644 --- a/telegram/ext/_dictpersistence.py +++ b/telegram/ext/_dictpersistence.py @@ -18,13 +18,13 @@ # along with this program. If not, see [http://www.gnu.org/licenses/]. """This module contains the DictPersistence class.""" -from typing import Dict, Optional, Tuple, cast +from typing import Dict, Optional, cast from copy import deepcopy from telegram.ext import BasePersistence, PersistenceInput from telegram._utils.types import JSONDict -from telegram.ext._utils.types import ConversationDict, CDCData +from telegram.ext._utils.types import ConversationDict, CDCData, ConversationKey try: import ujson as json @@ -296,7 +296,7 @@ async def get_conversations(self, name: str) -> ConversationDict: return self.conversations.get(name, {}).copy() # type: ignore[union-attr] async def update_conversation( - self, name: str, key: Tuple[int, ...], new_state: Optional[object] + self, name: str, key: ConversationKey, new_state: Optional[object] ) -> None: """Will update the conversations for the given handler. diff --git a/telegram/ext/_picklepersistence.py b/telegram/ext/_picklepersistence.py index d9ce1aa4a41..7899dea6675 100644 --- a/telegram/ext/_picklepersistence.py +++ b/telegram/ext/_picklepersistence.py @@ -40,8 +40,7 @@ from telegram._utils.warnings import warn from telegram.ext import BasePersistence, PersistenceInput from telegram.ext._contexttypes import ContextTypes -from telegram.ext._utils.types import UD, CD, BD, ConversationDict, CDCData - +from telegram.ext._utils.types import UD, CD, BD, ConversationDict, CDCData, ConversationKey _REPLACED_KNOWN_BOT = "a known bot replaced by PTB's PicklePersistence" _REPLACED_UNKNOWN_BOT = "an unknown bot replaced by PTB's PicklePersistence" @@ -393,7 +392,7 @@ async def get_conversations(self, name: str) -> ConversationDict: return self.conversations.get(name, {}).copy() # type: ignore[union-attr] async def update_conversation( - self, name: str, key: Tuple[int, ...], new_state: Optional[object] + self, name: str, key: ConversationKey, new_state: Optional[object] ) -> None: """Will update the conversations for the given handler and depending on :attr:`on_flush` save the pickle file. diff --git a/telegram/ext/_utils/types.py b/telegram/ext/_utils/types.py index 8113bff46b3..48e16143ee6 100644 --- a/telegram/ext/_utils/types.py +++ b/telegram/ext/_utils/types.py @@ -61,8 +61,9 @@ .. versionadded:: 14.0 """ -ConversationDict = MutableMapping[Tuple[int, ...], object] -"""Dict[Tuple[:obj:`int`, ...], Optional[:obj:`object`]]: +ConversationKey = Tuple[Union[int, str], ...] +ConversationDict = MutableMapping[ConversationKey, object] +"""Dict[Tuple[:obj:`int` | :obj:`str`, ...], Optional[:obj:`object`]]: Dicts as maintained by the :class:`telegram.ext.ConversationHandler`. .. versionadded:: 13.6 From 79de11cab21de7982ef3599daa26739f2bfb77e8 Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Thu, 24 Mar 2022 09:07:54 +0100 Subject: [PATCH 083/153] Remove remaining threading logic --- README.rst | 14 ++++ README_RAW.rst | 7 ++ telegram/ext/_callbackdatacache.py | 84 +++++++++++------------- telegram/ext/_conversationhandler.py | 52 ++++++--------- telegram/ext/filters.py | 95 ++++++++++++---------------- 5 files changed, 120 insertions(+), 132 deletions(-) diff --git a/README.rst b/README.rst index f573e62f0c9..ad83ab74398 100644 --- a/README.rst +++ b/README.rst @@ -113,6 +113,20 @@ Telegram API support All types and methods of the Telegram Bot API **5.7** are supported. +=========== +Concurrency +=========== + +Since v14.0, ``python-telegram-bot`` is built on top of Pythons ``asyncio`` module. +Because ``asyncio`` is in general single-threaded, ``python-telegram-bot`` does currently not aim to be thread-safe. +Noteworthy parts of ``python-telegram-bots`` API that are likely to cause issues (e.g. race conditions) when used in a multi-threaded setting include: + +* ``telegram.ext.Application/Updater.update_queue`` +* ``telegram.ext.ConversationHandler.check/handle_update`` +* ``telegram.ext.CallbackDataCache`` +* ``telegram.ext.BasePersistence`` +* all classes in the ``telegram.ext.filters`` module that allow to add/remove allowed users/chats at runtime + ========== Installing ========== diff --git a/README_RAW.rst b/README_RAW.rst index 676cc1c7231..e94a0d4bf62 100644 --- a/README_RAW.rst +++ b/README_RAW.rst @@ -107,6 +107,13 @@ Telegram API support All types and methods of the Telegram Bot API **5.7** are supported. +=========== +Concurrency +=========== + +Since v14.0, ``python-telegram-bot`` is built on top of Pythons ``asyncio`` module. +Because ``asyncio`` is in general single-threaded, ``python-telegram-bot`` does currently not aim to be thread-safe. + ========== Installing ========== diff --git a/telegram/ext/_callbackdatacache.py b/telegram/ext/_callbackdatacache.py index 8c8c4c02057..2656f3cc419 100644 --- a/telegram/ext/_callbackdatacache.py +++ b/telegram/ext/_callbackdatacache.py @@ -20,7 +20,6 @@ import logging import time from datetime import datetime -from threading import Lock from typing import Dict, Tuple, Union, Optional, MutableMapping, TYPE_CHECKING, cast from uuid import uuid4 @@ -119,7 +118,7 @@ class CallbackDataCache: """ - __slots__ = ('bot', 'maxsize', '_keyboard_data', '_callback_queries', '__lock', 'logger') + __slots__ = ('bot', 'maxsize', '_keyboard_data', '_callback_queries', 'logger') def __init__( self, @@ -133,7 +132,6 @@ def __init__( self.maxsize = maxsize self._keyboard_data: MutableMapping[str, _KeyboardData] = LRUCache(maxsize=maxsize) self._callback_queries: MutableMapping[str, str] = LRUCache(maxsize=maxsize) - self.__lock = Lock() if persistent_data: keyboard_data, callback_queries = persistent_data @@ -153,10 +151,9 @@ def persistence_data(self) -> CDCData: # While building a list/dict from the LRUCaches has linear runtime (in the number of # entries), the runtime is bounded by maxsize and it has the big upside of not throwing a # highly customized data structure at users trying to implement a custom persistence class - with self.__lock: - return [data.to_tuple() for data in self._keyboard_data.values()], dict( - self._callback_queries.items() - ) + return [data.to_tuple() for data in self._keyboard_data.values()], dict( + self._callback_queries.items() + ) def process_keyboard(self, reply_markup: InlineKeyboardMarkup) -> InlineKeyboardMarkup: """Registers the reply markup to the cache. If any of the buttons have @@ -171,10 +168,6 @@ def process_keyboard(self, reply_markup: InlineKeyboardMarkup) -> InlineKeyboard :class:`telegram.InlineKeyboardMarkup`: The keyboard to be passed to Telegram. """ - with self.__lock: - return self.__process_keyboard(reply_markup) - - def __process_keyboard(self, reply_markup: InlineKeyboardMarkup) -> InlineKeyboardMarkup: keyboard_uuid = uuid4().hex keyboard_data = _KeyboardData(keyboard_uuid) @@ -263,8 +256,7 @@ def process_message(self, message: Message) -> None: message (:class:`telegram.Message`): The message. """ - with self.__lock: - self.__process_message(message) + self.__process_message(message) def __process_message(self, message: Message) -> Optional[str]: """As documented in process_message, but returns the uuid of the attached keyboard, if any, @@ -324,31 +316,30 @@ def process_callback_query(self, callback_query: CallbackQuery) -> None: callback_query (:class:`telegram.CallbackQuery`): The callback query. """ - with self.__lock: - mapped = False - - if callback_query.data: - data = callback_query.data - - # Get the cached callback data for the CallbackQuery - keyboard_uuid, button_data = self.__get_keyboard_uuid_and_button_data(data) - callback_query.data = button_data # type: ignore[assignment] - - # Map the callback queries ID to the keyboards UUID for later use - if not mapped and not isinstance(button_data, InvalidCallbackData): - self._callback_queries[callback_query.id] = keyboard_uuid # type: ignore - mapped = True - - # Get the cached callback data for the inline keyboard attached to the - # CallbackQuery. - if callback_query.message: - self.__process_message(callback_query.message) - for message in ( - callback_query.message.pinned_message, - callback_query.message.reply_to_message, - ): - if message: - self.__process_message(message) + mapped = False + + if callback_query.data: + data = callback_query.data + + # Get the cached callback data for the CallbackQuery + keyboard_uuid, button_data = self.__get_keyboard_uuid_and_button_data(data) + callback_query.data = button_data # type: ignore[assignment] + + # Map the callback queries ID to the keyboards UUID for later use + if not mapped and not isinstance(button_data, InvalidCallbackData): + self._callback_queries[callback_query.id] = keyboard_uuid # type: ignore + mapped = True + + # Get the cached callback data for the inline keyboard attached to the + # CallbackQuery. + if callback_query.message: + self.__process_message(callback_query.message) + for message in ( + callback_query.message.pinned_message, + callback_query.message.reply_to_message, + ): + if message: + self.__process_message(message) def drop_data(self, callback_query: CallbackQuery) -> None: """Deletes the data for the specified callback query. @@ -364,12 +355,11 @@ def drop_data(self, callback_query: CallbackQuery) -> None: Raises: KeyError: If the callback query can not be found in the cache """ - with self.__lock: - try: - keyboard_uuid = self._callback_queries.pop(callback_query.id) - self.__drop_keyboard(keyboard_uuid) - except KeyError as exc: - raise KeyError('CallbackQuery was not found in cache.') from exc + try: + keyboard_uuid = self._callback_queries.pop(callback_query.id) + self.__drop_keyboard(keyboard_uuid) + except KeyError as exc: + raise KeyError('CallbackQuery was not found in cache.') from exc def __drop_keyboard(self, keyboard_uuid: str) -> None: try: @@ -387,13 +377,11 @@ def clear_callback_data(self, time_cutoff: Union[float, datetime] = None) -> Non bot will be used. """ - with self.__lock: - self.__clear(self._keyboard_data, time_cutoff=time_cutoff) + self.__clear(self._keyboard_data, time_cutoff=time_cutoff) def clear_callback_queries(self) -> None: """Clears the stored callback query IDs.""" - with self.__lock: - self.__clear(self._callback_queries) + self.__clear(self._callback_queries) def __clear(self, mapping: MutableMapping, time_cutoff: Union[float, datetime] = None) -> None: if not time_cutoff: diff --git a/telegram/ext/_conversationhandler.py b/telegram/ext/_conversationhandler.py index d69aff70493..30fce68c263 100644 --- a/telegram/ext/_conversationhandler.py +++ b/telegram/ext/_conversationhandler.py @@ -21,7 +21,6 @@ import asyncio import logging import datetime -import threading from dataclasses import dataclass from typing import ( # pylint: disable=unused-import # for the "Any" import TYPE_CHECKING, @@ -242,7 +241,6 @@ class ConversationHandler(Handler[Update, CCT]): '_child_conversations', '_conversation_timeout', '_conversations', - '_conversations_lock', '_entry_points', '_fallbacks', '_map_to_parent', @@ -304,8 +302,6 @@ def __init__( self.timeout_jobs: Dict[ConversationKey, 'Job'] = {} self._timeout_jobs_lock = asyncio.Lock() self._conversations: ConversationDict = {} - # TODO: Do we still need this lock? - self._conversations_lock = threading.Lock() self._child_conversations: Set['ConversationHandler'] = set() if persistent and not self.name: @@ -539,19 +535,18 @@ async def _initialize_persistence( 'persistence!' ) - with self._conversations_lock: - current_conversations = self._conversations - self._conversations = cast( - TrackingDict[ConversationKey, object], - TrackingDict(), - ) - # In the conversation already processed updates - self._conversations.update(current_conversations) - # above might be partly overridden but that's okay since we warn about that in - # add_handler - self._conversations.update_no_track( - await application.persistence.get_conversations(self.name) - ) + current_conversations = self._conversations + self._conversations = cast( + TrackingDict[ConversationKey, object], + TrackingDict(), + ) + # In the conversation already processed updates + self._conversations.update(current_conversations) + # above might be partly overridden but that's okay since we warn about that in + # add_handler + self._conversations.update_no_track( + await application.persistence.get_conversations(self.name) + ) for handler in self._child_conversations: await handler._initialize_persistence( # pylint: disable=protected-access @@ -660,8 +655,7 @@ def check_update(self, update: object) -> Optional[CheckUpdateType]: return None key = self._get_key(update) - with self._conversations_lock: - state = self._conversations.get(key) + state = self._conversations.get(key) # Resolve promises if isinstance(state, PendingState): @@ -671,8 +665,7 @@ def check_update(self, update: object) -> Optional[CheckUpdateType]: if state.done(): res = state.resolve() self._update_state(res, key) - with self._conversations_lock: - state = self._conversations.get(key) + state = self._conversations.get(key) # if not then handle WAITING state instead else: @@ -803,16 +796,14 @@ async def handle_update( # type: ignore[override] def _update_state(self, new_state: object, key: ConversationKey) -> None: if new_state == self.END: - with self._conversations_lock: - if key in self._conversations: - # If there is no key in conversations, nothing is done. - del self._conversations[key] + if key in self._conversations: + # If there is no key in conversations, nothing is done. + del self._conversations[key] elif isinstance(new_state, asyncio.Task): - with self._conversations_lock: - self._conversations[key] = PendingState( - old_state=self._conversations.get(key), task=new_state - ) + self._conversations[key] = PendingState( + old_state=self._conversations.get(key), task=new_state + ) elif new_state is not None: if new_state not in self.states: @@ -821,8 +812,7 @@ def _update_state(self, new_state: object, key: ConversationKey) -> None: f"ConversationHandler{' ' + self.name if self.name is not None else ''}.", stacklevel=2, ) - with self._conversations_lock: - self._conversations[key] = new_state + self._conversations[key] = new_state async def _trigger_timeout(self, context: CallbackContext) -> None: job = cast('Job', context.job) diff --git a/telegram/ext/filters.py b/telegram/ext/filters.py index 842b412e64f..9701ede191f 100644 --- a/telegram/ext/filters.py +++ b/telegram/ext/filters.py @@ -92,7 +92,6 @@ import re from abc import ABC, abstractmethod -from threading import Lock from typing import ( Dict, FrozenSet, @@ -585,7 +584,6 @@ class _ChatUserBaseFilter(MessageFilter, ABC): '_chat_id_name', '_username_name', 'allow_empty', - '__lock', '_chat_ids', '_usernames', ) @@ -600,7 +598,6 @@ def __init__( self._chat_id_name = 'chat_id' self._username_name = 'username' self.allow_empty = allow_empty - self.__lock = Lock() self._chat_ids: Set[int] = set() self._usernames: Set[str] = set() @@ -629,27 +626,24 @@ def _parse_username(username: Optional[SLT[str]]) -> Set[str]: return {chat[1:] if chat.startswith('@') else chat for chat in username} def _set_chat_ids(self, chat_id: Optional[SLT[int]]) -> None: - with self.__lock: - if chat_id and self._usernames: - raise RuntimeError( - f"Can't set {self._chat_id_name} in conjunction with (already set) " - f"{self._username_name}s." - ) - self._chat_ids = self._parse_chat_id(chat_id) + if chat_id and self._usernames: + raise RuntimeError( + f"Can't set {self._chat_id_name} in conjunction with (already set) " + f"{self._username_name}s." + ) + self._chat_ids = self._parse_chat_id(chat_id) def _set_usernames(self, username: Optional[SLT[str]]) -> None: - with self.__lock: - if username and self._chat_ids: - raise RuntimeError( - f"Can't set {self._username_name} in conjunction with (already set) " - f"{self._chat_id_name}s." - ) - self._usernames = self._parse_username(username) + if username and self._chat_ids: + raise RuntimeError( + f"Can't set {self._username_name} in conjunction with (already set) " + f"{self._chat_id_name}s." + ) + self._usernames = self._parse_username(username) @property def chat_ids(self) -> FrozenSet[int]: - with self.__lock: - return frozenset(self._chat_ids) + return frozenset(self._chat_ids) @chat_ids.setter def chat_ids(self, chat_id: SLT[int]) -> None: @@ -669,8 +663,7 @@ def usernames(self) -> FrozenSet[str]: Returns: frozenset(:obj:`str`) """ - with self.__lock: - return frozenset(self._usernames) + return frozenset(self._usernames) @usernames.setter def usernames(self, username: SLT[str]) -> None: @@ -684,27 +677,25 @@ def add_usernames(self, username: SLT[str]) -> None: username(:obj:`str` | Tuple[:obj:`str`] | List[:obj:`str`]): Which username(s) to allow through. Leading ``'@'`` s in usernames will be discarded. """ - with self.__lock: - if self._chat_ids: - raise RuntimeError( - f"Can't set {self._username_name} in conjunction with (already set) " - f"{self._chat_id_name}s." - ) + if self._chat_ids: + raise RuntimeError( + f"Can't set {self._username_name} in conjunction with (already set) " + f"{self._chat_id_name}s." + ) - parsed_username = self._parse_username(username) - self._usernames |= parsed_username + parsed_username = self._parse_username(username) + self._usernames |= parsed_username def _add_chat_ids(self, chat_id: SLT[int]) -> None: - with self.__lock: - if self._usernames: - raise RuntimeError( - f"Can't set {self._chat_id_name} in conjunction with (already set) " - f"{self._username_name}s." - ) + if self._usernames: + raise RuntimeError( + f"Can't set {self._chat_id_name} in conjunction with (already set) " + f"{self._username_name}s." + ) - parsed_chat_id = self._parse_chat_id(chat_id) + parsed_chat_id = self._parse_chat_id(chat_id) - self._chat_ids |= parsed_chat_id + self._chat_ids |= parsed_chat_id def remove_usernames(self, username: SLT[str]) -> None: """ @@ -714,25 +705,23 @@ def remove_usernames(self, username: SLT[str]) -> None: username(:obj:`str` | Tuple[:obj:`str`] | List[:obj:`str`]): Which username(s) to disallow through. Leading ``'@'`` s in usernames will be discarded. """ - with self.__lock: - if self._chat_ids: - raise RuntimeError( - f"Can't set {self._username_name} in conjunction with (already set) " - f"{self._chat_id_name}s." - ) + if self._chat_ids: + raise RuntimeError( + f"Can't set {self._username_name} in conjunction with (already set) " + f"{self._chat_id_name}s." + ) - parsed_username = self._parse_username(username) - self._usernames -= parsed_username + parsed_username = self._parse_username(username) + self._usernames -= parsed_username def _remove_chat_ids(self, chat_id: SLT[int]) -> None: - with self.__lock: - if self._usernames: - raise RuntimeError( - f"Can't set {self._chat_id_name} in conjunction with (already set) " - f"{self._username_name}s." - ) - parsed_chat_id = self._parse_chat_id(chat_id) - self._chat_ids -= parsed_chat_id + if self._usernames: + raise RuntimeError( + f"Can't set {self._chat_id_name} in conjunction with (already set) " + f"{self._username_name}s." + ) + parsed_chat_id = self._parse_chat_id(chat_id) + self._chat_ids -= parsed_chat_id def filter(self, message: Message) -> bool: chat_or_user = self.get_chat_or_user(message) From 476a68d9a4b61a800c7e80d5730af5058f62dd2e Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Thu, 24 Mar 2022 20:53:22 +0100 Subject: [PATCH 084/153] Add notes on testing CH-BP integration and postpone that for now --- tests/test_basepersistence.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/test_basepersistence.py b/tests/test_basepersistence.py index d5393146892..af7f668d6b9 100644 --- a/tests/test_basepersistence.py +++ b/tests/test_basepersistence.py @@ -302,8 +302,12 @@ class TestBasePersistence: """Tests basic behavior of BasePersistence and (most importantly) the integration of persistence into the Application.""" - # TODO: - # * conversations: pending states, ending conversations, unresolved pending states + # TODO: Test integration of the more intricate ConversationHandler things once CH itself is + # tested. This includes: + # * pending states, i.e. non-blocking handlers + # * pending states being unresolved on shutdown + # * conversation timeouts + # * nested conversations (can conversations be persistent if their parents aren't?) def job_callback(self, chat_id: int = None): async def callback(context): From cae4e188062573f9d160f6d284ac9339945c074e Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Thu, 24 Mar 2022 21:12:43 +0100 Subject: [PATCH 085/153] Test CallbackContext --- tests/test_callbackcontext.py | 219 ++++++++++++++++++++++++++++++++++ 1 file changed, 219 insertions(+) create mode 100644 tests/test_callbackcontext.py diff --git a/tests/test_callbackcontext.py b/tests/test_callbackcontext.py new file mode 100644 index 00000000000..70b2f0a0492 --- /dev/null +++ b/tests/test_callbackcontext.py @@ -0,0 +1,219 @@ +#!/usr/bin/env python +# +# A library that provides a Python interface to the Telegram Bot API +# Copyright (C) 2015-2022 +# Leandro Toledo de Souza +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser Public License for more details. +# +# You should have received a copy of the GNU Lesser Public License +# along with this program. If not, see [http://www.gnu.org/licenses/]. + +import pytest + +from telegram import ( + Update, + Message, + Chat, + User, + Bot, + InlineKeyboardMarkup, + InlineKeyboardButton, + CallbackQuery, +) +from telegram.ext import CallbackContext, ApplicationBuilder +from telegram.error import TelegramError + +""" +CallbackContext.refresh_data is tested in TestBasePersistence +""" + + +class TestCallbackContext: + def test_slot_behaviour(self, app, mro_slots, recwarn): + c = CallbackContext(app) + for attr in c.__slots__: + assert getattr(c, attr, 'err') != 'err', f"got extra slot '{attr}'" + assert not c.__dict__, f"got missing slot(s): {c.__dict__}" + assert len(mro_slots(c)) == len(set(mro_slots(c))), "duplicate slot" + c.args = c.args + assert len(recwarn) == 0, recwarn.list + + def test_from_job(self, app): + job = app.job_queue.run_once(lambda x: x, 10) + + callback_context = CallbackContext.from_job(job, app) + + assert callback_context.job is job + assert callback_context.chat_data is None + assert callback_context.user_data is None + assert callback_context.bot_data is app.bot_data + assert callback_context.bot is app.bot + assert callback_context.job_queue is app.job_queue + assert callback_context.update_queue is app.update_queue + + def test_from_update(self, app): + update = Update( + 0, message=Message(0, None, Chat(1, 'chat'), from_user=User(1, 'user', False)) + ) + + callback_context = CallbackContext.from_update(update, app) + + assert callback_context.chat_data == {} + assert callback_context.user_data == {} + assert callback_context.bot_data is app.bot_data + assert callback_context.bot is app.bot + assert callback_context.job_queue is app.job_queue + assert callback_context.update_queue is app.update_queue + + callback_context_same_user_chat = CallbackContext.from_update(update, app) + + callback_context.bot_data['test'] = 'bot' + callback_context.chat_data['test'] = 'chat' + callback_context.user_data['test'] = 'user' + + assert callback_context_same_user_chat.bot_data is callback_context.bot_data + assert callback_context_same_user_chat.chat_data is callback_context.chat_data + assert callback_context_same_user_chat.user_data is callback_context.user_data + + update_other_user_chat = Update( + 0, message=Message(0, None, Chat(2, 'chat'), from_user=User(2, 'user', False)) + ) + + callback_context_other_user_chat = CallbackContext.from_update(update_other_user_chat, app) + + assert callback_context_other_user_chat.bot_data is callback_context.bot_data + assert callback_context_other_user_chat.chat_data is not callback_context.chat_data + assert callback_context_other_user_chat.user_data is not callback_context.user_data + + def test_from_update_not_update(self, app): + callback_context = CallbackContext.from_update(None, app) + + assert callback_context.chat_data is None + assert callback_context.user_data is None + assert callback_context.bot_data is app.bot_data + assert callback_context.bot is app.bot + assert callback_context.job_queue is app.job_queue + assert callback_context.update_queue is app.update_queue + + callback_context = CallbackContext.from_update('', app) + + assert callback_context.chat_data is None + assert callback_context.user_data is None + assert callback_context.bot_data is app.bot_data + assert callback_context.bot is app.bot + assert callback_context.job_queue is app.job_queue + assert callback_context.update_queue is app.update_queue + + def test_from_error(self, app): + error = TelegramError('test') + update = Update( + 0, message=Message(0, None, Chat(1, 'chat'), from_user=User(1, 'user', False)) + ) + job = object() + coroutine = object() + + callback_context = CallbackContext.from_error( + update=update, error=error, application=app, job=job, coroutine=coroutine + ) + + assert callback_context.error is error + assert callback_context.chat_data == {} + assert callback_context.user_data == {} + assert callback_context.bot_data is app.bot_data + assert callback_context.bot is app.bot + assert callback_context.job_queue is app.job_queue + assert callback_context.update_queue is app.update_queue + assert callback_context.coroutine is coroutine + assert callback_context.job is job + + def test_match(self, app): + callback_context = CallbackContext(app) + + assert callback_context.match is None + + callback_context.matches = ['test', 'blah'] + + assert callback_context.match == 'test' + + def test_data_assignment(self, app): + update = Update( + 0, message=Message(0, None, Chat(1, 'chat'), from_user=User(1, 'user', False)) + ) + + callback_context = CallbackContext.from_update(update, app) + + with pytest.raises(AttributeError): + callback_context.bot_data = {"test": 123} + with pytest.raises(AttributeError): + callback_context.user_data = {} + with pytest.raises(AttributeError): + callback_context.chat_data = "test" + + def test_application_attribute(self, app): + callback_context = CallbackContext(app) + assert callback_context.application is app + + def test_drop_callback_data_exception(self, bot, app): + non_ext_bot = Bot(bot.token) + update = Update( + 0, message=Message(0, None, Chat(1, 'chat'), from_user=User(1, 'user', False)) + ) + + callback_context = CallbackContext.from_update(update, app) + + with pytest.raises(RuntimeError, match='This telegram.ext.ExtBot instance does not'): + callback_context.drop_callback_data(None) + + try: + app.bot = non_ext_bot + with pytest.raises(RuntimeError, match='telegram.Bot does not allow for'): + callback_context.drop_callback_data(None) + finally: + app.bot = bot + + @pytest.mark.asyncio + async def test_drop_callback_data(self, bot, monkeypatch, chat_id): + app = ApplicationBuilder().token(bot.token).arbitrary_callback_data(True).build() + + update = Update( + 0, message=Message(0, None, Chat(1, 'chat'), from_user=User(1, 'user', False)) + ) + + callback_context = CallbackContext.from_update(update, app) + async with app: + await app.bot.send_message( + chat_id=chat_id, + text='test', + reply_markup=InlineKeyboardMarkup.from_button( + InlineKeyboardButton('test', callback_data='callback_data') + ), + ) + keyboard_uuid = app.bot.callback_data_cache.persistence_data[0][0][0] + button_uuid = list(app.bot.callback_data_cache.persistence_data[0][0][2])[0] + callback_data = keyboard_uuid + button_uuid + callback_query = CallbackQuery( + id='1', + from_user=None, + chat_instance=None, + data=callback_data, + ) + app.bot.callback_data_cache.process_callback_query(callback_query) + + try: + assert len(app.bot.callback_data_cache.persistence_data[0]) == 1 + assert list(app.bot.callback_data_cache.persistence_data[1]) == ['1'] + + callback_context.drop_callback_data(callback_query) + assert app.bot.callback_data_cache.persistence_data == ([], {}) + finally: + app.bot.callback_data_cache.clear_callback_data() + app.bot.callback_data_cache.clear_callback_queries() From f90cd50016b1e91cf827f7e8622f432a6cf7d4a9 Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Thu, 24 Mar 2022 21:14:48 +0100 Subject: [PATCH 086/153] Test Defaults --- tests/test_defaults.py | 59 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 59 insertions(+) create mode 100644 tests/test_defaults.py diff --git a/tests/test_defaults.py b/tests/test_defaults.py new file mode 100644 index 00000000000..5dad8a9fad9 --- /dev/null +++ b/tests/test_defaults.py @@ -0,0 +1,59 @@ +#!/usr/bin/env python +# +# A library that provides a Python interface to the Telegram Bot API +# Copyright (C) 2015-2022 +# Leandro Toledo de Souza +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser Public License for more details. +# +# You should have received a copy of the GNU Lesser Public License +# along with this program. If not, see [http://www.gnu.org/licenses/]. + +import pytest +import inspect + +from telegram.ext import Defaults +from telegram import User + + +class TestDefault: + def test_slot_behaviour(self, mro_slots): + a = Defaults(parse_mode='HTML', quote=True) + for attr in a.__slots__: + assert getattr(a, attr, 'err') != 'err', f"got extra slot '{attr}'" + assert len(mro_slots(a)) == len(set(mro_slots(a))), "duplicate slot" + + def test_data_assignment(self): + defaults = Defaults() + + for name, val in inspect.getmembers(Defaults, lambda x: isinstance(x, property)): + with pytest.raises(AttributeError): + setattr(defaults, name, True) + + def test_equality(self): + a = Defaults(parse_mode='HTML', quote=True) + b = Defaults(parse_mode='HTML', quote=True) + c = Defaults(parse_mode='HTML', quote=True, protect_content=True) + d = Defaults(parse_mode='HTML', protect_content=True) + e = User(123, 'test_user', False) + + assert a == b + assert hash(a) == hash(b) + assert a is not b + + assert a != c + assert hash(a) != hash(c) + + assert a != d + assert hash(a) != hash(d) + + assert a != e + assert hash(a) != hash(e) From 6012501f7f3043ebfe58a1af4289dea7120633bb Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Thu, 24 Mar 2022 22:05:58 +0100 Subject: [PATCH 087/153] Get started on JobQueue Tests - half of them is still failing --- tests/test_jobqueue.py | 552 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 552 insertions(+) create mode 100644 tests/test_jobqueue.py diff --git a/tests/test_jobqueue.py b/tests/test_jobqueue.py new file mode 100644 index 00000000000..e7fa87332bd --- /dev/null +++ b/tests/test_jobqueue.py @@ -0,0 +1,552 @@ +#!/usr/bin/env python +# +# A library that provides a Python interface to the Telegram Bot API +# Copyright (C) 2015-2022 +# Leandro Toledo de Souza +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser Public License for more details. +# +# You should have received a copy of the GNU Lesser Public License +# along with this program. If not, see [http://www.gnu.org/licenses/]. +import asyncio +import calendar +import datetime as dtm +import logging +import os +import platform +import time + +import pytest +import pytz +from apscheduler.schedulers import SchedulerNotRunningError +from flaky import flaky +from telegram.ext import ( + JobQueue, + Job, + CallbackContext, + ContextTypes, + ApplicationBuilder, +) + + +class CustomContext(CallbackContext): + pass + + +@pytest.fixture(scope='function') +@pytest.mark.asyncio +async def job_queue(bot, app): + jq = JobQueue() + jq.set_application(app) + await jq.start() + yield jq + await jq.stop() + + +@pytest.mark.skipif( + os.getenv('GITHUB_ACTIONS', False) and platform.system() in ['Windows', 'Darwin'], + reason="On Windows & MacOS precise timings are not accurate.", +) +@flaky(10, 1) # Timings aren't quite perfect +class TestJobQueue: + result = 0 + job_time = 0 + received_error = None + + @pytest.fixture(autouse=True) + def reset(self): + self.result = 0 + self.job_time = 0 + self.received_error = None + + async def job_run_once(self, context): + if ( + isinstance(context, CallbackContext) + and isinstance(context.job, Job) + and isinstance(context.update_queue, asyncio.Queue) + and context.job.context is None + and context.chat_data is None + and context.user_data is None + and isinstance(context.bot_data, dict) + ): + self.result += 1 + + async def job_with_exception(self, context): + raise Exception('Test Error') + + async def job_remove_self(self, context): + self.result += 1 + context.job.schedule_removal() + + async def job_run_once_with_context(self, context): + self.result += context.job.context + + async def job_datetime_tests(self, context): + self.job_time = time.time() + + async def error_handler_context(self, update, context): + self.received_error = (str(context.error), context.job) + + async def error_handler_raise_error(self, *args): + raise Exception('Failing bigly') + + def test_slot_behaviour(self, job_queue, mro_slots): + for attr in job_queue.__slots__: + assert getattr(job_queue, attr, 'err') != 'err', f"got extra slot '{attr}'" + assert len(mro_slots(job_queue)) == len(set(mro_slots(job_queue))), "duplicate slot" + + def test_application_weakref(self, bot): + jq = JobQueue() + application = ApplicationBuilder().bot(bot).job_queue(None).build() + with pytest.raises(RuntimeError, match='No application was set'): + jq.application + jq.set_application(application) + assert jq.application is application + del application + with pytest.raises(RuntimeError, match='no longer alive'): + jq.application + + @pytest.mark.asyncio + async def test_run_once(self, job_queue): + job_queue.run_once(self.job_run_once, 0.01) + await asyncio.sleep(0.02) + assert self.result == 1 + + @pytest.mark.asyncio + async def test_run_once_timezone(self, job_queue, timezone): + """Test the correct handling of aware datetimes""" + # we're parametrizing this with two different UTC offsets to exclude the possibility + # of an xpass when the test is run in a timezone with the same UTC offset + when = dtm.datetime.now(timezone) + job_queue.run_once(self.job_run_once, when) + await asyncio.sleep(0.1) + assert self.result == 1 + + @pytest.mark.asyncio + async def test_job_with_context(self, job_queue): + job_queue.run_once(self.job_run_once_with_context, 0.01, context=5) + await asyncio.sleep(0.02) + assert self.result == 5 + + @pytest.mark.asyncio + async def test_run_repeating(self, job_queue): + job_queue.run_repeating(self.job_run_once, 0.1) + await asyncio.sleep(0.25) + assert self.result == 2 + + @pytest.mark.asyncio + async def test_run_repeating_first(self, job_queue): + job_queue.run_repeating(self.job_run_once, 0.05, first=0.2) + await asyncio.sleep(0.15) + assert self.result == 0 + await asyncio.sleep(0.07) + assert self.result == 1 + + @pytest.mark.asyncio + async def test_run_repeating_first_timezone(self, job_queue, timezone): + """Test correct scheduling of job when passing a timezone-aware datetime as ``first``""" + job_queue.run_repeating( + self.job_run_once, 0.1, first=dtm.datetime.now(timezone) + dtm.timedelta(seconds=0.05) + ) + await asyncio.sleep(0.1) + assert self.result == 1 + + @pytest.mark.asyncio + async def test_run_repeating_last(self, job_queue): + job_queue.run_repeating(self.job_run_once, 0.05, last=0.06) + await asyncio.sleep(0.1) + assert self.result == 1 + await asyncio.sleep(0.1) + assert self.result == 1 + + @pytest.mark.asyncio + async def test_run_repeating_last_timezone(self, job_queue, timezone): + """Test correct scheduling of job when passing a timezone-aware datetime as ``first``""" + job_queue.run_repeating( + self.job_run_once, 0.05, last=dtm.datetime.now(timezone) + dtm.timedelta(seconds=0.06) + ) + await asyncio.sleep(0.1) + assert self.result == 1 + await asyncio.sleep(0.1) + assert self.result == 1 + + @pytest.mark.asyncio + async def test_run_repeating_last_before_first(self, job_queue): + with pytest.raises(ValueError, match="'last' must not be before 'first'!"): + job_queue.run_repeating(self.job_run_once, 0.05, first=1, last=0.5) + + @pytest.mark.asyncio + async def test_run_repeating_timedelta(self, job_queue): + job_queue.run_repeating(self.job_run_once, dtm.timedelta(seconds=0.1)) + await asyncio.sleep(0.25) + assert self.result == 2 + + @pytest.mark.asyncio + async def test_run_custom(self, job_queue): + job_queue.run_custom(self.job_run_once, {'trigger': 'interval', 'seconds': 0.02}) + await asyncio.sleep(0.05) + assert self.result == 2 + + @pytest.mark.asyncio + async def test_multiple(self, job_queue): + job_queue.run_once(self.job_run_once, 0.01) + job_queue.run_once(self.job_run_once, 0.02) + job_queue.run_repeating(self.job_run_once, 0.02) + await asyncio.sleep(0.055) + assert self.result == 4 + + @pytest.mark.asyncio + async def test_disabled(self, job_queue): + j1 = job_queue.run_once(self.job_run_once, 0.1) + j2 = job_queue.run_repeating(self.job_run_once, 0.05) + + j1.enabled = False + j2.enabled = False + + await asyncio.sleep(0.06) + + assert self.result == 0 + + j1.enabled = True + + await asyncio.sleep(0.2) + + assert self.result == 1 + + @pytest.mark.asyncio + async def test_schedule_removal(self, job_queue): + j1 = job_queue.run_once(self.job_run_once, 0.03) + j2 = job_queue.run_repeating(self.job_run_once, 0.02) + + await asyncio.sleep(0.025) + + j1.schedule_removal() + j2.schedule_removal() + + await asyncio.sleep(0.04) + + assert self.result == 1 + + @pytest.mark.asyncio + async def test_schedule_removal_from_within(self, job_queue): + job_queue.run_repeating(self.job_remove_self, 0.01) + + await asyncio.sleep(0.05) + + assert self.result == 1 + + @pytest.mark.asyncio + async def test_longer_first(self, job_queue): + job_queue.run_once(self.job_run_once, 0.02) + job_queue.run_once(self.job_run_once, 0.01) + + await asyncio.sleep(0.015) + + assert self.result == 1 + + @pytest.mark.asyncio + async def test_error(self, job_queue): + job_queue.run_repeating(self.job_with_exception, 0.01) + job_queue.run_repeating(self.job_run_once, 0.02) + await asyncio.sleep(0.03) + assert self.result == 1 + + @pytest.mark.asyncio + async def test_in_application(self, bot): + application = ApplicationBuilder().bot(bot).build() + await application.job_queue.start() + try: + application.job_queue.run_repeating(self.job_run_once, 0.02) + await asyncio.sleep(0.03) + assert self.result == 1 + await application.stop() + await asyncio.sleep(1) + assert self.result == 1 + finally: + try: + await application.stop() + except SchedulerNotRunningError: + pass + + @pytest.mark.asyncio + async def test_time_unit_int(self, job_queue): + # Testing seconds in int + delta = 0.05 + expected_time = time.time() + delta + + job_queue.run_once(self.job_datetime_tests, delta) + await asyncio.sleep(0.06) + assert pytest.approx(self.job_time) == expected_time + + @pytest.mark.asyncio + async def test_time_unit_dt_timedelta(self, job_queue): + # Testing seconds, minutes and hours as datetime.timedelta object + # This is sufficient to test that it actually works. + interval = dtm.timedelta(seconds=0.05) + expected_time = time.time() + interval.total_seconds() + + job_queue.run_once(self.job_datetime_tests, interval) + await asyncio.sleep(0.06) + assert pytest.approx(self.job_time) == expected_time + + @pytest.mark.asyncio + async def test_time_unit_dt_datetime(self, job_queue): + # Testing running at a specific datetime + delta, now = dtm.timedelta(seconds=0.05), dtm.datetime.now(pytz.utc) + when = now + delta + expected_time = (now + delta).timestamp() + + job_queue.run_once(self.job_datetime_tests, when) + await asyncio.sleep(0.06) + assert self.job_time == pytest.approx(expected_time) + + @pytest.mark.asyncio + async def test_time_unit_dt_time_today(self, job_queue): + # Testing running at a specific time today + delta, now = 0.05, dtm.datetime.now(pytz.utc) + expected_time = now + dtm.timedelta(seconds=delta) + when = expected_time.time() + expected_time = expected_time.timestamp() + + job_queue.run_once(self.job_datetime_tests, when) + await asyncio.sleep(0.06) + assert self.job_time == pytest.approx(expected_time) + + @pytest.mark.asyncio + async def test_time_unit_dt_time_tomorrow(self, job_queue): + # Testing running at a specific time that has passed today. Since we can't wait a day, we + # test if the job's next scheduled execution time has been calculated correctly + delta, now = -2, dtm.datetime.now(pytz.utc) + when = (now + dtm.timedelta(seconds=delta)).time() + expected_time = (now + dtm.timedelta(seconds=delta, days=1)).timestamp() + + job_queue.run_once(self.job_datetime_tests, when) + scheduled_time = job_queue.jobs()[0].next_t.timestamp() + assert scheduled_time == pytest.approx(expected_time) + + @pytest.mark.asyncio + async def test_run_daily(self, job_queue): + delta, now = 1, dtm.datetime.now(pytz.utc) + time_of_day = (now + dtm.timedelta(seconds=delta)).time() + expected_reschedule_time = (now + dtm.timedelta(seconds=delta, days=1)).timestamp() + + job_queue.run_daily(self.job_run_once, time_of_day) + await asyncio.sleep(delta + 0.1) + assert self.result == 1 + scheduled_time = job_queue.jobs()[0].next_t.timestamp() + assert scheduled_time == pytest.approx(expected_reschedule_time) + + @pytest.mark.asyncio + async def test_run_monthly(self, job_queue, timezone): + delta, now = 1, dtm.datetime.now(timezone) + expected_reschedule_time = now + dtm.timedelta(seconds=delta) + time_of_day = expected_reschedule_time.time().replace(tzinfo=timezone) + + day = now.day + this_months_days = calendar.monthrange(now.year, now.month)[1] + if now.month == 12: + next_months_days = calendar.monthrange(now.year + 1, 1)[1] + else: + next_months_days = calendar.monthrange(now.year, now.month + 1)[1] + + expected_reschedule_time += dtm.timedelta(this_months_days) + if day > next_months_days: + expected_reschedule_time += dtm.timedelta(next_months_days) + + expected_reschedule_time = timezone.normalize(expected_reschedule_time) + # Adjust the hour for the special case that between now and next month a DST switch happens + expected_reschedule_time += dtm.timedelta( + hours=time_of_day.hour - expected_reschedule_time.hour + ) + expected_reschedule_time = expected_reschedule_time.timestamp() + + job_queue.run_monthly(self.job_run_once, time_of_day, day) + await asyncio.sleep(delta + 0.1) + assert self.result == 1 + scheduled_time = job_queue.jobs()[0].next_t.timestamp() + assert scheduled_time == pytest.approx(expected_reschedule_time, rel=1e-3) + + @pytest.mark.asyncio + async def test_run_monthly_non_strict_day(self, job_queue, timezone): + delta, now = 1, dtm.datetime.now(timezone) + expected_reschedule_time = now + dtm.timedelta(seconds=delta) + time_of_day = expected_reschedule_time.time().replace(tzinfo=timezone) + + expected_reschedule_time += dtm.timedelta( + calendar.monthrange(now.year, now.month)[1] + ) - dtm.timedelta(days=now.day) + # Adjust the hour for the special case that between now & end of month a DST switch happens + expected_reschedule_time = timezone.normalize(expected_reschedule_time) + expected_reschedule_time += dtm.timedelta( + hours=time_of_day.hour - expected_reschedule_time.hour + ) + expected_reschedule_time = expected_reschedule_time.timestamp() + + job_queue.run_monthly(self.job_run_once, time_of_day, -1) + scheduled_time = job_queue.jobs()[0].next_t.timestamp() + assert scheduled_time == pytest.approx(expected_reschedule_time) + + @pytest.mark.asyncio + async def test_default_tzinfo(self, app, tz_bot): + # we're parametrizing this with two different UTC offsets to exclude the possibility + # of an xpass when the test is run in a timezone with the same UTC offset + jq = JobQueue() + original_bot = app.bot + app.bot = tz_bot + jq.set_application(app) + try: + jq.start() + + when = dtm.datetime.now(tz_bot.defaults.tzinfo) + dtm.timedelta(seconds=0.0005) + jq.run_once(self.job_run_once, when.time()) + await asyncio.sleep(0.001) + assert self.result == 1 + + jq.stop() + finally: + app.bot = original_bot + + @pytest.mark.asyncio + async def test_get_jobs(self, job_queue): + callback = self.job_run_once + + job1 = job_queue.run_once(callback, 10, name='name1') + job2 = job_queue.run_once(callback, 10, name='name1') + job3 = job_queue.run_once(callback, 10, name='name2') + + assert job_queue.jobs() == (job1, job2, job3) + assert job_queue.get_jobs_by_name('name1') == (job1, job2) + assert job_queue.get_jobs_by_name('name2') == (job3,) + + @pytest.mark.asyncio + async def test_enable_disable_job(self, job_queue): + job = job_queue.run_repeating(self.job_run_once, 0.02) + await asyncio.sleep(0.05) + assert self.result == 2 + job.enabled = False + assert not job.enabled + await asyncio.sleep(0.05) + assert self.result == 2 + job.enabled = True + assert job.enabled + await asyncio.sleep(0.05) + assert self.result == 4 + + @pytest.mark.asyncio + async def test_remove_job(self, job_queue): + job = job_queue.run_repeating(self.job_run_once, 0.02) + await asyncio.sleep(0.05) + assert self.result == 2 + assert not job.removed + job.schedule_removal() + assert job.removed + await asyncio.sleep(0.05) + assert self.result == 2 + + @pytest.mark.asyncio + async def test_job_lt_eq(self, job_queue): + job = job_queue.run_repeating(self.job_run_once, 0.02) + assert not job == job_queue + assert not job < job + + @pytest.mark.asyncio + async def test_dispatch_error_context(self, job_queue, app): + app.add_error_handler(self.error_handler_context) + + job = job_queue.run_once(self.job_with_exception, 0.05) + await asyncio.sleep(0.1) + assert self.received_error[0] == 'Test Error' + assert self.received_error[1] is job + self.received_error = None + await job.run(app) + assert self.received_error[0] == 'Test Error' + assert self.received_error[1] is job + + # Remove handler + app.remove_error_handler(self.error_handler_context) + self.received_error = None + + job = job_queue.run_once(self.job_with_exception, 0.05) + await asyncio.sleep(0.1) + assert self.received_error is None + await job.run(app) + assert self.received_error is None + + @pytest.mark.asyncio + async def test_dispatch_error_that_raises_errors(self, job_queue, app, caplog): + app.add_error_handler(self.error_handler_raise_error) + + with caplog.at_level(logging.ERROR): + job = job_queue.run_once(self.job_with_exception, 0.05) + await asyncio.sleep(0.1) + assert len(caplog.records) == 1 + rec = caplog.records[-1] + assert 'An error was raised and an uncaught' in rec.getMessage() + caplog.clear() + + with caplog.at_level(logging.ERROR): + await job.run(app) + assert len(caplog.records) == 1 + rec = caplog.records[-1] + assert 'uncaught error was raised while handling' in rec.getMessage() + caplog.clear() + + # Remove handler + app.remove_error_handler(self.error_handler_raise_error) + self.received_error = None + + with caplog.at_level(logging.ERROR): + job = job_queue.run_once(self.job_with_exception, 0.05) + await asyncio.sleep(0.1) + assert len(caplog.records) == 1 + rec = caplog.records[-1] + assert 'No error handlers are registered' in rec.getMessage() + caplog.clear() + + with caplog.at_level(logging.ERROR): + await job.run(app) + assert len(caplog.records) == 1 + rec = caplog.records[-1] + assert 'No error handlers are registered' in rec.getMessage() + + @pytest.mark.asyncio + async def test_custom_context(self, bot, job_queue): + application = ( + ApplicationBuilder() + .bot(bot) + .context_types( + ContextTypes( + context=CustomContext, bot_data=int, user_data=float, chat_data=complex + ) + ) + .build() + ) + job_queue.set_application(application) + + def callback(context): + self.result = ( + type(context), + context.user_data, + context.chat_data, + type(context.bot_data), + ) + + job_queue.run_once(callback, 0.1) + await asyncio.sleep(0.15) + assert self.result == (CustomContext, None, None, int) + + @pytest.mark.asyncio + async def test_attribute_error(self): + job = Job(self.job_run_once) + with pytest.raises( + AttributeError, match="nor 'apscheduler.job.Job' has attribute 'error'" + ): + job.error From cc46bbf3aa2e36a98f2cdf01f75e060456a3880e Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Thu, 24 Mar 2022 22:07:19 +0100 Subject: [PATCH 088/153] Make JobQueue.start asyncio --- telegram/ext/_application.py | 2 +- telegram/ext/_jobqueue.py | 5 ++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/telegram/ext/_application.py b/telegram/ext/_application.py index 1d54124dc56..27d15577eb2 100644 --- a/telegram/ext/_application.py +++ b/telegram/ext/_application.py @@ -421,7 +421,7 @@ async def start(self) -> None: _logger.debug('Loop for updating persistence started') if self.job_queue: - self.job_queue.start() + await self.job_queue.start() _logger.debug('JobQueue started') self.__update_fetcher_task = asyncio.create_task( diff --git a/telegram/ext/_jobqueue.py b/telegram/ext/_jobqueue.py index ae4a2aae75c..4b625545173 100644 --- a/telegram/ext/_jobqueue.py +++ b/telegram/ext/_jobqueue.py @@ -489,9 +489,8 @@ def run_custom( job.job = j return job - def start(self) -> None: - # TODO: Make this async - not needed yet, but it's probably saver to have it async already - # in case future versions need that + async def start(self) -> None: + # this method async just in case future versions need that """Starts the job_queue thread.""" if not self.scheduler.running: self.scheduler.start() From ececf9fdda4b793183be249e0241c99d9765bc21 Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Thu, 24 Mar 2022 22:09:20 +0100 Subject: [PATCH 089/153] adjust tests --- tests/test_basepersistence.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_basepersistence.py b/tests/test_basepersistence.py index af7f668d6b9..0a7cab203a5 100644 --- a/tests/test_basepersistence.py +++ b/tests/test_basepersistence.py @@ -679,7 +679,7 @@ async def test_update_persistence_loop_call_count_update_handling( @pytest.mark.asyncio async def test_update_persistence_loop_call_count_job(self, papp: Application, caplog): async with papp: - papp.job_queue.start() + await papp.job_queue.start() papp.job_queue.run_once(self.job_callback(), when=1.5, chat_id=1, user_id=1) await asyncio.sleep(2.5) assert not papp.persistence.updated_bot_data @@ -839,7 +839,7 @@ async def test_update_persistence_loop_saved_data_job(self, papp: Application, c ) async with papp: - papp.job_queue.start() + await papp.job_queue.start() papp.job_queue.run_once( self.job_callback(chat_id=chat_id), when=1.5, chat_id=1, user_id=1 ) @@ -903,7 +903,7 @@ async def test_update_persistence_loop_async_logic( async with papp: if delay_type == 'job': - papp.job_queue.start() + await papp.job_queue.start() papp.job_queue.run_once(self.job_callback(), when=sleep, chat_id=1, user_id=1) elif delay_type == 'handler': papp.add_handler( @@ -1095,7 +1095,7 @@ async def raise_error(*args, **kwargs): assert errors == 0 if delay_type == 'job': - papp.job_queue.start() + await papp.job_queue.start() papp.job_queue.run_once(raise_error, when=sleep, chat_id=1, user_id=1) elif delay_type.endswith('_handler'): papp.add_handler( From 138ba20b71c076b8d3184e677bc7066222699c7d Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Fri, 25 Mar 2022 21:23:49 +0100 Subject: [PATCH 090/153] finish up job_queue tests --- tests/test_jobqueue.py | 195 ++++++++++++++++++++++------------------- 1 file changed, 106 insertions(+), 89 deletions(-) diff --git a/tests/test_jobqueue.py b/tests/test_jobqueue.py index e7fa87332bd..5c3fcbe28cb 100644 --- a/tests/test_jobqueue.py +++ b/tests/test_jobqueue.py @@ -26,7 +26,6 @@ import pytest import pytz -from apscheduler.schedulers import SchedulerNotRunningError from flaky import flaky from telegram.ext import ( JobQueue, @@ -116,8 +115,8 @@ def test_application_weakref(self, bot): @pytest.mark.asyncio async def test_run_once(self, job_queue): - job_queue.run_once(self.job_run_once, 0.01) - await asyncio.sleep(0.02) + job_queue.run_once(self.job_run_once, 0.1) + await asyncio.sleep(0.2) assert self.result == 1 @pytest.mark.asyncio @@ -132,8 +131,8 @@ async def test_run_once_timezone(self, job_queue, timezone): @pytest.mark.asyncio async def test_job_with_context(self, job_queue): - job_queue.run_once(self.job_run_once_with_context, 0.01, context=5) - await asyncio.sleep(0.02) + job_queue.run_once(self.job_run_once_with_context, 0.1, context=5) + await asyncio.sleep(0.2) assert self.result == 5 @pytest.mark.asyncio @@ -144,44 +143,46 @@ async def test_run_repeating(self, job_queue): @pytest.mark.asyncio async def test_run_repeating_first(self, job_queue): - job_queue.run_repeating(self.job_run_once, 0.05, first=0.2) + job_queue.run_repeating(self.job_run_once, 0.5, first=0.2) await asyncio.sleep(0.15) assert self.result == 0 - await asyncio.sleep(0.07) + await asyncio.sleep(0.1) assert self.result == 1 @pytest.mark.asyncio async def test_run_repeating_first_timezone(self, job_queue, timezone): """Test correct scheduling of job when passing a timezone-aware datetime as ``first``""" job_queue.run_repeating( - self.job_run_once, 0.1, first=dtm.datetime.now(timezone) + dtm.timedelta(seconds=0.05) + self.job_run_once, 0.5, first=dtm.datetime.now(timezone) + dtm.timedelta(seconds=0.2) ) - await asyncio.sleep(0.1) + await asyncio.sleep(0.15) + assert self.result == 0 + await asyncio.sleep(0.2) assert self.result == 1 @pytest.mark.asyncio async def test_run_repeating_last(self, job_queue): - job_queue.run_repeating(self.job_run_once, 0.05, last=0.06) - await asyncio.sleep(0.1) + job_queue.run_repeating(self.job_run_once, 0.25, last=0.4) + await asyncio.sleep(0.3) assert self.result == 1 - await asyncio.sleep(0.1) + await asyncio.sleep(0.4) assert self.result == 1 @pytest.mark.asyncio async def test_run_repeating_last_timezone(self, job_queue, timezone): """Test correct scheduling of job when passing a timezone-aware datetime as ``first``""" job_queue.run_repeating( - self.job_run_once, 0.05, last=dtm.datetime.now(timezone) + dtm.timedelta(seconds=0.06) + self.job_run_once, 0.25, last=dtm.datetime.now(timezone) + dtm.timedelta(seconds=0.4) ) - await asyncio.sleep(0.1) + await asyncio.sleep(0.3) assert self.result == 1 - await asyncio.sleep(0.1) + await asyncio.sleep(0.4) assert self.result == 1 @pytest.mark.asyncio async def test_run_repeating_last_before_first(self, job_queue): with pytest.raises(ValueError, match="'last' must not be before 'first'!"): - job_queue.run_repeating(self.job_run_once, 0.05, first=1, last=0.5) + job_queue.run_repeating(self.job_run_once, 0.5, first=1, last=0.5) @pytest.mark.asyncio async def test_run_repeating_timedelta(self, job_queue): @@ -191,133 +192,132 @@ async def test_run_repeating_timedelta(self, job_queue): @pytest.mark.asyncio async def test_run_custom(self, job_queue): - job_queue.run_custom(self.job_run_once, {'trigger': 'interval', 'seconds': 0.02}) - await asyncio.sleep(0.05) + job_queue.run_custom(self.job_run_once, {'trigger': 'interval', 'seconds': 0.2}) + await asyncio.sleep(0.5) assert self.result == 2 @pytest.mark.asyncio async def test_multiple(self, job_queue): - job_queue.run_once(self.job_run_once, 0.01) - job_queue.run_once(self.job_run_once, 0.02) - job_queue.run_repeating(self.job_run_once, 0.02) - await asyncio.sleep(0.055) + job_queue.run_once(self.job_run_once, 0.1) + job_queue.run_once(self.job_run_once, 0.2) + job_queue.run_repeating(self.job_run_once, 0.2) + await asyncio.sleep(0.55) assert self.result == 4 @pytest.mark.asyncio async def test_disabled(self, job_queue): j1 = job_queue.run_once(self.job_run_once, 0.1) - j2 = job_queue.run_repeating(self.job_run_once, 0.05) + j2 = job_queue.run_repeating(self.job_run_once, 0.5) j1.enabled = False j2.enabled = False - await asyncio.sleep(0.06) + await asyncio.sleep(0.6) assert self.result == 0 j1.enabled = True - await asyncio.sleep(0.2) + await asyncio.sleep(0.6) assert self.result == 1 @pytest.mark.asyncio async def test_schedule_removal(self, job_queue): - j1 = job_queue.run_once(self.job_run_once, 0.03) - j2 = job_queue.run_repeating(self.job_run_once, 0.02) + j1 = job_queue.run_once(self.job_run_once, 0.3) + j2 = job_queue.run_repeating(self.job_run_once, 0.2) - await asyncio.sleep(0.025) + await asyncio.sleep(0.25) j1.schedule_removal() j2.schedule_removal() - await asyncio.sleep(0.04) + await asyncio.sleep(0.4) assert self.result == 1 @pytest.mark.asyncio async def test_schedule_removal_from_within(self, job_queue): - job_queue.run_repeating(self.job_remove_self, 0.01) + job_queue.run_repeating(self.job_remove_self, 0.1) - await asyncio.sleep(0.05) + await asyncio.sleep(0.5) assert self.result == 1 @pytest.mark.asyncio async def test_longer_first(self, job_queue): - job_queue.run_once(self.job_run_once, 0.02) - job_queue.run_once(self.job_run_once, 0.01) + job_queue.run_once(self.job_run_once, 0.2) + job_queue.run_once(self.job_run_once, 0.1) - await asyncio.sleep(0.015) + await asyncio.sleep(0.15) assert self.result == 1 @pytest.mark.asyncio async def test_error(self, job_queue): - job_queue.run_repeating(self.job_with_exception, 0.01) - job_queue.run_repeating(self.job_run_once, 0.02) - await asyncio.sleep(0.03) + job_queue.run_repeating(self.job_with_exception, 0.1) + job_queue.run_repeating(self.job_run_once, 0.2) + await asyncio.sleep(0.3) assert self.result == 1 @pytest.mark.asyncio async def test_in_application(self, bot): - application = ApplicationBuilder().bot(bot).build() - await application.job_queue.start() - try: - application.job_queue.run_repeating(self.job_run_once, 0.02) - await asyncio.sleep(0.03) + app = ApplicationBuilder().bot(bot).build() + async with app: + assert not app.job_queue.scheduler.running + await app.start() + assert app.job_queue.scheduler.running + + app.job_queue.run_repeating(self.job_run_once, 0.2) + await asyncio.sleep(0.3) assert self.result == 1 - await application.stop() + await app.stop() + assert not app.job_queue.scheduler.running await asyncio.sleep(1) assert self.result == 1 - finally: - try: - await application.stop() - except SchedulerNotRunningError: - pass @pytest.mark.asyncio async def test_time_unit_int(self, job_queue): # Testing seconds in int - delta = 0.05 + delta = 0.5 expected_time = time.time() + delta job_queue.run_once(self.job_datetime_tests, delta) - await asyncio.sleep(0.06) + await asyncio.sleep(0.6) assert pytest.approx(self.job_time) == expected_time @pytest.mark.asyncio async def test_time_unit_dt_timedelta(self, job_queue): # Testing seconds, minutes and hours as datetime.timedelta object # This is sufficient to test that it actually works. - interval = dtm.timedelta(seconds=0.05) + interval = dtm.timedelta(seconds=0.5) expected_time = time.time() + interval.total_seconds() job_queue.run_once(self.job_datetime_tests, interval) - await asyncio.sleep(0.06) + await asyncio.sleep(0.6) assert pytest.approx(self.job_time) == expected_time @pytest.mark.asyncio async def test_time_unit_dt_datetime(self, job_queue): # Testing running at a specific datetime - delta, now = dtm.timedelta(seconds=0.05), dtm.datetime.now(pytz.utc) + delta, now = dtm.timedelta(seconds=0.5), dtm.datetime.now(pytz.utc) when = now + delta expected_time = (now + delta).timestamp() job_queue.run_once(self.job_datetime_tests, when) - await asyncio.sleep(0.06) + await asyncio.sleep(0.6) assert self.job_time == pytest.approx(expected_time) @pytest.mark.asyncio async def test_time_unit_dt_time_today(self, job_queue): # Testing running at a specific time today - delta, now = 0.05, dtm.datetime.now(pytz.utc) + delta, now = 0.5, dtm.datetime.now(pytz.utc) expected_time = now + dtm.timedelta(seconds=delta) when = expected_time.time() expected_time = expected_time.timestamp() job_queue.run_once(self.job_datetime_tests, when) - await asyncio.sleep(0.06) + await asyncio.sleep(0.6) assert self.job_time == pytest.approx(expected_time) @pytest.mark.asyncio @@ -395,24 +395,19 @@ async def test_run_monthly_non_strict_day(self, job_queue, timezone): assert scheduled_time == pytest.approx(expected_reschedule_time) @pytest.mark.asyncio - async def test_default_tzinfo(self, app, tz_bot): + async def test_default_tzinfo(self, tz_bot): # we're parametrizing this with two different UTC offsets to exclude the possibility # of an xpass when the test is run in a timezone with the same UTC offset - jq = JobQueue() - original_bot = app.bot - app.bot = tz_bot - jq.set_application(app) - try: - jq.start() - - when = dtm.datetime.now(tz_bot.defaults.tzinfo) + dtm.timedelta(seconds=0.0005) - jq.run_once(self.job_run_once, when.time()) - await asyncio.sleep(0.001) - assert self.result == 1 + app = ApplicationBuilder().bot(tz_bot).build() + jq = app.job_queue + await jq.start() + + when = dtm.datetime.now(tz_bot.defaults.tzinfo) + dtm.timedelta(seconds=0.1) + jq.run_once(self.job_run_once, when.time()) + await asyncio.sleep(0.15) + assert self.result == 1 - jq.stop() - finally: - app.bot = original_bot + await jq.stop() @pytest.mark.asyncio async def test_get_jobs(self, job_queue): @@ -428,32 +423,32 @@ async def test_get_jobs(self, job_queue): @pytest.mark.asyncio async def test_enable_disable_job(self, job_queue): - job = job_queue.run_repeating(self.job_run_once, 0.02) - await asyncio.sleep(0.05) + job = job_queue.run_repeating(self.job_run_once, 0.2) + await asyncio.sleep(0.5) assert self.result == 2 job.enabled = False assert not job.enabled - await asyncio.sleep(0.05) + await asyncio.sleep(0.5) assert self.result == 2 job.enabled = True assert job.enabled - await asyncio.sleep(0.05) + await asyncio.sleep(0.5) assert self.result == 4 @pytest.mark.asyncio async def test_remove_job(self, job_queue): - job = job_queue.run_repeating(self.job_run_once, 0.02) - await asyncio.sleep(0.05) + job = job_queue.run_repeating(self.job_run_once, 0.2) + await asyncio.sleep(0.5) assert self.result == 2 assert not job.removed job.schedule_removal() assert job.removed - await asyncio.sleep(0.05) + await asyncio.sleep(0.5) assert self.result == 2 @pytest.mark.asyncio async def test_job_lt_eq(self, job_queue): - job = job_queue.run_repeating(self.job_run_once, 0.02) + job = job_queue.run_repeating(self.job_run_once, 0.2) assert not job == job_queue assert not job < job @@ -461,8 +456,8 @@ async def test_job_lt_eq(self, job_queue): async def test_dispatch_error_context(self, job_queue, app): app.add_error_handler(self.error_handler_context) - job = job_queue.run_once(self.job_with_exception, 0.05) - await asyncio.sleep(0.1) + job = job_queue.run_once(self.job_with_exception, 0.1) + await asyncio.sleep(0.15) assert self.received_error[0] == 'Test Error' assert self.received_error[1] is job self.received_error = None @@ -474,8 +469,8 @@ async def test_dispatch_error_context(self, job_queue, app): app.remove_error_handler(self.error_handler_context) self.received_error = None - job = job_queue.run_once(self.job_with_exception, 0.05) - await asyncio.sleep(0.1) + job = job_queue.run_once(self.job_with_exception, 0.1) + await asyncio.sleep(0.15) assert self.received_error is None await job.run(app) assert self.received_error is None @@ -485,9 +480,9 @@ async def test_dispatch_error_that_raises_errors(self, job_queue, app, caplog): app.add_error_handler(self.error_handler_raise_error) with caplog.at_level(logging.ERROR): - job = job_queue.run_once(self.job_with_exception, 0.05) - await asyncio.sleep(0.1) - assert len(caplog.records) == 1 + job = job_queue.run_once(self.job_with_exception, 0.1) + await asyncio.sleep(0.15) + assert len(caplog.records) == 2 rec = caplog.records[-1] assert 'An error was raised and an uncaught' in rec.getMessage() caplog.clear() @@ -504,9 +499,9 @@ async def test_dispatch_error_that_raises_errors(self, job_queue, app, caplog): self.received_error = None with caplog.at_level(logging.ERROR): - job = job_queue.run_once(self.job_with_exception, 0.05) - await asyncio.sleep(0.1) - assert len(caplog.records) == 1 + job = job_queue.run_once(self.job_with_exception, 0.1) + await asyncio.sleep(0.15) + assert len(caplog.records) == 2 rec = caplog.records[-1] assert 'No error handlers are registered' in rec.getMessage() caplog.clear() @@ -550,3 +545,25 @@ async def test_attribute_error(self): AttributeError, match="nor 'apscheduler.job.Job' has attribute 'error'" ): job.error + + @pytest.mark.asyncio + @pytest.mark.parametrize('wait', (True, False)) + async def test_wait_on_shut_down(self, job_queue, wait): + ready_event = asyncio.Event() + + async def callback(_): + await ready_event.wait() + + await job_queue.start() + job_queue.run_once(callback, when=0.1) + await asyncio.sleep(0.15) + + task = asyncio.create_task(job_queue.stop(wait=wait)) + if wait: + assert not task.done() + ready_event.set() + await asyncio.sleep(0.1) + assert task.done() + else: + await asyncio.sleep(0.1) + assert task.done() From 363bfefab3db50bf41462d421aa524887b456a19 Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Fri, 25 Mar 2022 21:24:17 +0100 Subject: [PATCH 091/153] some JQ tweaks --- telegram/ext/_jobqueue.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/telegram/ext/_jobqueue.py b/telegram/ext/_jobqueue.py index 4b625545173..21266d16418 100644 --- a/telegram/ext/_jobqueue.py +++ b/telegram/ext/_jobqueue.py @@ -103,7 +103,10 @@ def set_application(self, application: 'Application') -> None: """ self._application = weakref.ref(application) if isinstance(application.bot, ExtBot) and application.bot.defaults: - self.scheduler.configure(timezone=application.bot.defaults.tzinfo or pytz.utc) + self.scheduler.configure( + timezone=application.bot.defaults.tzinfo or pytz.utc, + executors={'default': self._executor}, + ) @property def application(self) -> 'Application': @@ -484,7 +487,7 @@ def run_custom( name = name or callback.__name__ job = Job(callback=callback, context=context, name=name, chat_id=chat_id, user_id=user_id) - j = self.scheduler.add_job(job, args=(self.application,), name=name, **job_kwargs) + j = self.scheduler.add_job(job.run, args=(self.application,), name=name, **job_kwargs) job.job = j return job @@ -514,7 +517,7 @@ async def stop(self, wait: bool = True) -> None: ) if self.scheduler.running: self.scheduler.shutdown(wait=wait) - # scheduler.shutdown schedules a task in the event loop but immediatel returns + # scheduler.shutdown schedules a task in the event loop but immediately returns # so give it a tiny bit of time to actually shut down. await asyncio.sleep(0.01) From a852ec6e9f91dc6767662ad4fe0b3065abb57e65 Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Sat, 26 Mar 2022 17:00:39 +0100 Subject: [PATCH 092/153] Test all handles except for CH --- tests/test_callbackqueryhandler.py | 224 ++++++++++++++ tests/test_chatjoinrequesthandler.py | 143 +++++++++ tests/test_chatmemberhandler.py | 157 ++++++++++ tests/test_choseninlineresulthandler.py | 163 ++++++++++ tests/test_commandhandler.py | 395 ++++++++++++++++++++++++ tests/test_inlinequeryhandler.py | 165 ++++++++++ tests/test_messagehandler.py | 225 ++++++++++++++ tests/test_pollanswerhandler.py | 111 +++++++ tests/test_pollhandler.py | 124 ++++++++ tests/test_precheckoutqueryhandler.py | 116 +++++++ tests/test_shippingqueryhandler.py | 120 +++++++ tests/test_stringcommandhandler.py | 122 ++++++++ tests/test_stringregexhandler.py | 129 ++++++++ tests/test_typehandler.py | 68 ++++ 14 files changed, 2262 insertions(+) create mode 100644 tests/test_callbackqueryhandler.py create mode 100644 tests/test_chatjoinrequesthandler.py create mode 100644 tests/test_chatmemberhandler.py create mode 100644 tests/test_choseninlineresulthandler.py create mode 100644 tests/test_commandhandler.py create mode 100644 tests/test_inlinequeryhandler.py create mode 100644 tests/test_messagehandler.py create mode 100644 tests/test_pollanswerhandler.py create mode 100644 tests/test_pollhandler.py create mode 100644 tests/test_precheckoutqueryhandler.py create mode 100644 tests/test_shippingqueryhandler.py create mode 100644 tests/test_stringcommandhandler.py create mode 100644 tests/test_stringregexhandler.py create mode 100644 tests/test_typehandler.py diff --git a/tests/test_callbackqueryhandler.py b/tests/test_callbackqueryhandler.py new file mode 100644 index 00000000000..34666c9a3da --- /dev/null +++ b/tests/test_callbackqueryhandler.py @@ -0,0 +1,224 @@ +#!/usr/bin/env python +# +# A library that provides a Python interface to the Telegram Bot API +# Copyright (C) 2015-2022 +# Leandro Toledo de Souza +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser Public License for more details. +# +# You should have received a copy of the GNU Lesser Public License +# along with this program. If not, see [http://www.gnu.org/licenses/]. +import asyncio + +import pytest + +from telegram import ( + Update, + CallbackQuery, + Bot, + Message, + User, + Chat, + InlineQuery, + ChosenInlineResult, + ShippingQuery, + PreCheckoutQuery, +) +from telegram.ext import CallbackQueryHandler, CallbackContext, JobQueue + +message = Message(1, None, Chat(1, ''), from_user=User(1, '', False), text='Text') + +params = [ + {'message': message}, + {'edited_message': message}, + {'channel_post': message}, + {'edited_channel_post': message}, + {'inline_query': InlineQuery(1, User(1, '', False), '', '')}, + {'chosen_inline_result': ChosenInlineResult('id', User(1, '', False), '')}, + {'shipping_query': ShippingQuery('id', User(1, '', False), '', None)}, + {'pre_checkout_query': PreCheckoutQuery('id', User(1, '', False), '', 0, '')}, +] + +ids = ( + 'message', + 'edited_message', + 'channel_post', + 'edited_channel_post', + 'inline_query', + 'chosen_inline_result', + 'shipping_query', + 'pre_checkout_query', +) + + +@pytest.fixture(scope='class', params=params, ids=ids) +def false_update(request): + return Update(update_id=2, **request.param) + + +@pytest.fixture(scope='function') +def callback_query(bot): + return Update(0, callback_query=CallbackQuery(2, User(1, '', False), None, data='test data')) + + +class TestCallbackQueryHandler: + test_flag = False + + def test_slot_behaviour(self, mro_slots): + handler = CallbackQueryHandler(self.callback_data_1) + for attr in handler.__slots__: + assert getattr(handler, attr, 'err') != 'err', f"got extra slot '{attr}'" + assert len(mro_slots(handler)) == len(set(mro_slots(handler))), "duplicate slot" + + @pytest.fixture(autouse=True) + def reset(self): + self.test_flag = False + + def callback_basic(self, update, context): + test_bot = isinstance(context.bot, Bot) + test_update = isinstance(update, Update) + self.test_flag = test_bot and test_update + + def callback_data_1(self, bot, update, user_data=None, chat_data=None): + self.test_flag = (user_data is not None) or (chat_data is not None) + + def callback_data_2(self, bot, update, user_data=None, chat_data=None): + self.test_flag = (user_data is not None) and (chat_data is not None) + + def callback_queue_1(self, bot, update, job_queue=None, update_queue=None): + self.test_flag = (job_queue is not None) or (update_queue is not None) + + def callback_queue_2(self, bot, update, job_queue=None, update_queue=None): + self.test_flag = (job_queue is not None) and (update_queue is not None) + + def callback_group(self, bot, update, groups=None, groupdict=None): + if groups is not None: + self.test_flag = groups == ('t', ' data') + if groupdict is not None: + self.test_flag = groupdict == {'begin': 't', 'end': ' data'} + + async def callback(self, update, context): + self.test_flag = ( + isinstance(context, CallbackContext) + and isinstance(context.bot, Bot) + and isinstance(update, Update) + and isinstance(context.update_queue, asyncio.Queue) + and isinstance(context.job_queue, JobQueue) + and isinstance(context.user_data, dict) + and context.chat_data is None + and isinstance(context.bot_data, dict) + and isinstance(update.callback_query, CallbackQuery) + ) + + def callback_pattern(self, update, context): + if context.matches[0].groups(): + self.test_flag = context.matches[0].groups() == ('t', ' data') + if context.matches[0].groupdict(): + self.test_flag = context.matches[0].groupdict() == {'begin': 't', 'end': ' data'} + + def test_with_pattern(self, callback_query): + handler = CallbackQueryHandler(self.callback_basic, pattern='.*est.*') + + assert handler.check_update(callback_query) + + callback_query.callback_query.data = 'nothing here' + assert not handler.check_update(callback_query) + + callback_query.callback_query.data = None + callback_query.callback_query.game_short_name = "this is a short game name" + assert not handler.check_update(callback_query) + + def test_with_callable_pattern(self, callback_query): + class CallbackData: + pass + + def pattern(callback_data): + return isinstance(callback_data, CallbackData) + + handler = CallbackQueryHandler(self.callback_basic, pattern=pattern) + + callback_query.callback_query.data = CallbackData() + assert handler.check_update(callback_query) + callback_query.callback_query.data = 'callback_data' + assert not handler.check_update(callback_query) + + def test_with_type_pattern(self, callback_query): + class CallbackData: + pass + + handler = CallbackQueryHandler(self.callback_basic, pattern=CallbackData) + + callback_query.callback_query.data = CallbackData() + assert handler.check_update(callback_query) + callback_query.callback_query.data = 'callback_data' + assert not handler.check_update(callback_query) + + handler = CallbackQueryHandler(self.callback_basic, pattern=bool) + + callback_query.callback_query.data = False + assert handler.check_update(callback_query) + callback_query.callback_query.data = 'callback_data' + assert not handler.check_update(callback_query) + + def test_other_update_types(self, false_update): + handler = CallbackQueryHandler(self.callback_basic) + assert not handler.check_update(false_update) + + @pytest.mark.asyncio + async def test_context(self, app, callback_query): + handler = CallbackQueryHandler(self.callback) + app.add_handler(handler) + + async with app: + await app.process_update(callback_query) + assert self.test_flag + + @pytest.mark.asyncio + async def test_context_pattern(self, app, callback_query): + handler = CallbackQueryHandler( + self.callback_pattern, pattern=r'(?P.*)est(?P.*)' + ) + app.add_handler(handler) + + async with app: + await app.process_update(callback_query) + assert self.test_flag + + app.remove_handler(handler) + handler = CallbackQueryHandler(self.callback_pattern, pattern=r'(t)est(.*)') + app.add_handler(handler) + + await app.process_update(callback_query) + assert self.test_flag + + @pytest.mark.asyncio + async def test_context_callable_pattern(self, app, callback_query): + class CallbackData: + pass + + def pattern(callback_data): + return isinstance(callback_data, CallbackData) + + def callback(update, context): + assert context.matches is None + + handler = CallbackQueryHandler(callback, pattern=pattern) + app.add_handler(handler) + + async with app: + await app.process_update(callback_query) + + def test_async_pattern(self): + async def pattern(): + pass + + with pytest.raises(TypeError, match='must not be a coroutine function'): + CallbackQueryHandler(self.callback, pattern=pattern) diff --git a/tests/test_chatjoinrequesthandler.py b/tests/test_chatjoinrequesthandler.py new file mode 100644 index 00000000000..ccdee344bb2 --- /dev/null +++ b/tests/test_chatjoinrequesthandler.py @@ -0,0 +1,143 @@ +#!/usr/bin/env python +# +# A library that provides a Python interface to the Telegram Bot API +# Copyright (C) 2015-2022 +# Leandro Toledo de Souza +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser Public License for more details. +# +# You should have received a copy of the GNU Lesser Public License +# along with this program. If not, see [http://www.gnu.org/licenses/]. +import datetime +import asyncio + +import pytest +import pytz + +from telegram import ( + Update, + Bot, + Message, + User, + Chat, + CallbackQuery, + ChosenInlineResult, + ShippingQuery, + PreCheckoutQuery, + ChatJoinRequest, + ChatInviteLink, +) +from telegram.ext import CallbackContext, JobQueue, ChatJoinRequestHandler + + +message = Message(1, None, Chat(1, ''), from_user=User(1, '', False), text='Text') + +params = [ + {'message': message}, + {'edited_message': message}, + {'callback_query': CallbackQuery(1, User(1, '', False), 'chat', message=message)}, + {'channel_post': message}, + {'edited_channel_post': message}, + {'chosen_inline_result': ChosenInlineResult('id', User(1, '', False), '')}, + {'shipping_query': ShippingQuery('id', User(1, '', False), '', None)}, + {'pre_checkout_query': PreCheckoutQuery('id', User(1, '', False), '', 0, '')}, + {'callback_query': CallbackQuery(1, User(1, '', False), 'chat')}, +] + +ids = ( + 'message', + 'edited_message', + 'callback_query', + 'channel_post', + 'edited_channel_post', + 'chosen_inline_result', + 'shipping_query', + 'pre_checkout_query', + 'callback_query_without_message', +) + + +@pytest.fixture(scope='class', params=params, ids=ids) +def false_update(request): + return Update(update_id=2, **request.param) + + +@pytest.fixture(scope='class') +def time(): + return datetime.datetime.now(tz=pytz.utc) + + +@pytest.fixture(scope='class') +def chat_join_request(time, bot): + return ChatJoinRequest( + chat=Chat(1, Chat.SUPERGROUP), + from_user=User(2, 'first_name', False), + date=time, + bio='bio', + invite_link=ChatInviteLink( + 'https://invite.link', + User(42, 'creator', False), + creates_join_request=False, + name='InviteLink', + is_revoked=False, + is_primary=False, + ), + bot=bot, + ) + + +@pytest.fixture(scope='function') +def chat_join_request_update(bot, chat_join_request): + return Update(0, chat_join_request=chat_join_request) + + +class TestChatJoinRequestHandler: + test_flag = False + + def test_slot_behaviour(self, recwarn, mro_slots): + action = ChatJoinRequestHandler(self.callback) + for attr in action.__slots__: + assert getattr(action, attr, 'err') != 'err', f"got extra slot '{attr}'" + assert len(mro_slots(action)) == len(set(mro_slots(action))), "duplicate slot" + + @pytest.fixture(autouse=True) + def reset(self): + self.test_flag = False + + async def callback(self, update, context): + self.test_flag = ( + isinstance(context, CallbackContext) + and isinstance(context.bot, Bot) + and isinstance(update, Update) + and isinstance(context.update_queue, asyncio.Queue) + and isinstance(context.job_queue, JobQueue) + and isinstance(context.user_data, dict) + and isinstance(context.chat_data, dict) + and isinstance(context.bot_data, dict) + and isinstance( + update.chat_join_request, + ChatJoinRequest, + ) + ) + + def test_other_update_types(self, false_update): + handler = ChatJoinRequestHandler(self.callback) + assert not handler.check_update(false_update) + assert not handler.check_update(True) + + @pytest.mark.asyncio + async def test_context(self, app, chat_join_request_update): + handler = ChatJoinRequestHandler(callback=self.callback) + app.add_handler(handler) + + async with app: + await app.process_update(chat_join_request_update) + assert self.test_flag diff --git a/tests/test_chatmemberhandler.py b/tests/test_chatmemberhandler.py new file mode 100644 index 00000000000..8db1733c761 --- /dev/null +++ b/tests/test_chatmemberhandler.py @@ -0,0 +1,157 @@ +#!/usr/bin/env python +# +# A library that provides a Python interface to the Telegram Bot API +# Copyright (C) 2015-2022 +# Leandro Toledo de Souza +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser Public License for more details. +# +# You should have received a copy of the GNU Lesser Public License +# along with this program. If not, see [http://www.gnu.org/licenses/]. +import time +import asyncio + +import pytest + +from telegram import ( + Update, + Bot, + Message, + User, + Chat, + CallbackQuery, + ChosenInlineResult, + ShippingQuery, + PreCheckoutQuery, + ChatMemberUpdated, + ChatMember, +) +from telegram.ext import CallbackContext, JobQueue, ChatMemberHandler +from telegram._utils.datetime import from_timestamp + +message = Message(1, None, Chat(1, ''), from_user=User(1, '', False), text='Text') + +params = [ + {'message': message}, + {'edited_message': message}, + {'callback_query': CallbackQuery(1, User(1, '', False), 'chat', message=message)}, + {'channel_post': message}, + {'edited_channel_post': message}, + {'chosen_inline_result': ChosenInlineResult('id', User(1, '', False), '')}, + {'shipping_query': ShippingQuery('id', User(1, '', False), '', None)}, + {'pre_checkout_query': PreCheckoutQuery('id', User(1, '', False), '', 0, '')}, + {'callback_query': CallbackQuery(1, User(1, '', False), 'chat')}, +] + +ids = ( + 'message', + 'edited_message', + 'callback_query', + 'channel_post', + 'edited_channel_post', + 'chosen_inline_result', + 'shipping_query', + 'pre_checkout_query', + 'callback_query_without_message', +) + + +@pytest.fixture(scope='class', params=params, ids=ids) +def false_update(request): + return Update(update_id=2, **request.param) + + +@pytest.fixture(scope='class') +def chat_member_updated(): + return ChatMemberUpdated( + Chat(1, 'chat'), + User(1, '', False), + from_timestamp(int(time.time())), + ChatMember(User(1, '', False), ChatMember.CREATOR), + ChatMember(User(1, '', False), ChatMember.CREATOR), + ) + + +@pytest.fixture(scope='function') +def chat_member(bot, chat_member_updated): + return Update(0, my_chat_member=chat_member_updated) + + +class TestChatMemberHandler: + test_flag = False + + def test_slot_behaviour(self, mro_slots): + action = ChatMemberHandler(self.callback) + for attr in action.__slots__: + assert getattr(action, attr, 'err') != 'err', f"got extra slot '{attr}'" + assert len(mro_slots(action)) == len(set(mro_slots(action))), "duplicate slot" + + @pytest.fixture(autouse=True) + def reset(self): + self.test_flag = False + + async def callback(self, update, context): + self.test_flag = ( + isinstance(context, CallbackContext) + and isinstance(context.bot, Bot) + and isinstance(update, Update) + and isinstance(context.update_queue, asyncio.Queue) + and isinstance(context.job_queue, JobQueue) + and isinstance(context.user_data, dict) + and isinstance(context.chat_data, dict) + and isinstance(context.bot_data, dict) + and isinstance(update.chat_member or update.my_chat_member, ChatMemberUpdated) + ) + + @pytest.mark.parametrize( + argnames=['allowed_types', 'expected'], + argvalues=[ + (ChatMemberHandler.MY_CHAT_MEMBER, (True, False)), + (ChatMemberHandler.CHAT_MEMBER, (False, True)), + (ChatMemberHandler.ANY_CHAT_MEMBER, (True, True)), + ], + ids=['MY_CHAT_MEMBER', 'CHAT_MEMBER', 'ANY_CHAT_MEMBER'], + ) + @pytest.mark.asyncio + async def test_chat_member_types( + self, app, chat_member_updated, chat_member, expected, allowed_types + ): + result_1, result_2 = expected + + handler = ChatMemberHandler(self.callback, chat_member_types=allowed_types) + app.add_handler(handler) + + async with app: + assert handler.check_update(chat_member) == result_1 + await app.process_update(chat_member) + assert self.test_flag == result_1 + + self.test_flag = False + chat_member.my_chat_member = None + chat_member.chat_member = chat_member_updated + + assert handler.check_update(chat_member) == result_2 + await app.process_update(chat_member) + assert self.test_flag == result_2 + + def test_other_update_types(self, false_update): + handler = ChatMemberHandler(self.callback) + assert not handler.check_update(false_update) + assert not handler.check_update(True) + + @pytest.mark.asyncio + async def test_context(self, app, chat_member): + handler = ChatMemberHandler(self.callback) + app.add_handler(handler) + + async with app: + await app.process_update(chat_member) + assert self.test_flag diff --git a/tests/test_choseninlineresulthandler.py b/tests/test_choseninlineresulthandler.py new file mode 100644 index 00000000000..ed0b7aea100 --- /dev/null +++ b/tests/test_choseninlineresulthandler.py @@ -0,0 +1,163 @@ +#!/usr/bin/env python +# +# A library that provides a Python interface to the Telegram Bot API +# Copyright (C) 2015-2022 +# Leandro Toledo de Souza +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser Public License for more details. +# +# You should have received a copy of the GNU Lesser Public License +# along with this program. If not, see [http://www.gnu.org/licenses/]. +import asyncio + +import pytest + +from telegram import ( + Update, + Chat, + Bot, + ChosenInlineResult, + User, + Message, + CallbackQuery, + InlineQuery, + ShippingQuery, + PreCheckoutQuery, +) +from telegram.ext import ChosenInlineResultHandler, CallbackContext, JobQueue + +message = Message(1, None, Chat(1, ''), from_user=User(1, '', False), text='Text') + +params = [ + {'message': message}, + {'edited_message': message}, + {'callback_query': CallbackQuery(1, User(1, '', False), 'chat', message=message)}, + {'channel_post': message}, + {'edited_channel_post': message}, + {'inline_query': InlineQuery(1, User(1, '', False), '', '')}, + {'shipping_query': ShippingQuery('id', User(1, '', False), '', None)}, + {'pre_checkout_query': PreCheckoutQuery('id', User(1, '', False), '', 0, '')}, + {'callback_query': CallbackQuery(1, User(1, '', False), 'chat')}, +] + +ids = ( + 'message', + 'edited_message', + 'callback_query', + 'channel_post', + 'edited_channel_post', + 'inline_query', + 'shipping_query', + 'pre_checkout_query', + 'callback_query_without_message', +) + + +@pytest.fixture(scope='class', params=params, ids=ids) +def false_update(request): + return Update(update_id=1, **request.param) + + +@pytest.fixture(scope='class') +def chosen_inline_result(): + return Update( + 1, + chosen_inline_result=ChosenInlineResult('result_id', User(1, 'test_user', False), 'query'), + ) + + +class TestChosenInlineResultHandler: + test_flag = False + + @pytest.fixture(autouse=True) + def reset(self): + self.test_flag = False + + def test_slot_behaviour(self, mro_slots): + handler = ChosenInlineResultHandler(self.callback_basic) + for attr in handler.__slots__: + assert getattr(handler, attr, 'err') != 'err', f"got extra slot '{attr}'" + assert len(mro_slots(handler)) == len(set(mro_slots(handler))), "duplicate slot" + + def callback_basic(self, update, context): + test_bot = isinstance(context.bot, Bot) + test_update = isinstance(update, Update) + self.test_flag = test_bot and test_update + + def callback_data_1(self, bot, update, user_data=None, chat_data=None): + self.test_flag = (user_data is not None) or (chat_data is not None) + + def callback_data_2(self, bot, update, user_data=None, chat_data=None): + self.test_flag = (user_data is not None) and (chat_data is not None) + + def callback_queue_1(self, bot, update, job_queue=None, update_queue=None): + self.test_flag = (job_queue is not None) or (update_queue is not None) + + def callback_queue_2(self, bot, update, job_queue=None, update_queue=None): + self.test_flag = (job_queue is not None) and (update_queue is not None) + + async def callback(self, update, context): + self.test_flag = ( + isinstance(context, CallbackContext) + and isinstance(context.bot, Bot) + and isinstance(update, Update) + and isinstance(context.update_queue, asyncio.Queue) + and isinstance(context.job_queue, JobQueue) + and isinstance(context.user_data, dict) + and context.chat_data is None + and isinstance(context.bot_data, dict) + and isinstance(update.chosen_inline_result, ChosenInlineResult) + ) + + def callback_pattern(self, update, context): + if context.matches[0].groups(): + self.test_flag = context.matches[0].groups() == ('res', '_id') + if context.matches[0].groupdict(): + self.test_flag = context.matches[0].groupdict() == {'begin': 'res', 'end': '_id'} + + def test_other_update_types(self, false_update): + handler = ChosenInlineResultHandler(self.callback_basic) + assert not handler.check_update(false_update) + + @pytest.mark.asyncio + async def test_context(self, app, chosen_inline_result): + handler = ChosenInlineResultHandler(self.callback) + app.add_handler(handler) + + async with app: + await app.process_update(chosen_inline_result) + assert self.test_flag + + def test_with_pattern(self, chosen_inline_result): + handler = ChosenInlineResultHandler(self.callback_basic, pattern='.*ult.*') + + assert handler.check_update(chosen_inline_result) + + chosen_inline_result.chosen_inline_result.result_id = 'nothing here' + assert not handler.check_update(chosen_inline_result) + chosen_inline_result.chosen_inline_result.result_id = 'result_id' + + @pytest.mark.asyncio + async def test_context_pattern(self, app, chosen_inline_result): + handler = ChosenInlineResultHandler( + self.callback_pattern, pattern=r'(?P.*)ult(?P.*)' + ) + app.add_handler(handler) + async with app: + await app.process_update(chosen_inline_result) + assert self.test_flag + + app.remove_handler(handler) + handler = ChosenInlineResultHandler(self.callback_pattern, pattern=r'(res)ult(.*)') + app.add_handler(handler) + + await app.process_update(chosen_inline_result) + assert self.test_flag diff --git a/tests/test_commandhandler.py b/tests/test_commandhandler.py new file mode 100644 index 00000000000..2dd8c6afb0a --- /dev/null +++ b/tests/test_commandhandler.py @@ -0,0 +1,395 @@ +#!/usr/bin/env python +# +# A library that provides a Python interface to the Telegram Bot API +# Copyright (C) 2015-2022 +# Leandro Toledo de Souza +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser Public License for more details. +# +# You should have received a copy of the GNU Lesser Public License +# along with this program. If not, see [http://www.gnu.org/licenses/]. +import re +import asyncio + +import pytest + +from telegram import Message, Update, Chat, Bot +from telegram.ext import CommandHandler, filters, CallbackContext, JobQueue, PrefixHandler +from tests.conftest import ( + make_command_message, + make_command_update, + make_message, + make_message_update, +) + + +def is_match(handler, update): + """ + Utility function that returns whether an update matched + against a specific handler. + :param handler: ``CommandHandler`` to check against + :param update: update to check + :return: (bool) whether ``update`` matched with ``handler`` + """ + check = handler.check_update(update) + return check is not None and check is not False + + +class BaseTest: + """Base class for command and prefix handler test classes. Contains + utility methods an several callbacks used by both classes.""" + + test_flag = False + SRE_TYPE = type(re.match("", "")) + + @pytest.fixture(autouse=True) + def reset(self): + self.test_flag = False + + async def response(self, application, update): + """ + Utility to send an update to a dispatcher and assert + whether the callback was called appropriately. Its purpose is + for repeated usage in the same test function. + """ + self.test_flag = False + async with application: + await application.process_update(update) + return self.test_flag + + def callback_basic(self, update, context): + test_bot = isinstance(context.bot, Bot) + test_update = isinstance(update, Update) + self.test_flag = test_bot and test_update + + def make_callback_for(self, pass_keyword): + def callback(bot, update, **kwargs): + self.test_flag = kwargs.get(keyword) is not None + + keyword = pass_keyword[5:] + return callback + + async def callback(self, update, context): + self.test_flag = ( + isinstance(context, CallbackContext) + and isinstance(context.bot, Bot) + and isinstance(update, Update) + and isinstance(context.update_queue, asyncio.Queue) + and isinstance(context.job_queue, JobQueue) + and isinstance(context.user_data, dict) + and isinstance(context.chat_data, dict) + and isinstance(context.bot_data, dict) + and isinstance(update.message, Message) + ) + + def callback_args(self, update, context): + self.test_flag = context.args == ['one', 'two'] + + def callback_regex1(self, update, context): + if context.matches: + types = all(type(res) is self.SRE_TYPE for res in context.matches) + num = len(context.matches) == 1 + self.test_flag = types and num + + def callback_regex2(self, update, context): + if context.matches: + types = all(type(res) is self.SRE_TYPE for res in context.matches) + num = len(context.matches) == 2 + self.test_flag = types and num + + async def _test_context_args_or_regex(self, app, handler, text): + app.add_handler(handler) + update = make_command_update(text, bot=app.bot) + assert not await self.response(app, update) + update.message.text += ' one two' + assert await self.response(app, update) + + def _test_edited(self, message, handler_edited, handler_not_edited): + """ + Assert whether a handler that should accept edited messages + and a handler that shouldn't work correctly. + :param message: ``telegram.Message`` to check against the handlers + :param handler_edited: handler that should accept edited messages + :param handler_not_edited: handler that should not accept edited messages + """ + update = make_command_update(message) + edited_update = make_command_update(message, edited=True) + + assert is_match(handler_edited, update) + assert is_match(handler_edited, edited_update) + assert is_match(handler_not_edited, update) + assert not is_match(handler_not_edited, edited_update) + + +# ----------------------------- CommandHandler ----------------------------- + + +class TestCommandHandler(BaseTest): + CMD = '/test' + + def test_slot_behaviour(self, mro_slots): + handler = self.make_default_handler() + for attr in handler.__slots__: + assert getattr(handler, attr, 'err') != 'err', f"got extra slot '{attr}'" + assert len(mro_slots(handler)) == len(set(mro_slots(handler))), "duplicate slot" + + @pytest.fixture(scope='class') + def command(self): + return self.CMD + + @pytest.fixture(scope='class') + def command_message(self, command, bot): + return make_command_message(command, bot=bot) + + @pytest.fixture(scope='class') + def command_update(self, command_message): + return make_command_update(command_message) + + def make_default_handler(self, callback=None, **kwargs): + callback = callback or self.callback_basic + return CommandHandler(self.CMD[1:], callback, **kwargs) + + @pytest.mark.asyncio + async def test_basic(self, app, command): + """Test whether a command handler responds to its command + and not to others, or badly formatted commands""" + handler = self.make_default_handler() + app.add_handler(handler) + + assert await self.response(app, make_command_update(command, bot=app.bot)) + assert not is_match(handler, make_command_update(command[1:], bot=app.bot)) + assert not is_match(handler, make_command_update(f'/not{command[1:]}', bot=app.bot)) + assert not is_match(handler, make_command_update(f'not {command} at start', bot=app.bot)) + + @pytest.mark.parametrize( + 'cmd', + ['way_too_longcommand1234567yes_way_toooooooLong', 'ïñválídletters', 'invalid #&* chars'], + ids=['too long', 'invalid letter', 'invalid characters'], + ) + def test_invalid_commands(self, cmd): + with pytest.raises( + ValueError, match=f'`{re.escape(cmd.lower())}` is not a valid bot command' + ): + CommandHandler(cmd, self.callback_basic) + + def test_command_list(self, bot): + """A command handler with multiple commands registered should respond to all of them.""" + handler = CommandHandler(['test', 'star'], self.callback_basic) + assert is_match(handler, make_command_update('/test', bot=bot)) + assert is_match(handler, make_command_update('/star', bot=bot)) + assert not is_match(handler, make_command_update('/stop', bot=bot)) + + def test_edited(self, command_message): + """Test that a CH responds to an edited message if its filters allow it""" + handler_edited = self.make_default_handler() + handler_no_edited = self.make_default_handler(filters=~filters.UpdateType.EDITED_MESSAGE) + self._test_edited(command_message, handler_edited, handler_no_edited) + + def test_directed_commands(self, bot, command): + """Test recognition of commands with a mention to the bot""" + handler = self.make_default_handler() + assert is_match(handler, make_command_update(command + '@' + bot.username, bot=bot)) + assert not is_match(handler, make_command_update(command + '@otherbot', bot=bot)) + + def test_with_filter(self, command, bot): + """Test that a CH with a (generic) filter responds if its filters match""" + handler = self.make_default_handler(filters=filters.ChatType.GROUP) + assert is_match(handler, make_command_update(command, chat=Chat(-23, Chat.GROUP), bot=bot)) + assert not is_match( + handler, make_command_update(command, chat=Chat(23, Chat.PRIVATE), bot=bot) + ) + + @pytest.mark.asyncio + async def test_newline(self, app, command): + """Assert that newlines don't interfere with a command handler matching a message""" + handler = self.make_default_handler() + app.add_handler(handler) + update = make_command_update(command + '\nfoobar', bot=app.bot) + async with app: + assert is_match(handler, update) + assert await self.response(app, update) + + def test_other_update_types(self, false_update): + """Test that a command handler doesn't respond to unrelated updates""" + handler = self.make_default_handler() + assert not is_match(handler, false_update) + + def test_filters_for_wrong_command(self, mock_filter, bot): + """Filters should not be executed if the command does not match the handler""" + handler = self.make_default_handler(filters=mock_filter) + assert not is_match(handler, make_command_update('/star', bot=bot)) + assert not mock_filter.tested + + @pytest.mark.asyncio + async def test_context(self, app, command_update): + """Test correct behaviour of CHs with context-based callbacks""" + handler = self.make_default_handler(self.callback) + app.add_handler(handler) + assert await self.response(app, command_update) + + @pytest.mark.asyncio + async def test_context_args(self, app, command): + """Test CHs that pass arguments through ``context``""" + handler = self.make_default_handler(self.callback_args) + await self._test_context_args_or_regex(app, handler, command) + + @pytest.mark.asyncio + async def test_context_regex(self, app, command): + """Test CHs with context-based callbacks and a single filter""" + handler = self.make_default_handler(self.callback_regex1, filters=filters.Regex('one two')) + await self._test_context_args_or_regex(app, handler, command) + + @pytest.mark.asyncio + async def test_context_multiple_regex(self, app, command): + """Test CHs with context-based callbacks and filters combined""" + handler = self.make_default_handler( + self.callback_regex2, filters=filters.Regex('one') & filters.Regex('two') + ) + await self._test_context_args_or_regex(app, handler, command) + + +# ----------------------------- PrefixHandler ----------------------------- + + +def combinations(prefixes, commands): + return (prefix + command for prefix in prefixes for command in commands) + + +class TestPrefixHandler(BaseTest): + # Prefixes and commands with which to test PrefixHandler: + PREFIXES = ['!', '#', 'mytrig-'] + COMMANDS = ['help', 'test'] + COMBINATIONS = list(combinations(PREFIXES, COMMANDS)) + + def test_slot_behaviour(self, mro_slots): + handler = self.make_default_handler() + for attr in handler.__slots__: + assert getattr(handler, attr, 'err') != 'err', f"got extra slot '{attr}'" + assert len(mro_slots(handler)) == len(set(mro_slots(handler))), "duplicate slot" + + @pytest.fixture(scope='class', params=PREFIXES) + def prefix(self, request): + return request.param + + @pytest.fixture(scope='class', params=[1, 2], ids=['single prefix', 'multiple prefixes']) + def prefixes(self, request): + return TestPrefixHandler.PREFIXES[: request.param] + + @pytest.fixture(scope='class', params=COMMANDS) + def command(self, request): + return request.param + + @pytest.fixture(scope='class', params=[1, 2], ids=['single command', 'multiple commands']) + def commands(self, request): + return TestPrefixHandler.COMMANDS[: request.param] + + @pytest.fixture(scope='class') + def prefix_message_text(self, prefix, command): + return prefix + command + + @pytest.fixture(scope='class') + def prefix_message(self, prefix_message_text): + return make_message(prefix_message_text) + + @pytest.fixture(scope='class') + def prefix_message_update(self, prefix_message): + return make_message_update(prefix_message) + + def make_default_handler(self, callback=None, **kwargs): + callback = callback or self.callback_basic + return PrefixHandler(self.PREFIXES, self.COMMANDS, callback, **kwargs) + + @pytest.mark.asyncio + async def test_basic(self, app, prefix, command): + """Test the basic expected response from a prefix handler""" + handler = self.make_default_handler() + app.add_handler(handler) + text = prefix + command + + assert await self.response(app, make_message_update(text)) + assert not is_match(handler, make_message_update(command)) + assert not is_match(handler, make_message_update(prefix + 'notacommand')) + assert not is_match(handler, make_command_update(f'not {text} at start')) + + def test_single_multi_prefixes_commands(self, prefixes, commands, prefix_message_update): + """Test various combinations of prefixes and commands""" + handler = self.make_default_handler() + result = is_match(handler, prefix_message_update) + expected = prefix_message_update.message.text in combinations(prefixes, commands) + return result == expected + + def test_edited(self, prefix_message): + handler_edited = self.make_default_handler() + handler_no_edited = self.make_default_handler(filters=~filters.UpdateType.EDITED_MESSAGE) + self._test_edited(prefix_message, handler_edited, handler_no_edited) + + def test_with_filter(self, prefix_message_text): + handler = self.make_default_handler(filters=filters.ChatType.GROUP) + text = prefix_message_text + assert is_match(handler, make_message_update(text, chat=Chat(-23, Chat.GROUP))) + assert not is_match(handler, make_message_update(text, chat=Chat(23, Chat.PRIVATE))) + + def test_other_update_types(self, false_update): + handler = self.make_default_handler() + assert not is_match(handler, false_update) + + def test_filters_for_wrong_command(self, mock_filter): + """Filters should not be executed if the command does not match the handler""" + handler = self.make_default_handler(filters=mock_filter) + assert not is_match(handler, make_message_update('/test')) + assert not mock_filter.tested + + def test_edit_prefix(self): + handler = self.make_default_handler() + handler.prefix = ['?', '§'] + assert handler._commands == list(combinations(['?', '§'], self.COMMANDS)) + handler.prefix = '+' + assert handler._commands == list(combinations(['+'], self.COMMANDS)) + + def test_edit_command(self): + handler = self.make_default_handler() + handler.command = 'foo' + assert handler._commands == list(combinations(self.PREFIXES, ['foo'])) + + @pytest.mark.asyncio + async def test_basic_after_editing(self, app, prefix, command): + """Test the basic expected response from a prefix handler""" + handler = self.make_default_handler() + app.add_handler(handler) + text = prefix + command + + assert await self.response(app, make_message_update(text)) + handler.command = 'foo' + text = prefix + 'foo' + assert await self.response(app, make_message_update(text)) + + @pytest.mark.asyncio + async def test_context(self, app, prefix_message_update): + handler = self.make_default_handler(self.callback) + app.add_handler(handler) + assert await self.response(app, prefix_message_update) + + @pytest.mark.asyncio + async def test_context_args(self, app, prefix_message_text): + handler = self.make_default_handler(self.callback_args) + await self._test_context_args_or_regex(app, handler, prefix_message_text) + + @pytest.mark.asyncio + async def test_context_regex(self, app, prefix_message_text): + handler = self.make_default_handler(self.callback_regex1, filters=filters.Regex('one two')) + await self._test_context_args_or_regex(app, handler, prefix_message_text) + + @pytest.mark.asyncio + async def test_context_multiple_regex(self, app, prefix_message_text): + handler = self.make_default_handler( + self.callback_regex2, filters=filters.Regex('one') & filters.Regex('two') + ) + await self._test_context_args_or_regex(app, handler, prefix_message_text) diff --git a/tests/test_inlinequeryhandler.py b/tests/test_inlinequeryhandler.py new file mode 100644 index 00000000000..995fc09086c --- /dev/null +++ b/tests/test_inlinequeryhandler.py @@ -0,0 +1,165 @@ +#!/usr/bin/env python +# +# A library that provides a Python interface to the Telegram Bot API +# Copyright (C) 2015-2022 +# Leandro Toledo de Souza +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser Public License for more details. +# +# You should have received a copy of the GNU Lesser Public License +# along with this program. If not, see [http://www.gnu.org/licenses/]. +import asyncio + +import pytest + +from telegram import ( + Update, + CallbackQuery, + Bot, + Message, + User, + Chat, + InlineQuery, + ChosenInlineResult, + ShippingQuery, + PreCheckoutQuery, + Location, +) +from telegram.ext import InlineQueryHandler, CallbackContext, JobQueue + +message = Message(1, None, Chat(1, ''), from_user=User(1, '', False), text='Text') + +params = [ + {'message': message}, + {'edited_message': message}, + {'callback_query': CallbackQuery(1, User(1, '', False), 'chat', message=message)}, + {'channel_post': message}, + {'edited_channel_post': message}, + {'chosen_inline_result': ChosenInlineResult('id', User(1, '', False), '')}, + {'shipping_query': ShippingQuery('id', User(1, '', False), '', None)}, + {'pre_checkout_query': PreCheckoutQuery('id', User(1, '', False), '', 0, '')}, + {'callback_query': CallbackQuery(1, User(1, '', False), 'chat')}, +] + +ids = ( + 'message', + 'edited_message', + 'callback_query', + 'channel_post', + 'edited_channel_post', + 'chosen_inline_result', + 'shipping_query', + 'pre_checkout_query', + 'callback_query_without_message', +) + + +@pytest.fixture(scope='class', params=params, ids=ids) +def false_update(request): + return Update(update_id=2, **request.param) + + +@pytest.fixture(scope='function') +def inline_query(bot): + return Update( + 0, + inline_query=InlineQuery( + 'id', + User(2, 'test user', False), + 'test query', + offset='22', + location=Location(latitude=-23.691288, longitude=-46.788279), + ), + ) + + +class TestInlineQueryHandler: + test_flag = False + + def test_slot_behaviour(self, mro_slots): + handler = InlineQueryHandler(self.callback) + for attr in handler.__slots__: + assert getattr(handler, attr, 'err') != 'err', f"got extra slot '{attr}'" + assert len(mro_slots(handler)) == len(set(mro_slots(handler))), "duplicate slot" + + @pytest.fixture(autouse=True) + def reset(self): + self.test_flag = False + + async def callback(self, update, context): + self.test_flag = ( + isinstance(context, CallbackContext) + and isinstance(context.bot, Bot) + and isinstance(update, Update) + and isinstance(context.update_queue, asyncio.Queue) + and isinstance(context.job_queue, JobQueue) + and isinstance(context.user_data, dict) + and context.chat_data is None + and isinstance(context.bot_data, dict) + and isinstance(update.inline_query, InlineQuery) + ) + + def callback_pattern(self, update, context): + if context.matches[0].groups(): + self.test_flag = context.matches[0].groups() == ('t', ' query') + if context.matches[0].groupdict(): + self.test_flag = context.matches[0].groupdict() == {'begin': 't', 'end': ' query'} + + def test_other_update_types(self, false_update): + handler = InlineQueryHandler(self.callback) + assert not handler.check_update(false_update) + + @pytest.mark.asyncio + async def test_context(self, app, inline_query): + handler = InlineQueryHandler(self.callback) + app.add_handler(handler) + + async with app: + await app.process_update(inline_query) + assert self.test_flag + + @pytest.mark.asyncio + async def test_context_pattern(self, app, inline_query): + handler = InlineQueryHandler(self.callback_pattern, pattern=r'(?P.*)est(?P.*)') + app.add_handler(handler) + + async with app: + await app.process_update(inline_query) + assert self.test_flag + + app.remove_handler(handler) + handler = InlineQueryHandler(self.callback_pattern, pattern=r'(t)est(.*)') + app.add_handler(handler) + + await app.process_update(inline_query) + assert self.test_flag + + @pytest.mark.parametrize('chat_types', [[Chat.SENDER], [Chat.SENDER, Chat.SUPERGROUP], []]) + @pytest.mark.parametrize( + 'chat_type,result', [(Chat.SENDER, True), (Chat.CHANNEL, False), (None, False)] + ) + @pytest.mark.asyncio + async def test_chat_types(self, app, inline_query, chat_types, chat_type, result): + try: + inline_query.inline_query.chat_type = chat_type + + handler = InlineQueryHandler(self.callback, chat_types=chat_types) + app.add_handler(handler) + async with app: + await app.process_update(inline_query) + + if not chat_types: + assert self.test_flag is False + else: + assert self.test_flag == result + + finally: + inline_query.inline_query.chat_type = None diff --git a/tests/test_messagehandler.py b/tests/test_messagehandler.py new file mode 100644 index 00000000000..a727a0905f5 --- /dev/null +++ b/tests/test_messagehandler.py @@ -0,0 +1,225 @@ +#!/usr/bin/env python +# +# A library that provides a Python interface to the Telegram Bot API +# Copyright (C) 2015-2022 +# Leandro Toledo de Souza +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser Public License for more details. +# +# You should have received a copy of the GNU Lesser Public License +# along with this program. If not, see [http://www.gnu.org/licenses/]. +import re +import asyncio + +import pytest + +from telegram import ( + Message, + Update, + Chat, + Bot, + User, + CallbackQuery, + InlineQuery, + ChosenInlineResult, + ShippingQuery, + PreCheckoutQuery, +) +from telegram.ext import filters, MessageHandler, CallbackContext, JobQueue +from telegram.ext.filters import MessageFilter + +message = Message(1, None, Chat(1, ''), from_user=User(1, '', False), text='Text') + +params = [ + {'callback_query': CallbackQuery(1, User(1, '', False), 'chat', message=message)}, + {'inline_query': InlineQuery(1, User(1, '', False), '', '')}, + {'chosen_inline_result': ChosenInlineResult('id', User(1, '', False), '')}, + {'shipping_query': ShippingQuery('id', User(1, '', False), '', None)}, + {'pre_checkout_query': PreCheckoutQuery('id', User(1, '', False), '', 0, '')}, + {'callback_query': CallbackQuery(1, User(1, '', False), 'chat')}, +] + +ids = ( + 'callback_query', + 'inline_query', + 'chosen_inline_result', + 'shipping_query', + 'pre_checkout_query', + 'callback_query_without_message', +) + + +@pytest.fixture(scope='class', params=params, ids=ids) +def false_update(request): + return Update(update_id=1, **request.param) + + +@pytest.fixture(scope='class') +def message(bot): + return Message(1, None, Chat(1, ''), from_user=User(1, '', False), bot=bot) + + +class TestMessageHandler: + test_flag = False + SRE_TYPE = type(re.match("", "")) + + def test_slot_behaviour(self, mro_slots): + handler = MessageHandler(filters.ALL, self.callback) + for attr in handler.__slots__: + assert getattr(handler, attr, 'err') != 'err', f"got extra slot '{attr}'" + assert len(mro_slots(handler)) == len(set(mro_slots(handler))), "duplicate slot" + + @pytest.fixture(autouse=True) + def reset(self): + self.test_flag = False + + async def callback(self, update, context): + self.test_flag = ( + isinstance(context, CallbackContext) + and isinstance(context.bot, Bot) + and isinstance(update, Update) + and isinstance(context.update_queue, asyncio.Queue) + and isinstance(context.job_queue, JobQueue) + and isinstance(context.chat_data, dict) + and isinstance(context.bot_data, dict) + and ( + ( + isinstance(context.user_data, dict) + and ( + isinstance(update.message, Message) + or isinstance(update.edited_message, Message) + ) + ) + or ( + context.user_data is None + and ( + isinstance(update.channel_post, Message) + or isinstance(update.edited_channel_post, Message) + ) + ) + ) + ) + + def callback_regex1(self, update, context): + if context.matches: + types = all(type(res) is self.SRE_TYPE for res in context.matches) + num = len(context.matches) == 1 + self.test_flag = types and num + + def callback_regex2(self, update, context): + if context.matches: + types = all(type(res) is self.SRE_TYPE for res in context.matches) + num = len(context.matches) == 2 + self.test_flag = types and num + + def test_with_filter(self, message): + handler = MessageHandler(filters.ChatType.GROUP, self.callback) + + message.chat.type = 'group' + assert handler.check_update(Update(0, message)) + + message.chat.type = 'private' + assert not handler.check_update(Update(0, message)) + + def test_callback_query_with_filter(self, message): + class TestFilter(filters.UpdateFilter): + flag = False + + def filter(self, u): + self.flag = True + + test_filter = TestFilter() + handler = MessageHandler(test_filter, self.callback) + + update = Update(1, callback_query=CallbackQuery(1, None, None, message=message)) + + assert update.effective_message + assert not handler.check_update(update) + assert not test_filter.flag + + def test_specific_filters(self, message): + f = ( + ~filters.UpdateType.MESSAGES + & ~filters.UpdateType.CHANNEL_POST + & filters.UpdateType.EDITED_CHANNEL_POST + ) + handler = MessageHandler(f, self.callback) + + assert not handler.check_update(Update(0, edited_message=message)) + assert not handler.check_update(Update(0, message=message)) + assert not handler.check_update(Update(0, channel_post=message)) + assert handler.check_update(Update(0, edited_channel_post=message)) + + def test_other_update_types(self, false_update): + handler = MessageHandler(None, self.callback) + assert not handler.check_update(false_update) + + def test_filters_returns_empty_dict(self): + class DataFilter(MessageFilter): + data_filter = True + + def filter(self, msg: Message): + return {} + + handler = MessageHandler(DataFilter(), self.callback) + assert handler.check_update(Update(0, message)) is False + + @pytest.mark.asyncio + async def test_context(self, app, message): + handler = MessageHandler( + None, + self.callback, + ) + app.add_handler(handler) + + async with app: + await app.process_update(Update(0, message=message)) + assert self.test_flag + + self.test_flag = False + await app.process_update(Update(0, edited_message=message)) + assert self.test_flag + + self.test_flag = False + await app.process_update(Update(0, channel_post=message)) + assert self.test_flag + + self.test_flag = False + await app.process_update(Update(0, edited_channel_post=message)) + assert self.test_flag + + @pytest.mark.asyncio + async def test_context_regex(self, app, message): + handler = MessageHandler(filters.Regex('one two'), self.callback_regex1) + app.add_handler(handler) + + async with app: + message.text = 'not it' + await app.process_update(Update(0, message)) + assert not self.test_flag + + message.text += ' one two now it is' + await app.process_update(Update(0, message)) + assert self.test_flag + + @pytest.mark.asyncio + async def test_context_multiple_regex(self, app, message): + handler = MessageHandler(filters.Regex('one') & filters.Regex('two'), self.callback_regex2) + app.add_handler(handler) + + async with app: + message.text = 'not it' + await app.process_update(Update(0, message)) + assert not self.test_flag + + message.text += ' one two now it is' + await app.process_update(Update(0, message)) + assert self.test_flag diff --git a/tests/test_pollanswerhandler.py b/tests/test_pollanswerhandler.py new file mode 100644 index 00000000000..22ebd07c7ae --- /dev/null +++ b/tests/test_pollanswerhandler.py @@ -0,0 +1,111 @@ +#!/usr/bin/env python +# +# A library that provides a Python interface to the Telegram Bot API +# Copyright (C) 2015-2022 +# Leandro Toledo de Souza +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser Public License for more details. +# +# You should have received a copy of the GNU Lesser Public License +# along with this program. If not, see [http://www.gnu.org/licenses/]. +import asyncio + +import pytest + +from telegram import ( + Update, + CallbackQuery, + Bot, + Message, + User, + Chat, + PollAnswer, + ChosenInlineResult, + ShippingQuery, + PreCheckoutQuery, +) +from telegram.ext import PollAnswerHandler, CallbackContext, JobQueue + +message = Message(1, None, Chat(1, ''), from_user=User(1, '', False), text='Text') + +params = [ + {'message': message}, + {'edited_message': message}, + {'callback_query': CallbackQuery(1, User(1, '', False), 'chat', message=message)}, + {'channel_post': message}, + {'edited_channel_post': message}, + {'chosen_inline_result': ChosenInlineResult('id', User(1, '', False), '')}, + {'shipping_query': ShippingQuery('id', User(1, '', False), '', None)}, + {'pre_checkout_query': PreCheckoutQuery('id', User(1, '', False), '', 0, '')}, + {'callback_query': CallbackQuery(1, User(1, '', False), 'chat')}, +] + +ids = ( + 'message', + 'edited_message', + 'callback_query', + 'channel_post', + 'edited_channel_post', + 'chosen_inline_result', + 'shipping_query', + 'pre_checkout_query', + 'callback_query_without_message', +) + + +@pytest.fixture(scope='class', params=params, ids=ids) +def false_update(request): + return Update(update_id=2, **request.param) + + +@pytest.fixture(scope='function') +def poll_answer(bot): + return Update(0, poll_answer=PollAnswer(1, User(2, 'test user', False), [0, 1])) + + +class TestPollAnswerHandler: + test_flag = False + + def test_slot_behaviour(self, mro_slots): + handler = PollAnswerHandler(self.callback) + for attr in handler.__slots__: + assert getattr(handler, attr, 'err') != 'err', f"got extra slot '{attr}'" + assert len(mro_slots(handler)) == len(set(mro_slots(handler))), "duplicate slot" + + @pytest.fixture(autouse=True) + def reset(self): + self.test_flag = False + + async def callback(self, update, context): + self.test_flag = ( + isinstance(context, CallbackContext) + and isinstance(context.bot, Bot) + and isinstance(update, Update) + and isinstance(context.update_queue, asyncio.Queue) + and isinstance(context.job_queue, JobQueue) + and isinstance(context.user_data, dict) + and context.chat_data is None + and isinstance(context.bot_data, dict) + and isinstance(update.poll_answer, PollAnswer) + ) + + def test_other_update_types(self, false_update): + handler = PollAnswerHandler(self.callback) + assert not handler.check_update(false_update) + + @pytest.mark.asyncio + async def test_context(self, app, poll_answer): + handler = PollAnswerHandler(self.callback) + app.add_handler(handler) + + async with app: + await app.process_update(poll_answer) + assert self.test_flag diff --git a/tests/test_pollhandler.py b/tests/test_pollhandler.py new file mode 100644 index 00000000000..a55b34a41a5 --- /dev/null +++ b/tests/test_pollhandler.py @@ -0,0 +1,124 @@ +#!/usr/bin/env python +# +# A library that provides a Python interface to the Telegram Bot API +# Copyright (C) 2015-2022 +# Leandro Toledo de Souza +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser Public License for more details. +# +# You should have received a copy of the GNU Lesser Public License +# along with this program. If not, see [http://www.gnu.org/licenses/]. +import asyncio + +import pytest + +from telegram import ( + Update, + Poll, + PollOption, + Bot, + Message, + User, + Chat, + CallbackQuery, + ChosenInlineResult, + ShippingQuery, + PreCheckoutQuery, +) +from telegram.ext import PollHandler, CallbackContext, JobQueue + +message = Message(1, None, Chat(1, ''), from_user=User(1, '', False), text='Text') + +params = [ + {'message': message}, + {'edited_message': message}, + {'callback_query': CallbackQuery(1, User(1, '', False), 'chat', message=message)}, + {'channel_post': message}, + {'edited_channel_post': message}, + {'chosen_inline_result': ChosenInlineResult('id', User(1, '', False), '')}, + {'shipping_query': ShippingQuery('id', User(1, '', False), '', None)}, + {'pre_checkout_query': PreCheckoutQuery('id', User(1, '', False), '', 0, '')}, + {'callback_query': CallbackQuery(1, User(1, '', False), 'chat')}, +] + +ids = ( + 'message', + 'edited_message', + 'callback_query', + 'channel_post', + 'edited_channel_post', + 'chosen_inline_result', + 'shipping_query', + 'pre_checkout_query', + 'callback_query_without_message', +) + + +@pytest.fixture(scope='class', params=params, ids=ids) +def false_update(request): + return Update(update_id=2, **request.param) + + +@pytest.fixture(scope='function') +def poll(bot): + return Update( + 0, + poll=Poll( + 1, + 'question', + [PollOption('1', 0), PollOption('2', 0)], + 0, + False, + False, + Poll.REGULAR, + True, + ), + ) + + +class TestPollHandler: + test_flag = False + + def test_slot_behaviour(self, mro_slots): + inst = PollHandler(self.callback) + for attr in inst.__slots__: + assert getattr(inst, attr, 'err') != 'err', f"got extra slot '{attr}'" + assert len(mro_slots(inst)) == len(set(mro_slots(inst))), "duplicate slot" + + @pytest.fixture(autouse=True) + def reset(self): + self.test_flag = False + + async def callback(self, update, context): + self.test_flag = ( + isinstance(context, CallbackContext) + and isinstance(context.bot, Bot) + and isinstance(update, Update) + and isinstance(context.update_queue, asyncio.Queue) + and isinstance(context.job_queue, JobQueue) + and context.user_data is None + and context.chat_data is None + and isinstance(context.bot_data, dict) + and isinstance(update.poll, Poll) + ) + + def test_other_update_types(self, false_update): + handler = PollHandler(self.callback) + assert not handler.check_update(false_update) + + @pytest.mark.asyncio + async def test_context(self, app, poll): + handler = PollHandler(self.callback) + app.add_handler(handler) + + async with app: + await app.process_update(poll) + assert self.test_flag diff --git a/tests/test_precheckoutqueryhandler.py b/tests/test_precheckoutqueryhandler.py new file mode 100644 index 00000000000..c028640423d --- /dev/null +++ b/tests/test_precheckoutqueryhandler.py @@ -0,0 +1,116 @@ +#!/usr/bin/env python +# +# A library that provides a Python interface to the Telegram Bot API +# Copyright (C) 2015-2022 +# Leandro Toledo de Souza +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser Public License for more details. +# +# You should have received a copy of the GNU Lesser Public License +# along with this program. If not, see [http://www.gnu.org/licenses/]. +import asyncio + +import pytest + +from telegram import ( + Update, + Chat, + Bot, + ChosenInlineResult, + User, + Message, + CallbackQuery, + InlineQuery, + ShippingQuery, + PreCheckoutQuery, +) +from telegram.ext import PreCheckoutQueryHandler, CallbackContext, JobQueue + +message = Message(1, None, Chat(1, ''), from_user=User(1, '', False), text='Text') + +params = [ + {'message': message}, + {'edited_message': message}, + {'callback_query': CallbackQuery(1, User(1, '', False), 'chat', message=message)}, + {'channel_post': message}, + {'edited_channel_post': message}, + {'inline_query': InlineQuery(1, User(1, '', False), '', '')}, + {'chosen_inline_result': ChosenInlineResult('id', User(1, '', False), '')}, + {'shipping_query': ShippingQuery('id', User(1, '', False), '', None)}, + {'callback_query': CallbackQuery(1, User(1, '', False), 'chat')}, +] + +ids = ( + 'message', + 'edited_message', + 'callback_query', + 'channel_post', + 'edited_channel_post', + 'inline_query', + 'chosen_inline_result', + 'shipping_query', + 'callback_query_without_message', +) + + +@pytest.fixture(scope='class', params=params, ids=ids) +def false_update(request): + return Update(update_id=1, **request.param) + + +@pytest.fixture(scope='class') +def pre_checkout_query(): + return Update( + 1, + pre_checkout_query=PreCheckoutQuery( + 'id', User(1, 'test user', False), 'EUR', 223, 'invoice_payload' + ), + ) + + +class TestPreCheckoutQueryHandler: + test_flag = False + + def test_slot_behaviour(self, mro_slots): + inst = PreCheckoutQueryHandler(self.callback) + for attr in inst.__slots__: + assert getattr(inst, attr, 'err') != 'err', f"got extra slot '{attr}'" + assert len(mro_slots(inst)) == len(set(mro_slots(inst))), "duplicate slot" + + @pytest.fixture(autouse=True) + def reset(self): + self.test_flag = False + + async def callback(self, update, context): + self.test_flag = ( + isinstance(context, CallbackContext) + and isinstance(context.bot, Bot) + and isinstance(update, Update) + and isinstance(context.update_queue, asyncio.Queue) + and isinstance(context.job_queue, JobQueue) + and isinstance(context.user_data, dict) + and context.chat_data is None + and isinstance(context.bot_data, dict) + and isinstance(update.pre_checkout_query, PreCheckoutQuery) + ) + + def test_other_update_types(self, false_update): + handler = PreCheckoutQueryHandler(self.callback) + assert not handler.check_update(false_update) + + @pytest.mark.asyncio + async def test_context(self, app, pre_checkout_query): + handler = PreCheckoutQueryHandler(self.callback) + app.add_handler(handler) + + async with app: + await app.process_update(pre_checkout_query) + assert self.test_flag diff --git a/tests/test_shippingqueryhandler.py b/tests/test_shippingqueryhandler.py new file mode 100644 index 00000000000..6fef02f9148 --- /dev/null +++ b/tests/test_shippingqueryhandler.py @@ -0,0 +1,120 @@ +#!/usr/bin/env python +# +# A library that provides a Python interface to the Telegram Bot API +# Copyright (C) 2015-2022 +# Leandro Toledo de Souza +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser Public License for more details. +# +# You should have received a copy of the GNU Lesser Public License +# along with this program. If not, see [http://www.gnu.org/licenses/]. +import asyncio + +import pytest + +from telegram import ( + Update, + Chat, + Bot, + ChosenInlineResult, + User, + Message, + CallbackQuery, + InlineQuery, + ShippingQuery, + PreCheckoutQuery, + ShippingAddress, +) +from telegram.ext import ShippingQueryHandler, CallbackContext, JobQueue + +message = Message(1, None, Chat(1, ''), from_user=User(1, '', False), text='Text') + +params = [ + {'message': message}, + {'edited_message': message}, + {'callback_query': CallbackQuery(1, User(1, '', False), 'chat', message=message)}, + {'channel_post': message}, + {'edited_channel_post': message}, + {'inline_query': InlineQuery(1, User(1, '', False), '', '')}, + {'chosen_inline_result': ChosenInlineResult('id', User(1, '', False), '')}, + {'pre_checkout_query': PreCheckoutQuery('id', User(1, '', False), '', 0, '')}, + {'callback_query': CallbackQuery(1, User(1, '', False), 'chat')}, +] + +ids = ( + 'message', + 'edited_message', + 'callback_query', + 'channel_post', + 'edited_channel_post', + 'inline_query', + 'chosen_inline_result', + 'pre_checkout_query', + 'callback_query_without_message', +) + + +@pytest.fixture(scope='class', params=params, ids=ids) +def false_update(request): + return Update(update_id=1, **request.param) + + +@pytest.fixture(scope='class') +def shiping_query(): + return Update( + 1, + shipping_query=ShippingQuery( + 42, + User(1, 'test user', False), + 'invoice_payload', + ShippingAddress('EN', 'my_state', 'my_city', 'steer_1', '', 'post_code'), + ), + ) + + +class TestShippingQueryHandler: + test_flag = False + + def test_slot_behaviour(self, mro_slots): + inst = ShippingQueryHandler(self.callback) + for attr in inst.__slots__: + assert getattr(inst, attr, 'err') != 'err', f"got extra slot '{attr}'" + assert len(mro_slots(inst)) == len(set(mro_slots(inst))), "duplicate slot" + + @pytest.fixture(autouse=True) + def reset(self): + self.test_flag = False + + async def callback(self, update, context): + self.test_flag = ( + isinstance(context, CallbackContext) + and isinstance(context.bot, Bot) + and isinstance(update, Update) + and isinstance(context.update_queue, asyncio.Queue) + and isinstance(context.job_queue, JobQueue) + and isinstance(context.user_data, dict) + and context.chat_data is None + and isinstance(context.bot_data, dict) + and isinstance(update.shipping_query, ShippingQuery) + ) + + def test_other_update_types(self, false_update): + handler = ShippingQueryHandler(self.callback) + assert not handler.check_update(false_update) + + @pytest.mark.asyncio + async def test_context(self, app, shiping_query): + handler = ShippingQueryHandler(self.callback) + app.add_handler(handler) + + async with app: + await app.process_update(shiping_query) + assert self.test_flag diff --git a/tests/test_stringcommandhandler.py b/tests/test_stringcommandhandler.py new file mode 100644 index 00000000000..9c8450ef32b --- /dev/null +++ b/tests/test_stringcommandhandler.py @@ -0,0 +1,122 @@ +#!/usr/bin/env python +# +# A library that provides a Python interface to the Telegram Bot API +# Copyright (C) 2015-2022 +# Leandro Toledo de Souza +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser Public License for more details. +# +# You should have received a copy of the GNU Lesser Public License +# along with this program. If not, see [http://www.gnu.org/licenses/]. +import asyncio + +import pytest + +from telegram import ( + Bot, + Update, + Message, + User, + Chat, + CallbackQuery, + InlineQuery, + ChosenInlineResult, + ShippingQuery, + PreCheckoutQuery, +) +from telegram.ext import StringCommandHandler, CallbackContext, JobQueue + +message = Message(1, None, Chat(1, ''), from_user=User(1, '', False), text='Text') + +params = [ + {'message': message}, + {'edited_message': message}, + {'callback_query': CallbackQuery(1, User(1, '', False), 'chat', message=message)}, + {'channel_post': message}, + {'edited_channel_post': message}, + {'inline_query': InlineQuery(1, User(1, '', False), '', '')}, + {'chosen_inline_result': ChosenInlineResult('id', User(1, '', False), '')}, + {'shipping_query': ShippingQuery('id', User(1, '', False), '', None)}, + {'pre_checkout_query': PreCheckoutQuery('id', User(1, '', False), '', 0, '')}, + {'callback_query': CallbackQuery(1, User(1, '', False), 'chat')}, +] + +ids = ( + 'message', + 'edited_message', + 'callback_query', + 'channel_post', + 'edited_channel_post', + 'inline_query', + 'chosen_inline_result', + 'shipping_query', + 'pre_checkout_query', + 'callback_query_without_message', +) + + +@pytest.fixture(scope='class', params=params, ids=ids) +def false_update(request): + return Update(update_id=1, **request.param) + + +class TestStringCommandHandler: + test_flag = False + + def test_slot_behaviour(self, mro_slots): + inst = StringCommandHandler('sleepy', self.callback) + for attr in inst.__slots__: + assert getattr(inst, attr, 'err') != 'err', f"got extra slot '{attr}'" + assert len(mro_slots(inst)) == len(set(mro_slots(inst))), "duplicate slot" + + @pytest.fixture(autouse=True) + def reset(self): + self.test_flag = False + + async def callback(self, update, context): + self.test_flag = ( + isinstance(context, CallbackContext) + and isinstance(context.bot, Bot) + and isinstance(update, str) + and isinstance(context.update_queue, asyncio.Queue) + and isinstance(context.job_queue, JobQueue) + and context.user_data is None + and context.chat_data is None + and isinstance(context.bot_data, dict) + ) + + async def callback_args(self, update, context): + self.test_flag = context.args == ['one', 'two'] + + def test_other_update_types(self, false_update): + handler = StringCommandHandler('test', self.callback) + assert not handler.check_update(false_update) + + @pytest.mark.asyncio + async def test_context(self, app): + handler = StringCommandHandler('test', self.callback) + app.add_handler(handler) + + async with app: + await app.process_update('/test') + assert self.test_flag + + @pytest.mark.asyncio + async def test_context_args(self, app): + handler = StringCommandHandler('test', self.callback_args) + app.add_handler(handler) + + async with app: + await app.process_update('/test') + assert not self.test_flag + + await app.process_update('/test one two') + assert self.test_flag diff --git a/tests/test_stringregexhandler.py b/tests/test_stringregexhandler.py new file mode 100644 index 00000000000..b7db2ec5bbe --- /dev/null +++ b/tests/test_stringregexhandler.py @@ -0,0 +1,129 @@ +#!/usr/bin/env python +# +# A library that provides a Python interface to the Telegram Bot API +# Copyright (C) 2015-2022 +# Leandro Toledo de Souza +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser Public License for more details. +# +# You should have received a copy of the GNU Lesser Public License +# along with this program. If not, see [http://www.gnu.org/licenses/]. +import asyncio + +import pytest + +from telegram import ( + Bot, + Update, + Message, + User, + Chat, + CallbackQuery, + InlineQuery, + ChosenInlineResult, + ShippingQuery, + PreCheckoutQuery, +) +from telegram.ext import StringRegexHandler, CallbackContext, JobQueue + +message = Message(1, None, Chat(1, ''), from_user=User(1, '', False), text='Text') + +params = [ + {'message': message}, + {'edited_message': message}, + {'callback_query': CallbackQuery(1, User(1, '', False), 'chat', message=message)}, + {'channel_post': message}, + {'edited_channel_post': message}, + {'inline_query': InlineQuery(1, User(1, '', False), '', '')}, + {'chosen_inline_result': ChosenInlineResult('id', User(1, '', False), '')}, + {'shipping_query': ShippingQuery('id', User(1, '', False), '', None)}, + {'pre_checkout_query': PreCheckoutQuery('id', User(1, '', False), '', 0, '')}, + {'callback_query': CallbackQuery(1, User(1, '', False), 'chat')}, +] + +ids = ( + 'message', + 'edited_message', + 'callback_query', + 'channel_post', + 'edited_channel_post', + 'inline_query', + 'chosen_inline_result', + 'shipping_query', + 'pre_checkout_query', + 'callback_query_without_message', +) + + +@pytest.fixture(scope='class', params=params, ids=ids) +def false_update(request): + return Update(update_id=1, **request.param) + + +class TestStringRegexHandler: + test_flag = False + + def test_slot_behaviour(self, mro_slots): + inst = StringRegexHandler('pfft', self.callback) + for attr in inst.__slots__: + assert getattr(inst, attr, 'err') != 'err', f"got extra slot '{attr}'" + assert len(mro_slots(inst)) == len(set(mro_slots(inst))), "duplicate slot" + + @pytest.fixture(autouse=True) + def reset(self): + self.test_flag = False + + async def callback(self, update, context): + self.test_flag = ( + isinstance(context, CallbackContext) + and isinstance(context.bot, Bot) + and isinstance(update, str) + and isinstance(context.update_queue, asyncio.Queue) + and isinstance(context.job_queue, JobQueue) + ) + + async def callback_pattern(self, update, context): + if context.matches[0].groups(): + self.test_flag = context.matches[0].groups() == ('t', ' message') + if context.matches[0].groupdict(): + self.test_flag = context.matches[0].groupdict() == {'begin': 't', 'end': ' message'} + + @pytest.mark.asyncio + async def test_basic(self, app): + handler = StringRegexHandler('(?P.*)est(?P.*)', self.callback) + app.add_handler(handler) + + assert handler.check_update('test message') + async with app: + await app.process_update('test message') + assert self.test_flag + + assert not handler.check_update('does not match') + + def test_other_update_types(self, false_update): + handler = StringRegexHandler('test', self.callback) + assert not handler.check_update(false_update) + + @pytest.mark.asyncio + async def test_context_pattern(self, app): + handler = StringRegexHandler(r'(t)est(.*)', self.callback_pattern) + app.add_handler(handler) + + async with app: + await app.process_update('test message') + assert self.test_flag + + app.remove_handler(handler) + handler = StringRegexHandler(r'(t)est(.*)', self.callback_pattern) + app.add_handler(handler) + + await app.process_update('test message') + assert self.test_flag diff --git a/tests/test_typehandler.py b/tests/test_typehandler.py new file mode 100644 index 00000000000..8bb8fcbb264 --- /dev/null +++ b/tests/test_typehandler.py @@ -0,0 +1,68 @@ +#!/usr/bin/env python +# +# A library that provides a Python interface to the Telegram Bot API +# Copyright (C) 2015-2022 +# Leandro Toledo de Souza +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser Public License for more details. +# +# You should have received a copy of the GNU Lesser Public License +# along with this program. If not, see [http://www.gnu.org/licenses/]. +from collections import OrderedDict +import asyncio + +import pytest + +from telegram import Bot +from telegram.ext import TypeHandler, CallbackContext, JobQueue + + +class TestTypeHandler: + test_flag = False + + def test_slot_behaviour(self, mro_slots): + inst = TypeHandler(dict, self.callback) + for attr in inst.__slots__: + assert getattr(inst, attr, 'err') != 'err', f"got extra slot '{attr}'" + assert len(mro_slots(inst)) == len(set(mro_slots(inst))), "duplicate slot" + + @pytest.fixture(autouse=True) + def reset(self): + self.test_flag = False + + async def callback(self, update, context): + self.test_flag = ( + isinstance(context, CallbackContext) + and isinstance(context.bot, Bot) + and isinstance(update, dict) + and isinstance(context.update_queue, asyncio.Queue) + and isinstance(context.job_queue, JobQueue) + and context.user_data is None + and context.chat_data is None + and isinstance(context.bot_data, dict) + ) + + @pytest.mark.asyncio + async def test_basic(self, app): + handler = TypeHandler(dict, self.callback) + app.add_handler(handler) + + assert handler.check_update({'a': 1, 'b': 2}) + assert not handler.check_update('not a dict') + async with app: + await app.process_update({'a': 1, 'b': 2}) + assert self.test_flag + + def test_strict(self): + handler = TypeHandler(dict, self.callback, strict=True) + o = OrderedDict({'a': 1, 'b': 2}) + assert handler.check_update({'a': 1, 'b': 2}) + assert not handler.check_update(o) From 346ac17e72569c61d69da1848c0938db7bed6a79 Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Sat, 26 Mar 2022 17:35:24 +0100 Subject: [PATCH 093/153] fix the test suite --- tests/test_jobqueue.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_jobqueue.py b/tests/test_jobqueue.py index 5c3fcbe28cb..12eb95ebfc1 100644 --- a/tests/test_jobqueue.py +++ b/tests/test_jobqueue.py @@ -104,7 +104,7 @@ def test_slot_behaviour(self, job_queue, mro_slots): def test_application_weakref(self, bot): jq = JobQueue() - application = ApplicationBuilder().bot(bot).job_queue(None).build() + application = ApplicationBuilder().token(bot.token).job_queue(None).build() with pytest.raises(RuntimeError, match='No application was set'): jq.application jq.set_application(application) @@ -262,7 +262,7 @@ async def test_error(self, job_queue): @pytest.mark.asyncio async def test_in_application(self, bot): - app = ApplicationBuilder().bot(bot).build() + app = ApplicationBuilder().token(bot.token).build() async with app: assert not app.job_queue.scheduler.running await app.start() @@ -516,7 +516,7 @@ async def test_dispatch_error_that_raises_errors(self, job_queue, app, caplog): async def test_custom_context(self, bot, job_queue): application = ( ApplicationBuilder() - .bot(bot) + .token(bot.token) .context_types( ContextTypes( context=CustomContext, bot_data=int, user_data=float, chat_data=complex From 424f8100ea66b0a35d444200e14cb3024adf34a7 Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Sun, 27 Mar 2022 11:23:31 +0200 Subject: [PATCH 094/153] small tweak in DictPersistence --- telegram/ext/_dictpersistence.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/telegram/ext/_dictpersistence.py b/telegram/ext/_dictpersistence.py index 8ab8f8db0aa..c1f598129e2 100644 --- a/telegram/ext/_dictpersistence.py +++ b/telegram/ext/_dictpersistence.py @@ -237,7 +237,9 @@ def conversations_json(self) -> str: """:obj:`str`: The conversations serialized as a JSON-string.""" if self._conversations_json: return self._conversations_json - return self._encode_conversations_to_json(self.conversations) # type: ignore[arg-type] + if self.conversations: + return self._encode_conversations_to_json(self.conversations) + return json.dumps(self.conversations) async def get_user_data(self) -> Dict[int, Dict[object, object]]: """Returns the user_data created from the ``user_data_json`` or an empty :obj:`dict`. From d9e0eda5a483ffe451e34b9a431f9233d8344482 Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Wed, 30 Mar 2022 22:08:06 +0200 Subject: [PATCH 095/153] Try to improve non-blocking logic of CH --- telegram/ext/_conversationhandler.py | 35 +++++++++++++++++++++++----- 1 file changed, 29 insertions(+), 6 deletions(-) diff --git a/telegram/ext/_conversationhandler.py b/telegram/ext/_conversationhandler.py index 30fce68c263..d4bb18556c4 100644 --- a/telegram/ext/_conversationhandler.py +++ b/telegram/ext/_conversationhandler.py @@ -38,7 +38,7 @@ ) from telegram import Update -from telegram._utils.defaultvalue import DEFAULT_TRUE +from telegram._utils.defaultvalue import DEFAULT_TRUE, DefaultValue from telegram._utils.types import DVInput from telegram.ext import ( CallbackContext, @@ -50,6 +50,7 @@ StringCommandHandler, StringRegexHandler, TypeHandler, + ExtBot, ) from telegram._utils.warnings import warn from telegram.ext._utils.trackingdict import TrackingDict @@ -218,7 +219,8 @@ class ConversationHandler(Handler[Update, CCT]): block (:obj:`bool`, optional): Pass :obj:`False` to *overrule* the :attr:`Handler.block` setting of all handlers (in :attr:`entry_points`, :attr:`states` and :attr:`fallbacks`). - Defaults to :obj:`True`, in which case the handlers setting will be respected. + By default the handlers setting and :attr:`telegram.ext.Defaults.bock` will be + respected (in that order). .. versionadded:: 13.2 .. versionchanged:: 14.0 @@ -230,7 +232,8 @@ class ConversationHandler(Handler[Update, CCT]): Attributes: persistent (:obj:`bool`): Optional. If the conversations dict for this handler should be saved. Name is required and persistence has to be set in :class:`telegram.ext.Updater` - block (:obj:`bool`): Determines whether the callback will run asynchronously. + block (:obj:`bool`): Determines whether the callback will run asynchronously. Always + :obj:`True` since conversation handlers handle any non-blocking callbacks internally. .. versionadded:: 13.2 @@ -238,6 +241,7 @@ class ConversationHandler(Handler[Update, CCT]): __slots__ = ( '_allow_reentry', + '_block', '_child_conversations', '_conversation_timeout', '_conversations', @@ -285,7 +289,11 @@ def __init__( PollAnswerHandler, ) - self.block = block + # self.block is what the Application checks and we want it to always run CH in a blocking + # way so that CH can take care of any non-blocking logic internally + self.block = True + # Store the actual setting in a protected variable instead + self._block = block self._entry_points = entry_points self._states = states @@ -741,9 +749,23 @@ async def handle_update( # type: ignore[override] if timeout_job is not None: timeout_job.schedule_removal() + + # Resolution order of "block": + # 1. Setting of the ConversationHandler + # 2. Setting of the selected handler + # 3. Default values of the bot + if self._block is not DEFAULT_TRUE: + # CHs block-setting has highest priority + block = self._block + else: + if handler.block is not DEFAULT_TRUE: + block = handler.block + elif isinstance(application.bot, ExtBot) and application.bot.defaults is not None: + block = application.bot.defaults.block + else: + block = DefaultValue.get_value(handler.block) + try: - # TODO handle non-blocking handlers correctly - block = self.block and handler.block if block: new_state: object = await handler.handle_update( update, application, handler_check_result, context @@ -792,6 +814,7 @@ async def handle_update( # type: ignore[override] # Don't pass the new state here. If we're in a nested conversation, the parent is # expecting None as return value. raise ApplicationHandlerStop() + # Signals a possible parent conversation to stay in the current state return None def _update_state(self, new_state: object, key: ConversationKey) -> None: From 2fc5eeb75fc742b28015105e8e820a5f487f4cc0 Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Wed, 30 Mar 2022 22:14:05 +0200 Subject: [PATCH 096/153] Remove some resolved todo items --- telegram/ext/_application.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/telegram/ext/_application.py b/telegram/ext/_application.py index f1c61ada8fe..b91ab18b4b2 100644 --- a/telegram/ext/_application.py +++ b/telegram/ext/_application.py @@ -574,7 +574,6 @@ def __run(self, updater_coroutine: Coroutine, close_loop: bool = True) -> None: loop.run_until_complete(self.start()) try: loop.run_forever() - # TODO: maybe allow for custom exception classes to catch here? Or provide a custom one? except (KeyboardInterrupt, SystemExit): pass finally: @@ -986,9 +985,6 @@ def migrate_chat_data( # old_chat_id is marked for deletion by drop_chat_data above def _mark_for_persistence_update(self, *, update: object = None, job: 'Job' = None) -> None: - # TODO: This should be at the end of `Application.process_update`, when the task created - # by `Application.create_task` is done and when a `Job` is done. Add tests to make sure - # that this is happening if isinstance(update, Update): if update.effective_chat: self._chat_ids_to_be_updated_in_persistence.add(update.effective_chat.id) From 737e013ce0492c255bb14a02e5b25ec31372f4a9 Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Wed, 30 Mar 2022 22:33:53 +0200 Subject: [PATCH 097/153] some more --- telegram/ext/_messagehandler.py | 2 -- tests/test_callbackquery.py | 1 - 2 files changed, 3 deletions(-) diff --git a/telegram/ext/_messagehandler.py b/telegram/ext/_messagehandler.py index d9b43e00c36..477d21aed8f 100644 --- a/telegram/ext/_messagehandler.py +++ b/telegram/ext/_messagehandler.py @@ -94,8 +94,6 @@ def check_update(self, update: object) -> Optional[Union[bool, Dict[str, list]]] """ if isinstance(update, Update): - # The `or False` makes sure that we don't return empty dicts - # TODO: add a test for this to MessageHandler return self.filters.check_update(update) or False return None diff --git a/tests/test_callbackquery.py b/tests/test_callbackquery.py index 7e3dcb0c22f..9cad6e1c7ed 100644 --- a/tests/test_callbackquery.py +++ b/tests/test_callbackquery.py @@ -127,7 +127,6 @@ async def make_assertion(*_, **kwargs): assert await check_defaults_handling(callback_query.answer, callback_query.get_bot()) monkeypatch.setattr(callback_query.get_bot(), 'answer_callback_query', make_assertion) - # TODO: PEP8 assert await callback_query.answer() @pytest.mark.asyncio From d14e55d2edbe384090d06b638a0bd2cd32e72915 Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Wed, 30 Mar 2022 22:36:21 +0200 Subject: [PATCH 098/153] fix a typo --- telegram/ext/_conversationhandler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/telegram/ext/_conversationhandler.py b/telegram/ext/_conversationhandler.py index d4bb18556c4..97ec24b194e 100644 --- a/telegram/ext/_conversationhandler.py +++ b/telegram/ext/_conversationhandler.py @@ -219,7 +219,7 @@ class ConversationHandler(Handler[Update, CCT]): block (:obj:`bool`, optional): Pass :obj:`False` to *overrule* the :attr:`Handler.block` setting of all handlers (in :attr:`entry_points`, :attr:`states` and :attr:`fallbacks`). - By default the handlers setting and :attr:`telegram.ext.Defaults.bock` will be + By default the handlers setting and :attr:`telegram.ext.Defaults.block` will be respected (in that order). .. versionadded:: 13.2 From 0f69b0212f37f6a7b092447f07a726b789f6e314 Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Fri, 1 Apr 2022 18:08:31 +0200 Subject: [PATCH 099/153] Rename App.dispatch_error to App.process_error --- telegram/ext/_application.py | 33 ++++++++++++++++++--------------- telegram/ext/_jobqueue.py | 2 +- tests/test_jobqueue.py | 4 ++-- 3 files changed, 21 insertions(+), 18 deletions(-) diff --git a/telegram/ext/_application.py b/telegram/ext/_application.py index b91ab18b4b2..1e97c0f0816 100644 --- a/telegram/ext/_application.py +++ b/telegram/ext/_application.py @@ -506,7 +506,7 @@ def run_polling( ) def error_callback(exc: TelegramError) -> None: - self.create_task(self.dispatch_error(update=None, error=exc)) + self.create_task(self.process_error(error=exc, update=None)) return self.__run( updater_coroutine=self.updater.start_polling( @@ -589,17 +589,17 @@ def __run(self, updater_coroutine: Coroutine, close_loop: bool = True) -> None: def create_task(self, coroutine: Coroutine, update: object = None) -> asyncio.Task: """Thin wrapper around :func:`asyncio.create_task` that handles exceptions raised by - the :paramref:`coroutine` with :meth:`dispatch_error`. + the :paramref:`coroutine` with :meth:`process_error`. Note: * If :paramref:`coroutine` raises an exception, it will be set on the task created by - this method even though it's handled by :meth:`dispatch_error`. + this method even though it's handled by :meth:`process_error`. * If the application is currently running, tasks created by this methods will be awaited by :meth:`stop`. Args: coroutine: The coroutine to run as task. - update: Optional. If passed, will be passed to :meth:`dispatch_error` as additional + update: Optional. If passed, will be passed to :meth:`process_error` as additional information for the error handlers. Moreover, the corresponding :attr:`chat_data` and :attr:`user_data` entries will be updated in the next run of :meth:`update_persistence` after the :paramref:`coroutine` is finished. @@ -613,7 +613,7 @@ def __create_task( self, coroutine: Coroutine, update: object = None, is_error_handler: bool = False ) -> asyncio.Task: # Unfortunately, we can't know if `coroutine` runs one of the error handler functions - # but by passing `is_error_handler=True` from `dispatch_error`, we can make sure that we + # but by passing `is_error_handler=True` from `process_error`, we can make sure that we # get at most one recursion of the user calls `create_task` manually with an error handler # function task = asyncio.create_task( @@ -674,7 +674,7 @@ async def __create_task_callback( # If we arrive here, an exception happened in the task and was neither # ApplicationHandlerStop nor raised by an error handler. # So we can and must handle it - await self.dispatch_error(update, exception, coroutine=coroutine) + await self.process_error(update=update, error=exception, coroutine=coroutine) # Raise exception so that it can be set on the task raise exception @@ -709,6 +709,7 @@ async def __process_update_wrapper(self, update: object) -> None: async def process_update(self, update: object) -> None: """Processes a single update and updates the persistence. + Exceptions raised by handler callbacks will be processed by :meth:`process_update`. .. versionchanged:: 14.0 This calls :meth:`update_persistence` exactly once after handling of the update was @@ -758,7 +759,7 @@ async def process_update(self, update: object) -> None: # Dispatch any error. except Exception as exc: - if await self.dispatch_error(update, exc): + if await self.process_error(update=update, error=exc): _logger.debug('Error handler stopped further handlers.') break @@ -1125,7 +1126,7 @@ async def __update_persistence(self) -> None: # dispatch any errors await asyncio.gather( *( - self.dispatch_error(update=None, error=result) + self.process_error(error=result, update=None) for result in results if isinstance(result, Exception) ) @@ -1137,7 +1138,7 @@ def add_error_handler( block: DVInput[bool] = DEFAULT_TRUE, ) -> None: """Registers an error handler in the Application. This handler will receive every error - which happens in your bot. See the docs of :meth:`dispatch_error` for more details on how + which happens in your bot. See the docs of :meth:`process_error` for more details on how errors are handled. Note: @@ -1146,11 +1147,11 @@ def add_error_handler( Args: callback (:obj:`callable`): The callback function for this error handler. Will be called when an error is raised. Callback signature: - ``def callback(update: object, context: CallbackContext)``. + ``def callback(update: Optional[object], context: CallbackContext)``. The error that happened will be present in ``context.error``. block (:obj:`bool`, optional): Determines whether the return value of the callback should be awaited before processing the next error handler in - :meth:`dispatch_error`. Defaults to :obj:`True`. + :meth:`process_error`. Defaults to :obj:`True`. """ if callback in self.error_handlers: _logger.warning('The callback is already registered as an error handler. Ignoring.') @@ -1167,17 +1168,19 @@ def remove_error_handler(self, callback: Callable[[object, CCT], None]) -> None: """ self.error_handlers.pop(callback, None) - async def dispatch_error( + async def process_error( self, update: Optional[object], error: Exception, job: 'Job' = None, coroutine: Coroutine = None, ) -> bool: - """Dispatches an error by passing it to all error handlers registered with + """Processes an error by passing it to all error handlers registered with :meth:`add_error_handler`. If one of the error handlers raises - :class:`telegram.ext.ApplicationHandlerStop`, the update will not be handled by other error - handlers or handlers (even in other groups). All other exceptions raised by an error + :class:`telegram.ext.ApplicationHandlerStop`, the error will not be handled by other error + handlers. Raising :class:`telegram.ext.ApplicationHandlerStop` also stops processing of + the update when this method is called by :meth:`process_update`, i.e. no further handlers + (even in other groups) will handle the update. All other exceptions raised by an error handler will just be logged. .. versionchanged:: 14.0 diff --git a/telegram/ext/_jobqueue.py b/telegram/ext/_jobqueue.py index 21266d16418..91a6739cb4b 100644 --- a/telegram/ext/_jobqueue.py +++ b/telegram/ext/_jobqueue.py @@ -633,7 +633,7 @@ async def _run(self, application: 'Application') -> None: await context.refresh_data() await self.callback(context) except Exception as exc: - await application.create_task(application.dispatch_error(None, exc, job=self)) + await application.create_task(application.process_error(None, exc, job=self)) finally: # This is internal logic of application - let's keep it private for now application._mark_for_persistence_update(job=self) # pylint: disable=protected-access diff --git a/tests/test_jobqueue.py b/tests/test_jobqueue.py index 12eb95ebfc1..5dca6dfa6fd 100644 --- a/tests/test_jobqueue.py +++ b/tests/test_jobqueue.py @@ -453,7 +453,7 @@ async def test_job_lt_eq(self, job_queue): assert not job < job @pytest.mark.asyncio - async def test_dispatch_error_context(self, job_queue, app): + async def test_process_error_context(self, job_queue, app): app.add_error_handler(self.error_handler_context) job = job_queue.run_once(self.job_with_exception, 0.1) @@ -476,7 +476,7 @@ async def test_dispatch_error_context(self, job_queue, app): assert self.received_error is None @pytest.mark.asyncio - async def test_dispatch_error_that_raises_errors(self, job_queue, app, caplog): + async def test_process_error_that_raises_errors(self, job_queue, app, caplog): app.add_error_handler(self.error_handler_raise_error) with caplog.at_level(logging.ERROR): From e31419c2132f0507de2e3cde55bdfc3f234ce759 Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Fri, 1 Apr 2022 19:04:25 +0200 Subject: [PATCH 100/153] fix a bug in persisting non-blocking conversations --- telegram/ext/_application.py | 2 ++ telegram/ext/_utils/trackingdict.py | 5 +++++ 2 files changed, 7 insertions(+) diff --git a/telegram/ext/_application.py b/telegram/ext/_application.py index 1e97c0f0816..f9d727663e2 100644 --- a/telegram/ext/_application.py +++ b/telegram/ext/_application.py @@ -1106,6 +1106,8 @@ async def __update_persistence(self) -> None: 'persistence with the current state.' ) result = new_state.old_state + # We need to check again on the next run if the state is done + self._conversation_handler_conversations[name].mark_as_accessed(key) else: result = new_state.resolve() else: diff --git a/telegram/ext/_utils/trackingdict.py b/telegram/ext/_utils/trackingdict.py index f65e20d080c..314ab219d06 100644 --- a/telegram/ext/_utils/trackingdict.py +++ b/telegram/ext/_utils/trackingdict.py @@ -84,6 +84,11 @@ def pop_accessed_write_items(self) -> List[Tuple[_KT, _VT]]: keys = self.pop_accessed_keys() return [(key, self[key] if key in self else self.DELETED) for key in keys] + def mark_as_accessed(self, key: _KT) -> None: + """Use this method have the key returned again in the next call to + :meth:`pop_accessed_write_items` or :meth:`pop_accessed_keys""" + self._write_access_keys.add(key) + # Override methods to track access def __setitem__(self, key: _KT, value: _VT) -> None: From d189ce3703ebfe44670650af927ecd28a42b17f4 Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Sun, 3 Apr 2022 10:09:39 +0200 Subject: [PATCH 101/153] Review --- telegram/ext/_conversationhandler.py | 9 +++++---- telegram/ext/_defaults.py | 5 +---- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/telegram/ext/_conversationhandler.py b/telegram/ext/_conversationhandler.py index 97ec24b194e..4f57f22c4f7 100644 --- a/telegram/ext/_conversationhandler.py +++ b/telegram/ext/_conversationhandler.py @@ -59,7 +59,7 @@ if TYPE_CHECKING: from telegram.ext import Application, Job, JobQueue -CheckUpdateType = Tuple[object, ConversationKey, Handler, object] +_CheckUpdateType = Tuple[object, ConversationKey, Handler, object] _logger = logging.getLogger(__name__) @@ -638,7 +638,7 @@ def _schedule_job( _logger.exception("Failed to schedule timeout.", exc_info=exc) # pylint: disable=too-many-return-statements - def check_update(self, update: object) -> Optional[CheckUpdateType]: + def check_update(self, update: object) -> Optional[_CheckUpdateType]: """ Determines whether an update should be handled by this conversation handler, and if so in which state the conversation currently is. @@ -664,6 +664,7 @@ def check_update(self, update: object) -> Optional[CheckUpdateType]: key = self._get_key(update) state = self._conversations.get(key) + check: Optional[object] = None # Resolve promises if isinstance(state, PendingState): @@ -701,7 +702,7 @@ def check_update(self, update: object) -> Optional[CheckUpdateType]: return None # Get the handler list for current state, if we didn't find one yet and we're still here - if state is not None and not handler: + if state is not None and handler is None: for candidate in self.states.get(state, []): check = candidate.check_update(update) if check is not None and check is not False: @@ -725,7 +726,7 @@ async def handle_update( # type: ignore[override] self, update: Update, application: 'Application', - check_result: CheckUpdateType, + check_result: _CheckUpdateType, context: CallbackContext, ) -> Optional[object]: """Send the update to the callback for the current state and Handler diff --git a/telegram/ext/_defaults.py b/telegram/ext/_defaults.py index 81460795d22..2c3139479b0 100644 --- a/telegram/ext/_defaults.py +++ b/telegram/ext/_defaults.py @@ -22,8 +22,6 @@ import pytz -from telegram._utils.defaultvalue import DEFAULT_NONE - class Defaults: """Convenience Class to gather all parameters with a (user defined) default value @@ -101,7 +99,7 @@ def __init__( 'protect_content', ): value = getattr(self, kwarg) - if value not in [None, DEFAULT_NONE]: + if value is not None: self._api_defaults[kwarg] = value @property @@ -232,7 +230,6 @@ def __hash__(self) -> int: self._tzinfo, self._block, self._protect_content, - self._protect_content, ) ) From a4b37fd6e61517e75f9bfb356828c14c0ef95606 Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Sun, 3 Apr 2022 12:57:31 +0200 Subject: [PATCH 102/153] Small adjustments for persistence --- telegram/ext/_application.py | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/telegram/ext/_application.py b/telegram/ext/_application.py index f9d727663e2..92611003e00 100644 --- a/telegram/ext/_application.py +++ b/telegram/ext/_application.py @@ -1004,7 +1004,6 @@ async def _persistence_updater(self) -> None: if not self.persistence: return - await self.update_persistence() try: await asyncio.wait_for( self.__update_persistence_event.wait(), @@ -1014,6 +1013,10 @@ async def _persistence_updater(self) -> None: except asyncio.TimeoutError: pass + # putting this *after* the wait_for so we don't immediately update on startup as + # that would make little sense + await self.update_persistence() + async def update_persistence(self) -> None: """Updates :attr:`user_data`, :attr:`chat_data`, :attr:`bot_data` in :attr:`persistence` along with :attr:`~telegram.ext.ExtBot.callback_data_cache` and the conversation states of @@ -1100,11 +1103,17 @@ async def __update_persistence(self) -> None: # Note that when updating the persistence one last time during self.stop(), # *all* tasks will be done. if not new_state.done(): - # TODO: Try to test that this doesn't happen on shutdown - _logger.warning( - 'A ConversationHandlers state was not yet resolved. Updating the ' - 'persistence with the current state.' - ) + if self.running: + _logger.debug( + 'A ConversationHandlers state was not yet resolved. Updating the ' + 'persistence with the current state. Will check again on next run of ' + 'Application.update_persistence.' + ) + else: + _logger.warning( + 'A ConversationHandlers state was not yet resolved. Updating the ' + 'persistence with the current state.' + ) result = new_state.old_state # We need to check again on the next run if the state is done self._conversation_handler_conversations[name].mark_as_accessed(key) From a0204c69268f83d785acb06947c888bf0570b5a3 Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Sun, 3 Apr 2022 16:34:06 +0200 Subject: [PATCH 103/153] Fix persistence of nested conversations --- telegram/ext/_application.py | 22 +++++++--------------- telegram/ext/_conversationhandler.py | 16 ++++++++++++---- 2 files changed, 19 insertions(+), 19 deletions(-) diff --git a/telegram/ext/_application.py b/telegram/ext/_application.py index 92611003e00..f6ec22f6d1b 100644 --- a/telegram/ext/_application.py +++ b/telegram/ext/_application.py @@ -297,14 +297,15 @@ async def initialize(self) -> None: # Initialize the persistent conversation handlers with the stored states for handler in itertools.chain.from_iterable(self.handlers.values()): if isinstance(handler, ConversationHandler) and handler.persistent and handler.name: - self._conversation_handler_conversations[ - handler.name - ] = await handler._initialize_persistence( # pylint: disable=protected-access - self - ) + await self._add_ch_to_persistence(handler) self._initialized = True + async def _add_ch_to_persistence(self, handler: 'ConversationHandler') -> None: + self._conversation_handler_conversations.update( + await handler._initialize_persistence(self) # pylint: disable=protected-access + ) + async def shutdown(self) -> None: """ @@ -768,13 +769,6 @@ async def process_update(self, update: object) -> None: # blocking handler - the non-blocking handlers mark the update again when finished self._mark_for_persistence_update(update=update) - async def _add_ch_after_init(self, handler: 'ConversationHandler') -> None: - self._conversation_handler_conversations[ - handler.name # type: ignore[index] - ] = await handler._initialize_persistence( # pylint: disable=protected-access - self - ) - def add_handler(self, handler: Handler[Any, CCT], group: int = DEFAULT_GROUP) -> None: """Register a handler. @@ -822,7 +816,7 @@ def add_handler(self, handler: Handler[Any, CCT], group: int = DEFAULT_GROUP) -> f"can not be persistent if application has no persistence" ) if self._initialized: - self.create_task(self._add_ch_after_init(handler)) + self.create_task(self._add_ch_to_persistence(handler)) warn( 'A persistent `ConversationHandler` was passed to `add_handler`, ' 'after `Application.initialize` was called. This is discouraged.' @@ -1123,8 +1117,6 @@ async def __update_persistence(self) -> None: result = new_state effective_new_state = None if result is TrackingDict.DELETED else result - # TODO: Test that we actually pass `None` here in case the conversation had ended, - # i.e. effective_new_state is TrackingDict.DELETED coroutines.add( self.persistence.update_conversation( name=name, key=key, new_state=effective_new_state diff --git a/telegram/ext/_conversationhandler.py b/telegram/ext/_conversationhandler.py index 4f57f22c4f7..698380eeb98 100644 --- a/telegram/ext/_conversationhandler.py +++ b/telegram/ext/_conversationhandler.py @@ -528,7 +528,7 @@ def map_to_parent(self, value: object) -> NoReturn: async def _initialize_persistence( self, application: 'Application' - ) -> TrackingDict[ConversationKey, object]: + ) -> Dict[str, TrackingDict[ConversationKey, object]]: """Initializes the persistence for this handler. While this method is marked as protected, we expect it to be called by the Application/parent conversations. It's just protected to hide it from users. @@ -536,6 +536,10 @@ async def _initialize_persistence( Args: application (:class:`telegram.ext.Application`): The application. + Returns: + A dict {conversation.name -> TrackingDict}, which contains all dict of this + conversation and possible child conversations. + """ if not (self.persistent and self.name and application.persistence): raise RuntimeError( @@ -556,12 +560,16 @@ async def _initialize_persistence( await application.persistence.get_conversations(self.name) ) + out = {self.name: self._conversations} + for handler in self._child_conversations: - await handler._initialize_persistence( # pylint: disable=protected-access - application=application + out.update( + await handler._initialize_persistence( # pylint: disable=protected-access + application=application + ) ) - return self._conversations + return out def _get_key(self, update: Update) -> ConversationKey: chat = update.effective_chat From 655d080750e2d6861e5cde2d1545c522143378a6 Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Sun, 3 Apr 2022 20:58:19 +0200 Subject: [PATCH 104/153] Remove unused method from webhook setup --- telegram/ext/_utils/webhookhandler.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/telegram/ext/_utils/webhookhandler.py b/telegram/ext/_utils/webhookhandler.py index 90f2db08d21..4685f461777 100644 --- a/telegram/ext/_utils/webhookhandler.py +++ b/telegram/ext/_utils/webhookhandler.py @@ -83,15 +83,6 @@ async def shutdown(self) -> None: await self._http_server.close_all_connections() self._logger.debug('Webhook Server stopped') - # pylint: disable=unused-argument - def handle_error(self, request: object, client_address: str) -> None: - """Handle an error gracefully.""" - self._logger.debug( - 'Exception happened during processing of request from %s', - client_address, - exc_info=True, - ) - class WebhookAppClass(tornado.web.Application): """Application used in the Webserver""" From c99645edab13677319194c261a195df7c67eefbf Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Mon, 4 Apr 2022 07:47:42 +0200 Subject: [PATCH 105/153] Review --- telegram/request/_baserequest.py | 8 ++++---- telegram/request/_httpxrequest.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/telegram/request/_baserequest.py b/telegram/request/_baserequest.py index 18d643fdd47..78f4b0e0fdb 100644 --- a/telegram/request/_baserequest.py +++ b/telegram/request/_baserequest.py @@ -121,9 +121,9 @@ async def post( self, url: str, request_data: RequestData = None, - connect_timeout: ODVInput[float] = DEFAULT_NONE, read_timeout: ODVInput[float] = DEFAULT_NONE, write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, pool_timeout: ODVInput[float] = DEFAULT_NONE, ) -> Union[JSONDict, bool]: """Makes a request to the Bot API handles the return code and parses the answer. @@ -172,8 +172,8 @@ async def retrieve( self, url: str, read_timeout: ODVInput[float] = DEFAULT_NONE, - connect_timeout: ODVInput[float] = DEFAULT_NONE, write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, pool_timeout: ODVInput[float] = DEFAULT_NONE, ) -> bytes: """Retrieve the contents of a file by its URL. @@ -207,8 +207,8 @@ async def _request_wrapper( method: str, request_data: RequestData = None, read_timeout: ODVInput[float] = DEFAULT_NONE, - connect_timeout: ODVInput[float] = DEFAULT_NONE, write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, pool_timeout: ODVInput[float] = DEFAULT_NONE, ) -> bytes: """Wraps the real implementation request method. @@ -316,9 +316,9 @@ async def do_request( url: str, method: str, request_data: RequestData = None, - connect_timeout: ODVInput[float] = DEFAULT_NONE, read_timeout: ODVInput[float] = DEFAULT_NONE, write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, pool_timeout: ODVInput[float] = DEFAULT_NONE, ) -> Tuple[int, bytes]: """Makes a request to the Bot API. Must be implemented by a subclass. diff --git a/telegram/request/_httpxrequest.py b/telegram/request/_httpxrequest.py index 425b92224e7..98e100c1f5f 100644 --- a/telegram/request/_httpxrequest.py +++ b/telegram/request/_httpxrequest.py @@ -127,9 +127,9 @@ async def do_request( url: str, method: str, request_data: RequestData = None, - connect_timeout: ODVInput[float] = BaseRequest.DEFAULT_NONE, read_timeout: ODVInput[float] = BaseRequest.DEFAULT_NONE, write_timeout: ODVInput[float] = BaseRequest.DEFAULT_NONE, + connect_timeout: ODVInput[float] = BaseRequest.DEFAULT_NONE, pool_timeout: ODVInput[float] = BaseRequest.DEFAULT_NONE, ) -> Tuple[int, bytes]: """See :meth:`BaseRequest.do_request`.""" From 40f886b8350561a2021444b49a4ee177434b441a Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Mon, 4 Apr 2022 07:53:03 +0200 Subject: [PATCH 106/153] more adjusting --- telegram/request/_httpxrequest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/telegram/request/_httpxrequest.py b/telegram/request/_httpxrequest.py index 98e100c1f5f..4e243bf367b 100644 --- a/telegram/request/_httpxrequest.py +++ b/telegram/request/_httpxrequest.py @@ -83,9 +83,9 @@ def __init__( self, connection_pool_size: int = 1, proxy_url: str = None, - connect_timeout: Optional[float] = 5.0, read_timeout: Optional[float] = 5.0, write_timeout: Optional[float] = 5.0, + connect_timeout: Optional[float] = 5.0, pool_timeout: Optional[float] = 1.0, ): timeout = httpx.Timeout( From 02f3bf38428308508f7f7a6cf209af982a451029 Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Tue, 5 Apr 2022 08:06:14 +0200 Subject: [PATCH 107/153] Review --- tests/test_application.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_application.py b/tests/test_application.py index 715d012f627..4d333f93254 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -204,7 +204,7 @@ def test_custom_context_init(self, bot): assert isinstance(application.bot_data, complex) @pytest.mark.asyncio - @pytest.mark.asyncio('updater', (True, False)) + @pytest.mark.parametrize('updater', (True, False)) async def test_initialize(self, bot, monkeypatch, updater): """Initialization of persistence is tested test_basepersistence""" self.test_flag = set() From ed2a39ecf561a4abc745539efe5d58a8aa7f83cf Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Tue, 5 Apr 2022 18:15:26 +0200 Subject: [PATCH 108/153] Make CH.persistent immutable --- telegram/ext/_conversationhandler.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/telegram/ext/_conversationhandler.py b/telegram/ext/_conversationhandler.py index 698380eeb98..58e61ded9ba 100644 --- a/telegram/ext/_conversationhandler.py +++ b/telegram/ext/_conversationhandler.py @@ -230,8 +230,6 @@ class ConversationHandler(Handler[Update, CCT]): ValueError Attributes: - persistent (:obj:`bool`): Optional. If the conversations dict for this handler should be - saved. Name is required and persistence has to be set in :class:`telegram.ext.Updater` block (:obj:`bool`): Determines whether the callback will run asynchronously. Always :obj:`True` since conversation handlers handle any non-blocking callbacks internally. @@ -252,9 +250,9 @@ class ConversationHandler(Handler[Update, CCT]): '_per_chat', '_per_message', '_per_user', + '_persistent', '_states', '_timeout_jobs_lock', - 'persistent', 'timeout_jobs', ) @@ -314,7 +312,7 @@ def __init__( if persistent and not self.name: raise ValueError("Conversations can't be persistent when handler is unnamed.") - self.persistent: bool = persistent + self._persistent: bool = persistent if not any((self.per_user, self.per_chat, self.per_message)): raise ValueError("'per_user', 'per_chat' and 'per_message' can't all be 'False'") @@ -512,6 +510,16 @@ def name(self) -> Optional[str]: def name(self, value: object) -> NoReturn: raise AttributeError("You can not assign a new value to name after initialization.") + @property + def persistent(self) -> bool: + """:obj:`bool`: Optional. If the conversations dict for this handler should be + saved.""" + return self._persistent + + @persistent.setter + def persistent(self, value: object) -> NoReturn: + raise AttributeError("You can not assign a new value to persistent after initialization.") + @property def map_to_parent(self) -> Optional[Dict[object, object]]: """Dict[:obj:`object`, :obj:`object`]: Optional. A :obj:`dict` that can be From 51b4044fdbfe642f0f64264867ef31e4b264c629 Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Thu, 7 Apr 2022 22:08:44 +0200 Subject: [PATCH 109/153] typo --- telegram/ext/_conversationhandler.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/telegram/ext/_conversationhandler.py b/telegram/ext/_conversationhandler.py index 58e61ded9ba..2fc7779e80a 100644 --- a/telegram/ext/_conversationhandler.py +++ b/telegram/ext/_conversationhandler.py @@ -342,7 +342,7 @@ def __init__( per_faq_link = ( " Read this FAQ entry to learn more about the per_* settings: " "https://github.com/python-telegram-bot/python-telegram-bot/wiki" - "/Frequently-Asked-Questions#what-do-the-per_-settings-in-conversation handler-do." + "/Frequently-Asked-Questions#what-do-the-per_-settings-in-conversationhandler-do." ) for handler in all_handlers: @@ -784,9 +784,12 @@ async def handle_update( # type: ignore[override] try: if block: - new_state: object = await handler.handle_update( - update, application, handler_check_result, context - ) + try: + new_state: object = await handler.handle_update( + update, application, handler_check_result, context + ) + except Exception as exc: + print(exc) else: new_state = application.create_task( coroutine=handler.handle_update( From 1a16bda66b7ef3414518f509d4a754b69e1ab975 Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Fri, 8 Apr 2022 16:32:12 +0200 Subject: [PATCH 110/153] Remove falsely committed debug print --- telegram/ext/_conversationhandler.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/telegram/ext/_conversationhandler.py b/telegram/ext/_conversationhandler.py index 2fc7779e80a..45ed298a0c8 100644 --- a/telegram/ext/_conversationhandler.py +++ b/telegram/ext/_conversationhandler.py @@ -784,12 +784,9 @@ async def handle_update( # type: ignore[override] try: if block: - try: - new_state: object = await handler.handle_update( - update, application, handler_check_result, context - ) - except Exception as exc: - print(exc) + new_state: object = await handler.handle_update( + update, application, handler_check_result, context + ) else: new_state = application.create_task( coroutine=handler.handle_update( From 92e2cdd7ab3bf91b53299b503f883353a8d65349 Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Sun, 10 Apr 2022 10:22:28 +0200 Subject: [PATCH 111/153] Review --- setup.cfg | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/setup.cfg b/setup.cfg index b39b7bf5f68..5df067d852c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -14,11 +14,12 @@ max-line-length = 99 ignore = W503, W605 extend-ignore = E203 exclude = setup.py, setup-raw.py docs/source/conf.py -per-file-ignores = - telegram/ext/_jobqueue.py:E402 [pylint.message-control] -disable = C0330,R0801,R0913,R0904,R0903,R0902,W0511,C0116,C0115,W0703,R0914,R0914,C0302,R0912,R0915,R0401 +disable = duplicate-code,too-many-arguments,too-many-public-methods,too-few-public-methods, + broad-except,too-many-instance-attributes,fixme,missing-function-docstring, + missing-class-docstring,too-many-locals,too-many-lines,too-many-branches, + too-many-statements,cyclic-import [tool:pytest] testpaths = tests From 62fc4ffe6d578bede422c741a76be29d0fd1dcc1 Mon Sep 17 00:00:00 2001 From: Poolitzer Date: Sun, 10 Apr 2022 15:54:40 +0200 Subject: [PATCH 112/153] Updating the examples (#2937) * Fix: await coroutines Also put one logging line where it belongs * Feat: switch all markdown mentions to hml * Feat: readd pre-commits for examples * Fix: Apply Black * Fix: mypy caught wrong application definition * Fix: missed contexttype * Fix: Temporarily fixing black dependency * Fix: Remove markdown leftovers Co-authored-by: Harshil <37377066+harshil21@users.noreply.github.com> * verify examples by running them * Autofix issues in 1 file Resolved issues in examples/rawapibot.py via DeepSource Autofix * dont use my dev chat id * remove unused backslash from string * Apply suggestions from code review Co-authored-by: Bibo-Joshi <22366557+Bibo-Joshi@users.noreply.github.com> * Update passportbot.html, minor tweak in {passport, payment}bot.py Co-authored-by: Harshil <37377066+harshil21@users.noreply.github.com> Co-authored-by: deepsource-autofix[bot] <62050782+deepsource-autofix[bot]@users.noreply.github.com> Co-authored-by: Bibo-Joshi <22366557+Bibo-Joshi@users.noreply.github.com> --- .pre-commit-config.yaml | 14 +++++++ examples/arbitrarycallbackdatabot.py | 10 ++--- examples/chatmemberbot.py | 6 +-- examples/contexttypesbot.py | 10 ++++- examples/conversationbot.py | 2 +- examples/conversationbot2.py | 2 +- examples/deeplinking.py | 6 +-- examples/echobot.py | 4 +- examples/errorhandlerbot.py | 8 +--- examples/inlinebot.py | 14 +++---- examples/inlinekeyboard.py | 6 +-- examples/inlinekeyboard2.py | 20 ++++----- examples/nestedconversationbot.py | 24 ++++++----- examples/passportbot.html | 39 +++++++++-------- examples/passportbot.py | 7 +--- examples/paymentbot.py | 8 ++-- examples/persistentconversationbot.py | 5 +-- examples/pollbot.py | 20 ++++----- examples/rawapibot.py | 60 +++++++++++++++------------ examples/timerbot.py | 9 ++-- requirements-dev.txt | 2 + 21 files changed, 147 insertions(+), 129 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b78c12611f7..31751c4477e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -9,6 +9,8 @@ repos: args: - --diff - --check + additional_dependencies: + - click==8.0.2 - repo: https://gitlab.com/pycqa/flake8 rev: 4.0.1 hooks: @@ -44,6 +46,18 @@ repos: - APScheduler==3.6.3 - cachetools==4.2.2 - . # this basically does `pip install -e .`n + - id: mypy + name: mypy-examples + files: ^examples/.*\.py$ + args: + - --no-strict-optional + - --follow-imports=silent + additional_dependencies: + - certifi + - tornado>=6.1 + - APScheduler==3.6.3 + - cachetools==4.2.2 + - . # this basically does `pip install -e .` - repo: https://github.com/asottile/pyupgrade rev: v2.29.0 hooks: diff --git a/examples/arbitrarycallbackdatabot.py b/examples/arbitrarycallbackdatabot.py index 354ec934c2f..15adac4f994 100644 --- a/examples/arbitrarycallbackdatabot.py +++ b/examples/arbitrarycallbackdatabot.py @@ -97,13 +97,13 @@ def main() -> None: .build() ) - application.application.add_handler(CommandHandler('start', start)) - application.application.add_handler(CommandHandler('help', help_command)) - application.application.add_handler(CommandHandler('clear', clear)) - application.application.add_handler( + application.add_handler(CommandHandler('start', start)) + application.add_handler(CommandHandler('help', help_command)) + application.add_handler(CommandHandler('clear', clear)) + application.add_handler( CallbackQueryHandler(handle_invalid_button, pattern=InvalidCallbackData) ) - application.application.add_handler(CallbackQueryHandler(list_button)) + application.add_handler(CallbackQueryHandler(list_button)) # Run the bot until the user presses Ctrl-C application.run_polling() diff --git a/examples/chatmemberbot.py b/examples/chatmemberbot.py index c3606602be2..30f7138bc86 100644 --- a/examples/chatmemberbot.py +++ b/examples/chatmemberbot.py @@ -127,12 +127,12 @@ async def greet_chat_members(update: Update, context: CallbackContext.DEFAULT_TY member_name = update.chat_member.new_chat_member.user.mention_html() if not was_member and is_member: - update.effective_chat.send_message( + await update.effective_chat.send_message( f"{member_name} was added by {cause_name}. Welcome!", parse_mode=ParseMode.HTML, ) elif was_member and not is_member: - update.effective_chat.send_message( + await update.effective_chat.send_message( f"{member_name} is no longer with us. Thanks a lot, {cause_name} ...", parse_mode=ParseMode.HTML, ) @@ -153,7 +153,7 @@ def main() -> None: # Run the bot until the user presses Ctrl-C # We pass 'allowed_updates' handle *all* updates including `chat_member` updates # To reset this, simply pass `allowed_updates=[]` - application.run_polling()(allowed_updates=Update.ALL_TYPES) + application.run_polling(allowed_updates=Update.ALL_TYPES) if __name__ == "__main__": diff --git a/examples/contexttypesbot.py b/examples/contexttypesbot.py index c931f92ca33..8bbcbe5f3ed 100644 --- a/examples/contexttypesbot.py +++ b/examples/contexttypesbot.py @@ -10,6 +10,7 @@ bot. """ +import logging from collections import defaultdict from typing import DefaultDict, Optional, Set @@ -25,6 +26,12 @@ Application, ) +# Enable logging +logging.basicConfig( + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', level=logging.INFO +) +logger = logging.getLogger(__name__) + class ChatData: """Custom class for chat_data. Here we store data per message.""" @@ -88,7 +95,7 @@ async def count_click(update: Update, context: CustomContext) -> None: """Update the click count for the message.""" context.message_clicks += 1 await update.callback_query.answer() - update.effective_message.edit_text( + await update.effective_message.edit_text( f'This button was clicked {context.message_clicks} times.', reply_markup=InlineKeyboardMarkup.from_button( InlineKeyboardButton(text='Click me!', callback_data='button') @@ -116,7 +123,6 @@ def main() -> None: context_types = ContextTypes(context=CustomContext, chat_data=ChatData) application = Application.builder().token("TOKEN").context_types(context_types).build() - application = application.application # run track_users in its own group to not interfere with the user handlers application.add_handler(TypeHandler(Update, track_users), group=-1) application.add_handler(CommandHandler("start", start)) diff --git a/examples/conversationbot.py b/examples/conversationbot.py index 691e982c13f..3fb171cfe78 100644 --- a/examples/conversationbot.py +++ b/examples/conversationbot.py @@ -69,7 +69,7 @@ async def photo(update: Update, context: CallbackContext.DEFAULT_TYPE) -> int: """Stores the photo and asks for a location.""" user = update.message.from_user photo_file = await update.message.photo[-1].get_file() - photo_file.download('user_photo.jpg') + await photo_file.download('user_photo.jpg') logger.info("Photo of %s: %s", user.first_name, 'user_photo.jpg') await update.message.reply_text( 'Gorgeous! Now, send me your location please, or send /skip if you don\'t want to.' diff --git a/examples/conversationbot2.py b/examples/conversationbot2.py index 6d5737e8f8b..2e8f854a7f0 100644 --- a/examples/conversationbot2.py +++ b/examples/conversationbot2.py @@ -89,7 +89,7 @@ async def received_information(update: Update, context: CallbackContext.DEFAULT_ await update.message.reply_text( "Neat! Just so you know, this is what you already told me:" - f"{facts_to_str(user_data)} You can tell me more, or change your opinion" + f"{facts_to_str(user_data)}You can tell me more, or change your opinion" " on something.", reply_markup=markup, ) diff --git a/examples/deeplinking.py b/examples/deeplinking.py index f0644d16aef..659c2245d3c 100644 --- a/examples/deeplinking.py +++ b/examples/deeplinking.py @@ -73,10 +73,8 @@ async def deep_linked_level_2(update: Update, context: CallbackContext.DEFAULT_T """Reached through the SO_COOL payload""" bot = context.bot url = helpers.create_deep_linked_url(bot.username, USING_ENTITIES) - text = f"You can also mask the deep-linked URLs as links: [▶️ CLICK HERE]({url})." - await update.message.reply_text( - text, parse_mode=ParseMode.MARKDOWN, disable_web_page_preview=True - ) + text = f"You can also mask the deep-linked URLs as links: ▶️ CLICK HERE." + await update.message.reply_text(text, parse_mode=ParseMode.HTML, disable_web_page_preview=True) async def deep_linked_level_3(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: diff --git a/examples/echobot.py b/examples/echobot.py index 95c34d2d084..9a011091cff 100644 --- a/examples/echobot.py +++ b/examples/echobot.py @@ -39,8 +39,8 @@ async def start(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: """Send a message when the command /start is issued.""" user = update.effective_user - await update.message.reply_markdown_v2( - fr'Hi {user.mention_markdown_v2()}\!', + await update.message.reply_html( + fr'Hi {user.mention_html()}!', reply_markup=ForceReply(selective=True), ) diff --git a/examples/errorhandlerbot.py b/examples/errorhandlerbot.py index 8b0079d1648..8f78be734ac 100644 --- a/examples/errorhandlerbot.py +++ b/examples/errorhandlerbot.py @@ -18,9 +18,6 @@ ) logger = logging.getLogger(__name__) -# The token you got from @botfather when you created the bot -BOT_TOKEN = "TOKEN" - # This can be your own ID, or one for a developer group/channel. # You can use the /start command of this bot to see your chat id. DEVELOPER_CHAT_ID = 123456789 @@ -70,10 +67,7 @@ async def start(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: def main() -> None: """Run the bot.""" # Create the Application and pass it your bot's token. - application = Application.builder().token(BOT_TOKEN).build() - - # Get the application to register handlers - application = application.application + application = Application.builder().token("TOKEN").build() # Register the commands... application.add_handler(CommandHandler('start', start)) diff --git a/examples/inlinebot.py b/examples/inlinebot.py index c1cfac18547..fb6f782e9e0 100644 --- a/examples/inlinebot.py +++ b/examples/inlinebot.py @@ -14,10 +14,10 @@ """ import logging from uuid import uuid4 +from html import escape from telegram import InlineQueryResultArticle, InputTextMessageContent, Update from telegram.constants import ParseMode -from telegram.helpers import escape_markdown from telegram.ext import Application, InlineQueryHandler, CommandHandler, CallbackContext # Enable logging @@ -28,7 +28,7 @@ # Define a few command handlers. These usually take the two arguments update and -# context. Error handlers also receive the raised TelegramError object in error. +# context. async def start(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: """Send a message when the command /start is issued.""" await update.message.reply_text('Hi!') @@ -39,8 +39,8 @@ async def help_command(update: Update, context: CallbackContext.DEFAULT_TYPE) -> await update.message.reply_text('Help!') -async def inlinequery(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: - """Handle the inline query.""" +async def inline_query(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: + """Handle the inline query. This is run when you type: @botusername """ query = update.inline_query.query if query == "": @@ -56,14 +56,14 @@ async def inlinequery(update: Update, context: CallbackContext.DEFAULT_TYPE) -> id=str(uuid4()), title="Bold", input_message_content=InputTextMessageContent( - f"*{escape_markdown(query)}*", parse_mode=ParseMode.MARKDOWN + f"{escape(query)}", parse_mode=ParseMode.HTML ), ), InlineQueryResultArticle( id=str(uuid4()), title="Italic", input_message_content=InputTextMessageContent( - f"_{escape_markdown(query)}_", parse_mode=ParseMode.MARKDOWN + f"{escape(query)}", parse_mode=ParseMode.HTML ), ), ] @@ -81,7 +81,7 @@ def main() -> None: application.add_handler(CommandHandler("help", help_command)) # on non command i.e message - echo the message on Telegram - application.add_handler(InlineQueryHandler(inlinequery)) + application.add_handler(InlineQueryHandler(inline_query)) # Run the bot until the user presses Ctrl-C application.run_polling() diff --git a/examples/inlinekeyboard.py b/examples/inlinekeyboard.py index 730e70b23cd..b618d4b85ee 100644 --- a/examples/inlinekeyboard.py +++ b/examples/inlinekeyboard.py @@ -60,9 +60,9 @@ def main() -> None: # Create the Application and pass it your bot's token. application = Application.builder().token("TOKEN").build() - application.application.add_handler(CommandHandler('start', start)) - application.application.add_handler(CallbackQueryHandler(button)) - application.application.add_handler(CommandHandler('help', help_command)) + application.add_handler(CommandHandler('start', start)) + application.add_handler(CallbackQueryHandler(button)) + application.add_handler(CommandHandler('help', help_command)) # Run the bot until the user presses Ctrl-C application.run_polling() diff --git a/examples/inlinekeyboard2.py b/examples/inlinekeyboard2.py index cb95637e666..24f67b2adfd 100644 --- a/examples/inlinekeyboard2.py +++ b/examples/inlinekeyboard2.py @@ -32,7 +32,7 @@ logger = logging.getLogger(__name__) # Stages -FIRST, SECOND = range(2) +START_ROUTES, END_ROUTES = range(2) # Callback data ONE, TWO, THREE, FOUR = range(4) @@ -56,7 +56,7 @@ async def start(update: Update, context: CallbackContext.DEFAULT_TYPE) -> int: # Send message with text and appended InlineKeyboard await update.message.reply_text("Start handler, Choose a route", reply_markup=reply_markup) # Tell ConversationHandler that we're in state `FIRST` now - return FIRST + return START_ROUTES async def start_over(update: Update, context: CallbackContext.DEFAULT_TYPE) -> int: @@ -77,7 +77,7 @@ async def start_over(update: Update, context: CallbackContext.DEFAULT_TYPE) -> i # originated the CallbackQuery. This gives the feeling of an # interactive menu. await query.edit_message_text(text="Start handler, Choose a route", reply_markup=reply_markup) - return FIRST + return START_ROUTES async def one(update: Update, context: CallbackContext.DEFAULT_TYPE) -> int: @@ -94,7 +94,7 @@ async def one(update: Update, context: CallbackContext.DEFAULT_TYPE) -> int: await query.edit_message_text( text="First CallbackQueryHandler, Choose a route", reply_markup=reply_markup ) - return FIRST + return START_ROUTES async def two(update: Update, context: CallbackContext.DEFAULT_TYPE) -> int: @@ -111,11 +111,11 @@ async def two(update: Update, context: CallbackContext.DEFAULT_TYPE) -> int: await query.edit_message_text( text="Second CallbackQueryHandler, Choose a route", reply_markup=reply_markup ) - return FIRST + return START_ROUTES async def three(update: Update, context: CallbackContext.DEFAULT_TYPE) -> int: - """Show new choice of buttons""" + """Show new choice of buttons. This is the end point of the conversation.""" query = update.callback_query await query.answer() keyboard = [ @@ -129,7 +129,7 @@ async def three(update: Update, context: CallbackContext.DEFAULT_TYPE) -> int: text="Third CallbackQueryHandler. Do want to start over?", reply_markup=reply_markup ) # Transfer to conversation state `SECOND` - return SECOND + return END_ROUTES async def four(update: Update, context: CallbackContext.DEFAULT_TYPE) -> int: @@ -146,7 +146,7 @@ async def four(update: Update, context: CallbackContext.DEFAULT_TYPE) -> int: await query.edit_message_text( text="Fourth CallbackQueryHandler, Choose a route", reply_markup=reply_markup ) - return FIRST + return START_ROUTES async def end(update: Update, context: CallbackContext.DEFAULT_TYPE) -> int: @@ -173,13 +173,13 @@ def main() -> None: conv_handler = ConversationHandler( entry_points=[CommandHandler('start', start)], states={ - FIRST: [ + START_ROUTES: [ CallbackQueryHandler(one, pattern='^' + str(ONE) + '$'), CallbackQueryHandler(two, pattern='^' + str(TWO) + '$'), CallbackQueryHandler(three, pattern='^' + str(THREE) + '$'), CallbackQueryHandler(four, pattern='^' + str(FOUR) + '$'), ], - SECOND: [ + END_ROUTES: [ CallbackQueryHandler(start_over, pattern='^' + str(ONE) + '$'), CallbackQueryHandler(end, pattern='^' + str(TWO) + '$'), ], diff --git a/examples/nestedconversationbot.py b/examples/nestedconversationbot.py index d80c5f7e997..d08137636f5 100644 --- a/examples/nestedconversationbot.py +++ b/examples/nestedconversationbot.py @@ -120,27 +120,29 @@ async def adding_self(update: Update, context: CallbackContext.DEFAULT_TYPE) -> async def show_data(update: Update, context: CallbackContext.DEFAULT_TYPE) -> str: """Pretty print gathered data.""" - def prettyprint(user_data: Dict[str, Any], level: str) -> str: - people = user_data.get(level) + def pretty_print(data: Dict[str, Any], level: str) -> str: + people = data.get(level) if not people: return '\nNo information yet.' - text = '' + return_str = '' if level == SELF: - for person in user_data[level]: - text += f"\nName: {person.get(NAME, '-')}, Age: {person.get(AGE, '-')}" + for person in data[level]: + return_str += f"\nName: {person.get(NAME, '-')}, Age: {person.get(AGE, '-')}" else: male, female = _name_switcher(level) - for person in user_data[level]: + for person in data[level]: gender = female if person[GENDER] == FEMALE else male - text += f"\n{gender}: Name: {person.get(NAME, '-')}, Age: {person.get(AGE, '-')}" - return text + return_str += ( + f"\n{gender}: Name: {person.get(NAME, '-')}, Age: {person.get(AGE, '-')}" + ) + return return_str user_data = context.user_data - text = f"Yourself:{prettyprint(user_data, SELF)}" - text += f"\n\nParents:{prettyprint(user_data, PARENTS)}" - text += f"\n\nChildren:{prettyprint(user_data, CHILDREN)}" + text = f"Yourself:{pretty_print(user_data, SELF)}" + text += f"\n\nParents:{pretty_print(user_data, PARENTS)}" + text += f"\n\nChildren:{pretty_print(user_data, CHILDREN)}" buttons = [[InlineKeyboardButton(text='Back', callback_data=str(END))]] keyboard = InlineKeyboardMarkup(buttons) diff --git a/examples/passportbot.html b/examples/passportbot.html index 4e37f0c69c1..b25c51f6a50 100644 --- a/examples/passportbot.html +++ b/examples/passportbot.html @@ -3,27 +3,32 @@ Telegram passport test! - - - - - - + +

Telegram passport test

+ + + + diff --git a/examples/passportbot.py b/examples/passportbot.py index 47d5402ab7d..07efa9670f6 100644 --- a/examples/passportbot.py +++ b/examples/passportbot.py @@ -20,7 +20,7 @@ # Enable logging logging.basicConfig( - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', level=logging.DEBUG + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', level=logging.INFO ) logger = logging.getLogger(__name__) @@ -83,7 +83,7 @@ async def msg(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: selfie_file = await data.selfie.get_file() print(data.type, selfie_file) await selfie_file.download() - if data.type in ( + if data.translation and data.type in ( 'passport', 'driver_license', 'identity_card', @@ -109,9 +109,6 @@ def main() -> None: Application.builder().token("TOKEN").private_key(private_key.read_bytes()).build() ) - # Get the application to register handlers - application = application.application - # On messages that include passport data call msg application.add_handler(MessageHandler(filters.PASSPORT_DATA, msg)) diff --git a/examples/paymentbot.py b/examples/paymentbot.py index 4ccafd1ed29..28ae5899b0e 100644 --- a/examples/paymentbot.py +++ b/examples/paymentbot.py @@ -24,6 +24,8 @@ ) logger = logging.getLogger(__name__) +PAYMENT_PROVIDER_TOKEN = 'TOKEN' + async def start_callback(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: """Displays info on how to use the bot.""" @@ -45,7 +47,6 @@ async def start_with_shipping_callback( # select a payload just for you to recognize its the donation from your bot payload = "Custom-Payload" # In order to get a provider_token see https://core.telegram.org/bots/payments#getting-a-token - provider_token = "PROVIDER_TOKEN" currency = "USD" # price in dollars price = 1 @@ -60,7 +61,7 @@ async def start_with_shipping_callback( title, description, payload, - provider_token, + PAYMENT_PROVIDER_TOKEN, currency, prices, need_name=True, @@ -81,7 +82,6 @@ async def start_without_shipping_callback( # select a payload just for you to recognize its the donation from your bot payload = "Custom-Payload" # In order to get a provider_token see https://core.telegram.org/bots/payments#getting-a-token - provider_token = "PROVIDER_TOKEN" currency = "USD" # price in dollars price = 1 @@ -91,7 +91,7 @@ async def start_without_shipping_callback( # optionally pass need_name=True, need_phone_number=True, # need_email=True, need_shipping_address=True, is_flexible=True await context.bot.send_invoice( - chat_id, title, description, payload, provider_token, currency, prices + chat_id, title, description, payload, PAYMENT_PROVIDER_TOKEN, currency, prices ) diff --git a/examples/persistentconversationbot.py b/examples/persistentconversationbot.py index 71eb2f5bcba..1e871de8b84 100644 --- a/examples/persistentconversationbot.py +++ b/examples/persistentconversationbot.py @@ -123,7 +123,7 @@ async def done(update: Update, context: CallbackContext.DEFAULT_TYPE) -> int: del context.user_data['choice'] await update.message.reply_text( - f"I learned these facts about you: {facts_to_str(context.user_data)} Until next time!", + f"I learned these facts about you: {facts_to_str(context.user_data)}Until next time!", reply_markup=ReplyKeyboardRemove(), ) return ConversationHandler.END @@ -135,9 +135,6 @@ def main() -> None: persistence = PicklePersistence(filepath='conversationbot') application = Application.builder().token("TOKEN").persistence(persistence).build() - # Get the application to register handlers - application = application.application - # Add conversation handler with the states CHOOSING, TYPING_CHOICE and TYPING_REPLY conv_handler = ConversationHandler( entry_points=[CommandHandler('start', start)], diff --git a/examples/pollbot.py b/examples/pollbot.py index 2b7898cd1f3..dc4091f724a 100644 --- a/examples/pollbot.py +++ b/examples/pollbot.py @@ -69,9 +69,9 @@ async def poll(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: async def receive_poll_answer(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: """Summarize a users poll vote""" answer = update.poll_answer - poll_id = answer.poll_id + answered_poll = context.bot_data[answer.poll_id] try: - questions = context.bot_data[poll_id]["questions"] + questions = answered_poll["questions"] # this means this poll answer update is from an old poll, we can't do our answering then except KeyError: return @@ -83,16 +83,14 @@ async def receive_poll_answer(update: Update, context: CallbackContext.DEFAULT_T else: answer_string += questions[question_id] await context.bot.send_message( - context.bot_data[poll_id]["chat_id"], + answered_poll["chat_id"], f"{update.effective_user.mention_html()} feels {answer_string}!", parse_mode=ParseMode.HTML, ) - context.bot_data[poll_id]["answers"] += 1 + answered_poll["answers"] += 1 # Close poll after three participants voted - if context.bot_data[poll_id]["answers"] == 3: - await context.bot.stop_poll( - context.bot_data[poll_id]["chat_id"], context.bot_data[poll_id]["message_id"] - ) + if answered_poll["answers"] == 3: + await context.bot.stop_poll(answered_poll["chat_id"], answered_poll["message_id"]) async def quiz(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: @@ -158,12 +156,12 @@ def main() -> None: application = Application.builder().token("TOKEN").build() application.add_handler(CommandHandler('start', start)) application.add_handler(CommandHandler('poll', poll)) - application.add_handler(PollAnswerHandler(receive_poll_answer)) application.add_handler(CommandHandler('quiz', quiz)) - application.add_handler(PollHandler(receive_quiz_answer)) application.add_handler(CommandHandler('preview', preview)) - application.add_handler(MessageHandler(filters.POLL, receive_poll)) application.add_handler(CommandHandler('help', help_handler)) + application.add_handler(MessageHandler(filters.POLL, receive_poll)) + application.add_handler(PollAnswerHandler(receive_poll_answer)) + application.add_handler(PollHandler(receive_quiz_answer)) # Run the bot until the user presses Ctrl-C application.run_polling() diff --git a/examples/rawapibot.py b/examples/rawapibot.py index 19a9d512fe0..2d451dc3b37 100644 --- a/examples/rawapibot.py +++ b/examples/rawapibot.py @@ -10,51 +10,57 @@ import logging from typing import NoReturn -import telegram +from telegram import Bot from telegram.error import NetworkError, Forbidden -UPDATE_ID = None +logging.basicConfig( + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', level=logging.INFO +) +logger = logging.getLogger(__name__) async def main() -> NoReturn: """Run the bot.""" - global UPDATE_ID - # Telegram Bot Authorization Token - bot = telegram.Bot('TOKEN') - - # get the first pending update_id, this is so we can skip over it in case - # we get an "Forbidden" exception. - try: - UPDATE_ID = (await bot.get_updates())[0].update_id - except IndexError: - UPDATE_ID = None - - logging.basicConfig(format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') - - while True: + # Here we use the `async with` syntax to properly initialize and shutdown resources. + async with Bot("TOKEN") as bot: + # get the first pending update_id, this is so we can skip over it in case + # we get a "Forbidden" exception. try: - await echo(bot) - except NetworkError: - await asyncio.sleep(1) - except Forbidden: - # The user has removed or blocked the bot. - UPDATE_ID += 1 + update_id = (await bot.get_updates())[0].update_id + except IndexError: + update_id = None + + logger.info("listening for new messages...") + while True: + try: + update_id = await echo(bot, update_id) + except NetworkError: + await asyncio.sleep(1) + except Forbidden: + # The user has removed or blocked the bot. + update_id += 1 -async def echo(bot: telegram.Bot) -> None: +async def echo(bot: Bot, update_id: int) -> int: """Echo the message the user sent.""" - global UPDATE_ID # Request updates after the last update_id - for update in await bot.get_updates(offset=UPDATE_ID, timeout=10): - UPDATE_ID = update.update_id + 1 + updates = await bot.get_updates(offset=update_id, timeout=10) + for update in updates: + next_update_id = update.update_id + 1 # your bot can receive updates without messages # and not all messages contain text if update.message and update.message.text: # Reply to the message + logger.info("Found message %s!", update.message.text) await update.message.reply_text(update.message.text) + return next_update_id + return update_id if __name__ == '__main__': - asyncio.run(main()) + try: + asyncio.run(main()) + except KeyboardInterrupt: # Ignore exception when Ctrl-C is pressed + pass diff --git a/examples/timerbot.py b/examples/timerbot.py index 0d04241b5f8..8ff874dea96 100644 --- a/examples/timerbot.py +++ b/examples/timerbot.py @@ -31,7 +31,7 @@ # Define a few command handlers. These usually take the two arguments update and -# context. Error handlers also receive the raised TelegramError object in error. +# context. # Best practice would be to replace context with an underscore, # since context is an unused local variable. # This being an example and not having context present confusing beginners, @@ -59,7 +59,7 @@ def remove_job_if_exists(name: str, context: CallbackContext.DEFAULT_TYPE) -> bo async def set_timer(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: """Add a job to the queue.""" - chat_id = update.message.chat_id + chat_id = update.effective_message.chat_id try: # args[0] should contain the time for the timer in seconds due = int(context.args[0]) @@ -76,7 +76,7 @@ async def set_timer(update: Update, context: CallbackContext.DEFAULT_TYPE) -> No await update.message.reply_text(text) except (IndexError, ValueError): - await update.message.reply_text('Usage: /set ') + await update.effective_message.reply_text('Usage: /set ') async def unset(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: @@ -93,8 +93,7 @@ def main() -> None: application = Application.builder().token("TOKEN").build() # on different commands - answer in Telegram - application.add_handler(CommandHandler("start", start)) - application.add_handler(CommandHandler("help", start)) + application.add_handler(CommandHandler(["start", "help"], start)) application.add_handler(CommandHandler("set", set_timer)) application.add_handler(CommandHandler("unset", unset)) diff --git a/requirements-dev.txt b/requirements-dev.txt index ee7a8b6c744..9f3ae43d89d 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -4,6 +4,8 @@ cryptography!=3.4,!=3.4.1,!=3.4.2,!=3.4.3 pre-commit # Make sure that the versions specified here match the pre-commit settings! black==21.9b0 +# hardpinned dependency for black +click==8.0.2 flake8==4.0.1 pylint==2.12.1 mypy==0.910 From 8411d62e8a51574ce82485a002935aa50727a8f5 Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Sun, 10 Apr 2022 16:17:32 +0200 Subject: [PATCH 113/153] log errors in Update.de_json in Updater --- telegram/ext/_updater.py | 30 ++++++--- telegram/ext/_utils/webhookhandler.py | 10 ++- tests/test_updater.py | 94 +++++++++++++++++++++++++++ 3 files changed, 124 insertions(+), 10 deletions(-) diff --git a/telegram/ext/_updater.py b/telegram/ext/_updater.py index 9dae213ea56..0c41427d364 100644 --- a/telegram/ext/_updater.py +++ b/telegram/ext/_updater.py @@ -267,15 +267,27 @@ async def _start_polling( self._logger.debug('Bootstrap done') async def polling_action_cb() -> bool: - updates = await self.bot.get_updates( - offset=self._last_update_id, - timeout=timeout, - read_timeout=read_timeout, - connect_timeout=connect_timeout, - write_timeout=write_timeout, - pool_timeout=pool_timeout, - allowed_updates=allowed_updates, - ) + try: + updates = await self.bot.get_updates( + offset=self._last_update_id, + timeout=timeout, + read_timeout=read_timeout, + connect_timeout=connect_timeout, + write_timeout=write_timeout, + pool_timeout=pool_timeout, + allowed_updates=allowed_updates, + ) + except TelegramError as exc: + # TelegramErrors should be processed by the network retry loop + raise exc + except Exception as exc: + # Other exceptions should not. Let's log them for now. + self._logger.critical( + 'Something went wrong processing the data received from Telegram. ' + 'Received data was *not* processed!', + exc_info=exc, + ) + return True if updates: if not self.running: diff --git a/telegram/ext/_utils/webhookhandler.py b/telegram/ext/_utils/webhookhandler.py index 4685f461777..792bcf720b8 100644 --- a/telegram/ext/_utils/webhookhandler.py +++ b/telegram/ext/_utils/webhookhandler.py @@ -125,7 +125,15 @@ async def post(self) -> None: self.set_status(HTTPStatus.OK) self._logger.debug('Webhook received data: %s', json_string) - update = Update.de_json(data, self.bot) + try: + update = Update.de_json(data, self.bot) + except Exception as exc: + self._logger.critical( + 'Something went wrong processing the data received from Telegram. ' + 'Received data was *not* processed!', + exc_info=exc, + ) + if update: self._logger.debug('Received Update with ID %d on Webhook', update.update_id) diff --git a/tests/test_updater.py b/tests/test_updater.py index 5e31cccc76f..b2e1fb14c4c 100644 --- a/tests/test_updater.py +++ b/tests/test_updater.py @@ -482,6 +482,53 @@ async def _start_polling(*args, **kwargs): await updater.start_polling() assert updater.running is False + @pytest.mark.asyncio + async def test_polling_update_de_json_fails(self, monkeypatch, updater, caplog): + updates = asyncio.Queue() + raise_exception = True + await updates.put(Update(update_id=1)) + + async def get_updates(*args, **kwargs): + if raise_exception: + await asyncio.sleep(0.01) + raise TypeError('Invalid Data') + + next_update = await updates.get() + updates.task_done() + return [next_update] + + orig_del_webhook = updater.bot.delete_webhook + + async def delete_webhook(*args, **kwargs): + # Dropping pending updates is done by passing the parameter to delete_webhook + if kwargs.get('drop_pending_updates'): + self.message_count += 1 + return await orig_del_webhook(*args, **kwargs) + + monkeypatch.setattr(updater.bot, 'get_updates', get_updates) + monkeypatch.setattr(updater.bot, 'delete_webhook', delete_webhook) + + async with updater: + with caplog.at_level(logging.CRITICAL): + await updater.start_polling() + assert updater.running + await asyncio.sleep(1) + + assert len(caplog.records) > 0 + for record in caplog.records: + assert record.getMessage().startswith('Something went wrong processing') + + # Make sure that everything works fine again when receiving proper updates + raise_exception = False + await asyncio.sleep(0.5) + caplog.clear() + with caplog.at_level(logging.CRITICAL): + await updates.join() + assert len(caplog.records) == 0 + + await updater.stop() + assert not updater.running + @pytest.mark.asyncio @pytest.mark.parametrize('ext_bot', [True, False]) @pytest.mark.parametrize('drop_pending_updates', (True, False)) @@ -832,3 +879,50 @@ async def return_true(*args, **kwargs): # assert response.status_code == HTTPStatus.INTERNAL_SERVER_ERROR await updater.stop() + + @pytest.mark.asyncio + async def test_webhook_update_de_json_fails(self, monkeypatch, updater, caplog): + async def delete_webhook(*args, **kwargs): + return True + + async def set_webhook(*args, **kwargs): + return True + + def de_json_fails(*args, **kwargs): + raise TypeError('Invalid input') + + monkeypatch.setattr(updater.bot, 'set_webhook', set_webhook) + monkeypatch.setattr(updater.bot, 'delete_webhook', delete_webhook) + orig_de_json = Update.de_json + monkeypatch.setattr(Update, 'de_json', de_json_fails) + + ip = '127.0.0.1' + port = randrange(1024, 49152) # Select random port + + async with updater: + return_value = await updater.start_webhook( + ip_address=ip, + port=port, + url_path='TOKEN', + ) + assert return_value is updater.update_queue + assert updater.running + + # Now, we send an update to the server + update = make_message_update('Webhook') + with caplog.at_level(logging.CRITICAL): + await send_webhook_message(ip, port, update.to_json(), 'TOKEN') + + assert len(caplog.records) == 1 + assert caplog.records[-1].getMessage().startswith('Something went wrong processing') + + # Make sure that everything works fine again when receiving proper updates + caplog.clear() + with caplog.at_level(logging.CRITICAL): + monkeypatch.setattr(Update, 'de_json', orig_de_json) + await send_webhook_message(ip, port, update.to_json(), 'TOKEN') + assert (await updater.update_queue.get()).to_dict() == update.to_dict() + assert len(caplog.records) == 0 + + await updater.stop() + assert not updater.running From c116f2e3c3987d86245233b15726763c0af51012 Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Tue, 12 Apr 2022 20:39:24 +0200 Subject: [PATCH 114/153] Stabilize startup of App.run() --- telegram/ext/_application.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/telegram/ext/_application.py b/telegram/ext/_application.py index f6ec22f6d1b..697164ffb94 100644 --- a/telegram/ext/_application.py +++ b/telegram/ext/_application.py @@ -570,20 +570,28 @@ def __run(self, updater_coroutine: Coroutine, close_loop: bool = True) -> None: # running event loop or we are in the main thread, which are the intended use cases. # See the docs of get_event_loop() and get_running_loop() for more info loop = asyncio.get_event_loop() - loop.run_until_complete(self.initialize()) - loop.run_until_complete(updater_coroutine) - loop.run_until_complete(self.start()) try: + loop.run_until_complete(self.initialize()) + loop.run_until_complete(updater_coroutine) + loop.run_until_complete(self.start()) + loop.run_forever() except (KeyboardInterrupt, SystemExit): pass + except Exception as exc: + # In case the coroutine wasn't awaited, we don't need to bother the user with a warning + updater_coroutine.close() + raise exc finally: # We arrive here either by catching the exceptions above or if the loop gets stopped try: # Mypy doesn't know that we already check if updater is None - loop.run_until_complete(self.updater.stop()) # type: ignore[union-attr] - loop.run_until_complete(self.stop()) + if self.updater.running: # type: ignore[union-attr] + loop.run_until_complete(self.updater.stop()) # type: ignore[union-attr] + if self.running: + loop.run_until_complete(self.stop()) loop.run_until_complete(self.shutdown()) + loop.run_until_complete(self.updater.shutdown()) # type: ignore[union-attr] finally: if close_loop: loop.close() From 43ca1dae4017ee35212a3431593d5637b641d7d2 Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Tue, 12 Apr 2022 21:45:32 +0200 Subject: [PATCH 115/153] Add another warning about CH timeout not working if JQ is not running --- telegram/ext/_conversationhandler.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/telegram/ext/_conversationhandler.py b/telegram/ext/_conversationhandler.py index 45ed298a0c8..62aa9f3a90b 100644 --- a/telegram/ext/_conversationhandler.py +++ b/telegram/ext/_conversationhandler.py @@ -799,7 +799,16 @@ async def handle_update( # type: ignore[override] raise_dp_handler_stop = True async with self._timeout_jobs_lock: if self.conversation_timeout: - if application.job_queue is not None: + if application.job_queue is None: + warn( + "Ignoring `conversation_timeout` because the Application has no JobQueue.", + ) + elif not application.job_queue.scheduler.running: + warn( + "Ignoring `conversation_timeout` because the Applications JobQueue is " + "not running.", + ) + else: # Add the new timeout job # checking if the new state is self.END is done in _schedule_job if isinstance(new_state, asyncio.Task): @@ -813,10 +822,6 @@ async def handle_update( # type: ignore[override] self._schedule_job( new_state, application, update, context, conversation_key ) - else: - warn( - "Ignoring `conversation_timeout` because the Application has no JobQueue.", - ) if isinstance(self.map_to_parent, dict) and new_state in self.map_to_parent: self._update_state(self.END, conversation_key) From 87e977f2c3c5ff5c0dae90ad75e156959116acc6 Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Sat, 16 Apr 2022 22:22:22 +0200 Subject: [PATCH 116/153] change blocking-resolution order in CH --- telegram/ext/_conversationhandler.py | 32 ++++++++++++---------------- 1 file changed, 14 insertions(+), 18 deletions(-) diff --git a/telegram/ext/_conversationhandler.py b/telegram/ext/_conversationhandler.py index 62aa9f3a90b..38b9d8c02b8 100644 --- a/telegram/ext/_conversationhandler.py +++ b/telegram/ext/_conversationhandler.py @@ -216,15 +216,17 @@ class ConversationHandler(Handler[Update, CCT]): map_to_parent (Dict[:obj:`object`, :obj:`object`], optional): A :obj:`dict` that can be used to instruct a nested conversation handler to transition into a mapped state on its parent conversation handler in place of a specified nested state. - block (:obj:`bool`, optional): Pass :obj:`False` to *overrule* the + block (:obj:`bool`, optional): Pass :obj:`False` to set a default value for the :attr:`Handler.block` setting of all handlers (in :attr:`entry_points`, - :attr:`states` and :attr:`fallbacks`). - By default the handlers setting and :attr:`telegram.ext.Defaults.block` will be - respected (in that order). + :attr:`states` and :attr:`fallbacks`). The resolution order for checking if a handler + should be run non-blocking is: + + 1. :attr:`telegram.ext.Handler.block` (if set) + 2. the value passed to this parameter (if any) + 3. :attr:`telegram.ext.Defaults.block` (if defaults are used) - .. versionadded:: 13.2 .. versionchanged:: 14.0 - No longer *overrides* the handlers settings. + No longer overrides the handlers settings. Resolution order was changed. Raises: ValueError @@ -233,8 +235,6 @@ class ConversationHandler(Handler[Update, CCT]): block (:obj:`bool`): Determines whether the callback will run asynchronously. Always :obj:`True` since conversation handlers handle any non-blocking callbacks internally. - .. versionadded:: 13.2 - """ __slots__ = ( @@ -346,9 +346,6 @@ def __init__( ) for handler in all_handlers: - if self.block: - handler.block = True - if isinstance(handler, (StringCommandHandler, StringRegexHandler)): warn( "The `ConversationHandler` only handles updates of type `telegram.Update`. " @@ -768,15 +765,14 @@ async def handle_update( # type: ignore[override] timeout_job.schedule_removal() # Resolution order of "block": - # 1. Setting of the ConversationHandler - # 2. Setting of the selected handler + # 1. Setting of the selected handler + # 2. Setting of the ConversationHandler # 3. Default values of the bot - if self._block is not DEFAULT_TRUE: - # CHs block-setting has highest priority - block = self._block + if handler.block is not DEFAULT_TRUE: + block = handler.block else: - if handler.block is not DEFAULT_TRUE: - block = handler.block + if self._block is not DEFAULT_TRUE: + block = self._block elif isinstance(application.bot, ExtBot) and application.bot.defaults is not None: block = application.bot.defaults.block else: From ef4bc43b671f06680235b447e3f17447e817d056 Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Sat, 16 Apr 2022 23:35:58 +0200 Subject: [PATCH 117/153] small fix --- telegram/ext/_conversationhandler.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/telegram/ext/_conversationhandler.py b/telegram/ext/_conversationhandler.py index 38b9d8c02b8..af1dde51d9a 100644 --- a/telegram/ext/_conversationhandler.py +++ b/telegram/ext/_conversationhandler.py @@ -597,7 +597,8 @@ def _get_key(self, update: Update) -> ConversationKey: raise RuntimeError("Can't build key for update without CallbackQuery!") if update.callback_query.inline_message_id: key.append(update.callback_query.inline_message_id) - key.append(update.callback_query.message.message_id) # type: ignore[union-attr] + else: + key.append(update.callback_query.message.message_id) # type: ignore[union-attr] return tuple(key) From c46fcdd0f47aa2cbaf2800b0a281069b1b06a08e Mon Sep 17 00:00:00 2001 From: Harshil <37377066+harshil21@users.noreply.github.com> Date: Sun, 17 Apr 2022 12:04:51 +0530 Subject: [PATCH 118/153] Docs/CSI for asyncio (#2926) Co-authored-by: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> --- docs/source/conf.py | 3 + telegram/_bot.py | 1280 ++++++++++++++++---- telegram/_callbackquery.py | 2 +- telegram/_chat.py | 38 +- telegram/_chatjoinrequest.py | 8 +- telegram/_chatmemberupdated.py | 2 +- telegram/_choseninlineresult.py | 2 +- telegram/_files/_basethumbedmedium.py | 4 +- telegram/_files/file.py | 21 +- telegram/_files/inputfile.py | 10 +- telegram/_files/inputmedia.py | 53 +- telegram/_files/sticker.py | 26 +- telegram/_files/video.py | 4 +- telegram/_games/game.py | 12 +- telegram/_inline/inlinequery.py | 15 +- telegram/_message.py | 38 +- telegram/_payment/precheckoutquery.py | 2 +- telegram/_payment/shippingquery.py | 2 +- telegram/_replykeyboardmarkup.py | 4 +- telegram/_update.py | 19 +- telegram/_user.py | 36 +- telegram/_userprofilephotos.py | 2 +- telegram/_utils/datetime.py | 43 +- telegram/error.py | 2 +- telegram/ext/_application.py | 245 +++- telegram/ext/_applicationbuilder.py | 205 +++- telegram/ext/_basepersistence.py | 25 +- telegram/ext/_callbackcontext.py | 62 +- telegram/ext/_callbackdatacache.py | 17 +- telegram/ext/_callbackqueryhandler.py | 26 +- telegram/ext/_chatjoinrequesthandler.py | 15 +- telegram/ext/_chatmemberhandler.py | 16 +- telegram/ext/_choseninlineresulthandler.py | 17 +- telegram/ext/_commandhandler.py | 56 +- telegram/ext/_contexttypes.py | 33 +- telegram/ext/_conversationhandler.py | 161 ++- telegram/ext/_defaults.py | 15 +- telegram/ext/_dictpersistence.py | 16 +- telegram/ext/_handler.py | 17 +- telegram/ext/_inlinequeryhandler.py | 22 +- telegram/ext/_jobqueue.py | 63 +- telegram/ext/_messagehandler.py | 17 +- telegram/ext/_picklepersistence.py | 16 +- telegram/ext/_pollanswerhandler.py | 15 +- telegram/ext/_pollhandler.py | 16 +- telegram/ext/_precheckoutqueryhandler.py | 14 +- telegram/ext/_shippingqueryhandler.py | 14 +- telegram/ext/_stringcommandhandler.py | 18 +- telegram/ext/_stringregexhandler.py | 16 +- telegram/ext/_typehandler.py | 30 +- telegram/ext/_updater.py | 93 +- telegram/ext/_utils/stack.py | 2 +- telegram/ext/_utils/trackingdict.py | 5 +- telegram/ext/filters.py | 45 +- telegram/helpers.py | 14 +- telegram/request/__init__.py | 28 +- telegram/request/_baserequest.py | 90 +- telegram/request/_httpxrequest.py | 62 +- telegram/request/_requestdata.py | 7 +- telegram/request/_requestparameter.py | 4 +- tests/test_bot.py | 2 +- 61 files changed, 2175 insertions(+), 972 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index b59bc35189e..61a6cee75f8 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -127,6 +127,9 @@ # The name of the Pygments (syntax highlighting) style to use. pygments_style = 'sphinx' +# Decides the language used for syntax highlighting of code blocks. +highlight_language = 'python3' + # A list of ignored prefixes for module index sorting. #modindex_common_prefix = [] diff --git a/telegram/_bot.py b/telegram/_bot.py index c528fe60318..e204c3d8b03 100644 --- a/telegram/_bot.py +++ b/telegram/_bot.py @@ -163,16 +163,16 @@ class Bot(TelegramObject, AbstractAsyncContextManager): * Attempting to pickle a bot instance will now raise :exc:`pickle.PicklingError`. Args: - token (:obj:`str`): Bot's unique authentication. + token (:obj:`str`): Bot's unique authentication token. base_url (:obj:`str`, optional): Telegram Bot API service URL. base_file_url (:obj:`str`, optional): Telegram Bot API file URL. request (:class:`telegram.request.BaseRequest`, optional): Pre initialized :class:`telegram.request.BaseRequest` instances. Will be used for all bot methods - *except* for :attr:`get_updates`. If not passed, an instance of + *except* for :meth:`get_updates`. If not passed, an instance of :class:`telegram.request.HTTPXRequest` will be used. get_updates_request (:class:`telegram.request.BaseRequest`, optional): Pre initialized :class:`telegram.request.BaseRequest` instances. Will be used exclusively for - :attr:`get_updates`. If not passed, an instance of + :meth:`get_updates`. If not passed, an instance of :class:`telegram.request.HTTPXRequest` will be used. private_key (:obj:`bytes`, optional): Private key for decryption of telegram passport data. private_key_password (:obj:`bytes`, optional): Password for above private key. @@ -279,12 +279,12 @@ def _insert_defaults(self, data: Dict[str, object]) -> None: # pylint: disable= async def _post( self, endpoint: str, - data: JSONDict = None, # {'chat_id': 123, 'text': 'Hello there!'} + data: JSONDict = None, read_timeout: ODVInput[float] = DEFAULT_NONE, write_timeout: ODVInput[float] = DEFAULT_NONE, connect_timeout: ODVInput[float] = DEFAULT_NONE, pool_timeout: ODVInput[float] = DEFAULT_NONE, - api_kwargs: JSONDict = None, # {'new_param': whatever} + api_kwargs: JSONDict = None, ) -> Union[bool, JSONDict, None]: if data is None: data = {} @@ -514,9 +514,18 @@ async def get_me( """A simple method for testing your bot's auth token. Requires no parameters. Args: - timeout (:obj:`int` | :obj:`float`, optional): If this value is specified, use it as - the read timeout from the server (instead of the one specified during creation of - the connection pool). + read_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.read_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + write_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.write_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + connect_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.connect_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + pool_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.pool_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. api_kwargs (:obj:`dict`, optional): Arbitrary keyword arguments to be passed to the Telegram API. @@ -588,9 +597,18 @@ async def send_message( :class:`ReplyKeyboardRemove` | :class:`ForceReply`, optional): Additional interface options. An object for an inline keyboard, custom reply keyboard, instructions to remove reply keyboard or to force a reply from the user. - timeout (:obj:`int` | :obj:`float`, optional): If this value is specified, use it as - the read timeout from the server (instead of the one specified during creation of - the connection pool). + read_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.read_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + write_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.write_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + connect_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.connect_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + pool_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.pool_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. api_kwargs (:obj:`dict`, optional): Arbitrary keyword arguments to be passed to the Telegram API. @@ -656,9 +674,18 @@ async def delete_message( chat_id (:obj:`int` | :obj:`str`): Unique identifier for the target chat or username of the target channel (in the format ``@channelusername``). message_id (:obj:`int`): Identifier of the message to delete. - timeout (:obj:`int` | :obj:`float`, optional): If this value is specified, use it as - the read timeout from the server (instead of the one specified during creation of - the connection pool). + read_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.read_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + write_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.write_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + connect_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.connect_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + pool_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.pool_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. api_kwargs (:obj:`dict`, optional): Arbitrary keyword arguments to be passed to the Telegram API. @@ -718,9 +745,18 @@ async def forward_message( .. versionadded:: 13.10 - timeout (:obj:`int` | :obj:`float`, optional): If this value is specified, use it as - the read timeout from the server (instead of the one specified during creation of - the connection pool). + read_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.read_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + write_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.write_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + connect_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.connect_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + pool_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.pool_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. api_kwargs (:obj:`dict`, optional): Arbitrary keyword arguments to be passed to the Telegram API. @@ -760,8 +796,8 @@ async def send_photo( disable_notification: DVInput[bool] = DEFAULT_NONE, reply_to_message_id: int = None, reply_markup: ReplyMarkup = None, - read_timeout: float = 20, - write_timeout: float = 20, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = 20, connect_timeout: ODVInput[float] = DEFAULT_NONE, pool_timeout: ODVInput[float] = DEFAULT_NONE, parse_mode: ODVInput[str] = DEFAULT_NONE, @@ -818,7 +854,17 @@ async def send_photo( :class:`ReplyKeyboardRemove` | :class:`ForceReply`, optional): Additional interface options. An object for an inline keyboard, custom reply keyboard, instructions to remove reply keyboard or to force a reply from the user. - timeout (:obj:`int` | :obj:`float`, optional): Send file timeout (default: 20 seconds). + read_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.read_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + write_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.write_timeout`. Defaults to ``20``. + connect_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.connect_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + pool_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.pool_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. api_kwargs (:obj:`dict`, optional): Arbitrary keyword arguments to be passed to the Telegram API. @@ -868,8 +914,8 @@ async def send_audio( disable_notification: DVInput[bool] = DEFAULT_NONE, reply_to_message_id: int = None, reply_markup: ReplyMarkup = None, - read_timeout: float = 20, - write_timeout: float = 20, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = 20, connect_timeout: ODVInput[float] = DEFAULT_NONE, pool_timeout: ODVInput[float] = DEFAULT_NONE, parse_mode: ODVInput[str] = DEFAULT_NONE, @@ -947,7 +993,17 @@ async def send_audio( .. versionchanged:: 13.2 Accept :obj:`bytes` as input. - timeout (:obj:`int` | :obj:`float`, optional): Send file timeout (default: 20 seconds). + read_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.read_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + write_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.write_timeout`. Defaults to ``20``. + connect_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.connect_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + pool_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.pool_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. api_kwargs (:obj:`dict`, optional): Arbitrary keyword arguments to be passed to the Telegram API. @@ -1003,8 +1059,8 @@ async def send_document( disable_notification: DVInput[bool] = DEFAULT_NONE, reply_to_message_id: int = None, reply_markup: ReplyMarkup = None, - read_timeout: float = 20, - write_timeout: float = 20, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = 20, connect_timeout: ODVInput[float] = DEFAULT_NONE, pool_timeout: ODVInput[float] = DEFAULT_NONE, parse_mode: ODVInput[str] = DEFAULT_NONE, @@ -1076,7 +1132,17 @@ async def send_document( .. versionchanged:: 13.2 Accept :obj:`bytes` as input. - timeout (:obj:`int` | :obj:`float`, optional): Send file timeout (default: 20 seconds). + read_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.read_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + write_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.write_timeout`. Defaults to ``20``. + connect_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.connect_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + pool_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.pool_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. api_kwargs (:obj:`dict`, optional): Arbitrary keyword arguments to be passed to the Telegram API. @@ -1126,8 +1192,8 @@ async def send_sticker( disable_notification: DVInput[bool] = DEFAULT_NONE, reply_to_message_id: int = None, reply_markup: ReplyMarkup = None, - read_timeout: float = 20, - write_timeout: float = 20, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = 20, connect_timeout: ODVInput[float] = DEFAULT_NONE, pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, @@ -1168,7 +1234,17 @@ async def send_sticker( :class:`ReplyKeyboardRemove` | :class:`ForceReply`, optional): Additional interface options. An object for an inline keyboard, custom reply keyboard, instructions to remove reply keyboard or to force a reply from the user. - timeout (:obj:`int` | :obj:`float`, optional): Send file timeout (default: 20 seconds). + read_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.read_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + write_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.write_timeout`. Defaults to ``20``. + connect_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.connect_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + pool_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.pool_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. api_kwargs (:obj:`dict`, optional): Arbitrary keyword arguments to be passed to the Telegram API. @@ -1205,8 +1281,8 @@ async def send_video( disable_notification: DVInput[bool] = DEFAULT_NONE, reply_to_message_id: int = None, reply_markup: ReplyMarkup = None, - read_timeout: float = 20, - write_timeout: float = 20, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = 20, connect_timeout: ODVInput[float] = DEFAULT_NONE, pool_timeout: ODVInput[float] = DEFAULT_NONE, width: int = None, @@ -1290,7 +1366,17 @@ async def send_video( .. versionchanged:: 13.2 Accept :obj:`bytes` as input. - timeout (:obj:`int` | :obj:`float`, optional): Send file timeout (default: 20 seconds). + read_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.read_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + write_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.write_timeout`. Defaults to ``20``. + connect_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.connect_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + pool_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.pool_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. api_kwargs (:obj:`dict`, optional): Arbitrary keyword arguments to be passed to the Telegram API. @@ -1347,8 +1433,8 @@ async def send_video_note( disable_notification: DVInput[bool] = DEFAULT_NONE, reply_to_message_id: int = None, reply_markup: ReplyMarkup = None, - read_timeout: float = 20, - write_timeout: float = 20, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = 20, connect_timeout: ODVInput[float] = DEFAULT_NONE, pool_timeout: ODVInput[float] = DEFAULT_NONE, thumb: FileInput = None, @@ -1412,7 +1498,17 @@ async def send_video_note( .. versionchanged:: 13.2 Accept :obj:`bytes` as input. - timeout (:obj:`int` | :obj:`float`, optional): Send file timeout (default: 20 seconds). + read_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.read_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + write_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.write_timeout`. Defaults to ``20``. + connect_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.connect_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + pool_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.pool_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. api_kwargs (:obj:`dict`, optional): Arbitrary keyword arguments to be passed to the Telegram API. @@ -1464,8 +1560,8 @@ async def send_animation( disable_notification: DVInput[bool] = DEFAULT_NONE, reply_to_message_id: int = None, reply_markup: ReplyMarkup = None, - read_timeout: float = 20, - write_timeout: float = 20, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = 20, connect_timeout: ODVInput[float] = DEFAULT_NONE, pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, @@ -1539,7 +1635,17 @@ async def send_animation( :class:`ReplyKeyboardRemove` | :class:`ForceReply`, optional): Additional interface options. An object for an inline keyboard, custom reply keyboard, instructions to remove reply keyboard or to force a reply from the user. - timeout (:obj:`int` | :obj:`float`, optional): Send file timeout (default: 20 seconds). + read_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.read_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + write_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.write_timeout`. Defaults to ``20``. + connect_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.connect_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + pool_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.pool_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. api_kwargs (:obj:`dict`, optional): Arbitrary keyword arguments to be passed to the Telegram API. @@ -1594,8 +1700,8 @@ async def send_voice( disable_notification: DVInput[bool] = DEFAULT_NONE, reply_to_message_id: int = None, reply_markup: ReplyMarkup = None, - read_timeout: float = 20, - write_timeout: float = 20, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = 20, connect_timeout: ODVInput[float] = DEFAULT_NONE, pool_timeout: ODVInput[float] = DEFAULT_NONE, parse_mode: ODVInput[str] = DEFAULT_NONE, @@ -1658,7 +1764,17 @@ async def send_voice( :class:`ReplyKeyboardRemove` | :class:`ForceReply`, optional): Additional interface options. An object for an inline keyboard, custom reply keyboard, instructions to remove reply keyboard or to force a reply from the user. - timeout (:obj:`int` | :obj:`float`, optional): Send file timeout (default: 20 seconds). + read_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.read_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + write_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.write_timeout`. Defaults to ``20``. + connect_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.connect_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + pool_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.pool_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. api_kwargs (:obj:`dict`, optional): Arbitrary keyword arguments to be passed to the Telegram API. @@ -1707,8 +1823,8 @@ async def send_media_group( ], disable_notification: ODVInput[bool] = DEFAULT_NONE, reply_to_message_id: int = None, - read_timeout: float = 20, - write_timeout: float = 20, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = 20, connect_timeout: ODVInput[float] = DEFAULT_NONE, pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, @@ -1734,7 +1850,17 @@ async def send_media_group( original message. allow_sending_without_reply (:obj:`bool`, optional): Pass :obj:`True`, if the message should be sent even if the specified replied-to message is not found. - timeout (:obj:`int` | :obj:`float`, optional): Send file timeout (default: 20 seconds). + read_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.read_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + write_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.write_timeout`. Defaults to ``20``. + connect_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.connect_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + pool_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.pool_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. api_kwargs (:obj:`dict`, optional): Arbitrary keyword arguments to be passed to the Telegram API. @@ -1827,9 +1953,18 @@ async def send_location( :class:`ReplyKeyboardRemove` | :class:`ForceReply`, optional): Additional interface options. An object for an inline keyboard, custom reply keyboard, instructions to remove reply keyboard or to force a reply from the user. - timeout (:obj:`int` | :obj:`float`, optional): If this value is specified, use it as - the read timeout from the server (instead of the one specified during creation of - the connection pool). + read_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.read_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + write_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.write_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + connect_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.connect_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + pool_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.pool_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. api_kwargs (:obj:`dict`, optional): Arbitrary keyword arguments to be passed to the Telegram API. @@ -1928,9 +2063,18 @@ async def edit_message_live_location( :tg-const:`telegram.constants.LocationLimit.HEADING` if specified. reply_markup (:class:`telegram.InlineKeyboardMarkup`, optional): An object for a new inline keyboard. - timeout (:obj:`int` | :obj:`float`, optional): If this value is specified, use it as - the read timeout from the server (instead of the one specified during creation of - the connection pool). + read_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.read_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + write_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.write_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + connect_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.connect_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + pool_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.pool_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. api_kwargs (:obj:`dict`, optional): Arbitrary keyword arguments to be passed to the Telegram API. @@ -2003,9 +2147,18 @@ async def stop_message_live_location( specified. Identifier of the inline message. reply_markup (:class:`telegram.InlineKeyboardMarkup`, optional): An object for a new inline keyboard. - timeout (:obj:`int` | :obj:`float`, optional): If this value is specified, use it as - the read timeout from the server (instead of the one specified during creation of - the connection pool). + read_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.read_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + write_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.write_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + connect_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.connect_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + pool_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.pool_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. api_kwargs (:obj:`dict`, optional): Arbitrary keyword arguments to be passed to the Telegram API. @@ -2099,9 +2252,18 @@ async def send_venue( :class:`ReplyKeyboardRemove` | :class:`ForceReply`, optional): Additional interface options. An object for an inline keyboard, custom reply keyboard, instructions to remove reply keyboard or to force a reply from the user. - timeout (:obj:`int` | :obj:`float`, optional): If this value is specified, use it as - the read timeout from the server (instead of the one specified during creation of - the connection pool). + read_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.read_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + write_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.write_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + connect_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.connect_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + pool_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.pool_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. api_kwargs (:obj:`dict`, optional): Arbitrary keyword arguments to be passed to the Telegram API. @@ -2211,9 +2373,18 @@ async def send_contact( :class:`ReplyKeyboardRemove` | :class:`ForceReply`, optional): Additional interface options. An object for an inline keyboard, custom reply keyboard, instructions to remove reply keyboard or to force a reply from the user. - timeout (:obj:`int` | :obj:`float`, optional): If this value is specified, use it as - the read timeout from the server (instead of the one specified during creation of - the connection pool). + read_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.read_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + write_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.write_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + connect_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.connect_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + pool_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.pool_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. api_kwargs (:obj:`dict`, optional): Arbitrary keyword arguments to be passed to the Telegram API. @@ -2297,9 +2468,18 @@ async def send_game( reply_markup (:class:`telegram.InlineKeyboardMarkup`, optional): An object for a new inline keyboard. If empty, one ‘Play game_title’ button will be shown. If not empty, the first button must launch the game. - timeout (:obj:`int` | :obj:`float`, optional): If this value is specified, use it as - the read timeout from the server (instead of the one specified during creation of - the connection pool). + read_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.read_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + write_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.write_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + connect_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.connect_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + pool_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.pool_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. api_kwargs (:obj:`dict`, optional): Arbitrary keyword arguments to be passed to the Telegram API. @@ -2350,9 +2530,18 @@ async def send_chat_action( action(:obj:`str`): Type of action to broadcast. Choose one, depending on what the user is about to receive. For convenience look at the constants in :class:`telegram.constants.ChatAction`. - timeout (:obj:`int` | :obj:`float`, optional): If this value is specified, use it as - the read timeout from the server (instead of the one specified during creation of - the connection pool). + read_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.read_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + write_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.write_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + connect_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.connect_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + pool_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.pool_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. api_kwargs (:obj:`dict`, optional): Arbitrary keyword arguments to be passed to the Telegram API. @@ -2510,9 +2699,18 @@ async def answer_inline_query( the inline query to answer. If passed, PTB will automatically take care of the pagination for you, i.e. pass the correct :paramref:`next_offset` and truncate the results list/get the results from the callable you passed. - timeout (:obj:`int` | :obj:`float`, optional): If this value is specified, use it as - the read timeout from the server (instead of the one specified during creation of - the connection pool). + read_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.read_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + write_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.write_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + connect_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.connect_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + pool_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.pool_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. api_kwargs (:obj:`dict`, optional): Arbitrary keyword arguments to be passed to the Telegram API. @@ -2583,9 +2781,18 @@ async def get_user_profile_photos( By default, all photos are returned. limit (:obj:`int`, optional): Limits the number of photos to be retrieved. Values between 1-100 are accepted. Defaults to ``100``. - timeout (:obj:`int` | :obj:`float`, optional): If this value is specified, use it as - the read timeout from the server (instead of the one specified during creation of - the connection pool). + read_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.read_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + write_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.write_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + connect_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.connect_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + pool_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.pool_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. api_kwargs (:obj:`dict`, optional): Arbitrary keyword arguments to be passed to the Telegram API. @@ -2649,9 +2856,18 @@ async def get_file( :class:`telegram.Voice`): Either the file identifier or an object that has a file_id attribute to get file information about. - timeout (:obj:`int` | :obj:`float`, optional): If this value is specified, use it as - the read timeout from the server (instead of the one specified during creation of - the connection pool). + read_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.read_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + write_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.write_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + connect_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.connect_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + pool_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.pool_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. api_kwargs (:obj:`dict`, optional): Arbitrary keyword arguments to be passed to the Telegram API. @@ -2713,9 +2929,18 @@ async def ban_chat_member( chat_id (:obj:`int` | :obj:`str`): Unique identifier for the target group or username of the target supergroup or channel (in the format ``@channelusername``). user_id (:obj:`int`): Unique identifier of the target user. - timeout (:obj:`int` | :obj:`float`, optional): If this value is specified, use it as - the read timeout from the server (instead of the one specified during creation of - the connection pool). + read_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.read_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + write_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.write_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + connect_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.connect_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + pool_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.pool_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. until_date (:obj:`int` | :obj:`datetime.datetime`, optional): Date when the user will be unbanned, unix time. If user is banned for more than 366 days or less than 30 seconds from the current time they are considered to be banned forever. Applied @@ -2781,9 +3006,18 @@ async def ban_chat_sender_chat( chat_id (:obj:`int` | :obj:`str`): Unique identifier for the target group or username of the target supergroup or channel (in the format ``@channelusername``). sender_chat_id (:obj:`int`): Unique identifier of the target sender chat. - timeout (:obj:`int` | :obj:`float`, optional): If this value is specified, use it as - the read timeout from the server (instead of the one specified during creation of - the connection pool). + read_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.read_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + write_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.write_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + connect_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.connect_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + pool_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.pool_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. api_kwargs (:obj:`dict`, optional): Arbitrary keyword arguments to be passed to the Telegram API. @@ -2833,9 +3067,18 @@ async def unban_chat_member( of the target supergroup or channel (in the format ``@channelusername``). user_id (:obj:`int`): Unique identifier of the target user. only_if_banned (:obj:`bool`, optional): Do nothing if the user is not banned. - timeout (:obj:`int` | :obj:`float`, optional): If this value is specified, use it as - the read timeout from the server (instead of the one specified during creation of - the connection pool). + read_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.read_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + write_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.write_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + connect_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.connect_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + pool_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.pool_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. api_kwargs (:obj:`dict`, optional): Arbitrary keyword arguments to be passed to the Telegram API. @@ -2884,9 +3127,18 @@ async def unban_chat_sender_chat( chat_id (:obj:`int` | :obj:`str`): Unique identifier for the target chat or username of the target supergroup or channel (in the format ``@channelusername``). sender_chat_id (:obj:`int`): Unique identifier of the target sender chat. - timeout (:obj:`int` | :obj:`float`, optional): If this value is specified, use it as - the read timeout from the server (instead of the one specified during creation of - the connection pool). + read_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.read_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + write_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.write_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + connect_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.connect_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + pool_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.pool_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. api_kwargs (:obj:`dict`, optional): Arbitrary keyword arguments to be passed to the Telegram API. @@ -2950,9 +3202,18 @@ async def answer_callback_query( your bot with a parameter. cache_time (:obj:`int`, optional): The maximum amount of time in seconds that the result of the callback query may be cached client-side. Defaults to 0. - timeout (:obj:`int` | :obj:`float`, optional): If this value is specified, use it as - the read timeout from the server (instead of the one specified during creation of - the connection pool). + read_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.read_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + write_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.write_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + connect_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.connect_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + pool_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.pool_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. api_kwargs (:obj:`dict`, optional): Arbitrary keyword arguments to be passed to the Telegram API. @@ -3027,9 +3288,18 @@ async def edit_message_text( this message. reply_markup (:class:`telegram.InlineKeyboardMarkup`, optional): An object for an inline keyboard. - timeout (:obj:`int` | :obj:`float`, optional): If this value is specified, use it as - the read timeout from the server (instead of the one specified during creation of - the connection pool). + read_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.read_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + write_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.write_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + connect_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.connect_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + pool_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.pool_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. api_kwargs (:obj:`dict`, optional): Arbitrary keyword arguments to be passed to the Telegram API. @@ -3105,9 +3375,18 @@ async def edit_message_caption( :paramref:`parse_mode`. reply_markup (:class:`telegram.InlineKeyboardMarkup`, optional): An object for an inline keyboard. - timeout (:obj:`int` | :obj:`float`, optional): If this value is specified, use it as - the read timeout from the server (instead of the one specified during creation of - the connection pool). + read_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.read_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + write_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.write_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + connect_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.connect_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + pool_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.pool_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. api_kwargs (:obj:`dict`, optional): Arbitrary keyword arguments to be passed to the Telegram API. @@ -3182,9 +3461,18 @@ async def edit_message_media( specified. Identifier of the inline message. reply_markup (:class:`telegram.InlineKeyboardMarkup`, optional): An object for an inline keyboard. - timeout (:obj:`int` | :obj:`float`, optional): If this value is specified, use it as - the read timeout from the server (instead of the one specified during creation of - the connection pool). + read_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.read_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + write_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.write_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + connect_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.connect_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + pool_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.pool_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. api_kwargs (:obj:`dict`, optional): Arbitrary keyword arguments to be passed to the Telegram API. @@ -3248,9 +3536,18 @@ async def edit_message_reply_markup( specified. Identifier of the inline message. reply_markup (:class:`telegram.InlineKeyboardMarkup`, optional): An object for an inline keyboard. - timeout (:obj:`int` | :obj:`float`, optional): If this value is specified, use it as - the read timeout from the server (instead of the one specified during creation of - the connection pool). + read_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.read_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + write_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.write_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + connect_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.connect_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + pool_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.pool_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. api_kwargs (:obj:`dict`, optional): Arbitrary keyword arguments to be passed to the Telegram API. @@ -3307,18 +3604,27 @@ async def get_updates( offset (:obj:`int`, optional): Identifier of the first update to be returned. Must be greater by one than the highest among the identifiers of previously received updates. By default, updates starting with the earliest unconfirmed update are - returned. An update is considered confirmed as soon as getUpdates is called with an - offset higher than its :attr:`telegram.Update.update_id`. The negative offset can - be specified to retrieve updates starting from -offset update from the end of the - updates queue. All previous updates will forgotten. + returned. An update is considered confirmed as soon as this method is called with + an offset higher than its :attr:`telegram.Update.update_id`. The negative offset + can be specified to retrieve updates starting from -offset update from the end of + the updates queue. All previous updates will forgotten. limit (:obj:`int`, optional): Limits the number of updates to be retrieved. Values between 1-100 are accepted. Defaults to ``100``. timeout (:obj:`int`, optional): Timeout in seconds for long polling. Defaults to ``0``, i.e. usual short polling. Should be positive, short polling should be used for testing purposes only. - read_latency (:obj:`float` | :obj:`int`, optional): Grace time in seconds for receiving - the reply from server. Will be added to the ``timeout`` value and used as the read - timeout from server. Defaults to ``2``. + read_timeout (:obj:`float`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.read_timeout`. Defaults to + ``2``. :paramref:`timeout` will be added to this value. + write_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.write_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + connect_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.connect_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + pool_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.pool_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. allowed_updates (List[:obj:`str`]), optional): A list the types of updates you want your bot to receive. For example, specify ["message", "edited_channel_post", "callback_query"] to only receive updates of these types. @@ -3429,9 +3735,18 @@ async def set_webhook( a short period of time. drop_pending_updates (:obj:`bool`, optional): Pass :obj:`True` to drop all pending updates. - timeout (:obj:`int` | :obj:`float`, optional): If this value is specified, use it as - the read timeout from the server (instead of the one specified during creation of - the connection pool). + read_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.read_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + write_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.write_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + connect_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.connect_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + pool_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.pool_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. api_kwargs (:obj:`dict`, optional): Arbitrary keyword arguments to be passed to the Telegram API. @@ -3498,9 +3813,18 @@ async def delete_webhook( Args: drop_pending_updates (:obj:`bool`, optional): Pass :obj:`True` to drop all pending updates. - timeout (:obj:`int` | :obj:`float`, optional): If this value is specified, use it as - the read timeout from the server (instead of the one specified during creation of - the connection pool). + read_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.read_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + write_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.write_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + connect_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.connect_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + pool_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.pool_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. api_kwargs (:obj:`dict`, optional): Arbitrary keyword arguments to be passed to the Telegram API. @@ -3543,9 +3867,18 @@ async def leave_chat( Args: chat_id (:obj:`int` | :obj:`str`): Unique identifier for the target chat or username of the target supergroup or channel (in the format ``@channelusername``). - timeout (:obj:`int` | :obj:`float`, optional): If this value is specified, use it as - the read timeout from the server (instead of the one specified during creation of - the connection pool). + read_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.read_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + write_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.write_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + connect_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.connect_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + pool_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.pool_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. api_kwargs (:obj:`dict`, optional): Arbitrary keyword arguments to be passed to the Telegram API. @@ -3587,9 +3920,18 @@ async def get_chat( Args: chat_id (:obj:`int` | :obj:`str`): Unique identifier for the target chat or username of the target supergroup or channel (in the format ``@channelusername``). - timeout (:obj:`int` | :obj:`float`, optional): If this value is specified, use it as - the read timeout from the server (instead of the one specified during creation of - the connection pool). + read_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.read_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + write_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.write_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + connect_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.connect_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + pool_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.pool_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. api_kwargs (:obj:`dict`, optional): Arbitrary keyword arguments to be passed to the Telegram API. @@ -3630,9 +3972,18 @@ async def get_chat_administrators( Args: chat_id (:obj:`int` | :obj:`str`): Unique identifier for the target chat or username of the target supergroup or channel (in the format ``@channelusername``). - timeout (:obj:`int` | :obj:`float`, optional): If this value is specified, use it as - the read timeout from the server (instead of the one specified during creation of - the connection pool). + read_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.read_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + write_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.write_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + connect_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.connect_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + pool_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.pool_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. api_kwargs (:obj:`dict`, optional): Arbitrary keyword arguments to be passed to the Telegram API. @@ -3675,9 +4026,18 @@ async def get_chat_member_count( Args: chat_id (:obj:`int` | :obj:`str`): Unique identifier for the target chat or username of the target supergroup or channel (in the format ``@channelusername``). - timeout (:obj:`int` | :obj:`float`, optional): If this value is specified, use it as - the read timeout from the server (instead of the one specified during creation of - the connection pool). + read_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.read_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + write_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.write_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + connect_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.connect_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + pool_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.pool_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. api_kwargs (:obj:`dict`, optional): Arbitrary keyword arguments to be passed to the Telegram API. @@ -3717,9 +4077,18 @@ async def get_chat_member( chat_id (:obj:`int` | :obj:`str`): Unique identifier for the target chat or username of the target supergroup or channel (in the format ``@channelusername``). user_id (:obj:`int`): Unique identifier of the target user. - timeout (:obj:`int` | :obj:`float`, optional): If this value is specified, use it as - the read timeout from the server (instead of the one specified during creation of - the connection pool). + read_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.read_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + write_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.write_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + connect_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.connect_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + pool_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.pool_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. api_kwargs (:obj:`dict`, optional): Arbitrary keyword arguments to be passed to the Telegram API. @@ -3763,9 +4132,18 @@ async def set_chat_sticker_set( of the target supergroup (in the format @supergroupusername). sticker_set_name (:obj:`str`): Name of the sticker set to be set as the group sticker set. - timeout (:obj:`int` | :obj:`float`, optional): If this value is specified, use it as - the read timeout from the server (instead of the one specified during creation of - the connection pool). + read_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.read_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + write_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.write_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + connect_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.connect_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + pool_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.pool_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. api_kwargs (:obj:`dict`, optional): Arbitrary keyword arguments to be passed to the Telegram API. @@ -3802,9 +4180,18 @@ async def delete_chat_sticker_set( Args: chat_id (:obj:`int` | :obj:`str`): Unique identifier for the target chat or username of the target supergroup (in the format @supergroupusername). - timeout (:obj:`int` | :obj:`float`, optional): If this value is specified, use it as - the read timeout from the server (instead of the one specified during creation of - the connection pool). + read_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.read_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + write_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.write_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + connect_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.connect_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + pool_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.pool_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. api_kwargs (:obj:`dict`, optional): Arbitrary keyword arguments to be passed to the Telegram API. @@ -3837,9 +4224,18 @@ async def get_webhook_info( :attr:`telegram.WebhookInfo.url` field empty. Args: - timeout (:obj:`int` | :obj:`float`, optional): If this value is specified, use it as - the read timeout from the server (instead of the one specified during creation of - the connection pool). + read_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.read_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + write_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.write_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + connect_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.connect_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + pool_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.pool_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. api_kwargs (:obj:`dict`, optional): Arbitrary keyword arguments to be passed to the Telegram API. @@ -3890,9 +4286,18 @@ async def set_game_score( Identifier of the sent message. inline_message_id (:obj:`str`, optional): Required if chat_id and message_id are not specified. Identifier of the inline message. - timeout (:obj:`int` | :obj:`float`, optional): If this value is specified, use it as - the read timeout from the server (instead of the one specified during creation of - the connection pool). + read_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.read_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + write_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.write_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + connect_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.connect_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + pool_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.pool_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. api_kwargs (:obj:`dict`, optional): Arbitrary keyword arguments to be passed to the Telegram API. @@ -3958,9 +4363,18 @@ async def get_game_high_scores( Identifier of the sent message. inline_message_id (:obj:`str`, optional): Required if chat_id and message_id are not specified. Identifier of the inline message. - timeout (:obj:`int` | :obj:`float`, optional): If this value is specified, use it as - the read timeout from the server (instead of the one specified during creation of - the connection pool). + read_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.read_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + write_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.write_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + connect_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.connect_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + pool_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.pool_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. api_kwargs (:obj:`dict`, optional): Arbitrary keyword arguments to be passed to the Telegram API. @@ -4113,9 +4527,18 @@ async def send_invoice( reply_markup (:class:`telegram.InlineKeyboardMarkup`, optional): An object for an inline keyboard. If empty, one 'Pay total price' button will be shown. If not empty, the first button must be a Pay button. - timeout (:obj:`int` | :obj:`float`, optional): If this value is specified, use it as - the read timeout from the server (instead of the one specified during creation of - the connection pool). + read_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.read_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + write_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.write_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + connect_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.connect_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + pool_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.pool_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. api_kwargs (:obj:`dict`, optional): Arbitrary keyword arguments to be passed to the Telegram API. @@ -4214,9 +4637,18 @@ async def answer_shipping_query( # pylint: disable=invalid-name human readable form that explains why it is impossible to complete the order (e.g. "Sorry, delivery to your desired address is unavailable"). Telegram will display this message to the user. - timeout (:obj:`int` | :obj:`float`, optional): If this value is specified, use it as - the read timeout from the server (instead of the one specified during creation of - the connection pool). + read_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.read_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + write_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.write_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + connect_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.connect_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + pool_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.pool_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. api_kwargs (:obj:`dict`, optional): Arbitrary keyword arguments to be passed to the Telegram API. @@ -4296,9 +4728,18 @@ async def answer_pre_checkout_query( # pylint: disable=invalid-name the checkout (e.g. "Sorry, somebody just bought the last of our amazing black T-shirts while you were busy filling out your payment details. Please choose a different color or garment!"). Telegram will display this message to the user. - timeout (:obj:`int` | :obj:`float`, optional): If this value is specified, use it as - the read timeout from the server (instead of the one specified during creation of - the connection pool). + read_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.read_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + write_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.write_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + connect_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.connect_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + pool_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.pool_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. api_kwargs (:obj:`dict`, optional): Arbitrary keyword arguments to be passed to the Telegram API. @@ -4351,12 +4792,7 @@ async def restrict_chat_member( """ Use this method to restrict a user in a supergroup. The bot must be an administrator in the supergroup for this to work and must have the appropriate admin rights. Pass - :obj:`True` for all boolean parameters to lift restrictions from a user. - - Note: - Since Bot API 4.4, :meth:`restrict_chat_member` takes the new user permissions in a - single argument of type :class:`telegram.ChatPermissions`. The old way of passing - parameters will not keep working forever. + :obj:`True` for all boolean parameters in :class:`telegram.ChatPermissions` to lift restrictions from a user. Args: chat_id (:obj:`int` | :obj:`str`): Unique identifier for the target chat or username @@ -4370,9 +4806,18 @@ async def restrict_chat_member( bot will be used. permissions (:class:`telegram.ChatPermissions`): An object for new user permissions. - timeout (:obj:`int` | :obj:`float`, optional): If this value is specified, use it as - the read timeout from the server (instead of the one specified during creation of - the connection pool). + read_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.read_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + write_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.write_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + connect_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.connect_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + pool_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.pool_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. api_kwargs (:obj:`dict`, optional): Arbitrary keyword arguments to be passed to the Telegram API. @@ -4466,9 +4911,18 @@ async def promote_chat_member( add new administrators with a subset of his own privileges or demote administrators that he has promoted, directly or indirectly (promoted by administrators that were appointed by him). - timeout (:obj:`int` | :obj:`float`, optional): If this value is specified, use it as - the read timeout from the server (instead of the one specified during creation of - the connection pool). + read_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.read_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + write_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.write_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + connect_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.connect_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + pool_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.pool_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. api_kwargs (:obj:`dict`, optional): Arbitrary keyword arguments to be passed to the Telegram API. @@ -4536,9 +4990,18 @@ async def set_chat_permissions( chat_id (:obj:`int` | :obj:`str`): Unique identifier for the target chat or username of the target supergroup (in the format `@supergroupusername`). permissions (:class:`telegram.ChatPermissions`): New default chat permissions. - timeout (:obj:`int` | :obj:`float`, optional): If this value is specified, use it as - the read timeout from the server (instead of the one specified during creation of - the connection pool). + read_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.read_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + write_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.write_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + connect_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.connect_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + pool_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.pool_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. api_kwargs (:obj:`dict`, optional): Arbitrary keyword arguments to be passed to the Telegram API. @@ -4583,9 +5046,18 @@ async def set_chat_administrator_custom_title( user_id (:obj:`int`): Unique identifier of the target administrator. custom_title (:obj:`str`): New custom title for the administrator; 0-16 characters, emoji are not allowed. - timeout (:obj:`int` | :obj:`float`, optional): If this value is specified, use it as - the read timeout from the server (instead of the one specified during creation of - the connection pool). + read_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.read_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + write_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.write_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + connect_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.connect_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + pool_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.pool_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. api_kwargs (:obj:`dict`, optional): Arbitrary keyword arguments to be passed to the Telegram API. @@ -4628,9 +5100,18 @@ async def export_chat_invite_link( Args: chat_id (:obj:`int` | :obj:`str`): Unique identifier for the target chat or username of the target channel (in the format ``@channelusername``). - timeout (:obj:`int` | :obj:`float`, optional): If this value is specified, use it as - the read timeout from the server (instead of the one specified during creation of - the connection pool). + read_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.read_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + write_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.write_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + connect_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.connect_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + pool_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.pool_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. api_kwargs (:obj:`dict`, optional): Arbitrary keyword arguments to be passed to the Telegram API. @@ -4691,9 +5172,18 @@ async def create_chat_invite_link( member_limit (:obj:`int`, optional): Maximum number of users that can be members of the chat simultaneously after joining the chat via this invite link; 1-:tg-const:`telegram.constants.ChatInviteLinkLimit.MEMBER_LIMIT`. - timeout (:obj:`int` | :obj:`float`, optional): If this value is specified, use it as - the read timeout from the server (instead of the one specified during creation of - the connection pool). + read_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.read_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + write_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.write_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + connect_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.connect_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + pool_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.pool_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. api_kwargs (:obj:`dict`, optional): Arbitrary keyword arguments to be passed to the Telegram API. name (:obj:`str`, optional): Invite link name; @@ -4787,9 +5277,18 @@ async def edit_chat_invite_link( member_limit (:obj:`int`, optional): Maximum number of users that can be members of the chat simultaneously after joining the chat via this invite link; 1-:tg-const:`telegram.constants.ChatInviteLinkLimit.MEMBER_LIMIT`. - timeout (:obj:`int` | :obj:`float`, optional): If this value is specified, use it as - the read timeout from the server (instead of the one specified during creation of - the connection pool). + read_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.read_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + write_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.write_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + connect_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.connect_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + pool_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.pool_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. api_kwargs (:obj:`dict`, optional): Arbitrary keyword arguments to be passed to the Telegram API. name (:obj:`str`, optional): Invite link name; @@ -4866,9 +5365,18 @@ async def revoke_chat_invite_link( .. versionchanged:: 14.0 Now also accepts :obj:`telegram.ChatInviteLink` instances. - timeout (:obj:`int` | :obj:`float`, optional): If this value is specified, use it as - the read timeout from the server (instead of the one specified during creation of - the connection pool). + read_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.read_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + write_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.write_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + connect_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.connect_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + pool_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.pool_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. api_kwargs (:obj:`dict`, optional): Arbitrary keyword arguments to be passed to the Telegram API. @@ -4916,9 +5424,18 @@ async def approve_chat_join_request( chat_id (:obj:`int` | :obj:`str`): Unique identifier for the target chat or username of the target channel (in the format ``@channelusername``). user_id (:obj:`int`): Unique identifier of the target user. - timeout (:obj:`int` | :obj:`float`, optional): If this value is specified, use it as - the read timeout from the server (instead of the one specified during creation of - the connection pool). + read_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.read_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + write_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.write_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + connect_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.connect_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + pool_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.pool_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. api_kwargs (:obj:`dict`, optional): Arbitrary keyword arguments to be passed to the Telegram API. @@ -4964,9 +5481,18 @@ async def decline_chat_join_request( chat_id (:obj:`int` | :obj:`str`): Unique identifier for the target chat or username of the target channel (in the format ``@channelusername``). user_id (:obj:`int`): Unique identifier of the target user. - timeout (:obj:`int` | :obj:`float`, optional): If this value is specified, use it as - the read timeout from the server (instead of the one specified during creation of - the connection pool). + read_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.read_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + write_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.write_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + connect_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.connect_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + pool_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.pool_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. api_kwargs (:obj:`dict`, optional): Arbitrary keyword arguments to be passed to the Telegram API. @@ -4995,8 +5521,8 @@ async def set_chat_photo( self, chat_id: Union[str, int], photo: FileInput, - read_timeout: float = 20, - write_timeout: float = 20, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = 20, connect_timeout: ODVInput[float] = DEFAULT_NONE, pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, @@ -5013,9 +5539,18 @@ async def set_chat_photo( .. versionchanged:: 13.2 Accept :obj:`bytes` as input. - timeout (:obj:`int` | :obj:`float`, optional): If this value is specified, use it as - the read timeout from the server (instead of the one specified during creation of - the connection pool). + read_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.read_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + write_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.write_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + connect_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.connect_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + pool_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.pool_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. api_kwargs (:obj:`dict`, optional): Arbitrary keyword arguments to be passed to the Telegram API. @@ -5056,9 +5591,18 @@ async def delete_chat_photo( Args: chat_id (:obj:`int` | :obj:`str`): Unique identifier for the target chat or username of the target channel (in the format ``@channelusername``). - timeout (:obj:`int` | :obj:`float`, optional): If this value is specified, use it as - the read timeout from the server (instead of the one specified during creation of - the connection pool). + read_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.read_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + write_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.write_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + connect_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.connect_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + pool_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.pool_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. api_kwargs (:obj:`dict`, optional): Arbitrary keyword arguments to be passed to the Telegram API. @@ -5101,9 +5645,18 @@ async def set_chat_title( chat_id (:obj:`int` | :obj:`str`): Unique identifier for the target chat or username of the target channel (in the format ``@channelusername``). title (:obj:`str`): New chat title, 1-255 characters. - timeout (:obj:`int` | :obj:`float`, optional): If this value is specified, use it as - the read timeout from the server (instead of the one specified during creation of - the connection pool). + read_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.read_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + write_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.write_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + connect_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.connect_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + pool_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.pool_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. api_kwargs (:obj:`dict`, optional): Arbitrary keyword arguments to be passed to the Telegram API. @@ -5146,9 +5699,18 @@ async def set_chat_description( chat_id (:obj:`int` | :obj:`str`): Unique identifier for the target chat or username of the target channel (in the format ``@channelusername``). description (:obj:`str`, optional): New chat description, 0-255 characters. - timeout (:obj:`int` | :obj:`float`, optional): If this value is specified, use it as - the read timeout from the server (instead of the one specified during creation of - the connection pool). + read_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.read_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + write_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.write_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + connect_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.connect_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + pool_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.pool_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. api_kwargs (:obj:`dict`, optional): Arbitrary keyword arguments to be passed to the Telegram API. @@ -5199,9 +5761,18 @@ async def pin_chat_message( disable_notification (:obj:`bool`, optional): Pass :obj:`True`, if it is not necessary to send a notification to all chat members about the new pinned message. Notifications are always disabled in channels and private chats. - timeout (:obj:`int` | :obj:`float`, optional): If this value is specified, use it as - the read timeout from the server (instead of the one specified during creation of - the connection pool). + read_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.read_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + write_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.write_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + connect_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.connect_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + pool_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.pool_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. api_kwargs (:obj:`dict`, optional): Arbitrary keyword arguments to be passed to the Telegram API. @@ -5251,9 +5822,18 @@ async def unpin_chat_message( of the target channel (in the format ``@channelusername``). message_id (:obj:`int`, optional): Identifier of a message to unpin. If not specified, the most recent pinned message (by sending date) will be unpinned. - timeout (:obj:`int` | :obj:`float`, optional): If this value is specified, use it as - the read timeout from the server (instead of the one specified during creation of - the connection pool). + read_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.read_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + write_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.write_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + connect_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.connect_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + pool_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.pool_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. api_kwargs (:obj:`dict`, optional): Arbitrary keyword arguments to be passed to the Telegram API. @@ -5299,9 +5879,18 @@ async def unpin_all_chat_messages( Args: chat_id (:obj:`int` | :obj:`str`): Unique identifier for the target chat or username of the target channel (in the format ``@channelusername``). - timeout (:obj:`int` | :obj:`float`, optional): If this value is specified, use it as - the read timeout from the server (instead of the one specified during creation of - the connection pool). + read_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.read_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + write_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.write_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + connect_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.connect_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + pool_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.pool_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. api_kwargs (:obj:`dict`, optional): Arbitrary keyword arguments to be passed to the Telegram API. @@ -5337,9 +5926,18 @@ async def get_sticker_set( Args: name (:obj:`str`): Name of the sticker set. - timeout (:obj:`int` | :obj:`float`, optional): If this value is specified, use it as - the read timeout from the server (instead of the one specified during - creation of the connection pool). + read_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.read_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + write_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.write_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + connect_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.connect_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + pool_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.pool_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. api_kwargs (:obj:`dict`, optional): Arbitrary keyword arguments to be passed to the Telegram API. @@ -5367,8 +5965,8 @@ async def upload_sticker_file( self, user_id: Union[str, int], png_sticker: FileInput, - read_timeout: float = 20, - write_timeout: float = 20, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = 20, connect_timeout: ODVInput[float] = DEFAULT_NONE, pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, @@ -5390,9 +5988,18 @@ async def upload_sticker_file( .. versionchanged:: 13.2 Accept :obj:`bytes` as input. - timeout (:obj:`int` | :obj:`float`, optional): If this value is specified, use it as - the read timeout from the server (instead of the one specified during - creation of the connection pool). + read_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.read_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + write_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.write_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + connect_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.connect_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + pool_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.pool_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. api_kwargs (:obj:`dict`, optional): Arbitrary keyword arguments to be passed to the Telegram API. @@ -5425,8 +6032,8 @@ async def create_new_sticker_set( png_sticker: FileInput = None, contains_masks: bool = None, mask_position: MaskPosition = None, - read_timeout: float = 20, - write_timeout: float = 20, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = 20, connect_timeout: ODVInput[float] = DEFAULT_NONE, pool_timeout: ODVInput[float] = DEFAULT_NONE, tgs_sticker: FileInput = None, @@ -5485,9 +6092,18 @@ async def create_new_sticker_set( should be created. mask_position (:class:`telegram.MaskPosition`, optional): Position where the mask should be placed on faces. - timeout (:obj:`int` | :obj:`float`, optional): If this value is specified, use it as - the read timeout from the server (instead of the one specified during - creation of the connection pool). + read_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.read_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + write_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.write_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + connect_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.connect_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + pool_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.pool_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. api_kwargs (:obj:`dict`, optional): Arbitrary keyword arguments to be passed to the Telegram API. @@ -5531,8 +6147,8 @@ async def add_sticker_to_set( emojis: str, png_sticker: FileInput = None, mask_position: MaskPosition = None, - read_timeout: float = 20, - write_timeout: float = 20, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = 20, connect_timeout: ODVInput[float] = DEFAULT_NONE, pool_timeout: ODVInput[float] = DEFAULT_NONE, tgs_sticker: FileInput = None, @@ -5585,9 +6201,18 @@ async def add_sticker_to_set( emojis (:obj:`str`): One or more emoji corresponding to the sticker. mask_position (:class:`telegram.MaskPosition`, optional): Position where the mask should be placed on faces. - timeout (:obj:`int` | :obj:`float`, optional): If this value is specified, use it as - the read timeout from the server (instead of the one specified during - creation of the connection pool). + read_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.read_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + write_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.write_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + connect_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.connect_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + pool_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.pool_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. api_kwargs (:obj:`dict`, optional): Arbitrary keyword arguments to be passed to the Telegram API. @@ -5637,9 +6262,18 @@ async def set_sticker_position_in_set( Args: sticker (:obj:`str`): File identifier of the sticker. position (:obj:`int`): New sticker position in the set, zero-based. - timeout (:obj:`int` | :obj:`float`, optional): If this value is specified, use it as - the read timeout from the server (instead of the one specified during - creation of the connection pool). + read_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.read_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + write_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.write_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + connect_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.connect_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + pool_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.pool_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. api_kwargs (:obj:`dict`, optional): Arbitrary keyword arguments to be passed to the Telegram API. @@ -5676,9 +6310,18 @@ async def delete_sticker_from_set( Args: sticker (:obj:`str`): File identifier of the sticker. - timeout (:obj:`int` | :obj:`float`, optional): If this value is specified, use it as - the read timeout from the server (instead of the one specified during - creation of the connection pool). + read_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.read_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + write_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.write_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + connect_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.connect_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + pool_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.pool_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. api_kwargs (:obj:`dict`, optional): Arbitrary keyword arguments to be passed to the Telegram API. @@ -5738,9 +6381,18 @@ async def set_sticker_set_thumb( .. versionchanged:: 13.2 Accept :obj:`bytes` as input. - timeout (:obj:`int` | :obj:`float`, optional): If this value is specified, use it as - the read timeout from the server (instead of the one specified during - creation of the connection pool). + read_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.read_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + write_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.write_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + connect_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.connect_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + pool_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.pool_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. api_kwargs (:obj:`dict`, optional): Arbitrary keyword arguments to be passed to the Telegram API. @@ -5792,9 +6444,18 @@ async def set_passport_data_errors( user_id (:obj:`int`): User identifier errors (List[:class:`PassportElementError`]): An array describing the errors. - timeout (:obj:`int` | :obj:`float`, optional): If this value is specified, use it as - the read timeout from the server (instead of the one specified during - creation of the connection pool). + read_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.read_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + write_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.write_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + connect_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.connect_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + pool_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.pool_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. api_kwargs (:obj:`dict`, optional): Arbitrary keyword arguments to be passed to the Telegram API. @@ -5897,9 +6558,18 @@ async def send_poll( :class:`ReplyKeyboardRemove` | :class:`ForceReply`, optional): Additional interface options. An object for an inline keyboard, custom reply keyboard, instructions to remove reply keyboard or to force a reply from the user. - timeout (:obj:`int` | :obj:`float`, optional): If this value is specified, use it as - the read timeout from the server (instead of the one specified during creation of - the connection pool). + read_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.read_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + write_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.write_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + connect_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.connect_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + pool_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.pool_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. api_kwargs (:obj:`dict`, optional): Arbitrary keyword arguments to be passed to the Telegram API. @@ -5972,9 +6642,18 @@ async def stop_poll( message_id (:obj:`int`): Identifier of the original message with the poll. reply_markup (:class:`telegram.InlineKeyboardMarkup`, optional): An object for a new message inline keyboard. - timeout (:obj:`int` | :obj:`float`, optional): If this value is specified, use it as - the read timeout from the server (instead of the one specified during creation of - the connection pool). + read_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.read_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + write_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.write_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + connect_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.connect_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + pool_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.pool_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. api_kwargs (:obj:`dict`, optional): Arbitrary keyword arguments to be passed to the Telegram API. @@ -6050,9 +6729,18 @@ async def send_dice( :class:`ReplyKeyboardRemove` | :class:`ForceReply`, optional): Additional interface options. An object for an inline keyboard, custom reply keyboard, instructions to remove reply keyboard or to force a reply from the user. - timeout (:obj:`int` | :obj:`float`, optional): If this value is specified, use it as - the read timeout from the server (instead of the one specified during creation of - the connection pool). + read_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.read_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + write_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.write_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + connect_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.connect_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + pool_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.pool_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. api_kwargs (:obj:`dict`, optional): Arbitrary keyword arguments to be passed to the Telegram API. @@ -6098,9 +6786,18 @@ async def get_my_commands( language. Args: - timeout (:obj:`int` | :obj:`float`, optional): If this value is specified, use it as - the read timeout from the server (instead of the one specified during creation of - the connection pool). + read_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.read_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + write_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.write_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + connect_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.connect_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + pool_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.pool_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. api_kwargs (:obj:`dict`, optional): Arbitrary keyword arguments to be passed to the Telegram API. scope (:class:`telegram.BotCommandScope`, optional): An object, @@ -6162,9 +6859,18 @@ async def set_my_commands( commands (List[:class:`BotCommand` | (:obj:`str`, :obj:`str`)]): A list of bot commands to be set as the list of the bot's commands. At most 100 commands can be specified. - timeout (:obj:`int` | :obj:`float`, optional): If this value is specified, use it as - the read timeout from the server (instead of the one specified during creation of - the connection pool). + read_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.read_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + write_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.write_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + connect_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.connect_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + pool_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.pool_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. api_kwargs (:obj:`dict`, optional): Arbitrary keyword arguments to be passed to the Telegram API. scope (:class:`telegram.BotCommandScope`, optional): An object, @@ -6233,9 +6939,18 @@ def delete_my_commands( language_code (:obj:`str`, optional): A two-letter ISO 639-1 language code. If empty, commands will be applied to all users from the given scope, for whose language there are no dedicated commands. - timeout (:obj:`int` | :obj:`float`, optional): If this value is specified, use it as - the read timeout from the server (instead of the one specified during creation of - the connection pool). + read_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.read_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + write_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.write_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + connect_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.connect_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + pool_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.pool_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. api_kwargs (:obj:`dict`, optional): Arbitrary keyword arguments to be passed to the Telegram API. @@ -6281,9 +6996,18 @@ async def log_out( minutes. Args: - timeout (:obj:`int` | :obj:`float`, optional): If this value is specified, use it as - the read timeout from the server (instead of the one specified during creation of - the connection pool). + read_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.read_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + write_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.write_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + connect_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.connect_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + pool_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.pool_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. Returns: :obj:`True`: On success @@ -6315,9 +7039,18 @@ async def close( 10 minutes after the bot is launched. Args: - timeout (:obj:`int` | :obj:`float`, optional): If this value is specified, use it as - the read timeout from the server (instead of the one specified during creation of - the connection pool). + read_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.read_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + write_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.write_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + connect_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.connect_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + pool_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.pool_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. Returns: :obj:`True`: On success @@ -6388,9 +7121,18 @@ async def copy_message( :class:`ReplyKeyboardRemove` | :class:`ForceReply`, optional): Additional interface options. An object for an inline keyboard, custom reply keyboard, instructions to remove reply keyboard or to force a reply from the user. - timeout (:obj:`int` | :obj:`float`, optional): If this value is specified, use it as - the read timeout from the server (instead of the one specified during creation of - the connection pool). + read_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.read_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + write_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.write_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + connect_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.connect_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + pool_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.pool_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. api_kwargs (:obj:`dict`, optional): Arbitrary keyword arguments to be passed to the Telegram API. diff --git a/telegram/_callbackquery.py b/telegram/_callbackquery.py index 17e1712fc2e..b49dc7926c4 100644 --- a/telegram/_callbackquery.py +++ b/telegram/_callbackquery.py @@ -47,7 +47,7 @@ class CallbackQuery(TelegramObject): considered equal, if their :attr:`id` is equal. Note: - * In Python :keyword:`from` is a reserved word, :paramref:`from_user` + * In Python :keyword:`from` is a reserved word use :paramref:`from_user` instead. * Exactly one of the fields :attr:`data` or :attr:`game_short_name` will be present. * After the user presses an inline button, Telegram clients will display a progress bar until you call :attr:`answer`. It is, therefore, necessary to react diff --git a/telegram/_chat.py b/telegram/_chat.py index c6774b515ec..4fb931a29f6 100644 --- a/telegram/_chat.py +++ b/telegram/_chat.py @@ -68,7 +68,7 @@ class Chat(TelegramObject): Args: id (:obj:`int`): Unique identifier for this chat. This number may be greater than 32 bits and some programming languages may have difficulty/silent defects in interpreting it. - But it is smaller than 52 bits, so a signed 64 bit integer or double-precision float + But it is smaller than 52 bits, so a signed 64-bit integer or double-precision float type are safe for storing this identifier. type (:obj:`str`): Type of chat, can be either :attr:`PRIVATE`, :attr:`GROUP`, :attr:`SUPERGROUP` or :attr:`CHANNEL`. @@ -906,8 +906,8 @@ async def send_media_group( ], disable_notification: ODVInput[bool] = DEFAULT_NONE, reply_to_message_id: int = None, - read_timeout: float = 20, - write_timeout: float = 20, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = 20, connect_timeout: ODVInput[float] = DEFAULT_NONE, pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, @@ -977,8 +977,8 @@ async def send_photo( disable_notification: DVInput[bool] = DEFAULT_NONE, reply_to_message_id: int = None, reply_markup: ReplyMarkup = None, - read_timeout: float = 20, - write_timeout: float = 20, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = 20, connect_timeout: ODVInput[float] = DEFAULT_NONE, pool_timeout: ODVInput[float] = DEFAULT_NONE, parse_mode: ODVInput[str] = DEFAULT_NONE, @@ -1074,8 +1074,8 @@ async def send_audio( disable_notification: DVInput[bool] = DEFAULT_NONE, reply_to_message_id: int = None, reply_markup: ReplyMarkup = None, - read_timeout: float = 20, - write_timeout: float = 20, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = 20, connect_timeout: ODVInput[float] = DEFAULT_NONE, pool_timeout: ODVInput[float] = DEFAULT_NONE, parse_mode: ODVInput[str] = DEFAULT_NONE, @@ -1127,8 +1127,8 @@ async def send_document( disable_notification: DVInput[bool] = DEFAULT_NONE, reply_to_message_id: int = None, reply_markup: ReplyMarkup = None, - read_timeout: float = 20, - write_timeout: float = 20, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = 20, connect_timeout: ODVInput[float] = DEFAULT_NONE, pool_timeout: ODVInput[float] = DEFAULT_NONE, parse_mode: ODVInput[str] = DEFAULT_NONE, @@ -1398,8 +1398,8 @@ async def send_animation( disable_notification: DVInput[bool] = DEFAULT_NONE, reply_to_message_id: int = None, reply_markup: ReplyMarkup = None, - read_timeout: float = 20, - write_timeout: float = 20, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = 20, connect_timeout: ODVInput[float] = DEFAULT_NONE, pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, @@ -1447,8 +1447,8 @@ async def send_sticker( disable_notification: DVInput[bool] = DEFAULT_NONE, reply_to_message_id: int = None, reply_markup: ReplyMarkup = None, - read_timeout: float = 20, - write_timeout: float = 20, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = 20, connect_timeout: ODVInput[float] = DEFAULT_NONE, pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, @@ -1543,8 +1543,8 @@ async def send_video( disable_notification: DVInput[bool] = DEFAULT_NONE, reply_to_message_id: int = None, reply_markup: ReplyMarkup = None, - read_timeout: float = 20, - write_timeout: float = 20, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = 20, connect_timeout: ODVInput[float] = DEFAULT_NONE, pool_timeout: ODVInput[float] = DEFAULT_NONE, width: int = None, @@ -1600,8 +1600,8 @@ async def send_video_note( disable_notification: DVInput[bool] = DEFAULT_NONE, reply_to_message_id: int = None, reply_markup: ReplyMarkup = None, - read_timeout: float = 20, - write_timeout: float = 20, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = 20, connect_timeout: ODVInput[float] = DEFAULT_NONE, pool_timeout: ODVInput[float] = DEFAULT_NONE, thumb: FileInput = None, @@ -1647,8 +1647,8 @@ async def send_voice( disable_notification: DVInput[bool] = DEFAULT_NONE, reply_to_message_id: int = None, reply_markup: ReplyMarkup = None, - read_timeout: float = 20, - write_timeout: float = 20, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = 20, connect_timeout: ODVInput[float] = DEFAULT_NONE, pool_timeout: ODVInput[float] = DEFAULT_NONE, parse_mode: ODVInput[str] = DEFAULT_NONE, diff --git a/telegram/_chatjoinrequest.py b/telegram/_chatjoinrequest.py index 98cc5808a4c..daca7b41791 100644 --- a/telegram/_chatjoinrequest.py +++ b/telegram/_chatjoinrequest.py @@ -62,13 +62,7 @@ class ChatJoinRequest(TelegramObject): """ - __slots__ = ( - 'chat', - 'from_user', - 'date', - 'bio', - 'invite_link', - ) + __slots__ = ('chat', 'from_user', 'date', 'bio', 'invite_link') def __init__( self, diff --git a/telegram/_chatmemberupdated.py b/telegram/_chatmemberupdated.py index 1944f452784..640b7a2289a 100644 --- a/telegram/_chatmemberupdated.py +++ b/telegram/_chatmemberupdated.py @@ -38,7 +38,7 @@ class ChatMemberUpdated(TelegramObject): .. versionadded:: 13.4 Note: - In Python :keyword:`from` is a reserved word, :paramref:`from_user` + In Python :keyword:`from` is a reserved word use :paramref:`from_user` instead. Args: chat (:class:`telegram.Chat`): Chat the user belongs to. diff --git a/telegram/_choseninlineresult.py b/telegram/_choseninlineresult.py index bc3bdf8a006..d40500fe801 100644 --- a/telegram/_choseninlineresult.py +++ b/telegram/_choseninlineresult.py @@ -37,7 +37,7 @@ class ChosenInlineResult(TelegramObject): considered equal, if their :attr:`result_id` is equal. Note: - * In Python :keyword:`from` is a reserved word, :paramref:`from_user` + * In Python :keyword:`from` is a reserved word use :paramref:`from_user` instead. * It is necessary to enable inline feedback via `@Botfather `_ in order to receive these objects in updates. diff --git a/telegram/_files/_basethumbedmedium.py b/telegram/_files/_basethumbedmedium.py index fbfbed7abee..548b5112c6a 100644 --- a/telegram/_files/_basethumbedmedium.py +++ b/telegram/_files/_basethumbedmedium.py @@ -30,8 +30,8 @@ class _BaseThumbedMedium(_BaseMedium): - """Base class for objects representing the various media file types that may include a - thumbnail. + """ + Base class for objects representing the various media file types that may include a thumbnail. Objects of this class are comparable in terms of equality. Two objects of this class are considered equal, if their :attr:`file_unique_id` is equal. diff --git a/telegram/_files/file.py b/telegram/_files/file.py index 843cc0ee9f8..5d5b036746b 100644 --- a/telegram/_files/file.py +++ b/telegram/_files/file.py @@ -46,7 +46,7 @@ class File(TelegramObject): * Maximum file size to download is :tg-const:`telegram.constants.FileSizeLimit.FILESIZE_DOWNLOAD`. * If you obtain an instance of this class from :attr:`telegram.PassportFile.get_file`, - then it will automatically be decrypted as it downloads when you call :attr:`download()`. + then it will automatically be decrypted as it downloads when you call :meth:`download()`. Args: file_id (:obj:`str`): Identifier for this file, which can be used to download @@ -65,7 +65,7 @@ class File(TelegramObject): is supposed to be the same over time and for different bots. Can't be used to download or reuse the file. file_size (:obj:`str`): Optional. File size in bytes. - file_path (:obj:`str`): Optional. File path. Use :attr:`download` to get the file. + file_path (:obj:`str`): Optional. File path. Use :meth:`download` to get the file. """ @@ -102,8 +102,8 @@ async def download( custom_path: FilePathInput = None, out: IO = None, read_timeout: ODVInput[float] = DEFAULT_NONE, - connect_timeout: ODVInput[float] = DEFAULT_NONE, write_timeout: ODVInput[float] = DEFAULT_NONE, + connect_timeout: ODVInput[float] = DEFAULT_NONE, pool_timeout: ODVInput[float] = DEFAULT_NONE, ) -> Union[Path, IO]: """ @@ -129,9 +129,18 @@ async def download( custom_path (:class:`pathlib.Path` | :obj:`str`, optional): Custom path. out (:obj:`io.BufferedWriter`, optional): A file-like object. Must be opened for writing in binary mode, if applicable. - timeout (:obj:`int` | :obj:`float`, optional): If this value is specified, use it as - the read timeout from the server (instead of the one specified during creation of - the connection pool). + read_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.read_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + write_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.write_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + connect_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.connect_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + pool_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.request.BaseRequest.post.pool_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. Returns: :class:`pathlib.Path` | :obj:`io.BufferedWriter`: The same object as :paramref:`out` if diff --git a/telegram/_files/inputfile.py b/telegram/_files/inputfile.py index cc68ef622a1..ec77c1a1293 100644 --- a/telegram/_files/inputfile.py +++ b/telegram/_files/inputfile.py @@ -40,17 +40,15 @@ class InputFile: bytes. Note: - If ``obj`` is a string, it will be encoded as bytes via ``obj.encode('utf-8')``. + If :paramref:`obj` is a string, it will be encoded as bytes via + :external:obj:`obj.encode('utf-8') `. filename (:obj:`str`, optional): Filename for this InputFile. - Raises: - TelegramError - Attributes: input_file_content (:obj:`bytes`): The binary content of the file to send. attach_name (:obj:`str`): Attach name. - filename (:obj:`str`): Optional. Filename for the file to be sent. - mimetype (:obj:`str`): Optional. The mimetype inferred from the file to be sent. + filename (:obj:`str`): Filename for the file to be sent. + mimetype (:obj:`str`): The mimetype inferred from the file to be sent. """ diff --git a/telegram/_files/inputmedia.py b/telegram/_files/inputmedia.py index 4552e60da99..efa4665db4d 100644 --- a/telegram/_files/inputmedia.py +++ b/telegram/_files/inputmedia.py @@ -46,7 +46,7 @@ class InputMedia(TelegramObject): :attr:`caption_entities`, :paramref:`parse_mode`. Args: - media_type (:obj:`str`) Type of media that the instance represents. + media_type (:obj:`str`): Type of media that the instance represents. media (:obj:`str` | :term:`file object` | :obj:`bytes` | :class:`pathlib.Path` | \ :class:`telegram.Animation` | :class:`telegram.Audio` | \ :class:`telegram.Document` | :class:`telegram.PhotoSize` | \ @@ -55,10 +55,12 @@ class InputMedia(TelegramObject): (recommended), pass an HTTP URL for Telegram to get a file from the Internet. Lastly you can pass an existing telegram media object of the corresponding type to send. - caption (:obj:`str`, optional): Caption of the media to be sent, 0-1024 characters - after entities parsing. + caption (:obj:`str`, optional): Caption of the media to be sent, + 0-:tg-const:`telegram.constants.MessageLimit.CAPTION_LENGTH` characters after entities + parsing. caption_entities (List[:class:`telegram.MessageEntity`], optional): List of special - entities that appear in the caption, which can be specified instead of parse_mode. + entities that appear in the caption, which can be specified instead of + :paramref:`parse_mode`. parse_mode (:obj:`str`, optional): Send Markdown or HTML, if you want Telegram apps to show bold, italic, fixed-width text or inline URLs in the media caption. See the constants in :class:`telegram.constants.ParseMode` for the available modes. @@ -108,7 +110,7 @@ class InputMediaAnimation(InputMedia): """Represents an animation file (GIF or H.264/MPEG-4 AVC video without sound) to be sent. Note: - When using a :class:`telegram.Animation` for the :attr:`media` attribute. It will take the + When using a :class:`telegram.Animation` for the :attr:`media` attribute, it will take the width, height and duration from that video, unless otherwise specified with the optional arguments. @@ -129,8 +131,8 @@ class InputMediaAnimation(InputMedia): thumb (:term:`file object` | :obj:`bytes` | :class:`pathlib.Path`, optional): Thumbnail of the file sent; can be ignored if thumbnail generation for the file is supported server-side. The thumbnail should be - in JPEG format and less than 200 kB in size. A thumbnail's width and height should - not exceed 320. Ignored if the file is not uploaded using multipart/form-data. + in JPEG format and less than ``200`` kB in size. A thumbnail's width and height should + not exceed ``320``. Ignored if the file is not uploaded using multipart/form-data. Thumbnails can't be reused and can be only uploaded as a new file. .. versionchanged:: 13.2 @@ -142,7 +144,8 @@ class InputMediaAnimation(InputMedia): bold, italic, fixed-width text or inline URLs in the media caption. See the constants in :class:`telegram.constants.ParseMode` for the available modes. caption_entities (List[:class:`telegram.MessageEntity`], optional): List of special - entities that appear in the caption, which can be specified instead of parse_mode. + entities that appear in the caption, which can be specified instead of + :paramref:`parse_mode`. width (:obj:`int`, optional): Animation width. height (:obj:`int`, optional): Animation height. duration (:obj:`int`, optional): Animation duration in seconds. @@ -214,7 +217,8 @@ class InputMediaPhoto(InputMedia): bold, italic, fixed-width text or inline URLs in the media caption. See the constants in :class:`telegram.constants.ParseMode` for the available modes. caption_entities (List[:class:`telegram.MessageEntity`], optional): List of special - entities that appear in the caption, which can be specified instead of parse_mode. + entities that appear in the caption, which can be specified instead of + :paramref:`parse_mode`. Attributes: type (:obj:`str`): :tg-const:`telegram.constants.InputMediaType.PHOTO`. @@ -244,10 +248,10 @@ class InputMediaVideo(InputMedia): """Represents a video to be sent. Note: - * When using a :class:`telegram.Video` for the :attr:`media` attribute. It will take the + * When using a :class:`telegram.Video` for the :attr:`media` attribute, it will take the width, height and duration from that video, unless otherwise specified with the optional arguments. - * ``thumb`` will be ignored for small video files, for which Telegram can easily + * :paramref:`thumb` will be ignored for small video files, for which Telegram can easily generate thumbnails. However, this behaviour is undocumented and might be changed by Telegram. @@ -272,7 +276,8 @@ class InputMediaVideo(InputMedia): bold, italic, fixed-width text or inline URLs in the media caption. See the constants in :class:`telegram.constants.ParseMode` for the available modes. caption_entities (List[:class:`telegram.MessageEntity`], optional): List of special - entities that appear in the caption, which can be specified instead of parse_mode. + entities that appear in the caption, which can be specified instead of + :paramref:`parse_mode`. width (:obj:`int`, optional): Video width. height (:obj:`int`, optional): Video height. duration (:obj:`int`, optional): Video duration in seconds. @@ -281,8 +286,8 @@ class InputMediaVideo(InputMedia): thumb (:term:`file object` | :obj:`bytes` | :class:`pathlib.Path`, optional): Thumbnail of the file sent; can be ignored if thumbnail generation for the file is supported server-side. The thumbnail should be - in JPEG format and less than 200 kB in size. A thumbnail's width and height should - not exceed 320. Ignored if the file is not uploaded using multipart/form-data. + in JPEG format and less than ``200`` kB in size. A thumbnail's width and height should + not exceed ``320``. Ignored if the file is not uploaded using multipart/form-data. Thumbnails can't be reused and can be only uploaded as a new file. .. versionchanged:: 13.2 @@ -340,7 +345,7 @@ class InputMediaAudio(InputMedia): """Represents an audio file to be treated as music to be sent. Note: - When using a :class:`telegram.Audio` for the :attr:`media` attribute. It will take the + When using a :class:`telegram.Audio` for the :attr:`media` attribute, it will take the duration, performer and title from that video, unless otherwise specified with the optional arguments. @@ -366,7 +371,8 @@ class InputMediaAudio(InputMedia): bold, italic, fixed-width text or inline URLs in the media caption. See the constants in :class:`telegram.constants.ParseMode` for the available modes. caption_entities (List[:class:`telegram.MessageEntity`], optional): List of special - entities that appear in the caption, which can be specified instead of parse_mode. + entities that appear in the caption, which can be specified instead of + :paramref:`parse_mode`. duration (:obj:`int`): Duration of the audio in seconds as defined by sender. performer (:obj:`str`, optional): Performer of the audio as defined by sender or by audio tags. @@ -374,8 +380,8 @@ class InputMediaAudio(InputMedia): thumb (:term:`file object` | :obj:`bytes` | :class:`pathlib.Path`, optional): Thumbnail of the file sent; can be ignored if thumbnail generation for the file is supported server-side. The thumbnail should be - in JPEG format and less than 200 kB in size. A thumbnail's width and height should - not exceed 320. Ignored if the file is not uploaded using multipart/form-data. + in JPEG format and less than ``200`` kB in size. A thumbnail's width and height should + not exceed ``320``. Ignored if the file is not uploaded using multipart/form-data. Thumbnails can't be reused and can be only uploaded as a new file. .. versionchanged:: 13.2 @@ -449,19 +455,20 @@ class InputMediaDocument(InputMedia): bold, italic, fixed-width text or inline URLs in the media caption. See the constants in :class:`telegram.constants.ParseMode` for the available modes. caption_entities (List[:class:`telegram.MessageEntity`], optional): List of special - entities that appear in the caption, which can be specified instead of parse_mode. + entities that appear in the caption, which can be specified instead of + :paramref:`parse_mode`. thumb (:term:`file object` | :obj:`bytes` | :class:`pathlib.Path`, optional): Thumbnail of the file sent; can be ignored if thumbnail generation for the file is supported server-side. The thumbnail should be - in JPEG format and less than 200 kB in size. A thumbnail's width and height should - not exceed 320. Ignored if the file is not uploaded using multipart/form-data. + in JPEG format and less than ``200`` kB in size. A thumbnail's width and height should + not exceed ``320``. Ignored if the file is not uploaded using multipart/form-data. Thumbnails can't be reused and can be only uploaded as a new file. .. versionchanged:: 13.2 Accept :obj:`bytes` as input. disable_content_type_detection (:obj:`bool`, optional): Disables automatic server-side - content type detection for files uploaded using multipart/form-data. Always true, if - the document is sent as part of an album. + content type detection for files uploaded using multipart/form-data. Always + :obj:`True`, if the document is sent as part of an album. Attributes: type (:obj:`str`): :tg-const:`telegram.constants.InputMediaType.DOCUMENT`. diff --git a/telegram/_files/sticker.py b/telegram/_files/sticker.py index 6397b002751..58fbe4d69ef 100644 --- a/telegram/_files/sticker.py +++ b/telegram/_files/sticker.py @@ -16,7 +16,7 @@ # # You should have received a copy of the GNU Lesser Public License # along with this program. If not, see [http://www.gnu.org/licenses/]. -"""This module contains objects that represents stickers.""" +"""This module contains objects that represent stickers.""" from typing import TYPE_CHECKING, Any, List, Optional, ClassVar @@ -35,7 +35,7 @@ class Sticker(_BaseThumbedMedium): considered equal, if their :attr:`file_unique_id` is equal. Note: - As of v13.11 ``is_video`` is a required argument and therefore the order of the + As of v13.11 :paramref:`is_video` is a required argument and therefore the order of the arguments had to be changed. Use keyword arguments to make sure that the arguments are passed correctly. @@ -51,8 +51,8 @@ class Sticker(_BaseThumbedMedium): is_video (:obj:`bool`): :obj:`True`, if the sticker is a video sticker. .. versionadded:: 13.11 - thumb (:class:`telegram.PhotoSize`, optional): Sticker thumbnail in the .WEBP or .JPG - format. + thumb (:class:`telegram.PhotoSize`, optional): Sticker thumbnail in the ``.WEBP`` or + ``.JPG`` format. emoji (:obj:`str`, optional): Emoji associated with the sticker set_name (:obj:`str`, optional): Name of the sticker set to which the sticker belongs. @@ -73,8 +73,8 @@ class Sticker(_BaseThumbedMedium): is_video (:obj:`bool`): :obj:`True`, if the sticker is a video sticker. .. versionadded:: 13.11 - thumb (:class:`telegram.PhotoSize`): Optional. Sticker thumbnail in the .webp or .jpg - format. + thumb (:class:`telegram.PhotoSize`): Optional. Sticker thumbnail in the ``.WEBP`` or + ``.JPG`` format. emoji (:obj:`str`): Optional. Emoji associated with the sticker. set_name (:obj:`str`): Optional. Name of the sticker set to which the sticker belongs. mask_position (:class:`telegram.MaskPosition`): Optional. For mask stickers, the position @@ -148,7 +148,7 @@ class StickerSet(TelegramObject): considered equal, if their :attr:`name` is equal. Note: - As of v13.11 ``is_video`` is a required argument and therefore the order of the + As of v13.11 :paramref:`is_video` is a required argument and therefore the order of the arguments had to be changed. Use keyword arguments to make sure that the arguments are passed correctly. @@ -241,12 +241,12 @@ class MaskPosition(TelegramObject): point (:obj:`str`): The part of the face relative to which the mask should be placed. One of :attr:`FOREHEAD`, :attr:`EYES`, :attr:`MOUTH`, or :attr:`CHIN`. x_shift (:obj:`float`): Shift by X-axis measured in widths of the mask scaled to the face - size, from left to right. For example, choosing -1.0 will place mask just to the left - of the default mask position. + size, from left to right. For example, choosing ``-1.0`` will place mask just to the + left of the default mask position. y_shift (:obj:`float`): Shift by Y-axis measured in heights of the mask scaled to the face - size, from top to bottom. For example, 1.0 will place the mask just below the default - mask position. - scale (:obj:`float`): Mask scaling coefficient. For example, 2.0 means double size. + size, from top to bottom. For example, ``1.0`` will place the mask just below the + default mask position. + scale (:obj:`float`): Mask scaling coefficient. For example, ``2.0`` means double size. Attributes: point (:obj:`str`): The part of the face relative to which the mask should be placed. @@ -255,7 +255,7 @@ class MaskPosition(TelegramObject): size, from left to right. y_shift (:obj:`float`): Shift by Y-axis measured in heights of the mask scaled to the face size, from top to bottom. - scale (:obj:`float`): Mask scaling coefficient. For example, 2.0 means double size. + scale (:obj:`float`): Mask scaling coefficient. For example, ``2.0`` means double size. """ diff --git a/telegram/_files/video.py b/telegram/_files/video.py index 73dca49bf24..930f64b67f2 100644 --- a/telegram/_files/video.py +++ b/telegram/_files/video.py @@ -44,7 +44,7 @@ class Video(_BaseThumbedMedium): duration (:obj:`int`): Duration of the video in seconds as defined by sender. thumb (:class:`telegram.PhotoSize`, optional): Video thumbnail. file_name (:obj:`str`, optional): Original filename as defined by sender. - mime_type (:obj:`str`, optional): Mime type of a file as defined by sender. + mime_type (:obj:`str`, optional): MIME type of a file as defined by sender. file_size (:obj:`int`, optional): File size in bytes. bot (:class:`telegram.Bot`, optional): The Bot to use for instance methods. **kwargs (:obj:`dict`): Arbitrary keyword arguments. @@ -59,7 +59,7 @@ class Video(_BaseThumbedMedium): duration (:obj:`int`): Duration of the video in seconds as defined by sender. thumb (:class:`telegram.PhotoSize`): Optional. Video thumbnail. file_name (:obj:`str`): Optional. Original filename as defined by sender. - mime_type (:obj:`str`): Optional. Mime type of a file as defined by sender. + mime_type (:obj:`str`): Optional. MIME type of a file as defined by sender. file_size (:obj:`int`): Optional. File size in bytes. bot (:class:`telegram.Bot`): Optional. The Bot to use for instance methods. diff --git a/telegram/_games/game.py b/telegram/_games/game.py index dd7e48bf464..f5e2edf0907 100644 --- a/telegram/_games/game.py +++ b/telegram/_games/game.py @@ -154,8 +154,9 @@ def parse_text_entity(self, entity: MessageEntity) -> str: def parse_text_entities(self, types: List[str] = None) -> Dict[MessageEntity, str]: """ Returns a :obj:`dict` that maps :class:`telegram.MessageEntity` to :obj:`str`. - It contains entities from this message filtered by their ``type`` attribute as the key, and - the text that each entity belongs to as the value of the :obj:`dict`. + It contains entities from this message filtered by their + :attr:`~telegram.MessageEntity.type` attribute as the key, and the text that each entity + belongs to as the value of the :obj:`dict`. Note: This method should always be used instead of the :attr:`text_entities` attribute, since @@ -163,9 +164,10 @@ def parse_text_entities(self, types: List[str] = None) -> Dict[MessageEntity, st See :attr:`parse_text_entity` for more info. Args: - types (List[:obj:`str`], optional): List of ``MessageEntity`` types as strings. If the - ``type`` attribute of an entity is contained in this list, it will be returned. - Defaults to :attr:`telegram.MessageEntity.ALL_TYPES`. + types (List[:obj:`str`], optional): List of :class:`telegram.MessageEntity` types as + strings. If the :attr:`~telegram.MessageEntity.type` attribute of an entity is + contained in this list, it will be returned. Defaults to + :attr:`telegram.MessageEntity.ALL_TYPES`. Returns: Dict[:class:`telegram.MessageEntity`, :obj:`str`]: A dictionary of entities mapped to diff --git a/telegram/_inline/inlinequery.py b/telegram/_inline/inlinequery.py index 1fcd052b02e..026d5ebc481 100644 --- a/telegram/_inline/inlinequery.py +++ b/telegram/_inline/inlinequery.py @@ -38,7 +38,7 @@ class InlineQuery(TelegramObject): considered equal, if their :attr:`id` is equal. Note: - In Python :keyword:`from` is a reserved word, :paramref:`from_user` + In Python :keyword:`from` is a reserved word use :paramref:`from_user` instead. Args: id (:obj:`str`): Unique identifier for this query. @@ -130,11 +130,11 @@ async def answer( ) -> bool: """Shortcut for:: - bot.answer_inline_query( - update.inline_query.id, - *args, - current_offset=self.offset if auto_pagination else None, - **kwargs + await bot.answer_inline_query( + update.inline_query.id, + *args, + current_offset=self.offset if auto_pagination else None, + **kwargs ) For the documentation of the arguments, please see @@ -149,8 +149,7 @@ async def answer( Defaults to :obj:`False`. Raises: - ValueError: If both - :paramref:`~telegram.Bot.answer_inline_query.current_offset` and + ValueError: If both :paramref:`~telegram.Bot.answer_inline_query.current_offset` and :paramref:`auto_pagination` are supplied. """ if current_offset and auto_pagination: diff --git a/telegram/_message.py b/telegram/_message.py index 0e5a5f1b6ac..975d35cd1cf 100644 --- a/telegram/_message.py +++ b/telegram/_message.py @@ -81,7 +81,7 @@ class Message(TelegramObject): considered equal, if their :attr:`message_id` and :attr:`chat` are equal. Note: - In Python :keyword:`from` is a reserved word, :paramref:`from_user` + In Python :keyword:`from` is a reserved word use :paramref:`from_user` instead. Args: message_id (:obj:`int`): Unique message identifier inside this chat. @@ -963,8 +963,8 @@ async def reply_media_group( ], disable_notification: ODVInput[bool] = DEFAULT_NONE, reply_to_message_id: int = None, - read_timeout: float = 20, - write_timeout: float = 20, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = 20, connect_timeout: ODVInput[float] = DEFAULT_NONE, pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, @@ -1012,8 +1012,8 @@ async def reply_photo( disable_notification: DVInput[bool] = DEFAULT_NONE, reply_to_message_id: int = None, reply_markup: ReplyMarkup = None, - read_timeout: float = 20, - write_timeout: float = 20, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = 20, connect_timeout: ODVInput[float] = DEFAULT_NONE, pool_timeout: ODVInput[float] = DEFAULT_NONE, parse_mode: ODVInput[str] = DEFAULT_NONE, @@ -1070,8 +1070,8 @@ async def reply_audio( disable_notification: DVInput[bool] = DEFAULT_NONE, reply_to_message_id: int = None, reply_markup: ReplyMarkup = None, - read_timeout: float = 20, - write_timeout: float = 20, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = 20, connect_timeout: ODVInput[float] = DEFAULT_NONE, pool_timeout: ODVInput[float] = DEFAULT_NONE, parse_mode: ODVInput[str] = DEFAULT_NONE, @@ -1131,8 +1131,8 @@ async def reply_document( disable_notification: DVInput[bool] = DEFAULT_NONE, reply_to_message_id: int = None, reply_markup: ReplyMarkup = None, - read_timeout: float = 20, - write_timeout: float = 20, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = 20, connect_timeout: ODVInput[float] = DEFAULT_NONE, pool_timeout: ODVInput[float] = DEFAULT_NONE, parse_mode: ODVInput[str] = DEFAULT_NONE, @@ -1194,8 +1194,8 @@ async def reply_animation( disable_notification: DVInput[bool] = DEFAULT_NONE, reply_to_message_id: int = None, reply_markup: ReplyMarkup = None, - read_timeout: float = 20, - write_timeout: float = 20, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = 20, connect_timeout: ODVInput[float] = DEFAULT_NONE, pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, @@ -1251,8 +1251,8 @@ async def reply_sticker( disable_notification: DVInput[bool] = DEFAULT_NONE, reply_to_message_id: int = None, reply_markup: ReplyMarkup = None, - read_timeout: float = 20, - write_timeout: float = 20, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = 20, connect_timeout: ODVInput[float] = DEFAULT_NONE, pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, @@ -1300,8 +1300,8 @@ async def reply_video( disable_notification: DVInput[bool] = DEFAULT_NONE, reply_to_message_id: int = None, reply_markup: ReplyMarkup = None, - read_timeout: float = 20, - write_timeout: float = 20, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = 20, connect_timeout: ODVInput[float] = DEFAULT_NONE, pool_timeout: ODVInput[float] = DEFAULT_NONE, width: int = None, @@ -1365,8 +1365,8 @@ async def reply_video_note( disable_notification: DVInput[bool] = DEFAULT_NONE, reply_to_message_id: int = None, reply_markup: ReplyMarkup = None, - read_timeout: float = 20, - write_timeout: float = 20, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = 20, connect_timeout: ODVInput[float] = DEFAULT_NONE, pool_timeout: ODVInput[float] = DEFAULT_NONE, thumb: FileInput = None, @@ -1420,8 +1420,8 @@ async def reply_voice( disable_notification: DVInput[bool] = DEFAULT_NONE, reply_to_message_id: int = None, reply_markup: ReplyMarkup = None, - read_timeout: float = 20, - write_timeout: float = 20, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = 20, connect_timeout: ODVInput[float] = DEFAULT_NONE, pool_timeout: ODVInput[float] = DEFAULT_NONE, parse_mode: ODVInput[str] = DEFAULT_NONE, diff --git a/telegram/_payment/precheckoutquery.py b/telegram/_payment/precheckoutquery.py index 17524dafae5..897eb4c7d9a 100644 --- a/telegram/_payment/precheckoutquery.py +++ b/telegram/_payment/precheckoutquery.py @@ -35,7 +35,7 @@ class PreCheckoutQuery(TelegramObject): considered equal, if their :attr:`id` is equal. Note: - In Python :keyword:`from` is a reserved word, :paramref:`from_user` + In Python :keyword:`from` is a reserved word use :paramref:`from_user` instead. Args: id (:obj:`str`): Unique query identifier. diff --git a/telegram/_payment/shippingquery.py b/telegram/_payment/shippingquery.py index 0eca800e61e..901e3e28f14 100644 --- a/telegram/_payment/shippingquery.py +++ b/telegram/_payment/shippingquery.py @@ -35,7 +35,7 @@ class ShippingQuery(TelegramObject): considered equal, if their :attr:`id` is equal. Note: - In Python :keyword:`from` is a reserved word, :paramref:`from_user` + In Python :keyword:`from` is a reserved word use :paramref:`from_user` instead. Args: id (:obj:`str`): Unique query identifier. diff --git a/telegram/_replykeyboardmarkup.py b/telegram/_replykeyboardmarkup.py index 1c77d49a6c0..0fb88825796 100644 --- a/telegram/_replykeyboardmarkup.py +++ b/telegram/_replykeyboardmarkup.py @@ -29,7 +29,7 @@ class ReplyKeyboardMarkup(TelegramObject): """This object represents a custom keyboard with reply options. Objects of this class are comparable in terms of equality. Two objects of this class are - considered equal, if their the size of :attr:`keyboard` and all the buttons are equal. + considered equal, if their size of :attr:`keyboard` and all the buttons are equal. Example: A user requests to change the bot's language, bot replies to the request with a keyboard @@ -37,7 +37,7 @@ class ReplyKeyboardMarkup(TelegramObject): Args: keyboard (List[List[:obj:`str` | :class:`telegram.KeyboardButton`]]): Array of button rows, - each represented by an Array of :class:`telegram.KeyboardButton` objects. + each represented by an Array of :class:`telegram.KeyboardButton` objects. resize_keyboard (:obj:`bool`, optional): Requests clients to resize the keyboard vertically for optimal fit (e.g., make the keyboard smaller if there are just two rows of buttons). Defaults to :obj:`False`, in which case the custom keyboard is always of the diff --git a/telegram/_update.py b/telegram/_update.py index 1aea4c27472..561070a1c76 100644 --- a/telegram/_update.py +++ b/telegram/_update.py @@ -85,10 +85,11 @@ class Update(TelegramObject): .. versionadded:: 13.4 chat_member (:class:`telegram.ChatMemberUpdated`, optional): A chat member's status was updated in a chat. The bot must be an administrator in the chat and must explicitly - specify ``'chat_member'`` in the list of ``'allowed_updates'`` to receive these + specify :attr:`CHAT_MEMBER` in the list of + :paramref:`telegram.ext.Application.run_polling.allowed_updates` to receive these updates (see :meth:`telegram.Bot.get_updates`, :meth:`telegram.Bot.set_webhook`, - :meth:`telegram.ext.Updater.start_polling` and - :meth:`telegram.ext.Updater.start_webhook`). + :meth:`telegram.ext.Application.run_polling` and + :meth:`telegram.ext.Application.run_webhook`). .. versionadded:: 13.4 chat_join_request (:class:`telegram.ChatJoinRequest`, optional): A request to join the @@ -124,15 +125,17 @@ class Update(TelegramObject): .. versionadded:: 13.4 chat_member (:class:`telegram.ChatMemberUpdated`): Optional. A chat member's status was updated in a chat. The bot must be an administrator in the chat and must explicitly - specify ``'chat_member'`` in the list of ``'allowed_updates'`` to receive these + specify :attr:`CHAT_MEMBER` in the list of + :paramref:`telegram.ext.Application.run_polling.allowed_updates` to receive these updates (see :meth:`telegram.Bot.get_updates`, :meth:`telegram.Bot.set_webhook`, - :meth:`telegram.ext.Updater.start_polling` and - :meth:`telegram.ext.Updater.start_webhook`). + :meth:`telegram.ext.Application.run_polling` and + :meth:`telegram.ext.Application.run_webhook`). .. versionadded:: 13.4 chat_join_request (:class:`telegram.ChatJoinRequest`): Optional. A request to join the - chat has been sent. The bot must have the ``'can_invite_users'`` administrator - right in the chat to receive these updates. + chat has been sent. The bot must have the + :attr:`telegram.ChatPermissions.can_invite_users` administrator right in the chat to + receive these updates. .. versionadded:: 13.8 diff --git a/telegram/_user.py b/telegram/_user.py index ab938741701..e880c5f2967 100644 --- a/telegram/_user.py +++ b/telegram/_user.py @@ -399,8 +399,8 @@ async def send_photo( disable_notification: DVInput[bool] = DEFAULT_NONE, reply_to_message_id: int = None, reply_markup: ReplyMarkup = None, - read_timeout: float = 20, - write_timeout: float = 20, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = 20, connect_timeout: ODVInput[float] = DEFAULT_NONE, pool_timeout: ODVInput[float] = DEFAULT_NONE, parse_mode: ODVInput[str] = DEFAULT_NONE, @@ -446,8 +446,8 @@ async def send_media_group( ], disable_notification: ODVInput[bool] = DEFAULT_NONE, reply_to_message_id: int = None, - read_timeout: float = 20, - write_timeout: float = 20, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = 20, connect_timeout: ODVInput[float] = DEFAULT_NONE, pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, @@ -488,8 +488,8 @@ async def send_audio( disable_notification: DVInput[bool] = DEFAULT_NONE, reply_to_message_id: int = None, reply_markup: ReplyMarkup = None, - read_timeout: float = 20, - write_timeout: float = 20, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = 20, connect_timeout: ODVInput[float] = DEFAULT_NONE, pool_timeout: ODVInput[float] = DEFAULT_NONE, parse_mode: ODVInput[str] = DEFAULT_NONE, @@ -659,8 +659,8 @@ async def send_document( disable_notification: DVInput[bool] = DEFAULT_NONE, reply_to_message_id: int = None, reply_markup: ReplyMarkup = None, - read_timeout: float = 20, - write_timeout: float = 20, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = 20, connect_timeout: ODVInput[float] = DEFAULT_NONE, pool_timeout: ODVInput[float] = DEFAULT_NONE, parse_mode: ODVInput[str] = DEFAULT_NONE, @@ -891,8 +891,8 @@ async def send_animation( disable_notification: DVInput[bool] = DEFAULT_NONE, reply_to_message_id: int = None, reply_markup: ReplyMarkup = None, - read_timeout: float = 20, - write_timeout: float = 20, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = 20, connect_timeout: ODVInput[float] = DEFAULT_NONE, pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, @@ -940,8 +940,8 @@ async def send_sticker( disable_notification: DVInput[bool] = DEFAULT_NONE, reply_to_message_id: int = None, reply_markup: ReplyMarkup = None, - read_timeout: float = 20, - write_timeout: float = 20, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = 20, connect_timeout: ODVInput[float] = DEFAULT_NONE, pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, @@ -981,8 +981,8 @@ async def send_video( disable_notification: DVInput[bool] = DEFAULT_NONE, reply_to_message_id: int = None, reply_markup: ReplyMarkup = None, - read_timeout: float = 20, - write_timeout: float = 20, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = 20, connect_timeout: ODVInput[float] = DEFAULT_NONE, pool_timeout: ODVInput[float] = DEFAULT_NONE, width: int = None, @@ -1093,8 +1093,8 @@ async def send_video_note( disable_notification: DVInput[bool] = DEFAULT_NONE, reply_to_message_id: int = None, reply_markup: ReplyMarkup = None, - read_timeout: float = 20, - write_timeout: float = 20, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = 20, connect_timeout: ODVInput[float] = DEFAULT_NONE, pool_timeout: ODVInput[float] = DEFAULT_NONE, thumb: FileInput = None, @@ -1140,8 +1140,8 @@ async def send_voice( disable_notification: DVInput[bool] = DEFAULT_NONE, reply_to_message_id: int = None, reply_markup: ReplyMarkup = None, - read_timeout: float = 20, - write_timeout: float = 20, + read_timeout: ODVInput[float] = DEFAULT_NONE, + write_timeout: ODVInput[float] = 20, connect_timeout: ODVInput[float] = DEFAULT_NONE, pool_timeout: ODVInput[float] = DEFAULT_NONE, parse_mode: ODVInput[str] = DEFAULT_NONE, diff --git a/telegram/_userprofilephotos.py b/telegram/_userprofilephotos.py index aff50105d6c..6822477481e 100644 --- a/telegram/_userprofilephotos.py +++ b/telegram/_userprofilephotos.py @@ -28,7 +28,7 @@ class UserProfilePhotos(TelegramObject): - """This object represent a user's profile pictures. + """This object represents a user's profile pictures. Objects of this class are comparable in terms of equality. Two objects of this class are considered equal, if their :attr:`total_count` and :attr:`photos` are equal. diff --git a/telegram/_utils/datetime.py b/telegram/_utils/datetime.py index 86b60d6d903..63803bfff45 100644 --- a/telegram/_utils/datetime.py +++ b/telegram/_utils/datetime.py @@ -57,7 +57,7 @@ def to_float_timestamp( Converts a given time object to a float POSIX timestamp. Used to convert different time specifications to a common format. The time object can be relative (i.e. indicate a time increment, or a time of day) or absolute. - object objects from the :class:`datetime` module that are timezone-naive will be assumed + Objects from the :class:`datetime` module that are timezone-naive will be assumed to be in UTC, if ``bot`` is not passed or ``bot.defaults`` is :obj:`None`. Args: @@ -65,33 +65,36 @@ def to_float_timestamp( :obj:`datetime.datetime` | :obj:`datetime.time`): Time value to convert. The semantics of this parameter will depend on its type: - * :obj:`int` or :obj:`float` will be interpreted as "seconds from ``reference_t``" + * :obj:`int` or :obj:`float` will be interpreted as "seconds from + :paramref:`reference_t`" * :obj:`datetime.timedelta` will be interpreted as - "time increment from ``reference_t``" + "time increment from :paramref:`reference_timestamp`" * :obj:`datetime.datetime` will be interpreted as an absolute date/time value * :obj:`datetime.time` will be interpreted as a specific time of day reference_timestamp (:obj:`float`, optional): POSIX timestamp that indicates the absolute - time from which relative calculations are to be performed (e.g. when ``t`` is given as - an :obj:`int`, indicating "seconds from ``reference_t``"). Defaults to now (the time at - which this function is called). - - If ``t`` is given as an absolute representation of date & time (i.e. a - :obj:`datetime.datetime` object), ``reference_timestamp`` is not relevant and so its - value should be :obj:`None`. If this is not the case, a ``ValueError`` will be raised. - tzinfo (:obj:`pytz.BaseTzInfo`, optional): If ``t`` is a naive object from the - :class:`datetime` module, it will be interpreted as this timezone. Defaults to + time from which relative calculations are to be performed (e.g. when + :paramref:`time_object` is given as an :obj:`int`, indicating "seconds from + :paramref:`reference_time`"). Defaults to now (the time at which this function is + called). + + If :paramref:`time_object` is given as an absolute representation of date & time (i.e. + a :obj:`datetime.datetime` object), :paramref:`reference_timestamp` is not relevant + and so its value should be :obj:`None`. If this is not the case, a :exc:`ValueError` + will be raised. + tzinfo (:obj:`pytz.BaseTzInfo`, optional): If :paramref:`time_object` is a naive object + from the :mod:`datetime` module, it will be interpreted as this timezone. Defaults to ``pytz.utc``. Note: Only to be used by ``telegram.ext``. - Returns: :obj:`float` | :obj:`None`: - The return value depends on the type of argument ``t``. - If ``t`` is given as a time increment (i.e. as a :obj:`int`, :obj:`float` or - :obj:`datetime.timedelta`), then the return value will be ``reference_t`` + ``t``. + The return value depends on the type of argument :paramref:`time_object`. + If :paramref:`time_object` is given as a time increment (i.e. as a :obj:`int`, + :obj:`float` or :obj:`datetime.timedelta`), then the return value will be + :paramref:`reference_timestamp` + :paramref:`time_object`. Else if it is given as an absolute date/time value (i.e. a :obj:`datetime.datetime` object), the equivalent value as a POSIX timestamp will be returned. @@ -100,9 +103,9 @@ def to_float_timestamp( object), the return value is the nearest future occurrence of that time of day. Raises: - TypeError: If ``t``'s type is not one of those described above. - ValueError: If ``t`` is a :obj:`datetime.datetime` and :obj:`reference_timestamp` is not - :obj:`None`. + TypeError: If :paramref:`time_object` s type is not one of those described above. + ValueError: If :paramref:`time_object` is a :obj:`datetime.datetime` and + :paramref:`reference_timestamp` is not :obj:`None`. """ if reference_timestamp is None: reference_timestamp = time.time() @@ -169,7 +172,7 @@ def from_timestamp(unixtime: Optional[int], tzinfo: dtm.tzinfo = UTC) -> Optiona converted to. Defaults to UTC. Returns: - Timezone aware equivalent :obj:`datetime.datetime` value if ``unixtime`` is not + Timezone aware equivalent :obj:`datetime.datetime` value if :paramref:`unixtime` is not :obj:`None`; else :obj:`None`. """ if unixtime is None: diff --git a/telegram/error.py b/telegram/error.py index 69aa9024920..f12c38aeb93 100644 --- a/telegram/error.py +++ b/telegram/error.py @@ -16,7 +16,7 @@ # # You should have received a copy of the GNU Lesser Public License # along with this program. If not, see [http://www.gnu.org/licenses/]. -"""This module contains an classes that represent Telegram errors. +"""This module contains classes that represent Telegram errors. .. versionchanged:: 14.0 Replaced ``Unauthorized`` by :class:`Forbidden`. diff --git a/telegram/ext/_application.py b/telegram/ext/_application.py index 697164ffb94..74302293656 100644 --- a/telegram/ext/_application.py +++ b/telegram/ext/_application.py @@ -76,11 +76,11 @@ class ApplicationHandlerStop(Exception): different group). In order to use this exception in a :class:`telegram.ext.ConversationHandler`, pass the - optional ``state`` parameter instead of returning the next state: + optional :paramref:`state` parameter instead of returning the next state: .. code-block:: python - def callback(update, context): + async def conversation_callback(update, context): ... raise ApplicationHandlerStop(next_state) @@ -102,9 +102,10 @@ def __init__(self, state: object = None) -> None: class Application(Generic[BT, CCT, UD, CD, BD, JQ]): - """This class dispatches all kinds of updates to its registered handlers. + """This class dispatches all kinds of updates to its registered handlers, and is the entry + point to a PTB application. - Note: + Tip: This class may not be initialized directly. Use :class:`telegram.ext.ApplicationBuilder` or :meth:`builder` (for convenience). @@ -117,7 +118,7 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ]): bot (:class:`telegram.Bot`): The bot object that should be passed to the handlers. update_queue (:class:`asyncio.Queue`): The synchronized queue that will contain the updates. - updater (:class:`telegram.ext.Updater`, optional): The updater used by this application. + updater (:class:`telegram.ext.Updater`): Optional. The updater used by this application. job_queue (:class:`telegram.ext.JobQueue`): Optional. The :class:`telegram.ext.JobQueue` instance to pass onto handler callbacks. chat_data (:obj:`types.MappingProxyType`): A dictionary handlers can use to store data for @@ -139,15 +140,15 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ]): Manually modifying :attr:`user_data` is almost never needed and unadvisable. bot_data (:obj:`dict`): A dictionary handlers can use to store data for the bot. - persistence (:class:`telegram.ext.BasePersistence`): Optional. The persistence class to + persistence (:class:`telegram.ext.BasePersistence`): The persistence class to store data that should be persistent over restarts. handlers (Dict[:obj:`int`, List[:class:`telegram.ext.Handler`]]): A dictionary mapping each handler group to the list of handlers registered to that group. .. seealso:: :meth:`add_handler`, :meth:`add_handlers`. - error_handlers (Dict[:obj:`callable`, :obj:`bool`]): A dict, where the keys are error - handlers and the values indicate whether they are to be run blocking. + error_handlers (Dict[:term:`coroutine function`, :obj:`bool`]): A dict, where the keys are + error handlers and the values indicate whether they are to be run blocking. .. seealso:: :meth:`add_error_handler` @@ -234,7 +235,7 @@ def __init__( raise TypeError("persistence must be based on telegram.ext.BasePersistence") self.persistence = persistence - # Some book keeping for persistence logic + # Some bookkeeping for persistence logic self._chat_ids_to_be_updated_in_persistence: Set[int] = set() self._user_ids_to_be_updated_in_persistence: Set[int] = set() self._chat_ids_to_be_deleted_in_persistence: Set[int] = set() @@ -253,7 +254,7 @@ def __init__( self.__update_persistence_task: Optional[asyncio.Task] = None self.__update_persistence_event = asyncio.Event() self.__update_persistence_lock = asyncio.Lock() - self.__create_task_tasks: Set[asyncio.Task] = set() + self.__create_task_tasks: Set[asyncio.Task] = set() # Used for awaiting tasks upon exit def _check_initialized(self) -> None: if not self._initialized: @@ -272,10 +273,21 @@ def running(self) -> bool: @property def concurrent_updates(self) -> int: - """0 == not concurrent""" + """:obj:`int`: Indicates the number of concurrent updates set. A value of ``0`` indicates + updates are *not* being processed concurrently. + """ return self._concurrent_updates async def initialize(self) -> None: + """Initializes the Application by initializing: + + * The :attr:`bot`, by calling :meth:`telegram.Bot.initialize`. + * The :attr:`updater`, by calling :meth:`telegram.ext.Updater.initialize`. + * The :attr:`persistence`, by loading persistent conversations and data. + + .. seealso:: + :meth:`shutdown` + """ if self._initialized: _logger.debug('This Application is already initialized.') return @@ -307,9 +319,14 @@ async def _add_ch_to_persistence(self, handler: 'ConversationHandler') -> None: ) async def shutdown(self) -> None: - """ + """Shuts down the Application by shutting down: - Returns: + * :attr:`bot` by calling :meth:`telegram.Bot.shutdown` + * :attr:`updater` by calling :meth:`telegram.ext.Updater.shutdown` + * :attr:`persistence` by calling :meth:`update_persistence` and :meth`persistence.flush` + + .. seealso:: + :meth:`initialize` Raises: :exc:`RuntimeError`: If the application is still :attr:`running`. @@ -334,6 +351,7 @@ async def shutdown(self) -> None: self._initialized = False async def __aenter__(self: _AppType) -> _AppType: + """Simple context manager which initializes the App.""" try: await self.initialize() return self @@ -347,11 +365,13 @@ async def __aexit__( exc_val: Optional[BaseException], exc_tb: Optional[TracebackType], ) -> None: + """Shutdown the App from the context manager.""" # Make sure not to return `True` so that exceptions are not suppressed # https://docs.python.org/3/reference/datamodel.html?#object.__aexit__ await self.shutdown() async def _initialize_persistence(self) -> None: + """This method basically just loads all the data by awaiting the BP methods""" if not self.persistence: return @@ -392,16 +412,18 @@ def builder() -> 'InitApplicationBuilder': async def start(self) -> None: """Starts - * a background task that fetches updates from :attr:`update_queue` and - processes them. - * :attr:`job_queue`, if set - * a background tasks that calls :meth:`update_persistence` in regular intervals, if + * a background task that fetches updates from :attr:`update_queue` and processes them. + * :attr:`job_queue`, if set. + * a background task that calls :meth:`update_persistence` in regular intervals, if :attr:`persistence` is set. Note: This does *not* start fetching updates from Telegram. You need to either start :attr:`updater` manually or use one of :meth:`run_polling` or :meth:`run_webhook`. + .. seealso:: + :meth:`stop` + Raises: :exc:`RuntimeError`: If the application is already running or was not initialized. """ @@ -446,6 +468,9 @@ async def stop(self) -> None: Once this method is called, no more updates will be fetched from :attr:`update_queue`, even if it's not empty. + .. seealso:: + :meth:`start` + Note: This does *not* stop :attr:`updater`. You need to either manually call :meth:`telegram.ext.Updater.stop` or use one of :meth:`run_polling` or @@ -498,8 +523,46 @@ def run_polling( drop_pending_updates: bool = None, close_loop: bool = True, ) -> None: - """Temp docstring to make this referencable - #TODO: Adda meaningful description + """Starts polling updates from Telegram using :meth:`telegram.ext.Updater.start_polling`. + + .. seealso:: + :meth:`telegram.ext.Updater.start_polling`, :meth:`run_webhook` + + Args: + poll_interval (:obj:`float`, optional): Time to wait between polling updates from + Telegram in seconds. Default is ``0.0``. + timeout (:obj:`float`, optional): Passed to + :paramref:`telegram.Bot.get_updates.timeout`. Default is ``10`` seconds. + bootstrap_retries (:obj:`int`, optional): Whether the bootstrapping phase of the + :class:`telegram.ext.Updater` will retry on failures on the Telegram server. + + * < 0 - retry indefinitely (default) + * 0 - no retries + * > 0 - retry up to X times + + read_timeout (:obj:`float`, optional): Value to pass to + :paramref:`telegram.Bot.get_updates.read_timeout`. Defaults to ``2``. + write_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.Bot.get_updates.write_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + connect_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.Bot.get_updates.connect_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + pool_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.Bot.get_updates.pool_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + drop_pending_updates (:obj:`bool`, optional): Whether to clean any pending updates on + Telegram servers before actually starting to poll. Default is :obj:`False`. + allowed_updates (List[:obj:`str`], optional): Passed to + :meth:`telegram.Bot.get_updates`. + close_loop (:obj:`bool`, optional): If :obj:`True`, the current event loop will be + closed upon shutdown. + + .. seealso:: + :meth:`asyncio.loop.close` + + Raises: + :exc:`RuntimeError`: If the Application does not have an :class:`telegram.ext.Updater`. """ if not self.updater: raise RuntimeError( @@ -520,7 +583,7 @@ def error_callback(exc: TelegramError) -> None: pool_timeout=pool_timeout, allowed_updates=allowed_updates, drop_pending_updates=drop_pending_updates, - error_callback=error_callback, + error_callback=error_callback, # if there is an error in fetching updates ), close_loop=close_loop, ) @@ -540,8 +603,46 @@ def run_webhook( max_connections: int = 40, close_loop: bool = True, ) -> None: - """Temp docstring to make this referencable - #TODO: Adda meaningful description + """ + Starts a small http server to listen for updates via webhook using + :meth:`telegram.ext.Updater.start_webhook`. If :paramref:`cert` + and :paramref:`key` are not provided, the webhook will be started directly on + ``http://listen:port/url_path``, so SSL can be handled by another + application. Else, the webhook will be started on + ``https://listen:port/url_path``. Also calls :meth:`telegram.Bot.set_webhook` as required. + + .. seealso:: + :meth:`telegram.ext.Updater.start_webhook`, :meth:`run_polling` + + Args: + listen (:obj:`str`, optional): IP-Address to listen on. Defaults to + `127.0.0.1 `_. + port (:obj:`int`, optional): Port the bot should be listening on. Must be one of + :attr:`telegram.constants.SUPPORTED_WEBHOOK_PORTS`. Defaults to ``80``. + url_path (:obj:`str`, optional): Path inside url. Defaults to `` '' `` + cert (:class:`pathlib.Path` | :obj:`str`, optional): Path to the SSL certificate file. + key (:class:`pathlib.Path` | :obj:`str`, optional): Path to the SSL key file. + bootstrap_retries (:obj:`int`, optional): Whether the bootstrapping phase of the + :class:`telegram.ext.Updater` will retry on failures on the Telegram server. + + * < 0 - retry indefinitely + * 0 - no retries (default) + * > 0 - retry up to X times + webhook_url (:obj:`str`, optional): Explicitly specify the webhook url. Useful behind + NAT, reverse proxy, etc. Default is derived from :paramref:`listen`, + :paramref:`port`, :paramref:`url_path`, :paramref:`cert`, and :paramref:`key`. + allowed_updates (List[:obj:`str`], optional): Passed to + :meth:`telegram.Bot.set_webhook`. + drop_pending_updates (:obj:`bool`, optional): Whether to clean any pending updates on + Telegram servers before actually starting to poll. Default is :obj:`False`. + ip_address (:obj:`str`, optional): Passed to :meth:`telegram.Bot.set_webhook`. + max_connections (:obj:`int`, optional): Passed to + :meth:`telegram.Bot.set_webhook`. Defaults to ``40``. + close_loop (:obj:`bool`, optional): If :obj:`True`, the current event loop will be + closed upon shutdown. Defaults to :obj:`True`. + + .. seealso:: + :meth:`asyncio.loop.close` """ if not self.updater: raise RuntimeError( @@ -572,9 +673,8 @@ def __run(self, updater_coroutine: Coroutine, close_loop: bool = True) -> None: loop = asyncio.get_event_loop() try: loop.run_until_complete(self.initialize()) - loop.run_until_complete(updater_coroutine) + loop.run_until_complete(updater_coroutine) # one of updater.start_webhook/polling loop.run_until_complete(self.start()) - loop.run_forever() except (KeyboardInterrupt, SystemExit): pass @@ -603,14 +703,14 @@ def create_task(self, coroutine: Coroutine, update: object = None) -> asyncio.Ta Note: * If :paramref:`coroutine` raises an exception, it will be set on the task created by this method even though it's handled by :meth:`process_error`. - * If the application is currently running, tasks created by this methods will be + * If the application is currently running, tasks created by this method will be awaited by :meth:`stop`. Args: - coroutine: The coroutine to run as task. - update: Optional. If passed, will be passed to :meth:`process_error` as additional - information for the error handlers. Moreover, the corresponding :attr:`chat_data` - and :attr:`user_data` entries will be updated in the next run of + coroutine (:term:`coroutine function`): The coroutine to run as task. + update (:obj:`object`, optional): If passed, will be passed to :meth:`process_error` + as additional information for the error handlers. Moreover, the corresponding + :attr:`chat_data` and :attr:`user_data` entries will be updated in the next run of :meth:`update_persistence` after the :paramref:`coroutine` is finished. Returns: @@ -643,7 +743,7 @@ def __create_task( return task def __create_task_done_callback(self, task: asyncio.Task) -> None: - self.__create_task_tasks.discard(task) + self.__create_task_tasks.discard(task) # Discard from our set since we are done with it # We just retrieve the eventual exception so that asyncio doesn't complain in case # it's not retrieved somewhere else try: @@ -661,7 +761,7 @@ async def __create_task_callback( return await coroutine except asyncio.CancelledError as cancel: # TODO: in py3.8+, CancelledError is a subclass of BaseException, so we can drop this - # close when we drop py3.7 + # clause when we drop py3.7 raise cancel except Exception as exception: if isinstance(exception, ApplicationHandlerStop): @@ -685,7 +785,7 @@ async def __create_task_callback( # So we can and must handle it await self.process_error(update=update, error=exception, coroutine=coroutine) - # Raise exception so that it can be set on the task + # Raise exception so that it can be set on the task and retrieved by task.exception() raise exception finally: self._mark_for_persistence_update(update=update) @@ -707,6 +807,7 @@ async def _update_fetcher(self) -> None: _logger.debug('Processing update %s', update) if self._concurrent_updates: + # We don't await the below because it has to be run concurrently self.create_task(self.__process_update_wrapper(update), update=update) else: await self.__process_update_wrapper(update) @@ -717,18 +818,16 @@ async def __process_update_wrapper(self, update: object) -> None: self.update_queue.task_done() async def process_update(self, update: object) -> None: - """Processes a single update and updates the persistence. + """Processes a single update and marks the update to be updated by the persistence later. Exceptions raised by handler callbacks will be processed by :meth:`process_update`. .. versionchanged:: 14.0 - This calls :meth:`update_persistence` exactly once after handling of the update was - finished by *all* handlers that handled the update, including asynchronously running - handlers. + Persistence is now updated in an interval set by + :attr:`telegram.ext.BasePersistence.update_interval`. Args: update (:class:`telegram.Update` | :obj:`object` | \ - :class:`telegram.error.TelegramError`): - The update to process. + :class:`telegram.error.TelegramError`): The update to process. Raises: :exc:`RuntimeError`: If the application was not initialized. @@ -737,19 +836,19 @@ async def process_update(self, update: object) -> None: self._check_initialized() context = None - any_blocking = False + any_blocking = False # Flag which is set to True if any handler specifies block=True for handlers in self.handlers.values(): try: for handler in handlers: - check = handler.check_update(update) - if not (check is None or check is False): - if not context: + check = handler.check_update(update) # Should the handler handle this update? + if not (check is None or check is False): # if yes, + if not context: # build a context if not already built context = self.context_types.context.from_update(update, self) await context.refresh_data() coroutine: Coroutine = handler.handle_update(update, self, check, context) - if not handler.block or ( + if not handler.block or ( # if handler is running with block=False, handler.block is DEFAULT_TRUE and isinstance(self.bot, ExtBot) and self.bot.defaults @@ -759,7 +858,7 @@ async def process_update(self, update: object) -> None: else: any_blocking = True await coroutine - break + break # Only a max of 1 handler per group is handled # Stop processing with any other handler. except ApplicationHandlerStop: @@ -775,6 +874,7 @@ async def process_update(self, update: object) -> None: if any_blocking: # Only need to mark the update for persistence if there was at least one # blocking handler - the non-blocking handlers mark the update again when finished + # (in __create_task_callback) self._mark_for_persistence_update(update=update) def add_handler(self, handler: Handler[Any, CCT], group: int = DEFAULT_GROUP) -> None: @@ -806,7 +906,7 @@ def add_handler(self, handler: Handler[Any, CCT], group: int = DEFAULT_GROUP) -> Args: handler (:class:`telegram.ext.Handler`): A Handler instance. - group (:obj:`int`, optional): The group identifier. Default is 0. + group (:obj:`int`, optional): The group identifier. Default is ``0``. """ # Unfortunately due to circular imports this has to be here @@ -852,13 +952,20 @@ def add_handlers( .. seealso:: :meth:`add_handler` Args: - handlers (List[:obj:`telegram.ext.Handler`] | \ - Dict[int, List[:obj:`telegram.ext.Handler`]]): \ + handlers (List[:class:`telegram.ext.Handler`] | \ + Dict[int, List[:class:`telegram.ext.Handler`]]): \ Specify a sequence of handlers *or* a dictionary where the keys are groups and values are handlers. - group (:obj:`int`, optional): Specify which group the sequence of ``handlers`` + group (:obj:`int`, optional): Specify which group the sequence of :paramref:`handlers` should be added to. Defaults to ``0``. + Example:: + + app.add_handlers(handlers={ + -1: [MessageHandler(...)], + 1: [CallbackQueryHandler(...), CommandHandler(...)] + } + """ if isinstance(handlers, dict) and not isinstance(group, DefaultValue): raise ValueError('The `group` argument can only be used with a sequence of handlers.') @@ -885,8 +992,8 @@ def remove_handler(self, handler: Handler, group: int = DEFAULT_GROUP) -> None: """Remove a handler from the specified group. Args: - handler (:class:`telegram.ext.Handler`): A Handler instance. - group (:obj:`object`, optional): The group identifier. Default is 0. + handler (:class:`telegram.ext.Handler`): A :class:`telegram.ext.Handler` instance. + group (:obj:`object`, optional): The group identifier. Default is ``0``. """ if handler in self.handlers[group]: @@ -946,7 +1053,7 @@ def migrate_chat_data( :class:`telegram.ext.Job`. Warning: - When using :paramref:`concurrent_updates` or the :attr:`job_queue`, + When using :attr:`concurrent_updates` or the :attr:`job_queue`, :meth:`process_update` or :meth:`telegram.ext.Job.run` may re-create the old entry due to the asynchronous nature of these features. Please make sure that your program can avoid or handle such situations. @@ -955,9 +1062,12 @@ def migrate_chat_data( message (:class:`telegram.Message`, optional): A message with either :attr:`~telegram.Message.migrate_from_chat_id` or :attr:`~telegram.Message.migrate_to_chat_id`. - Mutually exclusive with passing :paramref:`old_chat_id`` and - :paramref:`new_chat_id` - .. seealso: `telegram.ext.filters.StatusUpdate.MIGRATE` + Mutually exclusive with passing :paramref:`old_chat_id` and + :paramref:`new_chat_id`. + + .. seealso:: + :attr:`telegram.ext.filters.StatusUpdate.MIGRATE` + old_chat_id (:obj:`int`, optional): The old chat ID. Mutually exclusive with passing :paramref:`message` new_chat_id (:obj:`int`, optional): The new chat ID. @@ -1006,6 +1116,8 @@ async def _persistence_updater(self) -> None: if not self.persistence: return + # asyncio synchronization primitives don't accept a timeout argument, it is recommended + # to use wait_for instead try: await asyncio.wait_for( self.__update_persistence_event.wait(), @@ -1031,6 +1143,10 @@ async def update_persistence(self) -> None: This method will be called in regular intervals by the application. There is usually no need to call it manually. + Note: + Any data is deep copied with :func:`copy.deepcopy` before handing it over to the + persistence in order to avoid race conditions, so all persisted data must be copyable. + .. seealso:: :attr:`telegram.ext.BasePersistence.update_interval`. """ async with self.__update_persistence_lock: @@ -1156,10 +1272,13 @@ def add_error_handler( Attempts to add the same callback multiple times will be ignored. Args: - callback (:obj:`callable`): The callback function for this error handler. Will be - called when an error is raised. Callback signature: - ``def callback(update: Optional[object], context: CallbackContext)``. - The error that happened will be present in ``context.error``. + callback (:term:`coroutine function`): The callback function for this error handler. + Will be called when an error is raised. Callback signature:: + + async def callback(update: Optional[object], context: CallbackContext) + + The error that happened will be present in + :attr:`telegram.ext.CallbackContext.error`. block (:obj:`bool`, optional): Determines whether the return value of the callback should be awaited before processing the next error handler in :meth:`process_error`. Defaults to :obj:`True`. @@ -1174,7 +1293,7 @@ def remove_error_handler(self, callback: Callable[[object, CCT], None]) -> None: """Removes an error handler. Args: - callback (:obj:`callable`): The error handler to remove. + callback (:term:`coroutine function`): The error handler to remove. """ self.error_handlers.pop(callback, None) @@ -1196,6 +1315,7 @@ async def process_error( .. versionchanged:: 14.0 + * ``dispatch_error`` was renamed to :meth:`process_error`. * Exceptions raised by error handlers are now properly logged. * :class:`telegram.ext.ApplicationHandlerStop` is no longer reraised but converted into the return value. @@ -1206,10 +1326,11 @@ async def process_error( job (:class:`telegram.ext.Job`, optional): The job that caused the error. .. versionadded:: 14.0 + coroutine (:term:`coroutine function`, optional): The coroutine that caused the error. Returns: - :obj:`bool`: :obj:`True` if one of the error handlers raised - :class:`telegram.ext.ApplicationHandlerStop`. :obj:`False`, otherwise. + :obj:`bool`: :obj:`True`, if one of the error handlers raised + :class:`telegram.ext.ApplicationHandlerStop`. :obj:`False`, otherwise. """ if self.error_handlers: for ( @@ -1223,7 +1344,7 @@ async def process_error( job=job, coroutine=coroutine, ) - if not block or ( + if not block or ( # If error handler has `block=False`, create a Task to run cb block is DEFAULT_TRUE and isinstance(self.bot, ExtBot) and self.bot.defaults diff --git a/telegram/ext/_applicationbuilder.py b/telegram/ext/_applicationbuilder.py index aeef5f8f008..cb02d65753c 100644 --- a/telegram/ext/_applicationbuilder.py +++ b/telegram/ext/_applicationbuilder.py @@ -45,7 +45,7 @@ # Type hinting is a bit complicated here because we try to get to a sane level of # leveraging generics and therefore need a number of type variables. -InBT = TypeVar('InBT', bound=Bot) +InBT = TypeVar('InBT', bound=Bot) # 'In' stands for input - used in parameters of methods below InJQ = TypeVar('InJQ', bound=Union[None, JobQueue]) InCCT = TypeVar('InCCT', bound='CallbackContext') InUD = TypeVar('InUD') @@ -205,6 +205,7 @@ def _build_request(self, get_updates: bool) -> BaseRequest: write_timeout=getattr(self, f'{prefix}write_timeout'), pool_timeout=getattr(self, f'{prefix}pool_timeout'), ) + # Get timeouts that were actually set- effective_timeouts = { key: value for key, value in timeouts.items() if not isinstance(value, DefaultValue) } @@ -244,19 +245,20 @@ def build( """ job_queue = DefaultValue.get_value(self._job_queue) persistence = DefaultValue.get_value(self._persistence) - + # If user didn't set updater if isinstance(self._updater, DefaultValue) or self._updater is None: - if isinstance(self._bot, DefaultValue): - bot: Bot = self._build_ext_bot() + if isinstance(self._bot, DefaultValue): # and didn't set a bot + bot: Bot = self._build_ext_bot() # build a bot else: bot = self._bot + # now also build an updater/update_queue for them update_queue = DefaultValue.get_value(self._update_queue) if self._updater is None: updater = None else: updater = Updater(bot=bot, update_queue=update_queue) - else: + else: # if they set an updater, get all necessary attributes for Application from Updater: updater = self._updater bot = self._updater.bot update_queue = self._updater.update_queue @@ -273,7 +275,7 @@ def build( job_queue=job_queue, persistence=persistence, context_types=DefaultValue.get_value(self._context_types), - **self._application_kwargs, + **self._application_kwargs, # For custom Application subclasses ) if job_queue is not None: @@ -290,7 +292,7 @@ def application_class( self: BuilderType, application_class: Type[Application], kwargs: Dict[str, object] = None ) -> BuilderType: """Sets a custom subclass to be used instead of :class:`telegram.ext.Application`. The - subclasses ``__init__`` should look like this + subclass's ``__init__`` should look like this .. code:: python @@ -300,7 +302,7 @@ def __init__(self, custom_arg_1, custom_arg_2, ..., **kwargs): self.custom_arg_2 = custom_arg_2 Args: - application_class (:obj:`type`): A subclass of :class:`telegram.ext.Application` + application_class (:obj:`type`): A subclass of :class:`telegram.ext.Application` kwargs (Dict[:obj:`str`, :obj:`object`], optional): Keyword arguments for the initialization. Defaults to an empty dict. @@ -333,7 +335,7 @@ def base_url(self: BuilderType, base_url: str) -> BuilderType: .. seealso:: :paramref:`telegram.Bot.base_url`, `Local Bot API Server `_, - :meth:`base_url` + :meth:`base_file_url` Args: base_url (:obj:`str`): The URL. @@ -354,7 +356,7 @@ def base_file_url(self: BuilderType, base_file_url: str) -> BuilderType: .. seealso:: :paramref:`telegram.Bot.base_file_url`, `Local Bot API Server `_, - :meth:`base_file_url` + :meth:`base_url` Args: base_file_url (:obj:`str`): The URL. @@ -373,6 +375,8 @@ def _request_check(self, get_updates: bool) -> None: prefix = 'get_updates_' if get_updates else '' name = prefix + 'request' + # Code below tests if it's okay to set a Request object. Only okay if no other request args + # or instances containing a Request were set previously for attr in ('connect_timeout', 'read_timeout', 'write_timeout', 'pool_timeout'): if not isinstance(getattr(self, f"_{prefix}{attr}"), DefaultValue): raise RuntimeError(_TWO_ARGS_REQ.format(name, attr)) @@ -387,27 +391,27 @@ def _request_check(self, get_updates: bool) -> None: def _request_param_check(self, name: str, get_updates: bool) -> None: if get_updates and self._get_updates_request is not DEFAULT_NONE: - raise RuntimeError( + raise RuntimeError( # disallow request args for get_updates if Request for that is set _TWO_ARGS_REQ.format(f'get_updates_{name}', 'get_updates_request instance') ) - if self._request is not DEFAULT_NONE: + if self._request is not DEFAULT_NONE: # disallow request args if request is set raise RuntimeError(_TWO_ARGS_REQ.format(name, 'request instance')) - if self._bot is not DEFAULT_NONE: + if self._bot is not DEFAULT_NONE: # disallow request args if bot is set (has Request) raise RuntimeError( _TWO_ARGS_REQ.format( f'get_updates_{name}' if get_updates else name, 'bot instance' ) ) - if self._updater not in (DEFAULT_NONE, None): + if self._updater not in (DEFAULT_NONE, None): # disallow request args for updater(has bot) raise RuntimeError( _TWO_ARGS_REQ.format(f'get_updates_{name}' if get_updates else name, 'updater') ) def request(self: BuilderType, request: BaseRequest) -> BuilderType: - """Sets a :class:`telegram.request.BaseRequest` object to be used for the ``request`` - parameter of :attr:`telegram.ext.Application.bot`. + """Sets a :class:`telegram.request.BaseRequest` object to be used for the + :paramref:`telegram.Bot.request` parameter of :attr:`telegram.ext.Application.bot`. .. seealso:: :meth:`get_updates_request` @@ -422,31 +426,95 @@ def request(self: BuilderType, request: BaseRequest) -> BuilderType: return self def connection_pool_size(self: BuilderType, connection_pool_size: int) -> BuilderType: + """Sets the size of the connection pool to be used for the + :paramref:`~telegram.request.HTTPXRequest.connection_pool_size` parameter of + :attr:`telegram.Bot.request`. Defaults to ``128``. + + Args: + connection_pool_size (:obj:`int`): The size of the connection pool. + + Returns: + :class:`ApplicationBuilder`: The same builder with the updated argument. + """ self._request_param_check(name='connection_pool_size', get_updates=False) self._connection_pool_size = connection_pool_size return self def proxy_url(self: BuilderType, proxy_url: str) -> BuilderType: + """Sets the proxy to be used for the :paramref:`~telegram.request.HTTPXRequest.proxy_url` + parameter of :attr:`telegram.Bot.request`. Defaults to :obj:`None`. + + Args: + proxy_url (:obj:`str`): The URL to the proxy server. See + :paramref:`telegram.request.HTTPXRequest.proxy_url` for more information. + + Returns: + :class:`ApplicationBuilder`: The same builder with the updated argument. + """ self._request_param_check(name='proxy_url', get_updates=False) self._proxy_url = proxy_url return self def connect_timeout(self: BuilderType, connect_timeout: Optional[float]) -> BuilderType: + """Sets the connection attempt timeout to be used for the + :paramref:`~telegram.request.HTTPXRequest.connect_timeout` parameter of + :attr:`telegram.Bot.request`. Defaults to ``5.0``. + + Args: + connect_timeout (:obj:`float`): See + :paramref:`telegram.request.HTTPXRequest.connect_timeout` for more information. + + Returns: + :class:`ApplicationBuilder`: The same builder with the updated argument. + """ self._request_param_check(name='connect_timeout', get_updates=False) self._connect_timeout = connect_timeout return self def read_timeout(self: BuilderType, read_timeout: Optional[float]) -> BuilderType: + """Sets the waiting timeout to be used for the + :paramref:`~telegram.request.HTTPXRequest.read_timeout` parameter of + :attr:`telegram.Bot.request`. Defaults to ``5.0``. + + Args: + read_timeout (:obj:`float`): See + :paramref:`telegram.request.HTTPXRequest.read_timeout` for more information. + + Returns: + :class:`ApplicationBuilder`: The same builder with the updated argument. + """ self._request_param_check(name='read_timeout', get_updates=False) self._read_timeout = read_timeout return self def write_timeout(self: BuilderType, write_timeout: Optional[float]) -> BuilderType: + """Sets the write operation timeout to be used for the + :paramref:`~telegram.request.HTTPXRequest.write_timeout` parameter of + :attr:`telegram.Bot.request`. Defaults to ``5.0``. + + Args: + write_timeout (:obj:`float`): See + :paramref:`telegram.request.HTTPXRequest.write_timeout` for more information. + + Returns: + :class:`ApplicationBuilder`: The same builder with the updated argument. + """ self._request_param_check(name='write_timeout', get_updates=False) self._write_timeout = write_timeout return self def pool_timeout(self: BuilderType, pool_timeout: Optional[float]) -> BuilderType: + """Sets the connection pool's connection freeing timeout to be used for the + :paramref:`~telegram.request.HTTPXRequest.pool_timeout` parameter of + :attr:`telegram.Bot.request`. Defaults to :obj:`None`. + + Args: + pool_timeout (:obj:`float`): See + :paramref:`telegram.request.HTTPXRequest.pool_timeout` for more information. + + Returns: + :class:`ApplicationBuilder`: The same builder with the updated argument. + """ self._request_param_check(name='pool_timeout', get_updates=False) self._pool_timeout = pool_timeout return self @@ -471,11 +539,31 @@ def get_updates_request(self: BuilderType, get_updates_request: BaseRequest) -> def get_updates_connection_pool_size( self: BuilderType, get_updates_connection_pool_size: int ) -> BuilderType: + """Sets the size of the connection pool to be used for the + :paramref:`telegram.request.HTTPXRequest.connection_pool_size` parameter which is used + for :meth:`telegram.Bot.get_updates`. Defaults to ``1``. + + Args: + get_updates_connection_pool_size (:obj:`int`): The size of the connection pool. + + Returns: + :class:`ApplicationBuilder`: The same builder with the updated argument. + """ self._request_param_check(name='connection_pool_size', get_updates=True) self._get_updates_connection_pool_size = get_updates_connection_pool_size return self def get_updates_proxy_url(self: BuilderType, get_updates_proxy_url: str) -> BuilderType: + """Sets the proxy to be used for the :paramref:`telegram.request.HTTPXRequest.proxy_url` + parameter which is used for :meth:`telegram.Bot.get_updates`. Defaults to :obj:`None`. + + Args: + get_updates_proxy_url (:obj:`str`): The URL to the proxy server. See + :paramref:`telegram.request.HTTPXRequest.proxy_url` for more information. + + Returns: + :class:`ApplicationBuilder`: The same builder with the updated argument. + """ self._request_param_check(name='proxy_url', get_updates=True) self._get_updates_proxy_url = get_updates_proxy_url return self @@ -483,6 +571,17 @@ def get_updates_proxy_url(self: BuilderType, get_updates_proxy_url: str) -> Buil def get_updates_connect_timeout( self: BuilderType, get_updates_connect_timeout: Optional[float] ) -> BuilderType: + """Sets the connection attempt timeout to be used for the + :paramref:`telegram.request.HTTPXRequest.connect_timeout` parameter which is used for + :meth:`telegram.Bot.get_updates`. Defaults to ``5.0``. + + Args: + get_updates_connect_timeout (:obj:`float`): See + :paramref:`telegram.request.HTTPXRequest.connect_timeout` for more information. + + Returns: + :class:`ApplicationBuilder`: The same builder with the updated argument. + """ self._request_param_check(name='connect_timeout', get_updates=True) self._get_updates_connect_timeout = get_updates_connect_timeout return self @@ -490,6 +589,17 @@ def get_updates_connect_timeout( def get_updates_read_timeout( self: BuilderType, get_updates_read_timeout: Optional[float] ) -> BuilderType: + """Sets the waiting timeout to be used for the + :paramref:`telegram.request.HTTPXRequest.read_timeout` parameter which is used for + :meth:`telegram.Bot.get_updates`. Defaults to ``5.0``. + + Args: + get_updates_read_timeout (:obj:`float`): See + :paramref:`telegram.request.HTTPXRequest.read_timeout` for more information. + + Returns: + :class:`ApplicationBuilder`: The same builder with the updated argument. + """ self._request_param_check(name='read_timeout', get_updates=True) self._get_updates_read_timeout = get_updates_read_timeout return self @@ -497,6 +607,17 @@ def get_updates_read_timeout( def get_updates_write_timeout( self: BuilderType, get_updates_write_timeout: Optional[float] ) -> BuilderType: + """Sets the write operation timeout to be used for the + :paramref:`telegram.request.HTTPXRequest.write_timeout` parameter which is used for + :meth:`telegram.Bot.get_updates`. Defaults to ``5.0``. + + Args: + get_updates_write_timeout (:obj:`float`): See + :paramref:`telegram.request.HTTPXRequest.write_timeout` for more information. + + Returns: + :class:`ApplicationBuilder`: The same builder with the updated argument. + """ self._request_param_check(name='write_timeout', get_updates=True) self._get_updates_write_timeout = get_updates_write_timeout return self @@ -504,6 +625,17 @@ def get_updates_write_timeout( def get_updates_pool_timeout( self: BuilderType, get_updates_pool_timeout: Optional[float] ) -> BuilderType: + """Sets the connection pool's connection freeing timeout to be used for the + :paramref:`~telegram.request.HTTPXRequest.pool_timeout` parameter which is used for + :meth:`telegram.Bot.get_updates`. Defaults to :obj:`None`. + + Args: + get_updates_pool_timeout (:obj:`float`): See + :paramref:`telegram.request.HTTPXRequest.pool_timeout` for more information. + + Returns: + :class:`ApplicationBuilder`: The same builder with the updated argument. + """ self._request_param_check(name='pool_timeout', get_updates=True) self._get_updates_pool_timeout = get_updates_pool_timeout return self @@ -581,7 +713,7 @@ def arbitrary_callback_data( Args: arbitrary_callback_data (:obj:`bool` | :obj:`int`): If :obj:`True` is passed, the - default cache size of 1024 will be used. Pass an integer to specify a different + default cache size of ``1024`` will be used. Pass an integer to specify a different cache size. Returns: @@ -617,7 +749,7 @@ def bot( return self # type: ignore[return-value] def update_queue(self: BuilderType, update_queue: Queue) -> BuilderType: - """Sets a :class:`queue.Queue` instance to be used for + """Sets a :class:`asyncio.Queue` instance to be used for :attr:`telegram.ext.Application.update_queue`, i.e. the queue that the application will fetch updates from. Will also be used for the :attr:`telegram.ext.Application.updater`. If not called, a queue will be instantiated. @@ -625,7 +757,7 @@ def update_queue(self: BuilderType, update_queue: Queue) -> BuilderType: .. seealso:: :attr:`telegram.ext.Updater.update_queue` Args: - update_queue (:class:`queue.Queue`): The queue. + update_queue (:class:`asyncio.Queue`): The queue. Returns: :class:`ApplicationBuilder`: The same builder with the updated argument. @@ -640,14 +772,16 @@ def concurrent_updates(self: BuilderType, concurrent_updates: Union[bool, int]) Warning: Processing updates concurrently is not recommended when stateful handlers like - :class:`telegram.ext.ConversationHandler` are used. + :class:`telegram.ext.ConversationHandler` are used. Only use this, when you are sure + that your bot does not (explicitly or implicitly) rely on updates being processed + sequentially. - .. seealso:: :paramref:`telegram.ext.Application.concurrent_updates` + .. seealso:: :paramref:`telegram.ext.Application.concurrent_updates` Args: - concurrent_updates (:obj:`bool` | :obj:`int`): Passing :obj:`True` will allow for 4096 - updates to be processed concurrently. Pass an integer to specify a different number - of updates that may be processed concurrently. + concurrent_updates (:obj:`bool` | :obj:`int`): Passing :obj:`True` will allow for + ``4096`` updates to be processed concurrently. Pass an integer to specify a + different number of updates that may be processed concurrently. Returns: :class:`ApplicationBuilder`: The same builder with the updated argument. @@ -678,8 +812,8 @@ def job_queue( this uses :attr:`telegram.ext.Application.job_queue` internally. Args: - job_queue (:class:`telegram.ext.JobQueue`, optional): The job queue. Pass :obj:`None` - if you don't want to use a job queue. + job_queue (:class:`telegram.ext.JobQueue`): The job queue. Pass :obj:`None` if you + don't want to use a job queue. Returns: :class:`ApplicationBuilder`: The same builder with the updated argument. @@ -691,6 +825,15 @@ def persistence(self: BuilderType, persistence: 'BasePersistence') -> BuilderTyp """Sets a :class:`telegram.ext.BasePersistence` instance to be used for :attr:`telegram.ext.Application.persistence`. + Note: + When using a persistence, note that all + data stored in :attr:`context.user_data `, + :attr:`context.chat_data `, + :attr:`context.bot_data ` and + in :attr:`telegram.ext.ExtBot.callback_data_cache` must be copyable with + :func:`copy.deepcopy`. This is due to the data being deep copied before handing it over + to the persistence in order to avoid race conditions. + .. seealso:: `Making your bot persistent `_, `persistentconversationbot.py BuilderTyp the persistence instance must use the same types! Args: - persistence (:class:`telegram.ext.BasePersistence`, optional): The persistence - instance. + persistence (:class:`telegram.ext.BasePersistence`): The persistence instance. Returns: :class:`ApplicationBuilder`: The same builder with the updated argument. @@ -721,7 +863,7 @@ def context_types( /python-telegram-bot/tree/master/examples#contexttypesbotpy>`_ Args: - context_types (:class:`telegram.ext.ContextTypes`, optional): The context types. + context_types (:class:`telegram.ext.ContextTypes`): The context types. Returns: :class:`ApplicationBuilder`: The same builder with the updated argument. @@ -732,8 +874,9 @@ def context_types( def updater(self: BuilderType, updater: Optional[Updater]) -> BuilderType: """Sets a :class:`telegram.ext.Updater` instance to be used for :attr:`telegram.ext.Application.updater`. The :attr:`telegram.ext.Updater.bot` and - :attr:`telegram.ext.Updater.update_queue` be used for :attr:`telegram.ext.Application.bot` - and :attr:`telegram.ext.Application.update_queue`, respectively. + :attr:`telegram.ext.Updater.update_queue` will be used for + :attr:`telegram.ext.Application.bot` and :attr:`telegram.ext.Application.update_queue`, + respectively. Args: updater (:class:`telegram.ext.Updater` | :obj:`None`): The updater instance or diff --git a/telegram/ext/_basepersistence.py b/telegram/ext/_basepersistence.py index 992eb1ee6d0..998836e121f 100644 --- a/telegram/ext/_basepersistence.py +++ b/telegram/ext/_basepersistence.py @@ -145,7 +145,7 @@ def __init__( @property def update_interval(self) -> float: - """:obj:`int`, optional): Time (in seconds) that the :class:`~telegram.ext.Application` + """:obj:`float`: Time (in seconds) that the :class:`~telegram.ext.Application` will wait between two consecutive runs of updating the persistence. .. versionadded:: 14.0 @@ -163,6 +163,10 @@ def set_bot(self, bot: Bot) -> None: Args: bot (:class:`telegram.Bot`): The bot. + + Raises: + :exc:`TypeError`: If :attr:`PersistenceInput.callback_data` is :obj:`True` and the + :paramref:`bot` is not an instance of :class:`telegram.ext.ExtBot`. """ if self.store_data.callback_data and not isinstance(bot, ExtBot): raise TypeError('callback_data can only be stored when using telegram.ext.ExtBot.') @@ -231,12 +235,12 @@ async def get_callback_data(self) -> Optional[CDCData]: .. versionadded:: 13.6 .. versionchanged:: 14.0 - Changed this method into an ``@abstractmethod``. + Changed this method into an :external:func:`~abc.abstractmethod`. Returns: Optional[Tuple[List[Tuple[:obj:`str`, :obj:`float`, \ Dict[:obj:`str`, :class:`object`]]], Dict[:obj:`str`, :obj:`str`]]]: - The restored meta data or :obj:`None`, if no data was stored. + The restored metadata or :obj:`None`, if no data was stored. """ @abstractmethod @@ -244,7 +248,8 @@ async def get_conversations(self, name: str) -> ConversationDict: """Will be called by :class:`telegram.ext.Application` when a :class:`telegram.ext.ConversationHandler` is added if :attr:`telegram.ext.ConversationHandler.persistent` is :obj:`True`. - It should return the conversations for the handler with `name` or an empty :obj:`dict` + It should return the conversations for the handler with :paramref:`name` or an empty + :obj:`dict`. Args: name (:obj:`str`): The handlers name. @@ -263,7 +268,7 @@ async def update_conversation( Args: name (:obj:`str`): The handler's name. key (:obj:`tuple`): The key the state is changed for. - new_state (:obj:`tuple` | :class:`object`): The new state for the given key. + new_state (:class:`object`): The new state for the given key. """ @abstractmethod @@ -306,7 +311,7 @@ async def update_callback_data(self, data: CDCData) -> None: .. versionadded:: 13.6 .. versionchanged:: 14.0 - Changed this method into an ``@abstractmethod``. + Changed this method into an :external:func:`~abc.abstractmethod`. Args: data (Optional[Tuple[List[Tuple[:obj:`str`, :obj:`float`, \ @@ -345,7 +350,7 @@ async def refresh_user_data(self, user_id: int, user_data: UD) -> None: .. versionadded:: 13.6 .. versionchanged:: 14.0 - Changed this method into an ``@abstractmethod``. + Changed this method into an :external:func:`~abc.abstractmethod`. Args: user_id (:obj:`int`): The user ID this :attr:`~telegram.ext.Application.user_data` is @@ -363,7 +368,7 @@ async def refresh_chat_data(self, chat_id: int, chat_data: CD) -> None: .. versionadded:: 13.6 .. versionchanged:: 14.0 - Changed this method into an ``@abstractmethod``. + Changed this method into an :external:func:`~abc.abstractmethod`. Args: chat_id (:obj:`int`): The chat ID this :attr:`~telegram.ext.Application.chat_data` is @@ -381,7 +386,7 @@ async def refresh_bot_data(self, bot_data: BD) -> None: .. versionadded:: 13.6 .. versionchanged:: 14.0 - Changed this method into an ``@abstractmethod``. + Changed this method into an :external:func:`~abc.abstractmethod`. Args: bot_data (:obj:`dict` | :attr:`telegram.ext.ContextTypes.bot_data`): @@ -394,5 +399,5 @@ async def flush(self) -> None: persistence a chance to finish up saving or close a database connection gracefully. .. versionchanged:: 14.0 - Changed this method into an ``@abstractmethod``. + Changed this method into an :external:func:`~abc.abstractmethod`. """ diff --git a/telegram/ext/_callbackcontext.py b/telegram/ext/_callbackcontext.py index e702974b5d5..58cbd89a245 100644 --- a/telegram/ext/_callbackcontext.py +++ b/telegram/ext/_callbackcontext.py @@ -56,11 +56,11 @@ class CallbackContext(Generic[BT, UD, CD, BD]): Note: :class:`telegram.ext.Application` will create a single context for an entire update. This means that if you got 2 handlers in different groups and they both get called, they will - get passed the same `CallbackContext` object (of course with proper attributes like - `.matches` differing). This allows you to add custom attributes in a lower handler group - callback, and then subsequently access those attributes in a higher handler group callback. - Note that the attributes on `CallbackContext` might change in the future, so make sure to - use a fairly unique name for the attributes. + receive the same :class:`CallbackContext` object (of course with proper attributes like + :attr:`matches` differing). This allows you to add custom attributes in a lower handler + group callback, and then subsequently access those attributes in a higher handler group + callback. Note that the attributes on :class:`CallbackContext` might change in the future, + so make sure to use a fairly unique name for the attributes. Warning: Do not combine custom attributes with :paramref:`telegram.ext.Handler.block` set to @@ -73,17 +73,19 @@ class CallbackContext(Generic[BT, UD, CD, BD]): context. Attributes: + coroutine (:term:`coroutine function`): Optional. Only present in error handlers if the + error was caused by a coroutine run with :meth:`Application.create_task` or a handler + callback with :attr:`block=False `. matches (List[:meth:`re.Match `]): Optional. If the associated update - originated from - a :class:`filters.Regex`, this will contain a list of match objects for every pattern - where ``re.search(pattern, string)`` returned a match. Note that filters short circuit, - so combined regex filters will not always be evaluated. + originated from a :class:`filters.Regex`, this will contain a list of match objects for + every pattern where ``re.search(pattern, string)`` returned a match. Note that filters + short circuit, so combined regex filters will not always be evaluated. args (List[:obj:`str`]): Optional. Arguments passed to a command if the associated update is handled by :class:`telegram.ext.CommandHandler`, :class:`telegram.ext.PrefixHandler` or :class:`telegram.ext.StringCommandHandler`. It contains a list of the words in the text after the command, using any whitespace string as a delimiter. - error (:obj:`Exception`): Optional. The error that was raised. Only present when passed - to a error handler registered with :attr:`telegram.ext.Application.add_error_handler`. + error (:exc:`Exception`): Optional. The error that was raised. Only present when passed + to an error handler registered with :attr:`telegram.ext.Application.add_error_handler`. job (:class:`telegram.ext.Job`): Optional. The job which originated this callback. Only present when passed to the callback of :class:`telegram.ext.Job` or in error handlers if the error is caused by a job. @@ -107,7 +109,7 @@ class CallbackContext(Generic[BT, UD, CD, BD]): Example: .. code:: python - def callback(update: Update, context: CallbackContext.DEFAULT_TYPE): + async def callback(update: Update, context: CallbackContext.DEFAULT_TYPE): ... .. versionadded: 14.0 @@ -126,10 +128,6 @@ def callback(update: Update, context: CallbackContext.DEFAULT_TYPE): ) def __init__(self: 'CCT', application: 'Application[BT, CCT, UD, CD, BD, JQ]'): - """ - Args: - application (:class:`telegram.ext.Application`): - """ self._application = application self._chat_id_and_data: Optional[Tuple[int, CD]] = None self._user_id_and_data: Optional[Tuple[int, UD]] = None @@ -146,8 +144,8 @@ def application(self) -> 'Application[BT, CCT, UD, CD, BD, JQ]': @property def bot_data(self) -> BD: - """:obj:`dict`: Optional. A dict that can be used to keep any data in. For each - update it will be the same ``dict``. + """:obj:`ContextTypes.bot_data`: Optional. An object that can be used to keep any data in. + For each update it will be the same :attr:`ContextTypes.bot_data`. Defaults to :obj:`dict`. """ return self.application.bot_data @@ -159,8 +157,9 @@ def bot_data(self, value: object) -> NoReturn: @property def chat_data(self) -> Optional[CD]: - """:obj:`dict`: Optional. A dict that can be used to keep any data in. For each - update from the same chat id it will be the same ``dict``. + """:obj:`ContextTypes.chat_data`: Optional. An object that can be used to keep any data in. + For each update from the same chat id it will be the same :obj:`ContextTypes.chat_data`. + Defaults to :obj:`dict`. Warning: When a group chat migrates to a supergroup, its chat id will change and the @@ -180,8 +179,9 @@ def chat_data(self, value: object) -> NoReturn: @property def user_data(self) -> Optional[UD]: - """:obj:`dict`: Optional. A dict that can be used to keep any data in. For each - update from the same user it will be the same ``dict``. + """:obj:`ContextTypes.user_data`: Optional. An object that can be used to keep any data in. + For each update from the same user it will be the same :obj:`ContextTypes.user_data`. + Defaults to :obj:`dict`. """ if self._user_id_and_data: return self._user_id_and_data[1] @@ -227,15 +227,14 @@ def drop_callback_data(self, callback_query: CallbackQuery) -> None: Note: Will *not* raise exceptions in case the data is not found in the cache. - *Will* raise :class:`KeyError` in case the callback query can not be found in the - cache. + *Will* raise :exc:`KeyError` in case the callback query can not be found in the cache. Args: callback_query (:class:`telegram.CallbackQuery`): The callback query. Raises: - KeyError | RuntimeError: :class:`KeyError`, if the callback query can not be found in - the cache and :class:`RuntimeError`, if the bot doesn't allow for arbitrary + KeyError | RuntimeError: :exc:`KeyError`, if the callback query can not be found in + the cache and :exc:`RuntimeError`, if the bot doesn't allow for arbitrary callback data. """ if isinstance(self.bot, ExtBot): @@ -374,9 +373,8 @@ def bot(self) -> BT: @property def job_queue(self) -> Optional['JobQueue']: """ - :class:`telegram.ext.JobQueue`: The ``JobQueue`` used by the - :class:`telegram.ext.Application` and (usually) the :class:`telegram.ext.Updater` - associated with this context. + :class:`telegram.ext.JobQueue`: The :class:`JobQueue` used by the + :class:`telegram.ext.Application`. """ return self._application.job_queue @@ -384,7 +382,7 @@ def job_queue(self) -> Optional['JobQueue']: @property def update_queue(self) -> 'Queue[object]': """ - :class:`asyncio.Queue`: The ``Queue`` instance used by the + :class:`asyncio.Queue`: The :class:`asyncio.Queue` instance used by the :class:`telegram.ext.Application` and (usually) the :class:`telegram.ext.Updater` associated with this context. @@ -394,9 +392,9 @@ def update_queue(self) -> 'Queue[object]': @property def match(self) -> Optional[Match[str]]: """ - `Regex match type`: The first match from :attr:`matches`. + :meth:`re.Match `: The first match from :attr:`matches`. Useful if you are only filtering using a single regex filter. - Returns `None` if :attr:`matches` is empty. + Returns :obj:`None` if :attr:`matches` is empty. """ try: return self.matches[0] # type: ignore[index] # pylint: disable=unsubscriptable-object diff --git a/telegram/ext/_callbackdatacache.py b/telegram/ext/_callbackdatacache.py index 2656f3cc419..b00f6bd514e 100644 --- a/telegram/ext/_callbackdatacache.py +++ b/telegram/ext/_callbackdatacache.py @@ -105,7 +105,7 @@ class CallbackDataCache: Args: bot (:class:`telegram.ext.ExtBot`): The bot this cache is for. maxsize (:obj:`int`, optional): Maximum number of items in each of the internal mappings. - Defaults to 1024. + Defaults to ``1024``. persistent_data (Tuple[List[Tuple[:obj:`str`, :obj:`float`, \ Dict[:obj:`str`, :class:`object`]]], Dict[:obj:`str`, :obj:`str`]], optional): \ @@ -158,8 +158,8 @@ def persistence_data(self) -> CDCData: def process_keyboard(self, reply_markup: InlineKeyboardMarkup) -> InlineKeyboardMarkup: """Registers the reply markup to the cache. If any of the buttons have :attr:`~telegram.InlineKeyboardButton.callback_data`, stores that data and builds a new - keyboard with the correspondingly - replaced buttons. Otherwise does nothing and returns the original reply markup. + keyboard with the correspondingly replaced buttons. Otherwise, does nothing and returns + the original reply markup. Args: reply_markup (:class:`telegram.InlineKeyboardMarkup`): The keyboard. @@ -221,10 +221,11 @@ def __get_keyboard_uuid_and_button_data( @staticmethod def extract_uuids(callback_data: str) -> Tuple[str, str]: - """Extracts the keyboard uuid and the button uuid from the given ``callback_data``. + """Extracts the keyboard uuid and the button uuid from the given :paramref:`callback_data`. Args: - callback_data (:obj:`str`): The ``callback_data`` as present in the button. + callback_data (:obj:`str`): The + :paramref:`~telegram.InlineKeyboardButton.callback_data` as present in the button. Returns: (:obj:`str`, :obj:`str`): Tuple of keyboard and button uuid @@ -240,7 +241,7 @@ def process_message(self, message: Message) -> None: Note: Checks :attr:`telegram.Message.via_bot` and :attr:`telegram.Message.from_user` to check - if the reply markup (if any) was actually sent by this caches bot. If it was not, the + if the reply markup (if any) was actually sent by this cache's bot. If it was not, the message will be returned unchanged. Note that this will fail for channel posts, as :attr:`telegram.Message.from_user` is @@ -249,7 +250,7 @@ def process_message(self, message: Message) -> None: Warning: * Does *not* consider :attr:`telegram.Message.reply_to_message` and - :attr:`telegram.Message.pinned_message`. Pass them to these method separately. + :attr:`telegram.Message.pinned_message`. Pass them to this method separately. * *In place*, i.e. the passed :class:`telegram.Message` will be changed! Args: @@ -346,7 +347,7 @@ def drop_data(self, callback_query: CallbackQuery) -> None: Note: Will *not* raise exceptions in case the callback data is not found in the cache. - *Will* raise :class:`KeyError` in case the callback query can not be found in the + *Will* raise :exc:`KeyError` in case the callback query can not be found in the cache. Args: diff --git a/telegram/ext/_callbackqueryhandler.py b/telegram/ext/_callbackqueryhandler.py index 2af9a474f56..e5a7fff5b92 100644 --- a/telegram/ext/_callbackqueryhandler.py +++ b/telegram/ext/_callbackqueryhandler.py @@ -43,7 +43,8 @@ class CallbackQueryHandler(Handler[Update, CCT]): - """Handler class to handle Telegram callback queries. Optionally based on a regex. + """Handler class to handle Telegram :attr:`callback queries `. + Optionally based on a regex. Read the documentation of the :mod:`re` module for more information. @@ -64,18 +65,21 @@ class CallbackQueryHandler(Handler[Update, CCT]): attributes to :class:`telegram.ext.CallbackContext`. See its docs for more info. Args: - callback (:obj:`callable`): The callback function for this handler. Will be called when - :attr:`check_update` has determined that an update should be processed by this handler. - Callback signature: ``def callback(update: Update, context: CallbackContext)`` + callback (:term:`coroutine function`): The callback function for this handler. Will be + called when :meth:`check_update` has determined that an update should be processed by + this handler. Callback signature:: + + async def callback(update: Update, context: CallbackContext) The return value of the callback is usually ignored except for the special case of :class:`telegram.ext.ConversationHandler`. - pattern (:obj:`str` | `Pattern` | :obj:`callable` | :obj:`type`, optional): + pattern (:obj:`str` | :func:`re.Pattern ` | :obj:`callable` | :obj:`type`, \ + optional): Pattern to test :attr:`telegram.CallbackQuery.data` against. If a string or a regex pattern is passed, :func:`re.match` is used on :attr:`telegram.CallbackQuery.data` to determine if an update should be handled by this handler. If your bot allows arbitrary - objects as ``callback_data``, non-strings will be accepted. To filter arbitrary - objects you may pass + objects as :paramref:`~telegram.InlineKeyboardButton.callback_data`, non-strings will + be accepted. To filter arbitrary objects you may pass: * a callable, accepting exactly one argument, namely the :attr:`telegram.CallbackQuery.data`. It must return :obj:`True` or @@ -93,9 +97,9 @@ class CallbackQueryHandler(Handler[Update, CCT]): :meth:`telegram.ext.Application.process_update`. Defaults to :obj:`True`. Attributes: - callback (:obj:`callable`): The callback function for this handler. - pattern (`Pattern` | :obj:`callable` | :obj:`type`): Optional. Regex pattern, callback or - type to test :attr:`telegram.CallbackQuery.data` against. + callback (:term:`coroutine function`): The callback function for this handler. + pattern (:func:`re.Pattern ` | :obj:`callable` | :obj:`type`): Optional. + Regex pattern, callback or type to test :attr:`telegram.CallbackQuery.data` against. .. versionchanged:: 13.6 Added support for arbitrary callback data. @@ -126,7 +130,7 @@ def __init__( self.pattern = pattern def check_update(self, update: object) -> Optional[Union[bool, object]]: - """Determines whether an update should be passed to this handlers :attr:`callback`. + """Determines whether an update should be passed to this handler's :attr:`callback`. Args: update (:class:`telegram.Update` | :obj:`object`): Incoming update. diff --git a/telegram/ext/_chatjoinrequesthandler.py b/telegram/ext/_chatjoinrequesthandler.py index b4660a4aa56..13e9b43054d 100644 --- a/telegram/ext/_chatjoinrequesthandler.py +++ b/telegram/ext/_chatjoinrequesthandler.py @@ -26,7 +26,8 @@ class ChatJoinRequestHandler(Handler[Update, CCT]): - """Handler class to handle Telegram updates that contain a chat join request. + """Handler class to handle Telegram updates that contain + :attr:`telegram.Update.chat_join_request`. Warning: When setting :paramref:`block` to :obj:`True`, you cannot rely on adding custom @@ -35,11 +36,11 @@ class ChatJoinRequestHandler(Handler[Update, CCT]): .. versionadded:: 13.8 Args: - callback (:obj:`callable`): The callback function for this handler. Will be called when - :attr:`check_update` has determined that an update should be processed by this handler. - Callback signature for context based API: + callback (:term:`coroutine function`): The callback function for this handler. Will be + called when :meth:`check_update` has determined that an update should be processed by + this handler. Callback signature:: - ``def callback(update: Update, context: CallbackContext)`` + async def callback(update: Update, context: CallbackContext) The return value of the callback is usually ignored except for the special case of :class:`telegram.ext.ConversationHandler`. @@ -48,7 +49,7 @@ class ChatJoinRequestHandler(Handler[Update, CCT]): :meth:`telegram.ext.Application.process_update`. Defaults to :obj:`True`. Attributes: - callback (:obj:`callable`): The callback function for this handler. + callback (:term:`coroutine function`): The callback function for this handler. block (:obj:`bool`): Determines whether the callback will run asynchronously. """ @@ -56,7 +57,7 @@ class ChatJoinRequestHandler(Handler[Update, CCT]): __slots__ = () def check_update(self, update: object) -> bool: - """Determines whether an update should be passed to this handlers :attr:`callback`. + """Determines whether an update should be passed to this handler's :attr:`callback`. Args: update (:class:`telegram.Update` | :obj:`object`): Incoming update. diff --git a/telegram/ext/_chatmemberhandler.py b/telegram/ext/_chatmemberhandler.py index 1241f08b778..7ea2388bbc4 100644 --- a/telegram/ext/_chatmemberhandler.py +++ b/telegram/ext/_chatmemberhandler.py @@ -16,7 +16,7 @@ # # You should have received a copy of the GNU Lesser Public License # along with this program. If not, see [http://www.gnu.org/licenses/]. -"""This module contains the ChatMemberHandler classes.""" +"""This module contains the ChatMemberHandler class.""" from typing import ClassVar, TypeVar from telegram import Update @@ -38,9 +38,11 @@ class ChatMemberHandler(Handler[Update, CCT]): attributes to :class:`telegram.ext.CallbackContext`. See its docs for more info. Args: - callback (:obj:`callable`): The callback function for this handler. Will be called when - :attr:`check_update` has determined that an update should be processed by this handler. - Callback signature: ``def callback(update: Update, context: CallbackContext)`` + callback (:term:`coroutine function`): The callback function for this handler. Will be + called when :meth:`check_update` has determined that an update should be processed by + this handler. Callback signature:: + + async def callback(update: Update, context: CallbackContext) The return value of the callback is usually ignored except for the special case of :class:`telegram.ext.ConversationHandler`. @@ -53,7 +55,7 @@ class ChatMemberHandler(Handler[Update, CCT]): :meth:`telegram.ext.Application.process_update`. Defaults to :obj:`True`. Attributes: - callback (:obj:`callable`): The callback function for this handler. + callback (:term:`coroutine function`): The callback function for this handler. chat_member_types (:obj:`int`, optional): Specifies if this handler should handle only updates with :attr:`telegram.Update.my_chat_member`, :attr:`telegram.Update.chat_member` or both. @@ -69,7 +71,7 @@ class ChatMemberHandler(Handler[Update, CCT]): CHAT_MEMBER: ClassVar[int] = 0 """:obj:`int`: Used as a constant to handle only :attr:`telegram.Update.chat_member`.""" ANY_CHAT_MEMBER: ClassVar[int] = 1 - """:obj:`int`: Used as a constant to handle bot :attr:`telegram.Update.my_chat_member` + """:obj:`int`: Used as a constant to handle both :attr:`telegram.Update.my_chat_member` and :attr:`telegram.Update.chat_member`.""" def __init__( @@ -83,7 +85,7 @@ def __init__( self.chat_member_types = chat_member_types def check_update(self, update: object) -> bool: - """Determines whether an update should be passed to this handlers :attr:`callback`. + """Determines whether an update should be passed to this handler's :attr:`callback`. Args: update (:class:`telegram.Update` | :obj:`object`): Incoming update. diff --git a/telegram/ext/_choseninlineresulthandler.py b/telegram/ext/_choseninlineresulthandler.py index 8ca94ea7751..9f40abc4d98 100644 --- a/telegram/ext/_choseninlineresulthandler.py +++ b/telegram/ext/_choseninlineresulthandler.py @@ -33,16 +33,19 @@ class ChosenInlineResultHandler(Handler[Update, CCT]): - """Handler class to handle Telegram updates that contain a chosen inline result. + """Handler class to handle Telegram updates that contain + :attr:`telegram.Update.chosen_inline_result`. Warning: When setting :paramref:`block` to :obj:`True`, you cannot rely on adding custom attributes to :class:`telegram.ext.CallbackContext`. See its docs for more info. Args: - callback (:obj:`callable`): The callback function for this handler. Will be called when - :attr:`check_update` has determined that an update should be processed by this handler. - Callback signature: ``def callback(update: Update, context: CallbackContext)`` + callback (:term:`coroutine function`): The callback function for this handler. Will be + called when :meth:`check_update` has determined that an update should be processed by + this handler. Callback signature:: + + async def callback(update: Update, context: CallbackContext) The return value of the callback is usually ignored except for the special case of :class:`telegram.ext.ConversationHandler`. @@ -58,7 +61,7 @@ class ChosenInlineResultHandler(Handler[Update, CCT]): .. versionadded:: 13.6 Attributes: - callback (:obj:`callable`): The callback function for this handler. + callback (:term:`coroutine function`): The callback function for this handler. block (:obj:`bool`): Determines whether the return value of the callback should be awaited before processing the next handler in :meth:`telegram.ext.Application.process_update`. @@ -85,13 +88,13 @@ def __init__( self.pattern = pattern def check_update(self, update: object) -> Optional[Union[bool, object]]: - """Determines whether an update should be passed to this handlers :attr:`callback`. + """Determines whether an update should be passed to this handler's :attr:`callback`. Args: update (:class:`telegram.Update` | :obj:`object`): Incoming update. Returns: - :obj:`bool` + :obj:`bool` | :obj:`re.match` """ if isinstance(update, Update) and update.chosen_inline_result: diff --git a/telegram/ext/_commandhandler.py b/telegram/ext/_commandhandler.py index b27b8843be2..8e3c29f17d4 100644 --- a/telegram/ext/_commandhandler.py +++ b/telegram/ext/_commandhandler.py @@ -36,12 +36,13 @@ class CommandHandler(Handler[Update, CCT]): """Handler class to handle Telegram commands. Commands are Telegram messages that start with ``/``, optionally followed by an ``@`` and the - bot's name and/or some additional text. The handler will add a ``list`` to the + bot's name and/or some additional text. The handler will add a :obj:`list` to the :class:`CallbackContext` named :attr:`CallbackContext.args`. It will contain a list of strings, which is the text following the command split on single or consecutive whitespace characters. - By default the handler listens to messages as well as edited messages. To change this behavior - use ``~filters.UpdateType.EDITED_MESSAGE`` in the filter argument. + By default, the handler listens to messages as well as edited messages. To change this behavior + use :attr:`~filters.UpdateType.EDITED_MESSAGE ` + in the filter argument. Note: * :class:`CommandHandler` does *not* handle (edited) channel posts. @@ -53,29 +54,31 @@ class CommandHandler(Handler[Update, CCT]): Args: command (:obj:`str` | Tuple[:obj:`str`] | List[:obj:`str`]): The command or list of commands this handler should listen for. - Limitations are the same as described here https://core.telegram.org/bots#commands - callback (:obj:`callable`): The callback function for this handler. Will be called when - :attr:`check_update` has determined that an update should be processed by this handler. - Callback signature: ``def callback(update: Update, context: CallbackContext)`` + Limitations are the same as described `here `_ + callback (:term:`coroutine function`): The callback function for this handler. Will be + called when :meth:`check_update` has determined that an update should be processed by + this handler. Callback signature:: + + async def callback(update: Update, context: CallbackContext) The return value of the callback is usually ignored except for the special case of :class:`telegram.ext.ConversationHandler`. filters (:class:`telegram.ext.filters.BaseFilter`, optional): A filter inheriting from :class:`telegram.ext.filters.BaseFilter`. Standard filters can be found in :mod:`telegram.ext.filters`. Filters can be combined using bitwise - operators (& for and, | for or, ~ for not). + operators (``&`` for :keyword:`and`, ``|`` for :keyword:`or`, ``~`` for :keyword:`not`) block (:obj:`bool`, optional): Determines whether the return value of the callback should be awaited before processing the next handler in :meth:`telegram.ext.Application.process_update`. Defaults to :obj:`True`. Raises: - ValueError: when command is too long or has illegal chars. + :exc:`ValueError`: When the command is too long or has illegal chars. Attributes: command (:obj:`str` | Tuple[:obj:`str`] | List[:obj:`str`]): The command or list of commands this handler should listen for. - Limitations are the same as described here https://core.telegram.org/bots#commands - callback (:obj:`callable`): The callback function for this handler. + Limitations are the same as described `here `_ + callback (:term:`coroutine function`): The callback function for this handler. filters (:class:`telegram.ext.filters.BaseFilter`): Optional. Only allow updates with these Filters. block (:obj:`bool`): Determines whether the return value of the callback should be @@ -107,7 +110,7 @@ def __init__( def check_update( self, update: object ) -> Optional[Union[bool, Tuple[List[str], Optional[Union[bool, Dict]]]]]: - """Determines whether an update should be passed to this handlers :attr:`callback`. + """Determines whether an update should be passed to this handler's :attr:`callback`. Args: update (:class:`telegram.Update` | :obj:`object`): Incoming update. @@ -162,11 +165,12 @@ def collect_additional_context( class PrefixHandler(CommandHandler): """Handler class to handle custom prefix commands. - This is a intermediate handler between :class:`MessageHandler` and :class:`CommandHandler`. - It supports configurable commands with the same options as CommandHandler. It will respond to - every combination of :attr:`prefix` and :attr:`command`. It will add a :obj:`list` to the - :class:`CallbackContext` named :attr:`CallbackContext.args`. It will contain a list of strings, - which is the text following the command split on single or consecutive whitespace characters. + This is an intermediate handler between :class:`MessageHandler` and :class:`CommandHandler`. + It supports configurable commands with the same options as :class:`CommandHandler`. It will + respond to every combination of :attr:`prefix` and :attr:`command`. It will add a :obj:`list` + to the :class:`CallbackContext` named :attr:`CallbackContext.args`. It will contain a list of + strings, which is the text following the command split on single or consecutive whitespace + characters. Examples: @@ -190,8 +194,8 @@ class PrefixHandler(CommandHandler): '#test', '!help' and '#help'. - By default the handler listens to messages as well as edited messages. To change this behavior - use ``~filters.UpdateType.EDITED_MESSAGE``. + By default, the handler listens to messages as well as edited messages. To change this behavior + use :attr:`~filters.UpdateType.EDITED_MESSAGE ` Note: * :class:`PrefixHandler` does *not* handle (edited) channel posts. @@ -205,22 +209,24 @@ class PrefixHandler(CommandHandler): The prefix(es) that will precede :attr:`command`. command (:obj:`str` | Tuple[:obj:`str`] | List[:obj:`str`]): The command or list of commands this handler should listen for. - callback (:obj:`callable`): The callback function for this handler. Will be called when - :attr:`check_update` has determined that an update should be processed by this handler. - Callback signature: ``def callback(update: Update, context: CallbackContext)`` + callback (:term:`coroutine function`): The callback function for this handler. Will be + called when :meth:`check_update` has determined that an update should be processed by + this handler. Callback signature:: + + async def callback(update: Update, context: CallbackContext) The return value of the callback is usually ignored except for the special case of :class:`telegram.ext.ConversationHandler`. filters (:class:`telegram.ext.filters.BaseFilter`, optional): A filter inheriting from :class:`telegram.ext.filters.BaseFilter`. Standard filters can be found in :mod:`telegram.ext.filters`. Filters can be combined using bitwise - operators (& for and, | for or, ~ for not). + operators (``&`` for :keyword:`and`, ``|`` for :keyword:`or`, ``~`` for :keyword:`not`) block (:obj:`bool`, optional): Determines whether the return value of the callback should be awaited before processing the next handler in :meth:`telegram.ext.Application.process_update`. Defaults to :obj:`True`. Attributes: - callback (:obj:`callable`): The callback function for this handler. + callback (:term:`coroutine function`): The callback function for this handler. filters (:class:`telegram.ext.filters.BaseFilter`): Optional. Only allow updates with these Filters. block (:obj:`bool`): Determines whether the return value of the callback should be @@ -298,7 +304,7 @@ def _build_commands(self) -> None: def check_update( self, update: object ) -> Optional[Union[bool, Tuple[List[str], Optional[Union[bool, Dict]]]]]: - """Determines whether an update should be passed to this handlers :attr:`callback`. + """Determines whether an update should be passed to this handler's :attr:`callback`. Args: update (:class:`telegram.Update` | :obj:`object`): Incoming update. diff --git a/telegram/ext/_contexttypes.py b/telegram/ext/_contexttypes.py index 4e8fc211724..c5af1ac1afe 100644 --- a/telegram/ext/_contexttypes.py +++ b/telegram/ext/_contexttypes.py @@ -39,15 +39,18 @@ class ContextTypes(Generic[CCT, UD, CD, BD]): (error-)handler callbacks and job callbacks. Must be a subclass of :class:`telegram.ext.CallbackContext`. Defaults to :class:`telegram.ext.CallbackContext`. - bot_data (:obj:`type`, optional): Determines the type of ``context.bot_data`` of all - (error-)handler callbacks and job callbacks. Defaults to :obj:`dict`. Must support - instantiating without arguments. - chat_data (:obj:`type`, optional): Determines the type of ``context.chat_data`` of all - (error-)handler callbacks and job callbacks. Defaults to :obj:`dict`. Must support - instantiating without arguments. - user_data (:obj:`type`, optional): Determines the type of ``context.user_data`` of all - (error-)handler callbacks and job callbacks. Defaults to :obj:`dict`. Must support - instantiating without arguments. + bot_data (:obj:`type`, optional): Determines the type of + :attr:`context.bot_data ` of all (error-)handler callbacks + and job callbacks. Defaults to :obj:`dict`. Must support instantiating without + arguments. + chat_data (:obj:`type`, optional): Determines the type of + :attr:`context.chat_data ` of all (error-)handler callbacks + and job callbacks. Defaults to :obj:`dict`. Must support instantiating without + arguments. + user_data (:obj:`type`, optional): Determines the type of + :attr:`context.user_data ` of all (error-)handler callbacks + and job callbacks. Defaults to :obj:`dict`. Must support instantiating without + arguments. """ @@ -201,15 +204,21 @@ def context(self) -> Type[CCT]: @property def bot_data(self) -> Type[BD]: - """The type of ``context.bot_data`` of all (error-)handler callbacks and job callbacks.""" + """The type of :attr:`context.bot_data ` of all (error-)handler + callbacks and job callbacks. + """ return self._bot_data @property def chat_data(self) -> Type[CD]: - """The type of ``context.chat_data`` of all (error-)handler callbacks and job callbacks.""" + """The type of :attr:`context.chat_data ` of all (error-)handler + callbacks and job callbacks. + """ return self._chat_data @property def user_data(self) -> Type[UD]: - """The type of ``context.user_data`` of all (error-)handler callbacks and job callbacks.""" + """The type of :attr:`context.user_data ` of all (error-)handler + callbacks and job callbacks. + """ return self._user_data diff --git a/telegram/ext/_conversationhandler.py b/telegram/ext/_conversationhandler.py index af1dde51d9a..3870558986e 100644 --- a/telegram/ext/_conversationhandler.py +++ b/telegram/ext/_conversationhandler.py @@ -66,6 +66,10 @@ @dataclass class _ConversationTimeoutContext(Generic[CCT]): + """Used as a datastore for conversation timeouts. Passed in the + :paramref:`JobQueue.run_once.context` parameter. See :meth:`_trigger_timeout`. + """ + __slots__ = ('conversation_key', 'update', 'application', 'callback_context') conversation_key: ConversationKey @@ -76,8 +80,8 @@ class _ConversationTimeoutContext(Generic[CCT]): @dataclass class PendingState: - """Thin wrapper around asyncio.Task to handle block=False handlers. Note that this is a - public class of this module, since `Application.update_persistence` needs to access it. + """Thin wrapper around :class:`asyncio.Task` to handle block=False handlers. Note that this is + a public class of this module, since :meth:`Application.update_persistence` needs to access it. It's still hidden from users, since this module itself is private. """ @@ -90,6 +94,13 @@ def done(self) -> bool: return self.task.done() def resolve(self) -> object: + """Returns the new state of the :class:`ConversationHandler` if available. If there was an + exception during the task execution, then return the old state. If the returned state was + :obj:`None`, then end the conversation. + + Raises: + :exc:`RuntimeError`: If the current task has not yet finished. + """ if not self.task.done(): raise RuntimeError('New state is not yet available') @@ -112,7 +123,7 @@ def resolve(self) -> object: class ConversationHandler(Handler[Update, CCT]): """ A handler to hold a conversation with a single or multiple users through Telegram updates by - managing four collections of other handlers. + managing three collections of other handlers. Warning: :class:`ConversationHandler` heavily relies on incoming updates being processed one by one. @@ -120,36 +131,38 @@ class ConversationHandler(Handler[Update, CCT]): :obj:`False`. Note: - ``ConversationHandler`` will only accept updates that are (subclass-)instances of + :class:`ConversationHandler` will only accept updates that are (subclass-)instances of :class:`telegram.Update`. This is, because depending on the :attr:`per_user` and - :attr:`per_chat` ``ConversationHandler`` relies on + :attr:`per_chat`, :class:`ConversationHandler` relies on :attr:`telegram.Update.effective_user` and/or :attr:`telegram.Update.effective_chat` in - order to determine which conversation an update should belong to. For ``per_message=True``, - ``ConversationHandler`` uses ``update.callback_query.message.message_id`` when - ``per_chat=True`` and ``update.callback_query.inline_message_id`` when ``per_chat=False``. - For a more detailed explanation, please see our `FAQ`_. + order to determine which conversation an update should belong to. For + :attr:`per_message=True `, :class:`ConversationHandler` uses + :attr:`update.callback_query.message.message_id ` when + :attr:`per_chat=True ` and + :attr:`update.callback_query.inline_message_id <.CallbackQuery.inline_message_id>` when + :attr:`per_chat=False `. For a more detailed explanation, please see our `FAQ`_. - Finally, ``ConversationHandler``, does *not* handle (edited) channel posts. + Finally, :class:`ConversationHandler`, does *not* handle (edited) channel posts. .. _`FAQ`: https://github.com/python-telegram-bot/python-telegram-bot/wiki\ /Frequently-Asked-Questions#what-do-the-per_-settings-in-conversation handler-do - The first collection, a ``list`` named :attr:`entry_points`, is used to initiate the + The first collection, a :obj:`list` named :attr:`entry_points`, is used to initiate the conversation, for example with a :class:`telegram.ext.CommandHandler` or :class:`telegram.ext.MessageHandler`. - The second collection, a ``dict`` named :attr:`states`, contains the different conversation + The second collection, a :obj:`dict` named :attr:`states`, contains the different conversation steps and one or more associated handlers that should be used if the user sends a message when the conversation with them is currently in that state. Here you can also define a state for :attr:`TIMEOUT` to define the behavior when :attr:`conversation_timeout` is exceeded, and a state for :attr:`WAITING` to define behavior when a new update is received while the previous - ``@run_async`` decorated handler is not finished. + :attr:`block=False ` handler is not finished. - The third collection, a ``list`` named :attr:`fallbacks`, is used if the user is currently in a - conversation but the state has either no associated handler or the handler that is associated - to the state is inappropriate for the update, for example if the update contains a command, but - a regular text message is expected. You could use this for a ``/cancel`` command or to let the - user know their message was not recognized. + The third collection, a :obj:`list` named :attr:`fallbacks`, is used if the user is currently + in a conversation but the state has either no associated handler or the handler that is + associated to the state is inappropriate for the update, for example if the update contains a + command, but a regular text message is expected. You could use this for a ``/cancel`` command + or to let the user know their message was not recognized. To change the state of conversation, the callback function of a handler must return the new state after responding to the user. If it does not return anything (returning :obj:`None` by @@ -158,78 +171,82 @@ class ConversationHandler(Handler[Update, CCT]): To end the conversation, the callback function must return :attr:`END` or ``-1``. To handle the conversation timeout, use handler :attr:`TIMEOUT` or ``-2``. Finally, :class:`telegram.ext.ApplicationHandlerStop` can be used in conversations as described - in the corresponding documentation. + in its documentation. Note: In each of the described collections of handlers, a handler may in turn be a - :class:`ConversationHandler`. In that case, the nested :class:`ConversationHandler` should - have the attribute :attr:`map_to_parent` which allows to return to the parent conversation - at specified states within the nested conversation. + :class:`ConversationHandler`. In that case, the child :class:`ConversationHandler` should + have the attribute :attr:`map_to_parent` which allows returning to the parent conversation + at specified states within the child conversation. Note that the keys in :attr:`map_to_parent` must not appear as keys in :attr:`states` attribute or else the latter will be ignored. You may map :attr:`END` to one of the parents - states to continue the parent conversation after this has ended or even map a state to - :attr:`END` to end the *parent* conversation from within the nested one. For an example on - nested :class:`ConversationHandler` s, see our `examples`_. + states to continue the parent conversation after the child conversation has ended or even + map a state to :attr:`END` to end the *parent* conversation from within the child + conversation. For an example on nested :class:`ConversationHandler` s, see our `examples`_. .. _`examples`: https://github.com/python-telegram-bot/python-telegram-bot/blob/master/examples Args: - entry_points (List[:class:`telegram.ext.Handler`]): A list of ``Handler`` objects that can - trigger the start of the conversation. The first handler which :attr:`check_update` + entry_points (List[:class:`telegram.ext.Handler`]): A list of :obj:`Handler` objects that + can trigger the start of the conversation. The first handler whose :meth:`check_update` method returns :obj:`True` will be used. If all return :obj:`False`, the update is not handled. states (Dict[:obj:`object`, List[:class:`telegram.ext.Handler`]]): A :obj:`dict` that defines the different states of conversation a user can be in and one or more - associated ``Handler`` objects that should be used in that state. The first handler - which :attr:`check_update` method returns :obj:`True` will be used. + associated :obj:`Handler` objects that should be used in that state. The first handler + whose :meth:`check_update` method returns :obj:`True` will be used. fallbacks (List[:class:`telegram.ext.Handler`]): A list of handlers that might be used if the user is in a conversation, but every handler for their current state returned - :obj:`False` on :attr:`check_update`. The first handler which :attr:`check_update` + :obj:`False` on :meth:`check_update`. The first handler which :meth:`check_update` method returns :obj:`True` will be used. If all return :obj:`False`, the update is not handled. allow_reentry (:obj:`bool`, optional): If set to :obj:`True`, a user that is currently in a conversation can restart the conversation by triggering one of the entry points. - per_chat (:obj:`bool`, optional): If the conversationkey should contain the Chat's ID. + per_chat (:obj:`bool`, optional): If the conversation key should contain the Chat's ID. Default is :obj:`True`. - per_user (:obj:`bool`, optional): If the conversationkey should contain the User's ID. + per_user (:obj:`bool`, optional): If the conversation key should contain the User's ID. Default is :obj:`True`. - per_message (:obj:`bool`, optional): If the conversationkey should contain the Message's + per_message (:obj:`bool`, optional): If the conversation key should contain the Message's ID. Default is :obj:`False`. conversation_timeout (:obj:`float` | :obj:`datetime.timedelta`, optional): When this handler is inactive more than this timeout (in seconds), it will be automatically - ended. If this value is 0 or :obj:`None` (default), there will be no timeout. The last - received update and the corresponding ``context`` will be handled by ALL the handler's - who's :attr:`check_update` method returns :obj:`True` that are in the state - :attr:`ConversationHandler.TIMEOUT`. + ended. If this value is ``0`` or :obj:`None` (default), there will be no timeout. The + last received update and the corresponding :class:`context <.CallbackContext>` will be + handled by *ALL* the handler's whose :meth:`check_update` method returns :obj:`True` + that are in the state :attr:`ConversationHandler.TIMEOUT`. Note: - Using `conversation_timeout` with nested conversations is currently not + Using :paramref:`conversation_timeout` with nested conversations is currently not supported. You can still try to use it, but it will likely behave differently from what you expect. - name (:obj:`str`, optional): The name for this conversation handler. Required for persistence. - persistent (:obj:`bool`, optional): If the conversations dict for this handler should be - saved. Name is required and persistence has to be set in :class:`telegram.ext.Updater` + persistent (:obj:`bool`, optional): If the conversation's dict for this handler should be + saved. :paramref:`name` is required and persistence has to be set in + :attr:`Application <.Application.persistence>`. + + .. versionchanged:: 14.0 + Was previously named as ``persistence``. map_to_parent (Dict[:obj:`object`, :obj:`object`], optional): A :obj:`dict` that can be - used to instruct a nested conversation handler to transition into a mapped state on + used to instruct a child conversation handler to transition into a mapped state on its parent conversation handler in place of a specified nested state. - block (:obj:`bool`, optional): Pass :obj:`False` to set a default value for the - :attr:`Handler.block` setting of all handlers (in :attr:`entry_points`, + block (:obj:`bool`, optional): Pass :obj:`False` or :obj:`True` to set a default value for + the :attr:`Handler.block` setting of all handlers (in :attr:`entry_points`, :attr:`states` and :attr:`fallbacks`). The resolution order for checking if a handler should be run non-blocking is: - 1. :attr:`telegram.ext.Handler.block` (if set) - 2. the value passed to this parameter (if any) - 3. :attr:`telegram.ext.Defaults.block` (if defaults are used) + 1. :attr:`telegram.ext.Handler.block` (if set) + 2. the value passed to this parameter (if any) + 3. :attr:`telegram.ext.Defaults.block` (if defaults are used) .. versionchanged:: 14.0 No longer overrides the handlers settings. Resolution order was changed. Raises: - ValueError + :exc:`ValueError`: If :paramref:`persistent` is used but :paramref:`name` was not set, or + when :attr:`per_message`, :attr:`per_chat`, :attr:`per_user` are all :obj:`False`. Attributes: block (:obj:`bool`): Determines whether the callback will run asynchronously. Always @@ -259,10 +276,12 @@ class ConversationHandler(Handler[Update, CCT]): END: ClassVar[int] = -1 """:obj:`int`: Used as a constant to return when a conversation is ended.""" TIMEOUT: ClassVar[int] = -2 - """:obj:`int`: Used as a constant to handle state when a conversation is timed out.""" + """:obj:`int`: Used as a constant to handle state when a conversation is timed out + (exceeded :attr:`conversation_timeout`). + """ WAITING: ClassVar[int] = -3 """:obj:`int`: Used as a constant to handle state when a conversation is still waiting on the - previous ``@run_sync`` decorated running handler to finish.""" + previous :attr:`block=False ` handler to finish.""" # pylint: disable=super-init-not-called def __init__( self, @@ -305,6 +324,8 @@ def __init__( self._name = name self._map_to_parent = map_to_parent + # if conversation_timeout is used, this dict is used to schedule a job which runs when the + # conv has timed out. self.timeout_jobs: Dict[ConversationKey, 'Job'] = {} self._timeout_jobs_lock = asyncio.Lock() self._conversations: ConversationDict = {} @@ -335,9 +356,6 @@ def __init__( handler for handler in all_handlers if isinstance(handler, ConversationHandler) ) - # this loop is going to warn the user about handlers which can work unexpected - # in conversations - # this link will be added to all warnings tied to per_* setting per_faq_link = ( " Read this FAQ entry to learn more about the per_* settings: " @@ -345,6 +363,8 @@ def __init__( "/Frequently-Asked-Questions#what-do-the-per_-settings-in-conversationhandler-do." ) + # this loop is going to warn the user about handlers which can work unexpectedly + # in conversations for handler in all_handlers: if isinstance(handler, (StringCommandHandler, StringRegexHandler)): warn( @@ -409,8 +429,8 @@ def __init__( @property def entry_points(self) -> List[Handler]: - """List[:class:`telegram.ext.Handler`]: A list of ``Handler`` objects that can trigger the - start of the conversation. + """List[:class:`telegram.ext.Handler`]: A list of :obj:`Handler` objects that can trigger + the start of the conversation. """ return self._entry_points @@ -424,7 +444,7 @@ def entry_points(self, value: object) -> NoReturn: def states(self) -> Dict[object, List[Handler]]: """Dict[:obj:`object`, List[:class:`telegram.ext.Handler`]]: A :obj:`dict` that defines the different states of conversation a user can be in and one or more - associated ``Handler`` objects that should be used in that state. + associated :obj:`Handler` objects that should be used in that state. """ return self._states @@ -436,7 +456,7 @@ def states(self, value: object) -> NoReturn: def fallbacks(self) -> List[Handler]: """List[:class:`telegram.ext.Handler`]: A list of handlers that might be used if the user is in a conversation, but every handler for their current state returned - :obj:`False` on :attr:`check_update`. + :obj:`False` on :meth:`check_update`. """ return self._fallbacks @@ -510,7 +530,9 @@ def name(self, value: object) -> NoReturn: @property def persistent(self) -> bool: """:obj:`bool`: Optional. If the conversations dict for this handler should be - saved.""" + saved. :attr:`name` is required and persistence has to be set in + :attr:`Application <.Application.persistence>`. + """ return self._persistent @persistent.setter @@ -534,9 +556,9 @@ def map_to_parent(self, value: object) -> NoReturn: async def _initialize_persistence( self, application: 'Application' ) -> Dict[str, TrackingDict[ConversationKey, object]]: - """Initializes the persistence for this handler. While this method is marked as protected, - we expect it to be called by the Application/parent conversations. It's just protected to - hide it from users. + """Initializes the persistence for this handler and its child conversations. + While this method is marked as protected, we expect it to be called by the + Application/parent conversations. It's just protected to hide it from users. Args: application (:class:`telegram.ext.Application`): The application. @@ -577,6 +599,7 @@ async def _initialize_persistence( return out def _get_key(self, update: Update) -> ConversationKey: + """Builds the conversation key associated with the update.""" chat = update.effective_chat user = update.effective_user @@ -635,6 +658,7 @@ def _schedule_job( context: CallbackContext, conversation_key: ConversationKey, ) -> None: + """Schedules a job which executes :meth:`_trigger_timeout` upon conversation timeout.""" if new_state == self.END: return @@ -680,11 +704,11 @@ def check_update(self, update: object) -> Optional[_CheckUpdateType]: state = self._conversations.get(key) check: Optional[object] = None - # Resolve promises + # Resolve futures if isinstance(state, PendingState): _logger.debug('Waiting for asyncio Task to finish ...') - # check if promise is finished or not + # check if future is finished or not if state.done(): res = state.resolve() self._update_state(res, key) @@ -746,8 +770,8 @@ async def handle_update( # type: ignore[override] """Send the update to the callback for the current state and Handler Args: - check_result: The result from check_update. For this handler it's a tuple of the - conversation state, key, handler, and the handler's check result. + check_result: The result from :meth:`check_update`. For this handler it's a tuple of + the conversation state, key, handler, and the handler's check result. update (:class:`telegram.Update`): Incoming telegram update. application (:class:`telegram.ext.Application`): Application that originated the update. @@ -779,7 +803,7 @@ async def handle_update( # type: ignore[override] else: block = DefaultValue.get_value(handler.block) - try: + try: # Now create task or await the callback if block: new_state: object = await handler.handle_update( update, application, handler_check_result, context @@ -857,6 +881,10 @@ def _update_state(self, new_state: object, key: ConversationKey) -> None: self._conversations[key] = new_state async def _trigger_timeout(self, context: CallbackContext) -> None: + """This is run whenever a conversation has timed out. Also makes sure that all handlers + which are in the :attr:`TIMEOUT` state and whose :meth:`Handler.check_update` returns + :obj:`True` is handled. + """ job = cast('Job', context.job) ctxt = cast(_ConversationTimeoutContext, job.context) @@ -873,6 +901,7 @@ async def _trigger_timeout(self, context: CallbackContext) -> None: return del self.timeout_jobs[ctxt.conversation_key] + # Now run all handlers which are in TIMEOUT state handlers = self.states.get(self.TIMEOUT, []) for handler in handlers: check = handler.check_update(ctxt.update) diff --git a/telegram/ext/_defaults.py b/telegram/ext/_defaults.py index 2c3139479b0..8e1a43efd73 100644 --- a/telegram/ext/_defaults.py +++ b/telegram/ext/_defaults.py @@ -17,7 +17,7 @@ # You should have received a copy of the GNU Lesser Public License # along with this program. If not, see [http://www.gnu.org/licenses/]. # pylint: disable=no-self-use -"""This module contains the class Defaults, which allows to pass default values to Updater.""" +"""This module contains the class Defaults, which allows passing default values to Application.""" from typing import NoReturn, Optional, Dict, Any import pytz @@ -32,7 +32,8 @@ class Defaults: Parameters: - parse_mode (:obj:`str`, optional): Send Markdown or HTML, if you want Telegram apps to show + parse_mode (:obj:`str`, optional): Send :attr:`~telegram.constants.ParseMode.MARKDOWN` or + :attr:`~telegram.constants.ParseMode.HTML`, if you want Telegram apps to show bold, italic, fixed-width text or URLs in your bot's message. disable_notification (:obj:`bool`, optional): Sends the message silently. Users will receive a notification with no sound. @@ -45,10 +46,10 @@ class Defaults: be ignored. Default: :obj:`True` in group chats and :obj:`False` in private chats. tzinfo (:obj:`tzinfo`, optional): A timezone to be used for all date(time) inputs appearing throughout PTB, i.e. if a timezone naive date(time) object is passed - somewhere, it will be assumed to be in ``tzinfo``. Must be a timezone provided by the - ``pytz`` module. Defaults to UTC. - block (:obj:`bool`, optional): Default setting for the ``block`` parameter of - handlers and error handlers registered through :meth:`Application.add_handler` and + somewhere, it will be assumed to be in :paramref:`tzinfo`. Must be a timezone provided + by the ``pytz`` module. Defaults to UTC. + block (:obj:`bool`, optional): Default setting for the :paramref:`Handler.block` parameter + of handlers and error handlers registered through :meth:`Application.add_handler` and :meth:`Application.add_error_handler`. Defaults to :obj:`True`. protect_content (:obj:`bool`, optional): Protects the contents of the sent message from forwarding and saving. @@ -194,7 +195,7 @@ def tzinfo(self, value: object) -> NoReturn: @property def block(self) -> bool: - """:obj:`bool`: Optional. Default setting for the ``block`` parameter of + """:obj:`bool`: Optional. Default setting for the :paramref:`Handler.block` parameter of handlers and error handlers registered through :meth:`Application.add_handler` and :meth:`Application.add_error_handler`. """ diff --git a/telegram/ext/_dictpersistence.py b/telegram/ext/_dictpersistence.py index c1f598129e2..c5e6d676397 100644 --- a/telegram/ext/_dictpersistence.py +++ b/telegram/ext/_dictpersistence.py @@ -56,24 +56,24 @@ class DictPersistence(BasePersistence): store_data (:class:`PersistenceInput`, optional): Specifies which kinds of data will be saved by this persistence instance. By default, all available kinds of data will be saved. - update_interval (:obj:`int` | :obj:`float`, optional): The - :class:`~telegram.ext.Application` will update - the persistence in regular intervals. This parameter specifies the time (in seconds) to - wait between two consecutive runs of updating the persistence. Defaults to 60 seconds. - - .. versionadded:: 14.0 user_data_json (:obj:`str`, optional): JSON string that will be used to reconstruct user_data on creating this persistence. Default is ``""``. chat_data_json (:obj:`str`, optional): JSON string that will be used to reconstruct chat_data on creating this persistence. Default is ``""``. bot_data_json (:obj:`str`, optional): JSON string that will be used to reconstruct bot_data on creating this persistence. Default is ``""``. + conversations_json (:obj:`str`, optional): JSON string that will be used to reconstruct + conversation on creating this persistence. Default is ``""``. callback_data_json (:obj:`str`, optional): Json string that will be used to reconstruct callback_data on creating this persistence. Default is ``""``. .. versionadded:: 13.6 - conversations_json (:obj:`str`, optional): JSON string that will be used to reconstruct - conversation on creating this persistence. Default is ``""``. + update_interval (:obj:`int` | :obj:`float`, optional): The + :class:`~telegram.ext.Application` will update + the persistence in regular intervals. This parameter specifies the time (in seconds) to + wait between two consecutive runs of updating the persistence. Defaults to 60 seconds. + + .. versionadded:: 14.0 Attributes: store_data (:class:`PersistenceInput`): Specifies which kinds of data will be saved by this diff --git a/telegram/ext/_handler.py b/telegram/ext/_handler.py index 3bbbd372125..a12602b6623 100644 --- a/telegram/ext/_handler.py +++ b/telegram/ext/_handler.py @@ -38,10 +38,15 @@ class Handler(Generic[UT, CCT], ABC): When setting :paramref:`block` to :obj:`True`, you cannot rely on adding custom attributes to :class:`telegram.ext.CallbackContext`. See its docs for more info. + .. versionchanged:: 14.0 + The attribute ``run_async`` is now :paramref:`block`. + Args: - callback (:obj:`callable`): The callback function for this handler. Will be called when - :attr:`check_update` has determined that an update should be processed by this handler. - Callback signature: ``def callback(update: Update, context: CallbackContext)`` + callback (:term:`coroutine function`): The callback function for this handler. Will be + called when :meth:`check_update` has determined that an update should be processed by + this handler. Callback signature:: + + async def callback(update: Update, context: CallbackContext) The return value of the callback is usually ignored except for the special case of :class:`telegram.ext.ConversationHandler`. @@ -50,7 +55,7 @@ class Handler(Generic[UT, CCT], ABC): :meth:`telegram.ext.Application.process_update`. Defaults to :obj:`True`. Attributes: - callback (:obj:`callable`): The callback function for this handler. + callback (:term:`coroutine function`): The callback function for this handler. block (:obj:`bool`): Determines whether the callback will run asynchronously. """ @@ -105,7 +110,7 @@ async def handle_update( Args: update (:obj:`str` | :class:`telegram.Update`): The update to be handled. application (:class:`telegram.ext.Application`): The calling application. - check_result (:class:`object`): The result from :attr:`check_update`. + check_result (:class:`object`): The result from :meth:`check_update`. context (:class:`telegram.ext.CallbackContext`): The context as provided by the application. @@ -126,6 +131,6 @@ def collect_additional_context( context (:class:`telegram.ext.CallbackContext`): The context object. update (:class:`telegram.Update`): The update to gather chat/user id from. application (:class:`telegram.ext.Application`): The calling application. - check_result: The result (return value) from :attr:`check_update`. + check_result: The result (return value) from :meth:`check_update`. """ diff --git a/telegram/ext/_inlinequeryhandler.py b/telegram/ext/_inlinequeryhandler.py index 2e33fd37cf5..5f9fed0ff61 100644 --- a/telegram/ext/_inlinequeryhandler.py +++ b/telegram/ext/_inlinequeryhandler.py @@ -54,28 +54,30 @@ class InlineQueryHandler(Handler[Update, CCT]): updates won't be handled, if :attr:`chat_types` is passed. Args: - callback (:obj:`callable`): The callback function for this handler. Will be called when - :attr:`check_update` has determined that an update should be processed by this handler. - Callback signature: ``def callback(update: Update, context: CallbackContext)`` + callback (:term:`coroutine function`): The callback function for this handler. Will be + called when :meth:`check_update` has determined that an update should be processed by + this handler. Callback signature:: + + async def callback(update: Update, context: CallbackContext) The return value of the callback is usually ignored except for the special case of :class:`telegram.ext.ConversationHandler`. pattern (:obj:`str` | :func:`re.Pattern `, optional): Regex pattern. If not :obj:`None`, :func:`re.match` is used on :attr:`telegram.InlineQuery.query` to determine if an update should be handled by this handler. + block (:obj:`bool`, optional): Determines whether the return value of the callback should + be awaited before processing the next handler in + :meth:`telegram.ext.Application.process_update`. Defaults to :obj:`True`. chat_types (List[:obj:`str`], optional): List of allowed chat types. If passed, will only handle inline queries with the appropriate :attr:`telegram.InlineQuery.chat_type`. .. versionadded:: 13.5 - block (:obj:`bool`, optional): Determines whether the return value of the callback should - be awaited before processing the next handler in - :meth:`telegram.ext.Application.process_update`. Defaults to :obj:`True`. Attributes: - callback (:obj:`callable`): The callback function for this handler. + callback (:term:`coroutine function`): The callback function for this handler. pattern (:obj:`str` | :func:`re.Pattern `): Optional. Regex pattern to test :attr:`telegram.InlineQuery.query` against. - chat_types (List[:obj:`str`], optional): List of allowed chat types. + chat_types (List[:obj:`str`]): Optional. List of allowed chat types. .. versionadded:: 13.5 block (:obj:`bool`): Determines whether the return value of the callback should be @@ -103,13 +105,13 @@ def __init__( def check_update(self, update: object) -> Optional[Union[bool, Match]]: """ - Determines whether an update should be passed to this handlers :attr:`callback`. + Determines whether an update should be passed to this handler's :attr:`callback`. Args: update (:class:`telegram.Update` | :obj:`object`): Incoming update. Returns: - :obj:`bool` + :obj:`bool` | :obj:`re.match` """ if isinstance(update, Update) and update.inline_query: diff --git a/telegram/ext/_jobqueue.py b/telegram/ext/_jobqueue.py index 91a6739cb4b..7bd8c18eceb 100644 --- a/telegram/ext/_jobqueue.py +++ b/telegram/ext/_jobqueue.py @@ -91,7 +91,6 @@ def _parse_time_input( if shift_day and date_time <= datetime.datetime.now(pytz.utc): date_time += datetime.timedelta(days=1) return date_time - # isinstance(time, datetime.datetime): return time def set_application(self, application: 'Application') -> None: @@ -131,8 +130,11 @@ def run_once( """Creates a new :class:`Job` instance that runs once and adds it to the queue. Args: - callback (:obj:`callable`): The callback function that should be executed by the new - job. Callback signature: ``def callback(context: CallbackContext)`` + callback (:term:`coroutine function`): The callback function that should be executed by + the new job. Callback signature:: + + async def callback(context: CallbackContext) + when (:obj:`int` | :obj:`float` | :obj:`datetime.timedelta` | \ :obj:`datetime.datetime` | :obj:`datetime.time`): Time in or at which the job should run. This parameter will be interpreted @@ -165,7 +167,7 @@ def run_once( Can be accessed through :attr:`Job.context` in the callback. Defaults to :obj:`None`. name (:obj:`str`, optional): The name of the new job. Defaults to - ``callback.__name__``. + :external:attr:`callback.__name__ `. job_kwargs (:obj:`dict`, optional): Arbitrary keyword arguments to pass to the :meth:`apscheduler.schedulers.base.BaseScheduler.add_job()`. @@ -216,8 +218,11 @@ def run_repeating( #daylight-saving-time-behavior Args: - callback (:obj:`callable`): The callback function that should be executed by the new - job. Callback signature: ``def callback(context: CallbackContext)`` + callback (:term:`coroutine function`): The callback function that should be executed by + the new job. Callback signature:: + + async def callback(context: CallbackContext) + interval (:obj:`int` | :obj:`float` | :obj:`datetime.timedelta`): The interval in which the job will run. If it is an :obj:`int` or a :obj:`float`, it will be interpreted as seconds. @@ -253,7 +258,7 @@ def run_repeating( Can be accessed through :attr:`Job.context` in the callback. Defaults to :obj:`None`. name (:obj:`str`, optional): The name of the new job. Defaults to - ``callback.__name__``. + :external:attr:`callback.__name__ `. chat_id (:obj:`int`, optional): Chat id of the chat associated with this job. If passed, the corresponding :attr:`~telegram.ext.CallbackContext.chat_data` will be available in the callback. @@ -320,8 +325,11 @@ def run_monthly( parameter to have the job run on the last day of the month. Args: - callback (:obj:`callable`): The callback function that should be executed by the new - job. Callback signature: ``def callback(context: CallbackContext)`` + callback (:term:`coroutine function`): The callback function that should be executed by + the new job. Callback signature:: + + async def callback(context: CallbackContext) + when (:obj:`datetime.time`): Time of day at which the job should run. If the timezone (``when.tzinfo``) is :obj:`None`, the default timezone of the bot will be used. day (:obj:`int`): Defines the day of the month whereby the job would run. It should @@ -332,7 +340,7 @@ def run_monthly( Can be accessed through :attr:`Job.context` in the callback. Defaults to :obj:`None`. name (:obj:`str`, optional): The name of the new job. Defaults to - ``callback.__name__``. + :external:attr:`callback.__name__ `. chat_id (:obj:`int`, optional): Chat id of the chat associated with this job. If passed, the corresponding :attr:`~telegram.ext.CallbackContext.chat_data` will be available in the callback. @@ -393,8 +401,11 @@ def run_daily( #daylight-saving-time-behavior Args: - callback (:obj:`callable`): The callback function that should be executed by the new - job. Callback signature: ``def callback(context: CallbackContext)`` + callback (:term:`coroutine function`): The callback function that should be executed by + the new job. Callback signature:: + + async def callback(context: CallbackContext) + time (:obj:`datetime.time`): Time of day at which the job should run. If the timezone (:obj:`datetime.time.tzinfo`) is :obj:`None`, the default timezone of the bot will be used. @@ -404,7 +415,7 @@ def run_daily( Can be accessed through :attr:`Job.context` in the callback. Defaults to :obj:`None`. name (:obj:`str`, optional): The name of the new job. Defaults to - ``callback.__name__``. + :external:attr:`callback.__name__ `. chat_id (:obj:`int`, optional): Chat id of the chat associated with this job. If passed, the corresponding :attr:`~telegram.ext.CallbackContext.chat_data` will be available in the callback. @@ -458,15 +469,18 @@ def run_custom( """Creates a new custom defined :class:`Job`. Args: - callback (:obj:`callable`): The callback function that should be executed by the new - job. Callback signature: ``def callback(context: CallbackContext)`` + callback (:term:`coroutine function`): The callback function that should be executed by + the new job. Callback signature:: + + async def callback(context: CallbackContext) + job_kwargs (:obj:`dict`): Arbitrary keyword arguments. Used as arguments for :meth:`apscheduler.schedulers.base.BaseScheduler.add_job`. context (:obj:`object`, optional): Additional data needed for the callback function. Can be accessed through :attr:`Job.context` in the callback. Defaults to :obj:`None`. name (:obj:`str`, optional): The name of the new job. Defaults to - ``callback.__name__``. + :external:attr:`callback.__name__ `. chat_id (:obj:`int`, optional): Chat id of the chat associated with this job. If passed, the corresponding :attr:`~telegram.ext.CallbackContext.chat_data` will be available in the callback. @@ -502,7 +516,7 @@ async def stop(self, wait: bool = True) -> None: """Shuts down the :class:`~telegram.ext.JobQueue`. Args: - wait (:obj:`bool`, optional): Whether or not to wait until all currently running jobs + wait (:obj:`bool`, optional): Whether to wait until all currently running jobs have finished. Defaults to :obj:`True`. """ @@ -546,8 +560,6 @@ class Job: Note: * All attributes and instance methods of :attr:`job` are also directly available as attributes/methods of the corresponding :class:`telegram.ext.Job` object. - * Two instances of :class:`telegram.ext.Job` are considered equal, if their corresponding - :attr:`job` attributes have the same ``id``. * If :attr:`job` isn't passed on initialization, it must be set manually afterwards for this :class:`telegram.ext.Job` to be useful. @@ -555,11 +567,15 @@ class Job: Removed argument and attribute ``job_queue``. Args: - callback (:obj:`callable`): The callback function that should be executed by the new job. - Callback signature: ``def callback(context: CallbackContext)`` + callback (:term:`coroutine function`): The callback function that should be executed by the + new job. Callback signature:: + + async def callback(context: CallbackContext) + context (:obj:`object`, optional): Additional data needed for the callback function. Can be accessed through :attr:`Job.context` in the callback. Defaults to :obj:`None`. - name (:obj:`str`, optional): The name of the new job. Defaults to ``callback.__name__``. + name (:obj:`str`, optional): The name of the new job. Defaults to + :external:obj:`callback.__name__ `. job (:class:`apscheduler.job.Job`, optional): The APS Job this job is a wrapper for. chat_id (:obj:`int`, optional): Chat id of the chat that this job is associated with. @@ -569,7 +585,8 @@ class Job: ..versionadded:: 14.0 Attributes: - callback (:obj:`callable`): The callback function that should be executed by the new job. + callback (:term:`coroutine function`): The callback function that should be executed by the + new job. context (:obj:`object`): Optional. Additional data needed for the callback function. name (:obj:`str`): Optional. The name of the new job. job (:class:`apscheduler.job.Job`): Optional. The APS Job this job is a wrapper for. diff --git a/telegram/ext/_messagehandler.py b/telegram/ext/_messagehandler.py index 477d21aed8f..97f3ef26855 100644 --- a/telegram/ext/_messagehandler.py +++ b/telegram/ext/_messagehandler.py @@ -33,7 +33,7 @@ class MessageHandler(Handler[Update, CCT]): - """Handler class to handle telegram messages. They might contain text, media or status updates. + """Handler class to handle Telegram messages. They might contain text, media or status updates. Warning: When setting :paramref:`block` to :obj:`True`, you cannot rely on adding custom @@ -48,9 +48,11 @@ class MessageHandler(Handler[Update, CCT]): :attr:`telegram.Update.channel_post` and :attr:`telegram.Update.edited_channel_post`. If you don't want or need any of those pass ``~filters.UpdateType.*`` in the filter argument. - callback (:obj:`callable`): The callback function for this handler. Will be called when - :attr:`check_update` has determined that an update should be processed by this handler. - Callback signature: ``def callback(update: Update, context: CallbackContext)`` + callback (:term:`coroutine function`): The callback function for this handler. Will be + called when :meth:`check_update` has determined that an update should be processed by + this handler. Callback signature:: + + async def callback(update: Update, context: CallbackContext) The return value of the callback is usually ignored except for the special case of :class:`telegram.ext.ConversationHandler`. @@ -58,13 +60,10 @@ class MessageHandler(Handler[Update, CCT]): be awaited before processing the next handler in :meth:`telegram.ext.Application.process_update`. Defaults to :obj:`True`. - Raises: - ValueError - Attributes: filters (:class:`telegram.ext.filters.BaseFilter`): Only allow updates with these Filters. See :mod:`telegram.ext.filters` for a full list of all available filters. - callback (:obj:`callable`): The callback function for this handler. + callback (:term:`coroutine function`): The callback function for this handler. block (:obj:`bool`): Determines whether the return value of the callback should be awaited before processing the next handler in :meth:`telegram.ext.Application.process_update`. @@ -84,7 +83,7 @@ def __init__( self.filters = filters if filters is not None else filters_module.ALL def check_update(self, update: object) -> Optional[Union[bool, Dict[str, list]]]: - """Determines whether an update should be passed to this handlers :attr:`callback`. + """Determines whether an update should be passed to this handler's :attr:`callback`. Args: update (:class:`telegram.Update` | :obj:`object`): Incoming update. diff --git a/telegram/ext/_picklepersistence.py b/telegram/ext/_picklepersistence.py index 7899dea6675..8d3091ea10a 100644 --- a/telegram/ext/_picklepersistence.py +++ b/telegram/ext/_picklepersistence.py @@ -128,7 +128,7 @@ def persistent_load(self, pid: str) -> Optional[Bot]: class PicklePersistence(BasePersistence[UD, CD, BD]): - """Using python's builtin pickle for making your bot persistent. + """Using python's builtin :mod:`pickle` for making your bot persistent. Attention: The interface provided by this class is intended to be accessed exclusively by @@ -153,12 +153,6 @@ class PicklePersistence(BasePersistence[UD, CD, BD]): store_data (:class:`PersistenceInput`, optional): Specifies which kinds of data will be saved by this persistence instance. By default, all available kinds of data will be saved. - update_interval (:obj:`int` | :obj:`float`, optional): The - :class:`~telegram.ext.Application` will update - the persistence in regular intervals. This parameter specifies the time (in seconds) to - wait between two consecutive runs of updating the persistence. Defaults to 60 seconds. - - .. versionadded:: 14.0 single_file (:obj:`bool`, optional): When :obj:`False` will store 5 separate files of `filename_user_data`, `filename_bot_data`, `filename_chat_data`, `filename_callback_data` and `filename_conversations`. Default is :obj:`True`. @@ -172,6 +166,12 @@ class PicklePersistence(BasePersistence[UD, CD, BD]): :class:`telegram.ext.ContextTypes` will be used. .. versionadded:: 13.6 + update_interval (:obj:`int` | :obj:`float`, optional): The + :class:`~telegram.ext.Application` will update + the persistence in regular intervals. This parameter specifies the time (in seconds) to + wait between two consecutive runs of updating the persistence. Defaults to 60 seconds. + + .. versionadded:: 14.0 Attributes: filepath (:obj:`str` | :obj:`pathlib.Path`): The filepath for storing the pickle files. @@ -400,7 +400,7 @@ async def update_conversation( Args: name (:obj:`str`): The handler's name. key (:obj:`tuple`): The key the state is changed for. - new_state (:obj:`tuple` | :class:`object`): The new state for the given key. + new_state (:class:`object`): The new state for the given key. """ if not self.conversations: self.conversations = {} diff --git a/telegram/ext/_pollanswerhandler.py b/telegram/ext/_pollanswerhandler.py index ddd9e46f9b5..a5a9276fed0 100644 --- a/telegram/ext/_pollanswerhandler.py +++ b/telegram/ext/_pollanswerhandler.py @@ -26,16 +26,19 @@ class PollAnswerHandler(Handler[Update, CCT]): - """Handler class to handle Telegram updates that contain a poll answer. + """Handler class to handle Telegram updates that contain a + :attr:`poll answer `. Warning: When setting :paramref:`block` to :obj:`True`, you cannot rely on adding custom attributes to :class:`telegram.ext.CallbackContext`. See its docs for more info. Args: - callback (:obj:`callable`): The callback function for this handler. Will be called when - :attr:`check_update` has determined that an update should be processed by this handler. - Callback signature: ``def callback(update: Update, context: CallbackContext)`` + callback (:term:`coroutine function`): The callback function for this handler. Will be + called when :meth:`check_update` has determined that an update should be processed by + this handler. Callback signature:: + + async def callback(update: Update, context: CallbackContext) The return value of the callback is usually ignored except for the special case of :class:`telegram.ext.ConversationHandler`. @@ -44,7 +47,7 @@ class PollAnswerHandler(Handler[Update, CCT]): :meth:`telegram.ext.Application.process_update`. Defaults to :obj:`True`. Attributes: - callback (:obj:`callable`): The callback function for this handler. + callback (:term:`coroutine function`): The callback function for this handler. block (:obj:`bool`): Determines whether the callback will run asynchronously. """ @@ -52,7 +55,7 @@ class PollAnswerHandler(Handler[Update, CCT]): __slots__ = () def check_update(self, update: object) -> bool: - """Determines whether an update should be passed to this handlers :attr:`callback`. + """Determines whether an update should be passed to this handler's :attr:`callback`. Args: update (:class:`telegram.Update` | :obj:`object`): Incoming update. diff --git a/telegram/ext/_pollhandler.py b/telegram/ext/_pollhandler.py index 8426aaa75db..d6b37a7824a 100644 --- a/telegram/ext/_pollhandler.py +++ b/telegram/ext/_pollhandler.py @@ -16,7 +16,7 @@ # # You should have received a copy of the GNU Lesser Public License # along with this program. If not, see [http://www.gnu.org/licenses/]. -"""This module contains the PollHandler classes.""" +"""This module contains the PollHandler class.""" from telegram import Update @@ -26,16 +26,18 @@ class PollHandler(Handler[Update, CCT]): - """Handler class to handle Telegram updates that contain a poll. + """Handler class to handle Telegram updates that contain a :attr:`poll `. Warning: When setting :paramref:`block` to :obj:`True`, you cannot rely on adding custom attributes to :class:`telegram.ext.CallbackContext`. See its docs for more info. Args: - callback (:obj:`callable`): The callback function for this handler. Will be called when - :attr:`check_update` has determined that an update should be processed by this handler. - Callback signature: ``def callback(update: Update, context: CallbackContext)`` + callback (:term:`coroutine function`): The callback function for this handler. Will be + called when :meth:`check_update` has determined that an update should be processed by + this handler. Callback signature:: + + async def callback(update: Update, context: CallbackContext) The return value of the callback is usually ignored except for the special case of :class:`telegram.ext.ConversationHandler`. @@ -44,7 +46,7 @@ class PollHandler(Handler[Update, CCT]): :meth:`telegram.ext.Application.process_update`. Defaults to :obj:`True`. Attributes: - callback (:obj:`callable`): The callback function for this handler. + callback (:term:`coroutine function`): The callback function for this handler. block (:obj:`bool`): Determines whether the callback will run asynchronously. """ @@ -52,7 +54,7 @@ class PollHandler(Handler[Update, CCT]): __slots__ = () def check_update(self, update: object) -> bool: - """Determines whether an update should be passed to this handlers :attr:`callback`. + """Determines whether an update should be passed to this handler's :attr:`callback`. Args: update (:class:`telegram.Update` | :obj:`object`): Incoming update. diff --git a/telegram/ext/_precheckoutqueryhandler.py b/telegram/ext/_precheckoutqueryhandler.py index f253f4b36b3..30e5f919275 100644 --- a/telegram/ext/_precheckoutqueryhandler.py +++ b/telegram/ext/_precheckoutqueryhandler.py @@ -25,16 +25,18 @@ class PreCheckoutQueryHandler(Handler[Update, CCT]): - """Handler class to handle Telegram PreCheckout callback queries. + """Handler class to handle Telegram :attr:`telegram.Update.pre_checkout_query`. Warning: When setting :paramref:`block` to :obj:`True`, you cannot rely on adding custom attributes to :class:`telegram.ext.CallbackContext`. See its docs for more info. Args: - callback (:obj:`callable`): The callback function for this handler. Will be called when - :attr:`check_update` has determined that an update should be processed by this handler. - Callback signature: ``def callback(update: Update, context: CallbackContext)`` + callback (:term:`coroutine function`): The callback function for this handler. Will be + called when :meth:`check_update` has determined that an update should be processed by + this handler. Callback signature:: + + async def callback(update: Update, context: CallbackContext) The return value of the callback is usually ignored except for the special case of :class:`telegram.ext.ConversationHandler`. @@ -43,7 +45,7 @@ class PreCheckoutQueryHandler(Handler[Update, CCT]): :meth:`telegram.ext.Application.process_update`. Defaults to :obj:`True`. Attributes: - callback (:obj:`callable`): The callback function for this handler. + callback (:term:`coroutine function`): The callback function for this handler. block (:obj:`bool`): Determines whether the callback will run asynchronously. """ @@ -51,7 +53,7 @@ class PreCheckoutQueryHandler(Handler[Update, CCT]): __slots__ = () def check_update(self, update: object) -> bool: - """Determines whether an update should be passed to this handlers :attr:`callback`. + """Determines whether an update should be passed to this handler's :attr:`callback`. Args: update (:class:`telegram.Update` | :obj:`object`): Incoming update. diff --git a/telegram/ext/_shippingqueryhandler.py b/telegram/ext/_shippingqueryhandler.py index 393136d2cba..e26a2028ef2 100644 --- a/telegram/ext/_shippingqueryhandler.py +++ b/telegram/ext/_shippingqueryhandler.py @@ -25,16 +25,18 @@ class ShippingQueryHandler(Handler[Update, CCT]): - """Handler class to handle Telegram shipping callback queries. + """Handler class to handle Telegram :attr:`telegram.Update.shipping_query`. Warning: When setting :paramref:`block` to :obj:`True`, you cannot rely on adding custom attributes to :class:`telegram.ext.CallbackContext`. See its docs for more info. Args: - callback (:obj:`callable`): The callback function for this handler. Will be called when - :attr:`check_update` has determined that an update should be processed by this handler. - Callback signature: ``def callback(update: Update, context: CallbackContext)`` + callback (:term:`coroutine function`): The callback function for this handler. Will be + called when :meth:`check_update` has determined that an update should be processed by + this handler. Callback signature:: + + async def callback(update: Update, context: CallbackContext) The return value of the callback is usually ignored except for the special case of :class:`telegram.ext.ConversationHandler`. @@ -43,7 +45,7 @@ class ShippingQueryHandler(Handler[Update, CCT]): :meth:`telegram.ext.Application.process_update`. Defaults to :obj:`True`. Attributes: - callback (:obj:`callable`): The callback function for this handler. + callback (:term:`coroutine function`): The callback function for this handler. block (:obj:`bool`): Determines whether the callback will run asynchronously. """ @@ -51,7 +53,7 @@ class ShippingQueryHandler(Handler[Update, CCT]): __slots__ = () def check_update(self, update: object) -> bool: - """Determines whether an update should be passed to this handlers :attr:`callback`. + """Determines whether an update should be passed to this handler's :attr:`callback`. Args: update (:class:`telegram.Update` | :obj:`object`): Incoming update. diff --git a/telegram/ext/_stringcommandhandler.py b/telegram/ext/_stringcommandhandler.py index fc20633e7ed..06821d615ec 100644 --- a/telegram/ext/_stringcommandhandler.py +++ b/telegram/ext/_stringcommandhandler.py @@ -31,12 +31,12 @@ class StringCommandHandler(Handler[str, CCT]): """Handler class to handle string commands. Commands are string updates that start with ``/``. - The handler will add a ``list`` to the + The handler will add a :obj:`list` to the :class:`CallbackContext` named :attr:`CallbackContext.args`. It will contain a list of strings, which is the text following the command split on single whitespace characters. Note: - This handler is not used to handle Telegram :attr:`telegram.Update`, but strings manually + This handler is not used to handle Telegram :class:`telegram.Update`, but strings manually put in the queue. For example to send messages with the bot using command line or API. Warning: @@ -45,9 +45,11 @@ class StringCommandHandler(Handler[str, CCT]): Args: command (:obj:`str`): The command this handler should listen for. - callback (:obj:`callable`): The callback function for this handler. Will be called when - :attr:`check_update` has determined that an update should be processed by this handler. - Callback signature: ``def callback(update: Update, context: CallbackContext)`` + callback (:term:`coroutine function`): The callback function for this handler. Will be + called when :meth:`check_update` has determined that an update should be processed by + this handler. Callback signature:: + + async def callback(update: Update, context: CallbackContext) The return value of the callback is usually ignored except for the special case of :class:`telegram.ext.ConversationHandler`. @@ -57,7 +59,7 @@ class StringCommandHandler(Handler[str, CCT]): Attributes: command (:obj:`str`): The command this handler should listen for. - callback (:obj:`callable`): The callback function for this handler. + callback (:term:`coroutine function`): The callback function for this handler. block (:obj:`bool`): Determines whether the return value of the callback should be awaited before processing the next handler in :meth:`telegram.ext.Application.process_update`. @@ -76,13 +78,13 @@ def __init__( self.command = command def check_update(self, update: object) -> Optional[List[str]]: - """Determines whether an update should be passed to this handlers :attr:`callback`. + """Determines whether an update should be passed to this handler's :attr:`callback`. Args: update (:obj:`object`): The incoming update. Returns: - :obj:`bool` + List[:obj:`str`]: List containing the text command split on whitespace. """ if isinstance(update, str) and update.startswith('/'): diff --git a/telegram/ext/_stringregexhandler.py b/telegram/ext/_stringregexhandler.py index 7ac278bbdfe..381d0152f4a 100644 --- a/telegram/ext/_stringregexhandler.py +++ b/telegram/ext/_stringregexhandler.py @@ -39,7 +39,7 @@ class StringRegexHandler(Handler[str, CCT]): function is used to determine if an update should be handled by this handler. Note: - This handler is not used to handle Telegram :attr:`telegram.Update`, but strings manually + This handler is not used to handle Telegram :class:`telegram.Update`, but strings manually put in the queue. For example to send messages with the bot using command line or API. Warning: @@ -48,9 +48,11 @@ class StringRegexHandler(Handler[str, CCT]): Args: pattern (:obj:`str` | :func:`re.Pattern `): The regex pattern. - callback (:obj:`callable`): The callback function for this handler. Will be called when - :attr:`check_update` has determined that an update should be processed by this handler. - Callback signature: ``def callback(update: Update, context: CallbackContext)`` + callback (:term:`coroutine function`): The callback function for this handler. Will be + called when :meth:`check_update` has determined that an update should be processed by + this handler. Callback signature:: + + async def callback(update: Update, context: CallbackContext) The return value of the callback is usually ignored except for the special case of :class:`telegram.ext.ConversationHandler`. @@ -60,7 +62,7 @@ class StringRegexHandler(Handler[str, CCT]): Attributes: pattern (:obj:`str` | :func:`re.Pattern `): The regex pattern. - callback (:obj:`callable`): The callback function for this handler. + callback (:term:`coroutine function`): The callback function for this handler. block (:obj:`bool`): Determines whether the return value of the callback should be awaited before processing the next handler in :meth:`telegram.ext.Application.process_update`. @@ -83,13 +85,13 @@ def __init__( self.pattern = pattern def check_update(self, update: object) -> Optional[Match]: - """Determines whether an update should be passed to this handlers :attr:`callback`. + """Determines whether an update should be passed to this handler's :attr:`callback`. Args: update (:obj:`object`): The incoming update. Returns: - :obj:`bool` + :obj:`None` | :obj:`re.match` """ if isinstance(update, str): diff --git a/telegram/ext/_typehandler.py b/telegram/ext/_typehandler.py index 6e38464cdc8..7b53f5c1f43 100644 --- a/telegram/ext/_typehandler.py +++ b/telegram/ext/_typehandler.py @@ -37,24 +37,28 @@ class TypeHandler(Handler[UT, CCT]): attributes to :class:`telegram.ext.CallbackContext`. See its docs for more info. Args: - type (:obj:`type`): The ``type`` of updates this handler should process, as - determined by ``isinstance`` - callback (:obj:`callable`): The callback function for this handler. Will be called when - :attr:`check_update` has determined that an update should be processed by this handler. - Callback signature: ``def callback(update: Update, context: CallbackContext)`` + type (:external:class:`type`): The :external:class:`type` of updates this handler should + process, as determined by :obj:`isinstance` + callback (:term:`coroutine function`): The callback function for this handler. Will be + called when :meth:`check_update` has determined that an update should be processed by + this handler. Callback signature:: + + async def callback(update: Update, context: CallbackContext) The return value of the callback is usually ignored except for the special case of :class:`telegram.ext.ConversationHandler`. - strict (:obj:`bool`, optional): Use ``type`` instead of ``isinstance``. - Default is :obj:`False` + strict (:obj:`bool`, optional): Use ``type`` instead of :obj:`isinstance`. + Default is :obj:`False`. block (:obj:`bool`, optional): Determines whether the return value of the callback should be awaited before processing the next handler in :meth:`telegram.ext.Application.process_update`. Defaults to :obj:`True`. Attributes: - type (:obj:`type`): The ``type`` of updates this handler should process. - callback (:obj:`callable`): The callback function for this handler. - strict (:obj:`bool`): Use ``type`` instead of ``isinstance``. Default is :obj:`False`. + type (:external:class:`type`): The :external:class:`type` of updates this handler should + process. + callback (:term:`coroutine function`): The callback function for this handler. + strict (:obj:`bool`): Use :external:class:`type` instead of :obj:`isinstance`. Default is + :obj:`False`. block (:obj:`bool`): Determines whether the return value of the callback should be awaited before processing the next handler in :meth:`telegram.ext.Application.process_update`. @@ -71,11 +75,11 @@ def __init__( block: DVInput[bool] = DEFAULT_TRUE, ): super().__init__(callback, block=block) - self.type = type # pylint: disable=assigning-non-slot - self.strict = strict # pylint: disable=assigning-non-slot + self.type = type + self.strict = strict def check_update(self, update: object) -> bool: - """Determines whether an update should be passed to this handlers :attr:`callback`. + """Determines whether an update should be passed to this handler's :attr:`callback`. Args: update (:obj:`object`): Incoming update. diff --git a/telegram/ext/_updater.py b/telegram/ext/_updater.py index 0c41427d364..c7c3272c001 100644 --- a/telegram/ext/_updater.py +++ b/telegram/ext/_updater.py @@ -57,11 +57,11 @@ class Updater: the sole purpose of this class is to fetch updates. The entry point to a PTB application is now :class:`telegram.ext.Application`. - Attributes: + Args: bot (:class:`telegram.Bot`): The bot used with this Updater. update_queue (:class:`asyncio.Queue`): Queue for the updates. - Args: + Attributes: bot (:class:`telegram.Bot`): The bot used with this Updater. update_queue (:class:`asyncio.Queue`): Queue for the updates. @@ -100,6 +100,12 @@ def running(self) -> bool: return self._running async def initialize(self) -> None: + """Initialize the Updater & the associated :attr:`bot` by calling + :meth:`telegram.Bot.initialize`. + + .. seealso:: + :meth:`shutdown` + """ if self._initialized: self._logger.debug('This Updater is already initialized.') return @@ -109,8 +115,10 @@ async def initialize(self) -> None: async def shutdown(self) -> None: """ + Shutdown the Updater & the associated :attr:`bot` by calling :meth:`telegram.Bot.shutdown`. - Returns: + .. seealso:: + :meth:`initialize` Raises: :exc:`RuntimeError`: If the updater is still running. @@ -127,6 +135,7 @@ async def shutdown(self) -> None: self._logger.debug('Shut down of Updater complete') async def __aenter__(self: _UpdaterType) -> _UpdaterType: + """Simple context manager which initializes the Updater.""" try: await self.initialize() return self @@ -140,6 +149,7 @@ async def __aexit__( exc_val: Optional[BaseException], exc_tb: Optional[TracebackType], ) -> None: + """Shutdown the Updater from the context manager.""" # Make sure not to return `True` so that exceptions are not suppressed # https://docs.python.org/3/reference/datamodel.html?#object.__aexit__ await self.shutdown() @@ -165,38 +175,46 @@ async def start_polling( Args: poll_interval (:obj:`float`, optional): Time to wait between polling updates from Telegram in seconds. Default is ``0.0``. - timeout (:obj:`float`, optional): Passed to :meth:`telegram.Bot.get_updates`. - drop_pending_updates (:obj:`bool`, optional): Whether to clean any pending updates on - Telegram servers before actually starting to poll. Default is :obj:`False`. - - .. versionadded :: 13.4 + timeout (:obj:`float`, optional): Passed to + :paramref:`telegram.Bot.get_updates.timeout`. Defaults to ``10`` seconds. bootstrap_retries (:obj:`int`, optional): Whether the bootstrapping phase of the :class:`telegram.ext.Updater` will retry on failures on the Telegram server. * < 0 - retry indefinitely (default) * 0 - no retries * > 0 - retry up to X times - + read_timeout (:obj:`float`, optional): Value to pass to + :paramref:`telegram.Bot.get_updates.read_timeout`. Defaults to ``2``. + write_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.Bot.get_updates.write_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + connect_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.Bot.get_updates.connect_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. + pool_timeout (:obj:`float` | :obj:`None`, optional): Value to pass to + :paramref:`telegram.Bot.get_updates.pool_timeout`. Defaults to + :attr:`~telegram.request.BaseRequest.DEFAULT_NONE`. allowed_updates (List[:obj:`str`], optional): Passed to :meth:`telegram.Bot.get_updates`. - read_timeout (:obj:`float` | :obj:`int`, optional): Grace time in seconds for receiving - the reply from server. Will be added to the ``timeout`` value and used as the read - timeout from server (Default: ``2``). + drop_pending_updates (:obj:`bool`, optional): Whether to clean any pending updates on + Telegram servers before actually starting to poll. Default is :obj:`False`. + + .. versionadded :: 13.4 error_callback (Callable[[:exc:`telegram.error.TelegramError`], :obj:`None`], \ optional): Callback to handle :exc:`telegram.error.TelegramError` s that occur while calling :meth:`telegram.Bot.get_updates` during polling. Defaults to :obj:`None`, in which case errors will be logged. Note: - The :paramref:`error_callback` must *not* be a coroutine function! If - asynchorous behavior of the callback is wanted, please schedule a task from + The :paramref:`error_callback` must *not* be a :term:`coroutine function`! If + asynchronous behavior of the callback is wanted, please schedule a task from within the callback. Returns: :class:`asyncio.Queue`: The update queue that can be filled from the main thread. Raises: - :exc:`RuntimeError`: If the updater is already running. + :exc:`RuntimeError`: If the updater is already running or was not initialized. """ if error_callback and asyncio.iscoroutinefunction(error_callback): @@ -244,7 +262,7 @@ async def _start_polling( self, poll_interval: float, timeout: int, - read_timeout: Optional[float], + read_timeout: float, write_timeout: ODVInput[float], connect_timeout: ODVInput[float], pool_timeout: ODVInput[float], @@ -257,6 +275,9 @@ async def _start_polling( self._logger.debug('Updater started (polling)') + # the bootstrapping phase does two things: + # 1) make sure there is no webhook set + # 2) apply drop_pending_updates await self._bootstrap( bootstrap_retries, drop_pending_updates=drop_pending_updates, @@ -298,9 +319,9 @@ async def polling_action_cb() -> bool: else: for update in updates: await self.update_queue.put(update) - self._last_update_id = updates[-1].update_id + 1 + self._last_update_id = updates[-1].update_id + 1 # Add one to 'confirm' it - return True + return True # Keep fetching updates & don't quit. Polls with poll_interval. def default_error_callback(exc: TelegramError) -> None: self._logger.exception('Exception happened while polling for updates.', exc_info=exc) @@ -349,10 +370,12 @@ async def start_webhook( the deprecated argument ``force_event_loop``. Args: - listen (:obj:`str`, optional): IP-Address to listen on. Default ``127.0.0.1``. + listen (:obj:`str`, optional): IP-Address to listen on. Defaults to + `127.0.0.1 `_. port (:obj:`int`, optional): Port the bot should be listening on. Must be one of :attr:`telegram.constants.SUPPORTED_WEBHOOK_PORTS`. Defaults to ``80``. - url_path (:obj:`str`, optional): Path inside url. + url_path (:obj:`str`, optional): Path inside url (http(s)://listen:port/). + Defaults to ``''``. cert (:class:`pathlib.Path` | :obj:`str`, optional): Path to the SSL certificate file. key (:class:`pathlib.Path` | :obj:`str`, optional): Path to the SSL key file. drop_pending_updates (:obj:`bool`, optional): Whether to clean any pending updates on @@ -361,24 +384,25 @@ async def start_webhook( bootstrap_retries (:obj:`int`, optional): Whether the bootstrapping phase of the :class:`telegram.ext.Updater` will retry on failures on the Telegram server. - * < 0 - retry indefinitely (default) - * 0 - no retries + * < 0 - retry indefinitely + * 0 - no retries (default) * > 0 - retry up to X times webhook_url (:obj:`str`, optional): Explicitly specify the webhook url. Useful behind - NAT, reverse proxy, etc. Default is derived from ``listen``, ``port`` & - ``url_path``. + NAT, reverse proxy, etc. Default is derived from :paramref:`listen`, + :paramref:`port`, :paramref:`url_path`, :paramref:`cert`, and :paramref:`key`. ip_address (:obj:`str`, optional): Passed to :meth:`telegram.Bot.set_webhook`. + Defaults to :obj:`None`. .. versionadded :: 13.4 allowed_updates (List[:obj:`str`], optional): Passed to - :meth:`telegram.Bot.set_webhook`. + :meth:`telegram.Bot.set_webhook`. Defaults to :obj:`None`. max_connections (:obj:`int`, optional): Passed to - :meth:`telegram.Bot.set_webhook`. + :meth:`telegram.Bot.set_webhook`. Defaults to ``40``. .. versionadded:: 13.6 Returns: :class:`queue.Queue`: The update queue that can be filled from the main thread. Raises: - :exc:`RuntimeError`: If the updater is already running. + :exc:`RuntimeError`: If the updater is already running or was not initialized. """ async with self.__lock: if self.running: @@ -499,7 +523,7 @@ async def _network_loop_retry( `action_cb` evaluates :obj:`False`. Args: - action_cb (:obj:`callable`): Network oriented callback function to call. + action_cb (:term:`coroutine function`): Network oriented callback function to call. on_err_cb (:obj:`callable`): Callback to call when TelegramError is caught. Receives the exception object as a parameter. description (:obj:`str`): Description text to use for logs and exception raised. @@ -551,6 +575,10 @@ async def _bootstrap( ip_address: str = None, max_connections: int = 40, ) -> None: + """Prepares the setup for fetching updates: delete or set the webhook and drop pending + updates if appropriate. If there are unsuccessful attempts, this will retry as specified by + :paramref:`max_retries`. + """ retries = 0 async def bootstrap_del_webhook() -> bool: @@ -589,7 +617,7 @@ def bootstrap_on_err_cb(exc: Exception) -> None: raise exc # Dropping pending updates from TG can be efficiently done with the drop_pending_updates - # parameter of delete/start_webhook, even in the case of polling. Also we want to make + # parameter of delete/start_webhook, even in the case of polling. Also, we want to make # sure that no webhook is configured in case of polling, so we just always call # delete_webhook for polling if drop_pending_updates or not webhook_url: @@ -616,6 +644,9 @@ def bootstrap_on_err_cb(exc: Exception) -> None: async def stop(self) -> None: """Stops the polling/webhook. + .. seealso:: + :meth:`start_polling`, :meth:`start_webhook` + Raises: :exc:`RuntimeError`: If the updater is not running. """ @@ -633,12 +664,14 @@ async def stop(self) -> None: self._logger.debug('Updater.stop() is complete') async def _stop_httpd(self) -> None: + """Stops the Webhook server by calling ``WebhookServer.shutdown()``""" if self._httpd: self._logger.debug('Waiting for current webhook connection to be closed.') await self._httpd.shutdown() self._httpd = None async def _stop_polling(self) -> None: + """Stops the polling task by awaiting it.""" if self.__polling_task: self._logger.debug('Waiting background polling task to finish up.') self.__polling_task.cancel() @@ -647,7 +680,7 @@ async def _stop_polling(self) -> None: await self.__polling_task except asyncio.CancelledError: # This only happens in rare edge-cases, e.g. when `stop()` is called directly - # after start_polling(), but let's better be safe than sorry ... + # after start_polling(), but lets better be safe than sorry ... pass self.__polling_task = None diff --git a/telegram/ext/_utils/stack.py b/telegram/ext/_utils/stack.py index b4a675d9318..6b2324a80c8 100644 --- a/telegram/ext/_utils/stack.py +++ b/telegram/ext/_utils/stack.py @@ -46,7 +46,7 @@ def was_called_by(frame: Optional[FrameType], caller: Path) -> bool: caller (:obj:`pathlib.Path`): File that should be the caller. Returns: - :obj:`bool`: Whether or not the frame was called by the specified file. + :obj:`bool`: Whether the frame was called by the specified file. """ if frame is None: return False diff --git a/telegram/ext/_utils/trackingdict.py b/telegram/ext/_utils/trackingdict.py index 314ab219d06..4086a5410c1 100644 --- a/telegram/ext/_utils/trackingdict.py +++ b/telegram/ext/_utils/trackingdict.py @@ -49,7 +49,7 @@ class TrackingDict(UserDict, Generic[_KT, _VT]): Read-access is not tracked. Note: - * ``setdefault()`` and ``pop`` are considered writing only depending on whether or not the + * ``setdefault()`` and ``pop`` are considered writing only depending on whether the key is present * deleting values is considered writing """ @@ -86,7 +86,8 @@ def pop_accessed_write_items(self) -> List[Tuple[_KT, _VT]]: def mark_as_accessed(self, key: _KT) -> None: """Use this method have the key returned again in the next call to - :meth:`pop_accessed_write_items` or :meth:`pop_accessed_keys""" + :meth:`pop_accessed_write_items` or :meth:`pop_accessed_keys` + """ self._write_access_keys.add(key) # Override methods to track access diff --git a/telegram/ext/filters.py b/telegram/ext/filters.py index 9701ede191f..b20d6a0a92e 100644 --- a/telegram/ext/filters.py +++ b/telegram/ext/filters.py @@ -119,37 +119,36 @@ class BaseFilter: Filters subclassing from this class can combined using bitwise operators: - And: + And:: - >>> (filters.TEXT & filters.Entity(MENTION)) + filters.TEXT & filters.Entity(MENTION) - Or: + Or:: - >>> (filters.AUDIO | filters.VIDEO) + filters.AUDIO | filters.VIDEO - Exclusive Or: + Exclusive Or:: - >>> (filters.Regex('To Be') ^ filters.Regex('Not 2B')) + filters.Regex('To Be') ^ filters.Regex('Not 2B') - Not: + Not:: - >>> ~ filters.COMMAND + ~ filters.COMMAND - Also works with more than two filters: + Also works with more than two filters:: - >>> (filters.TEXT & (filters.Entity(URL) | filters.Entity(TEXT_LINK))) - >>> filters.TEXT & (~ filters.FORWARDED) + filters.TEXT & (filters.Entity(URL) | filters.Entity(TEXT_LINK)) + filters.TEXT & (~ filters.FORWARDED) Note: - Filters use the same short circuiting logic as python's `and`, `or` and `not`. - This means that for example: + Filters use the same short circuiting logic as python's :keyword:`and`, :keyword:`or` and + :keyword:`not`. This means that for example:: - >>> filters.Regex(r'(a?x)') | filters.Regex(r'(b?x)') + filters.Regex(r'(a?x)') | filters.Regex(r'(b?x)') With ``message.text == 'x'``, will only ever return the matches for the first filter, since the second one is never evaluated. - If you want to create your own filters create a class inheriting from either :class:`MessageFilter` or :class:`UpdateFilter` and implement a ``filter()`` method that returns a boolean: :obj:`True` if the message should be @@ -157,7 +156,7 @@ class BaseFilter: Note that the filters work only as class instances, not actual class objects (so remember to initialize your filter classes). - By default the filters name (what will get printed when converted to a string for display) + By default, the filters name (what will get printed when converted to a string for display) will be the class name. If you want to overwrite this assign a better name to the :attr:`name` class variable. @@ -547,7 +546,7 @@ def filter(self, message: Message) -> bool: class CaptionRegex(MessageFilter): """ - Filters updates by searching for an occurrence of ``pattern`` in the message caption. + Filters updates by searching for an occurrence of :paramref:`pattern` in the message caption. This filter works similarly to :class:`Regex`, with the only exception being that it applies to the message caption instead of the text. @@ -874,7 +873,7 @@ def filter(self, message: Message) -> bool: class Command(MessageFilter): """ - Messages with a :attr:`telegram.MessageEntity.BOT_COMMAND`. By default only allows + Messages with a :attr:`telegram.MessageEntity.BOT_COMMAND`. By default, only allows messages `starting` with a bot command. Pass :obj:`False` to also allow messages that contain a bot command `anywhere` in the text. @@ -1497,7 +1496,7 @@ def filter(self, message: Message) -> bool: class Regex(MessageFilter): """ - Filters updates by searching for an occurrence of ``pattern`` in the message text. + Filters updates by searching for an occurrence of :paramref:`pattern` in the message text. The :func:`re.search` function is used to determine whether an update should be filtered. Refer to the documentation of the :obj:`re` module for more information. @@ -1512,7 +1511,8 @@ class Regex(MessageFilter): if you need to specify flags on your pattern. Note: - Filters use the same short circuiting logic as python's `and`, `or` and `not`. + Filters use the same short circuiting logic as python's :keyword:`and`, :keyword:`or` and + :keyword:`not`. This means that for example: >>> filters.Regex(r'(a?x)') | filters.Regex(r'(b?x)') @@ -1973,7 +1973,10 @@ def filter(self, update: Update) -> bool: EDITED = _Edited(name="filters.UpdateType.EDITED") """Updates with either :attr:`telegram.Update.edited_message` or - :attr:`telegram.Update.edited_channel_post`.""" + :attr:`telegram.Update.edited_channel_post`. + + .. versionadded:: 14.0 + """ class _EditedChannelPost(UpdateFilter): __slots__ = () diff --git a/telegram/helpers.py b/telegram/helpers.py index 965754e801d..f9ce4a3a81d 100644 --- a/telegram/helpers.py +++ b/telegram/helpers.py @@ -54,8 +54,10 @@ def escape_markdown(text: str, version: int = 1, entity_type: str = None) -> str text (:obj:`str`): The text. version (:obj:`int` | :obj:`str`): Use to specify the version of telegrams Markdown. Either ``1`` or ``2``. Defaults to ``1``. - entity_type (:obj:`str`, optional): For the entity types ``PRE``, ``CODE`` and the link - part of ``TEXT_LINKS``, only certain characters need to be escaped in ``MarkdownV2``. + entity_type (:obj:`str`, optional): For the entity types + :tg-const:`telegram.MessageEntity.PRE`, :tg-const:`telegram.MessageEntity.CODE` and + the link part of :tg-const:`telegram.MessageEntity.TEXT_LINK`, only certain characters + need to be escaped in :tg-const:`telegram.constants.ParseMode.MARKDOWN_V2`. See the official API documentation for details. Only valid in combination with ``version=2``, will be ignored else. """ @@ -135,14 +137,14 @@ def effective_message_type(entity: Union['Message', 'Update']) -> Optional[str]: def create_deep_linked_url(bot_username: str, payload: str = None, group: bool = False) -> str: """ - Creates a deep-linked URL for this ``bot_username`` with the specified ``payload``. - See https://core.telegram.org/bots#deep-linking to learn more. + Creates a deep-linked URL for this :paramref:`bot_username` with the specified + :paramref:`payload`. See https://core.telegram.org/bots#deep-linking to learn more. - The ``payload`` may consist of the following characters: ``A-Z, a-z, 0-9, _, -`` + The :paramref:`payload` may consist of the following characters: ``A-Z, a-z, 0-9, _, -`` Note: Works well in conjunction with - ``CommandHandler("start", callback, filters = filters.Regex('payload'))`` + ``CommandHandler("start", callback, filters=filters.Regex('payload'))`` Examples: ``create_deep_linked_url(bot.get_me().username, "some-params")`` diff --git a/telegram/request/__init__.py b/telegram/request/__init__.py index 98e9a676740..91dfa60d8c9 100644 --- a/telegram/request/__init__.py +++ b/telegram/request/__init__.py @@ -1,20 +1,20 @@ +# !/usr/bin/env python +# A library that provides a Python interface to the Telegram Bot API +# Copyright (C) 2015-2022 +# Leandro Toledo de Souza # -# A library that provides a Python interface to the Telegram Bot API -# Copyright (C) 2015-2022 -# Leandro Toledo de Souza +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. # -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser Public License for more details. # -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Lesser Public License for more details. -# -# You should have received a copy of the GNU Lesser Public License -# along with this program. If not, see [http://www.gnu.org/licenses/]. +# You should have received a copy of the GNU Lesser Public License +# along with this program. If not, see [http://www.gnu.org/licenses/]. """This module contains classes that handle the networking backend of ``python-telegram-bot``.""" from ._requestdata import RequestData diff --git a/telegram/request/_baserequest.py b/telegram/request/_baserequest.py index 78f4b0e0fdb..639f639e486 100644 --- a/telegram/request/_baserequest.py +++ b/telegram/request/_baserequest.py @@ -136,19 +136,22 @@ async def post( url (:obj:`str`): The URL to request. request_data (:class:`telegram.request.RequestData`, optional): An object containing information about parameters and files to upload for the request. - connect_timeout (:obj:`float`, optional): If passed, specifies the maximum amount of - time (in seconds) to wait for a connection attempt to a server to succeed instead - of the time specified during creating of this object. - read_timeout (:obj:`float`, optional): If passed, specifies the maximum amount of time - (in seconds) to wait for a response from Telegram's server instead - of the time specified during creating of this object. - write_timeout (:obj:`float`, optional): If passed, specifies the maximum amount of time - (in seconds) to wait for a write operation to complete (in terms of a network - socket; i.e. POSTing a request or uploading a file) instead - of the time specified during creating of this object. - pool_timeout (:obj:`float`, optional): If passed, specifies the maximum amount of time - (in seconds) to wait for a connection to become available instead - of the time specified during creating of this object. + read_timeout (:obj:`float` | :obj:`None`, optional): If passed, specifies the maximum + amount of time (in seconds) to wait for a response from Telegram's server instead + of the time specified during creating of this object. Defaults to + :attr:`DEFAULT_NONE`. + write_timeout (:obj:`float` | :obj:`None`, optional): If passed, specifies the maximum + amount of time (in seconds) to wait for a write operation to complete (in terms of + a network socket; i.e. POSTing a request or uploading a file) instead of the time + specified during creating of this object. Defaults to :attr:`DEFAULT_NONE`. + connect_timeout (:obj:`float` | :obj:`None`, optional): If passed, specifies the + maximum amount of time (in seconds) to wait for a connection attempt to a server + to succeed instead of the time specified during creating of this object. Defaults + to :attr:`DEFAULT_NONE`. + pool_timeout (:obj:`float` | :obj:`None`, optional): If passed, specifies the maximum + amount of time (in seconds) to wait for a connection to become available instead + of the time specified during creating of this object. Defaults to + :attr:`DEFAULT_NONE`. Returns: Dict[:obj:`str`, ...]: The JSON response of the Bot API. @@ -184,9 +187,22 @@ async def retrieve( Args: url (:obj:`str`): The web location we want to retrieve. - timeout (:obj:`float`, optional): If this value is specified, use it as the read - timeout from the server (instead of the one specified during creation of the - connection pool). + read_timeout (:obj:`float` | :obj:`None`, optional): If passed, specifies the maximum + amount of time (in seconds) to wait for a response from Telegram's server instead + of the time specified during creating of this object. Defaults to + :attr:`DEFAULT_NONE`. + write_timeout (:obj:`float` | :obj:`None`, optional): If passed, specifies the maximum + amount of time (in seconds) to wait for a write operation to complete (in terms of + a network socket; i.e. POSTing a request or uploading a file) instead of the time + specified during creating of this object. Defaults to :attr:`DEFAULT_NONE`. + connect_timeout (:obj:`float` | :obj:`None`, optional): If passed, specifies the + maximum amount of time (in seconds) to wait for a connection attempt to a server + to succeed instead of the time specified during creating of this object. Defaults + to :attr:`DEFAULT_NONE`. + pool_timeout (:obj:`float` | :obj:`None`, optional): If passed, specifies the maximum + amount of time (in seconds) to wait for a connection to become available instead + of the time specified during creating of this object. Defaults to + :attr:`DEFAULT_NONE`. Returns: :obj:`bytes`: The files contents. @@ -220,10 +236,24 @@ async def _request_wrapper( Args: url (:obj:`str`): The URL to request. method (:obj:`str`): HTTP method (i.e. 'POST', 'GET', etc.). - url (:obj:`str`): The request's URL. request_data (:class:`telegram.request.RequestData`, optional): An object containing information about parameters and files to upload for the request. - read_timeout: Timeout for waiting to server's response. + read_timeout (:obj:`float` | :obj:`None`, optional): If passed, specifies the maximum + amount of time (in seconds) to wait for a response from Telegram's server instead + of the time specified during creating of this object. Defaults to + :attr:`DEFAULT_NONE`. + write_timeout (:obj:`float` | :obj:`None`, optional): If passed, specifies the maximum + amount of time (in seconds) to wait for a write operation to complete (in terms of + a network socket; i.e. POSTing a request or uploading a file) instead of the time + specified during creating of this object. Defaults to :attr:`DEFAULT_NONE`. + connect_timeout (:obj:`float` | :obj:`None`, optional): If passed, specifies the + maximum amount of time (in seconds) to wait for a connection attempt to a server + to succeed instead of the time specified during creating of this object. Defaults + to :attr:`DEFAULT_NONE`. + pool_timeout (:obj:`float` | :obj:`None`, optional): If passed, specifies the maximum + amount of time (in seconds) to wait for a connection to become available instead + of the time specified during creating of this object. Defaults to + :attr:`DEFAULT_NONE`. Returns: bytes: The payload part of the HTTP server response. @@ -263,7 +293,7 @@ async def _request_wrapper( else: message = 'Unknown HTTPError' - # In some special cases, we ca raise more informative exceptions: + # In some special cases, we can raise more informative exceptions: # see https://core.telegram.org/bots/api#responseparameters and # https://core.telegram.org/bots/api#making-requests parameters = response_data.get('parameters') @@ -332,12 +362,22 @@ async def do_request( method (:obj:`str`): HTTP method (i.e. ``'POST'``, ``'GET'``, etc.). request_data (:class:`telegram.request.RequestData`, optional): An object containing information about parameters and files to upload for the request. - read_timeout (:obj:`float`, optional): If this value is specified, use it as the read - timeout from the server (instead of the one specified during creation of the - connection pool). - write_timeout (:obj:`float`, optional): If this value is specified, use it as the write - timeout from the server (instead of the one specified during creation of the - connection pool). + read_timeout (:obj:`float` | :obj:`None`, optional): If passed, specifies the maximum + amount of time (in seconds) to wait for a response from Telegram's server instead + of the time specified during creating of this object. Defaults to + :attr:`DEFAULT_NONE`. + write_timeout (:obj:`float` | :obj:`None`, optional): If passed, specifies the maximum + amount of time (in seconds) to wait for a write operation to complete (in terms of + a network socket; i.e. POSTing a request or uploading a file) instead of the time + specified during creating of this object. Defaults to :attr:`DEFAULT_NONE`. + connect_timeout (:obj:`float` | :obj:`None`, optional): If passed, specifies the + maximum amount of time (in seconds) to wait for a connection attempt to a server + to succeed instead of the time specified during creating of this object. Defaults + to :attr:`DEFAULT_NONE`. + pool_timeout (:obj:`float` | :obj:`None`, optional): If passed, specifies the maximum + amount of time (in seconds) to wait for a connection to become available instead + of the time specified during creating of this object. Defaults to + :attr:`DEFAULT_NONE`. Returns: Tuple[:obj:`int`, :obj:`bytes`]: The HTTP return code & the payload part of the server diff --git a/telegram/request/_httpxrequest.py b/telegram/request/_httpxrequest.py index 4e243bf367b..ddbfe8bab10 100644 --- a/telegram/request/_httpxrequest.py +++ b/telegram/request/_httpxrequest.py @@ -1,20 +1,21 @@ +#!/usr/bin/env python # -# A library that provides a Python interface to the Telegram Bot API -# Copyright (C) 2015-2022 -# Leandro Toledo de Souza +# A library that provides a Python interface to the Telegram Bot API +# Copyright (C) 2015-2022 +# Leandro Toledo de Souza # -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. # -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Lesser Public License for more details. +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser Public License for more details. # -# You should have received a copy of the GNU Lesser Public License -# along with this program. If not, see [http://www.gnu.org/licenses/]. +# You should have received a copy of the GNU Lesser Public License +# along with this program. If not, see [http://www.gnu.org/licenses/]. """This module contains methods to make POST and GET requests using the httpx library.""" import logging from typing import Tuple, Optional @@ -51,25 +52,26 @@ class HTTPXRequest(BaseRequest): Note: * The proxy URL can also be set via the environment variables ``HTTPS_PROXY`` or - ``ALL_PROXY``. See `the docs`_ of ``httpx`` for more info. + ``ALL_PROXY``. See `the docs of httpx`_ for more info. * For Socks5 support, additional dependencies are required. Make sure to install - PTB via ``pip install python-telegram-bot[socks]`` in this case. + PTB via :command:`pip install python-telegram-bot[socks]` in this case. * Socks5 proxies can not be set via environment variables. - .. _the docs: https://www.python-httpx.org/environment_variables/#proxies - connect_timeout (:obj:`float`, optional): The maximum amount of time (in seconds) to wait - for a connection attempt to a server to succeed. :obj:`None` will set an infinite - timeout for connection attempts. Defaults to ``5.0``. - read_timeout (:obj:`float`, optional): The maximum amount of time (in seconds) to wait for - a response from Telegram's server. :obj:`None` will set an infinite timeout. This value - is usually overridden by the various methods of :class:`telegram.Bot`. Defaults to - ``5.0``. - write_timeout (:obj:`float`, optional): The maximum amount of time (in seconds) to wait for - a write operation to complete (in terms of a network socket; i.e. POSTing a request or - uploading a file).:obj:`None` will set an infinite timeout. Defaults to ``5.0``. - pool_timeout (:obj:`float`, optional): The maximum amount of time (in seconds) to wait for - a connection from the connection pool becoming available. :obj:`None` will set an - infinite timeout. Defaults to :obj:`None`. + .. _the docs of httpx: https://www.python-httpx.org/environment_variables/#proxies + read_timeout (:obj:`float` | :obj:`None`, optional): If passed, specifies the maximum + amount of time (in seconds) to wait for a response from Telegram's server instead + of the time specified during creating of this object. Defaults to ``5``. + write_timeout (:obj:`float` | :obj:`None`, optional): If passed, specifies the maximum + amount of time (in seconds) to wait for a write operation to complete (in terms of + a network socket; i.e. POSTing a request or uploading a file) instead of the time + specified during creating of this object. Defaults to ``5``. + connect_timeout (:obj:`float` | :obj:`None`, optional): If passed, specifies the + maximum amount of time (in seconds) to wait for a connection attempt to a server + to succeed instead of the time specified during creating of this object. Defaults + to ``5``. + pool_timeout (:obj:`float` | :obj:`None`, optional): If passed, specifies the maximum + amount of time (in seconds) to wait for a connection to become available instead + of the time specified during creating of this object. Defaults to ``1``. Warning: With a finite pool timeout, you must expect :exc:`telegram.error.TimedOut` @@ -136,6 +138,8 @@ async def do_request( if self._client.is_closed: raise RuntimeError('This HTTPXRequest is not initialized!') + # If user did not specify timeouts (for e.g. in a bot method), use the default ones when we + # created this instance. if isinstance(read_timeout, DefaultValue): read_timeout = self._client.timeout.read if isinstance(write_timeout, DefaultValue): diff --git a/telegram/request/_requestdata.py b/telegram/request/_requestdata.py index 93fba9e6608..f541b75392e 100644 --- a/telegram/request/_requestdata.py +++ b/telegram/request/_requestdata.py @@ -16,7 +16,7 @@ # # You should have received a copy of the GNU Lesser Public License # along with this program. If not, see [http://www.gnu.org/licenses/]. -"""This module contains an class that holds a parameters of a request to the Bot API.""" +"""This module contains a class that holds the parameters of a request to the Bot API.""" from typing import List, Dict, Any, Union from urllib.parse import urlencode @@ -47,10 +47,7 @@ class RequestData: __slots__ = ('_parameters', 'contains_files') - def __init__( - self, - parameters: List[RequestParameter] = None, - ): + def __init__(self, parameters: List[RequestParameter] = None): self._parameters = parameters or [] self.contains_files = any(param.input_files for param in self._parameters) diff --git a/telegram/request/_requestparameter.py b/telegram/request/_requestparameter.py index da20cabe700..e6764c2302f 100644 --- a/telegram/request/_requestparameter.py +++ b/telegram/request/_requestparameter.py @@ -16,7 +16,7 @@ # # You should have received a copy of the GNU Lesser Public License # along with this program. If not, see [http://www.gnu.org/licenses/]. -"""This module contains an class that describes a single parameter of a request to the Bot API.""" +"""This module contains a class that describes a single parameter of a request to the Bot API.""" from dataclasses import dataclass from datetime import datetime from enum import Enum @@ -47,7 +47,7 @@ class RequestParameter: Args: name (:obj:`str`): The name of the parameter. value (:obj:`object`): The value of the parameter. Must be JSON-dumpable. - input_files (List[:class:`telegram.InputFile`, optional): A list of files that should be + input_files (List[:class:`telegram.InputFile`], optional): A list of files that should be uploaded along with this parameter. Attributes: diff --git a/tests/test_bot.py b/tests/test_bot.py index 23bbdfe1f58..30c61d15d9e 100644 --- a/tests/test_bot.py +++ b/tests/test_bot.py @@ -2438,7 +2438,7 @@ class OkException(BaseException): pass async def do_request(*args, **kwargs): - obj = kwargs.get('read_timeout') + obj = kwargs.get('write_timeout') if obj == 20: raise OkException From e8fdd13598a1dd3b59acdc2810d6e738169c349c Mon Sep 17 00:00:00 2001 From: Bibo-Joshi <22366557+Bibo-Joshi@users.noreply.github.com> Date: Sun, 17 Apr 2022 08:35:30 +0200 Subject: [PATCH 119/153] Asyncio Tests (#2936) Co-authored-by: Harshil <37377066+harshil21@users.noreply.github.com> --- telegram/error.py | 3 +- telegram/ext/_utils/trackingdict.py | 2 +- tests/conftest.py | 2 + tests/test_application.py | 45 + tests/test_applicationbuilder.py | 5 + tests/test_basepersistence.py | 385 ++++- tests/test_bot.py | 65 +- tests/test_callbackcontext.py | 2 - tests/test_chat.py | 5 + tests/test_chatjoinrequest.py | 2 +- tests/test_chatjoinrequesthandler.py | 2 +- tests/test_chatphoto.py | 5 + tests/test_conversationhandler.py | 2168 ++++++++++++++++++++++++++ tests/test_dictpersistence.py | 403 +++++ tests/test_document.py | 5 + tests/test_error.py | 18 +- tests/test_gamehighscore.py | 2 + tests/test_inlinequery.py | 5 + tests/test_inlinequeryhandler.py | 7 + tests/test_message.py | 45 +- tests/test_messagehandler.py | 1 + tests/test_photo.py | 5 + tests/test_picklepersistence.py | 1015 ++++++++++++ tests/test_poll.py | 12 + tests/test_precheckoutquery.py | 6 + tests/test_replykeyboardmarkup.py | 12 + tests/test_requestdata.py | 5 + tests/test_requestparameter.py | 6 + tests/test_shippingquery.py | 6 + tests/test_stringregexhandler.py | 9 +- tests/test_trackingdict.py | 12 + tests/test_user.py | 5 + tests/test_video.py | 5 + tests/test_videonote.py | 5 + tests/test_voice.py | 5 + tests/test_voicechat.py | 16 +- 36 files changed, 4197 insertions(+), 104 deletions(-) create mode 100644 tests/test_conversationhandler.py create mode 100644 tests/test_dictpersistence.py create mode 100644 tests/test_picklepersistence.py diff --git a/telegram/error.py b/telegram/error.py index f12c38aeb93..5b95ff3f684 100644 --- a/telegram/error.py +++ b/telegram/error.py @@ -74,9 +74,8 @@ def __init__(self, message: str): def __str__(self) -> str: return self.message - # TODO: test this def __repr__(self) -> str: - return f'{self.__class__.__name__}({self.message})' + return f"{self.__class__.__name__}('{self.message}')" def __reduce__(self) -> Tuple[type, Tuple[str]]: return self.__class__, (self.message,) diff --git a/telegram/ext/_utils/trackingdict.py b/telegram/ext/_utils/trackingdict.py index 4086a5410c1..f14c6baa52c 100644 --- a/telegram/ext/_utils/trackingdict.py +++ b/telegram/ext/_utils/trackingdict.py @@ -57,7 +57,7 @@ class TrackingDict(UserDict, Generic[_KT, _VT]): DELETED: ClassVar = object() """Special marker indicating that an entry was deleted.""" - __slots__ = ('_data', '_write_access_keys') + __slots__ = ('_write_access_keys',) def __init__(self) -> None: super().__init__() diff --git a/tests/conftest.py b/tests/conftest.py index 120444a05d3..e1903b45876 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -155,6 +155,7 @@ class DictApplication(Application): @pytest.fixture(scope='session') @pytest.mark.asyncio async def bot(bot_info): + """Makes an ExtBot instance with the given bot_info""" async with make_bot(bot_info) as _bot: yield _bot @@ -162,6 +163,7 @@ async def bot(bot_info): @pytest.fixture(scope='session') @pytest.mark.asyncio async def raw_bot(bot_info): + """Makes an regular Bot instance with the given bot_info""" async with DictBot( bot_info['token'], private_key=PRIVATE_KEY, diff --git a/tests/test_application.py b/tests/test_application.py index 4d333f93254..cd76034becc 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -1609,3 +1609,48 @@ def test_run_without_updater(self, bot): with pytest.raises(RuntimeError, match='only available if the application has an Updater'): app.run_polling() + + @pytest.mark.parametrize('method', ['start', 'initialize']) + def test_run_error_in_application(self, bot, monkeypatch, method): + shutdowns = [] + + async def raise_method(*args, **kwargs): + raise RuntimeError('Test Exception') + + async def shutdown(*args, **kwargs): + shutdowns.append(True) + + monkeypatch.setattr(Application, method, raise_method) + monkeypatch.setattr(Application, 'shutdown', shutdown) + monkeypatch.setattr(Updater, 'shutdown', shutdown) + app = ApplicationBuilder().token(bot.token).build() + with pytest.raises(RuntimeError, match='Test Exception'): + app.run_polling(close_loop=False) + + assert not app.running + assert not app.updater.running + assert shutdowns == [True, True] + + @pytest.mark.parametrize('method', ['start_polling', 'start_webhook']) + def test_run_error_in_updater(self, bot, monkeypatch, method): + shutdowns = [] + + async def raise_method(*args, **kwargs): + raise RuntimeError('Test Exception') + + async def shutdown(*args, **kwargs): + shutdowns.append(True) + + monkeypatch.setattr(Updater, method, raise_method) + monkeypatch.setattr(Application, 'shutdown', shutdown) + monkeypatch.setattr(Updater, 'shutdown', shutdown) + app = ApplicationBuilder().token(bot.token).build() + with pytest.raises(RuntimeError, match='Test Exception'): + if 'polling' in method: + app.run_polling(close_loop=False) + else: + app.run_webhook(close_loop=False) + + assert not app.running + assert not app.updater.running + assert shutdowns == [True, True] diff --git a/tests/test_applicationbuilder.py b/tests/test_applicationbuilder.py index 6ee13d1605d..3b0c3a09248 100644 --- a/tests/test_applicationbuilder.py +++ b/tests/test_applicationbuilder.py @@ -44,6 +44,11 @@ def builder(): class TestApplicationBuilder: + def test_slot_behaviour(self, builder, mro_slots): + for attr in builder.__slots__: + assert getattr(builder, attr, 'err') != 'err', f"got extra slot '{attr}'" + assert len(mro_slots(builder)) == len(set(mro_slots(builder))), "duplicate slot" + def test_build_without_token(self, builder): with pytest.raises(RuntimeError, match='No bot token was set.'): builder.build() diff --git a/tests/test_basepersistence.py b/tests/test_basepersistence.py index 0a7cab203a5..6d83c396764 100644 --- a/tests/test_basepersistence.py +++ b/tests/test_basepersistence.py @@ -29,7 +29,7 @@ import pytest from flaky import flaky -from telegram import User, Chat, InlineKeyboardMarkup, InlineKeyboardButton +from telegram import User, Chat, InlineKeyboardMarkup, InlineKeyboardButton, Bot, Update from telegram.ext import ( ApplicationBuilder, PersistenceInput, @@ -209,10 +209,10 @@ def build_update(state: HandlerStates, chat_id: int): return make_message_update(message=str(state.value), user=user, chat=chat) @classmethod - def build_handler(cls, state: HandlerStates): + def build_handler(cls, state: HandlerStates, callback=None): return MessageHandler( filters.Regex(f'^{state.value}$'), - functools.partial(cls.callback, state=state), + callback or functools.partial(cls.callback, state=state), ) @@ -229,7 +229,7 @@ class PappInput(NamedTuple): def build_papp( token: str, store_data: dict = None, update_interval: float = None, fill_data: bool = False ) -> Application: - store_data = PersistenceInput(**store_data) + store_data = PersistenceInput(**(store_data or {})) if update_interval is not None: persistence = TrackingPersistence( store_data=store_data, update_interval=update_interval, fill_data=fill_data @@ -302,13 +302,6 @@ class TestBasePersistence: """Tests basic behavior of BasePersistence and (most importantly) the integration of persistence into the Application.""" - # TODO: Test integration of the more intricate ConversationHandler things once CH itself is - # tested. This includes: - # * pending states, i.e. non-blocking handlers - # * pending states being unresolved on shutdown - # * conversation timeouts - # * nested conversations (can conversations be persistent if their parents aren't?) - def job_callback(self, chat_id: int = None): async def callback(context): if context.user_data: @@ -355,7 +348,6 @@ def test_slot_behaviour(self, mro_slots): assert getattr(inst, attr, 'err') != 'err', f"got extra slot '{attr}'" # We're interested in BasePersistence, not in the implementation slots = mro_slots(inst, only_parents=True) - print(slots) assert len(slots) == len(set(slots)), "duplicate slot" @pytest.mark.parametrize('bot_data', (True, False)) @@ -388,6 +380,16 @@ def test_abstract_methods(self): ): BasePersistence() + @default_papp + def test_update_interval_immutable(self, papp): + with pytest.raises(AttributeError, match='can not assign a new value to update_interval'): + papp.persistence.update_interval = 7 + + @default_papp + def test_set_bot_error(self, papp): + with pytest.raises(TypeError, match='when using telegram.ext.ExtBot'): + papp.persistence.set_bot(Bot(papp.bot.token)) + def test_construction_with_bad_persistence(self, caplog, bot): class MyPersistence: def __init__(self): @@ -1129,3 +1131,362 @@ async def raise_error(*args, **kwargs): else: assert not papp.persistence.updated_chat_ids assert not papp.persistence.updated_conversations + + @pytest.mark.asyncio + async def test_non_blocking_conversations(self, bot): + papp = build_papp(token=bot.token) + event = asyncio.Event() + + async def callback(_, __): + await event.wait() + return HandlerStates.STATE_1 + + conversation = ConversationHandler( + entry_points=[ + TrackingConversationHandler.build_handler(HandlerStates.END, callback=callback) + ], + states={}, + fallbacks=[], + persistent=True, + name='conv', + block=False, + ) + papp.add_handler(conversation) + + async with papp: + assert papp.persistence.updated_conversations == {} + + await papp.process_update( + TrackingConversationHandler.build_update(HandlerStates.END, 1) + ) + assert papp.persistence.updated_conversations == {} + + await papp.update_persistence() + await asyncio.sleep(0.01) + # Conversation should have been updated with the current state, i.e. None + assert papp.persistence.updated_conversations == {'conv': ({(1, 1): 1})} + assert papp.persistence.conversations == {'conv': {(1, 1): None}} + + papp.persistence.reset_tracking() + event.set() + await asyncio.sleep(0.01) + await papp.update_persistence() + assert papp.persistence.updated_conversations == {'conv': {(1, 1): 1}} + assert papp.persistence.conversations == {'conv': {(1, 1): HandlerStates.STATE_1}} + + @pytest.mark.asyncio + async def test_non_blocking_conversations_raises_Exception(self, bot): + papp = build_papp(token=bot.token) + + async def callback_1(_, __): + return HandlerStates.STATE_1 + + async def callback_2(_, __): + raise Exception('Test Exception') + + conversation = ConversationHandler( + entry_points=[ + TrackingConversationHandler.build_handler(HandlerStates.END, callback=callback_1) + ], + states={ + HandlerStates.STATE_1: [ + TrackingConversationHandler.build_handler( + HandlerStates.STATE_1, callback=callback_2 + ) + ] + }, + fallbacks=[], + persistent=True, + name='conv', + block=False, + ) + papp.add_handler(conversation) + + async with papp: + assert papp.persistence.updated_conversations == {} + + await papp.process_update( + TrackingConversationHandler.build_update(HandlerStates.END, 1) + ) + assert papp.persistence.updated_conversations == {} + + await papp.update_persistence() + await asyncio.sleep(0.05) + assert papp.persistence.updated_conversations == {'conv': ({(1, 1): 1})} + # The result of the pending state wasn't retrieved by the CH yet, so we must be in + # state `None` + assert papp.persistence.conversations == {'conv': {(1, 1): None}} + + await papp.process_update( + TrackingConversationHandler.build_update(HandlerStates.STATE_1, 1) + ) + + papp.persistence.reset_tracking() + await asyncio.sleep(0.01) + await papp.update_persistence() + assert papp.persistence.updated_conversations == {'conv': {(1, 1): 1}} + # since the second callback raised an exception, the state must be the previous one! + assert papp.persistence.conversations == {'conv': {(1, 1): HandlerStates.STATE_1}} + + @pytest.mark.asyncio + async def test_non_blocking_conversations_on_stop(self, bot): + papp = build_papp(token=bot.token, update_interval=100) + event = asyncio.Event() + + async def callback(_, __): + await event.wait() + return HandlerStates.STATE_1 + + conversation = ConversationHandler( + entry_points=[ + TrackingConversationHandler.build_handler(HandlerStates.END, callback=callback) + ], + states={}, + fallbacks=[], + persistent=True, + name='conv', + block=False, + ) + papp.add_handler(conversation) + + await papp.initialize() + assert papp.persistence.updated_conversations == {} + await papp.start() + + await papp.process_update(TrackingConversationHandler.build_update(HandlerStates.END, 1)) + assert papp.persistence.updated_conversations == {} + + stop_task = asyncio.create_task(papp.stop()) + assert not stop_task.done() + event.set() + await asyncio.sleep(0.5) + assert stop_task.done() + assert papp.persistence.updated_conversations == {} + + await papp.shutdown() + await asyncio.sleep(0.01) + # The pending state must have been resolved on shutdown! + assert papp.persistence.updated_conversations == {'conv': {(1, 1): 1}} + assert papp.persistence.conversations == {'conv': {(1, 1): HandlerStates.STATE_1}} + + @pytest.mark.asyncio + async def test_non_blocking_conversations_on_improper_stop(self, bot, caplog): + papp = build_papp(token=bot.token, update_interval=100) + event = asyncio.Event() + + async def callback(_, __): + await event.wait() + return HandlerStates.STATE_1 + + conversation = ConversationHandler( + entry_points=[ + TrackingConversationHandler.build_handler(HandlerStates.END, callback=callback) + ], + states={}, + fallbacks=[], + persistent=True, + name='conv', + block=False, + ) + papp.add_handler(conversation) + + await papp.initialize() + assert papp.persistence.updated_conversations == {} + + await papp.process_update(TrackingConversationHandler.build_update(HandlerStates.END, 1)) + assert papp.persistence.updated_conversations == {} + + with caplog.at_level(logging.WARNING): + await papp.shutdown() + await asyncio.sleep(0.01) + # Because the app wasn't running, the pending state isn't ensured to be done on + # shutdown - hence we expect the persistence to be updated with state `None` + assert papp.persistence.updated_conversations == {'conv': {(1, 1): 1}} + assert papp.persistence.conversations == {'conv': {(1, 1): None}} + + # Ensure that we warn the user about this! + found_record = None + for record in caplog.records: + if record.getMessage().startswith('A ConversationHandlers state was not yet resolved'): + found_record = record + break + assert found_record is not None + + @default_papp + @pytest.mark.asyncio + async def test_conversation_ends(self, papp): + async with papp: + assert papp.persistence.updated_conversations == {} + + for state in HandlerStates: + await papp.process_update(TrackingConversationHandler.build_update(state, 1)) + assert papp.persistence.updated_conversations == {} + + await papp.update_persistence() + assert papp.persistence.updated_conversations == {'conv_1': ({(1, 1): 1})} + # This is the important part: the persistence is updated with `None` when the conv ends + assert papp.persistence.conversations == {'conv_1': {(1, 1): None}} + + @pytest.mark.asyncio + async def test_conversation_timeout(self, bot): + # high update_interval so that we can instead manually call it + papp = build_papp(token=bot.token, update_interval=150) + + async def callback(_, __): + return HandlerStates.STATE_1 + + conversation = ConversationHandler( + entry_points=[ + TrackingConversationHandler.build_handler(HandlerStates.END, callback=callback) + ], + states={HandlerStates.STATE_1: []}, + fallbacks=[], + persistent=True, + name='conv', + conversation_timeout=3, + ) + papp.add_handler(conversation) + + async with papp: + await papp.start() + assert papp.persistence.updated_conversations == {} + + await papp.process_update( + TrackingConversationHandler.build_update(HandlerStates.END, 1) + ) + assert papp.persistence.updated_conversations == {} + await papp.update_persistence() + assert papp.persistence.updated_conversations == {'conv': ({(1, 1): 1})} + assert papp.persistence.conversations == {'conv': {(1, 1): HandlerStates.STATE_1}} + + papp.persistence.reset_tracking() + await asyncio.sleep(4) + # After the timeout the conversation should run the entry point again … + assert conversation.check_update( + TrackingConversationHandler.build_update(HandlerStates.END, 1) + ) + await papp.update_persistence() + # … and persistence should be updated with `None` + assert papp.persistence.updated_conversations == {'conv': {(1, 1): 1}} + assert papp.persistence.conversations == {'conv': {(1, 1): None}} + + await papp.stop() + + @pytest.mark.asyncio + async def test_persistent_nested_conversations(self, bot): + papp = build_papp(token=bot.token, update_interval=150) + + def build_callback( + state: HandlerStates, + ): + async def callback(_: Update, __: CallbackContext) -> HandlerStates: + return state + + return callback + + grand_child = ConversationHandler( + entry_points=[TrackingConversationHandler.build_handler(HandlerStates.END)], + states={ + HandlerStates.STATE_1: [ + TrackingConversationHandler.build_handler( + HandlerStates.STATE_1, callback=build_callback(HandlerStates.END) + ) + ] + }, + fallbacks=[], + persistent=True, + name='grand_child', + map_to_parent={HandlerStates.END: HandlerStates.STATE_2}, + ) + + child = ConversationHandler( + entry_points=[TrackingConversationHandler.build_handler(HandlerStates.END)], + states={ + HandlerStates.STATE_1: [grand_child], + HandlerStates.STATE_2: [ + TrackingConversationHandler.build_handler(HandlerStates.STATE_2) + ], + }, + fallbacks=[], + persistent=True, + name='child', + map_to_parent={HandlerStates.STATE_3: HandlerStates.STATE_2}, + ) + + parent = ConversationHandler( + entry_points=[TrackingConversationHandler.build_handler(HandlerStates.END)], + states={ + HandlerStates.STATE_1: [child], + HandlerStates.STATE_2: [ + TrackingConversationHandler.build_handler( + HandlerStates.STATE_2, callback=build_callback(HandlerStates.END) + ) + ], + }, + fallbacks=[], + persistent=True, + name='parent', + ) + + papp.add_handler(parent) + papp.persistence.conversations['grand_child'][(1, 1)] = HandlerStates.STATE_1 + papp.persistence.conversations['child'][(1, 1)] = HandlerStates.STATE_1 + papp.persistence.conversations['parent'][(1, 1)] = HandlerStates.STATE_1 + + # Should load the stored data into the persistence so that the updates below are handled + # accordingly + await papp.initialize() + assert papp.persistence.updated_conversations == {} + + assert not parent.check_update( + TrackingConversationHandler.build_update(HandlerStates.STATE_2, 1) + ) + assert not parent.check_update( + TrackingConversationHandler.build_update(HandlerStates.END, 1) + ) + assert parent.check_update( + TrackingConversationHandler.build_update(HandlerStates.STATE_1, 1) + ) + + await papp.process_update( + TrackingConversationHandler.build_update(HandlerStates.STATE_1, 1) + ) + assert papp.persistence.updated_conversations == {} + await papp.update_persistence() + assert papp.persistence.updated_conversations == { + 'grand_child': {(1, 1): 1}, + 'child': {(1, 1): 1}, + } + assert papp.persistence.conversations == { + 'grand_child': {(1, 1): None}, + 'child': {(1, 1): HandlerStates.STATE_2}, + 'parent': {(1, 1): HandlerStates.STATE_1}, + } + + papp.persistence.reset_tracking() + await papp.process_update( + TrackingConversationHandler.build_update(HandlerStates.STATE_2, 1) + ) + await papp.update_persistence() + assert papp.persistence.updated_conversations == { + 'parent': {(1, 1): 1}, + 'child': {(1, 1): 1}, + } + assert papp.persistence.conversations == { + 'child': {(1, 1): None}, + 'parent': {(1, 1): HandlerStates.STATE_2}, + } + + papp.persistence.reset_tracking() + await papp.process_update( + TrackingConversationHandler.build_update(HandlerStates.STATE_2, 1) + ) + await papp.update_persistence() + assert papp.persistence.updated_conversations == { + 'parent': {(1, 1): 1}, + } + assert papp.persistence.conversations == { + 'parent': {(1, 1): None}, + } + + await papp.shutdown() diff --git a/tests/test_bot.py b/tests/test_bot.py index 30c61d15d9e..eba5573bb51 100644 --- a/tests/test_bot.py +++ b/tests/test_bot.py @@ -80,16 +80,6 @@ def to_camel_case(snake_str): return components[0] + ''.join(x.title() for x in components[1:]) -class ExtBotSubClass(ExtBot): - # used for test_defaults_warning below - pass - - -class BotSubClass(Bot): - # used for test_defaults_warning below - pass - - @pytest.fixture(scope='class') @pytest.mark.asyncio async def message(bot, chat_id): @@ -142,16 +132,6 @@ def inline_results(): ) -@pytest.fixture(scope='function') -@pytest.mark.asyncio -async def inst(request, bot_info, default_bot): - if request.param == 'bot': - async with Bot(bot_info['token']) as _bot: - yield _bot - else: - yield default_bot - - class TestBot: """ Most are executed on tg.ext.ExtBot, as that class only extends the functionality of tg.bot @@ -491,49 +471,6 @@ async def make_assertion(url, request_data: RequestData, *args, **kwargs): finally: await bot.get_me() # because running the mock-get_me messages with bot.bot & friends - # check that tg.Bot does the right thing - # make_assertion basically checks everything that happens in - # Bot._insert_defaults and Bot._insert_defaults_for_ilq_results - async def make_assertion(url, request_data: RequestData, *args, **kwargs): - data = request_data.parameters - - # Check regular kwargs - for k, v in data.items(): - if isinstance(v, DefaultValue): - pytest.fail(f'Parameter {k} was passed as DefaultValue to request') - elif isinstance(v, InputMedia) and isinstance(v.parse_mode, DefaultValue): - pytest.fail(f'Parameter {k} has a DefaultValue parse_mode') - - # Check InputMedia - elif k == 'media' and isinstance(v, list): - for med in v: - if isinstance(med.get('parse_mode', None), DefaultValue): - pytest.fail('One of the media items has a DefaultValue parse_mode') - - # Check inline query results - if bot_method_name.lower().replace('_', '') == 'answerinlinequery': - for result_dict in data['results']: - if isinstance(result_dict.get('parse_mode'), DefaultValue): - pytest.fail('InlineQueryResult has DefaultValue parse_mode') - imc = result_dict.get('input_message_content') - if imc and isinstance(imc.get('parse_mode'), DefaultValue): - pytest.fail( - 'InlineQueryResult is InputMessageContext with DefaultValue parse_mode' - ) - if imc and isinstance(imc.get('disable_web_page_preview'), DefaultValue): - pytest.fail( - 'InlineQueryResult is InputMessageContext with DefaultValue ' - 'disable_web_page_preview ' - ) - # Check datetime conversion - until_date = data.pop('until_date', None) - if until_date and until_date != 946684800: - pytest.fail('Naive until_date was not interpreted as UTC') - - if bot_method_name in ['get_file', 'getFile']: - # The get_file methods try to check if the result is a local file - return File(file_id='result', file_unique_id='result').to_dict() - method = getattr(raw_bot, bot_method_name) signature = inspect.signature(method) kwargs_need_default = [ @@ -3059,7 +2996,7 @@ def test_camel_case_bot(self): if ( function_name.startswith("_") or not callable(function) - or function_name in ["to_dict", "do_init", "do_teardown"] + or function_name in ["to_dict"] ): continue camel_case_function = getattr(Bot, to_camel_case(function_name), False) diff --git a/tests/test_callbackcontext.py b/tests/test_callbackcontext.py index 70b2f0a0492..735b694c181 100644 --- a/tests/test_callbackcontext.py +++ b/tests/test_callbackcontext.py @@ -44,8 +44,6 @@ def test_slot_behaviour(self, app, mro_slots, recwarn): assert getattr(c, attr, 'err') != 'err', f"got extra slot '{attr}'" assert not c.__dict__, f"got missing slot(s): {c.__dict__}" assert len(mro_slots(c)) == len(set(mro_slots(c))), "duplicate slot" - c.args = c.args - assert len(recwarn) == 0, recwarn.list def test_from_job(self, app): job = app.job_queue.run_once(lambda x: x, 10) diff --git a/tests/test_chat.py b/tests/test_chat.py index e3eef6f5dee..eaa5cf4de53 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -65,6 +65,11 @@ class TestChat: has_protected_content = True has_private_forwards = True + def test_slot_behaviour(self, chat, mro_slots): + for attr in chat.__slots__: + assert getattr(chat, attr, 'err') != 'err', f"got extra slot '{attr}'" + assert len(mro_slots(chat)) == len(set(mro_slots(chat))), "duplicate slot" + def test_de_json(self, bot): json_dict = { 'id': self.id_, diff --git a/tests/test_chatjoinrequest.py b/tests/test_chatjoinrequest.py index c5a53398d49..16e7adcbb9a 100644 --- a/tests/test_chatjoinrequest.py +++ b/tests/test_chatjoinrequest.py @@ -56,7 +56,7 @@ class TestChatJoinRequest: is_primary=False, ) - def test_slot_behaviour(self, chat_join_request, recwarn, mro_slots): + def test_slot_behaviour(self, chat_join_request, mro_slots): inst = chat_join_request for attr in inst.__slots__: assert getattr(inst, attr, 'err') != 'err', f"got extra slot '{attr}'" diff --git a/tests/test_chatjoinrequesthandler.py b/tests/test_chatjoinrequesthandler.py index ccdee344bb2..9bfb8432eb2 100644 --- a/tests/test_chatjoinrequesthandler.py +++ b/tests/test_chatjoinrequesthandler.py @@ -102,7 +102,7 @@ def chat_join_request_update(bot, chat_join_request): class TestChatJoinRequestHandler: test_flag = False - def test_slot_behaviour(self, recwarn, mro_slots): + def test_slot_behaviour(self, mro_slots): action = ChatJoinRequestHandler(self.callback) for attr in action.__slots__: assert getattr(action, attr, 'err') != 'err', f"got extra slot '{attr}'" diff --git a/tests/test_chatphoto.py b/tests/test_chatphoto.py index 9fd3ce954b8..2c774057594 100644 --- a/tests/test_chatphoto.py +++ b/tests/test_chatphoto.py @@ -60,6 +60,11 @@ class TestChatPhoto: chatphoto_big_file_unique_id = 'bigadc3145fd2e84d95b64d68eaa22aa33e' chatphoto_file_url = 'https://python-telegram-bot.org/static/testfiles/telegram.jpg' + def test_slot_behaviour(self, chat_photo, mro_slots): + for attr in chat_photo.__slots__: + assert getattr(chat_photo, attr, 'err') != 'err', f"got extra slot '{attr}'" + assert len(mro_slots(chat_photo)) == len(set(mro_slots(chat_photo))), "duplicate slot" + @flaky(3, 1) @pytest.mark.asyncio async def test_send_all_args( diff --git a/tests/test_conversationhandler.py b/tests/test_conversationhandler.py new file mode 100644 index 00000000000..2827eac5ced --- /dev/null +++ b/tests/test_conversationhandler.py @@ -0,0 +1,2168 @@ +#!/usr/bin/env python +# +# A library that provides a Python interface to the Telegram Bot API +# Copyright (C) 2015-2022 +# Leandro Toledo de Souza +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser Public License for more details. +# +# You should have received a copy of the GNU Lesser Public License +# along with this program. If not, see [http://www.gnu.org/licenses/]. +"""Persistence of conversations is tested in test_basepersistence.py""" +import asyncio +import logging +from warnings import filterwarnings + +import pytest +from flaky import flaky + +from telegram import ( + Chat, + Update, + Message, + MessageEntity, + User, + CallbackQuery, + InlineQuery, + ChosenInlineResult, + ShippingQuery, + PreCheckoutQuery, + Bot, +) +from telegram.ext import ( + ConversationHandler, + CommandHandler, + ApplicationHandlerStop, + TypeHandler, + CallbackContext, + CallbackQueryHandler, + MessageHandler, + filters, + JobQueue, + StringCommandHandler, + StringRegexHandler, + PollHandler, + ShippingQueryHandler, + PreCheckoutQueryHandler, + InlineQueryHandler, + PollAnswerHandler, + ChosenInlineResultHandler, + Defaults, + ApplicationBuilder, + ExtBot, +) +from telegram.warnings import PTBUserWarning +from tests.conftest import make_command_message + + +@pytest.fixture(scope='class') +def user1(): + return User(first_name='Misses Test', id=123, is_bot=False) + + +@pytest.fixture(scope='class') +def user2(): + return User(first_name='Mister Test', id=124, is_bot=False) + + +def raise_ahs(func): + async def decorator(self, *args, **kwargs): + result = await func(self, *args, **kwargs) + if self.raise_app_handler_stop: + raise ApplicationHandlerStop(result) + return result + + return decorator + + +class TestConversationHandler: + """Persistence of conversations is tested in test_basepersistence.py""" + + # State definitions + # At first we're thirsty. Then we brew coffee, we drink it + # and then we can start coding! + END, THIRSTY, BREWING, DRINKING, CODING = range(-1, 4) + + # Drinking state definitions (nested) + # At first we're holding the cup. Then we sip coffee, and last we swallow it + HOLDING, SIPPING, SWALLOWING, REPLENISHING, STOPPING = map(chr, range(ord('a'), ord('f'))) + + current_state, entry_points, states, fallbacks = None, None, None, None + group = Chat(0, Chat.GROUP) + second_group = Chat(1, Chat.GROUP) + + raise_app_handler_stop = False + test_flag = False + + # Test related + @pytest.fixture(autouse=True) + def reset(self): + self.raise_app_handler_stop = False + self.test_flag = False + self.current_state = {} + self.entry_points = [CommandHandler('start', self.start)] + self.states = { + self.THIRSTY: [CommandHandler('brew', self.brew), CommandHandler('wait', self.start)], + self.BREWING: [CommandHandler('pourCoffee', self.drink)], + self.DRINKING: [ + CommandHandler('startCoding', self.code), + CommandHandler('drinkMore', self.drink), + CommandHandler('end', self.end), + ], + self.CODING: [ + CommandHandler('keepCoding', self.code), + CommandHandler('gettingThirsty', self.start), + CommandHandler('drinkMore', self.drink), + ], + } + self.fallbacks = [CommandHandler('eat', self.start)] + self.is_timeout = False + + # for nesting tests + self.nested_states = { + self.THIRSTY: [CommandHandler('brew', self.brew), CommandHandler('wait', self.start)], + self.BREWING: [CommandHandler('pourCoffee', self.drink)], + self.CODING: [ + CommandHandler('keepCoding', self.code), + CommandHandler('gettingThirsty', self.start), + CommandHandler('drinkMore', self.drink), + ], + } + self.drinking_entry_points = [CommandHandler('hold', self.hold)] + self.drinking_states = { + self.HOLDING: [CommandHandler('sip', self.sip)], + self.SIPPING: [CommandHandler('swallow', self.swallow)], + self.SWALLOWING: [CommandHandler('hold', self.hold)], + } + self.drinking_fallbacks = [ + CommandHandler('replenish', self.replenish), + CommandHandler('stop', self.stop), + CommandHandler('end', self.end), + CommandHandler('startCoding', self.code), + CommandHandler('drinkMore', self.drink), + ] + self.drinking_entry_points.extend(self.drinking_fallbacks) + + # Map nested states to parent states: + self.drinking_map_to_parent = { + # Option 1 - Map a fictional internal state to an external parent state + self.REPLENISHING: self.BREWING, + # Option 2 - Map a fictional internal state to the END state on the parent + self.STOPPING: self.END, + # Option 3 - Map the internal END state to an external parent state + self.END: self.CODING, + # Option 4 - Map an external state to the same external parent state + self.CODING: self.CODING, + # Option 5 - Map an external state to the internal entry point + self.DRINKING: self.DRINKING, + } + + # State handlers + def _set_state(self, update, state): + self.current_state[update.message.from_user.id] = state + return state + + # Actions + @raise_ahs + async def start(self, update, context): + if isinstance(update, Update): + return self._set_state(update, self.THIRSTY) + return self._set_state(context.bot, self.THIRSTY) + + @raise_ahs + async def end(self, update, context): + return self._set_state(update, self.END) + + @raise_ahs + async def start_end(self, update, context): + return self._set_state(update, self.END) + + @raise_ahs + async def start_none(self, update, context): + return self._set_state(update, None) + + @raise_ahs + async def brew(self, update, context): + if isinstance(update, Update): + return self._set_state(update, self.BREWING) + return self._set_state(context.bot, self.BREWING) + + @raise_ahs + async def drink(self, update, context): + return self._set_state(update, self.DRINKING) + + @raise_ahs + async def code(self, update, context): + return self._set_state(update, self.CODING) + + @raise_ahs + async def passout(self, update, context): + assert update.message.text == '/brew' + assert isinstance(update, Update) + self.is_timeout = True + + @raise_ahs + async def passout2(self, update, context): + assert isinstance(update, Update) + self.is_timeout = True + + @raise_ahs + async def passout_context(self, update, context): + assert update.message.text == '/brew' + assert isinstance(context, CallbackContext) + self.is_timeout = True + + @raise_ahs + async def passout2_context(self, update, context): + assert isinstance(context, CallbackContext) + self.is_timeout = True + + # Drinking actions (nested) + + @raise_ahs + async def hold(self, update, context): + return self._set_state(update, self.HOLDING) + + @raise_ahs + async def sip(self, update, context): + return self._set_state(update, self.SIPPING) + + @raise_ahs + async def swallow(self, update, context): + return self._set_state(update, self.SWALLOWING) + + @raise_ahs + async def replenish(self, update, context): + return self._set_state(update, self.REPLENISHING) + + @raise_ahs + async def stop(self, update, context): + return self._set_state(update, self.STOPPING) + + def test_slot_behaviour(self, mro_slots): + handler = ConversationHandler(entry_points=[], states={}, fallbacks=[]) + for attr in handler.__slots__: + assert getattr(handler, attr, 'err') != 'err', f"got extra slot '{attr}'" + assert len(mro_slots(handler)) == len(set(mro_slots(handler))), "duplicate slot" + + def test_init(self): + entry_points = [] + states = {} + fallbacks = [] + map_to_parent = {} + ch = ConversationHandler( + entry_points=entry_points, + states=states, + fallbacks=fallbacks, + per_chat='per_chat', + per_user='per_user', + per_message='per_message', + persistent='persistent', + name='name', + allow_reentry='allow_reentry', + conversation_timeout=42, + map_to_parent=map_to_parent, + ) + assert ch.entry_points is entry_points + assert ch.states is states + assert ch.fallbacks is fallbacks + assert ch.map_to_parent is map_to_parent + assert ch.per_chat == 'per_chat' + assert ch.per_user == 'per_user' + assert ch.per_message == 'per_message' + assert ch.persistent == 'persistent' + assert ch.name == 'name' + assert ch.allow_reentry == 'allow_reentry' + + def test_init_persistent_no_name(self): + with pytest.raises(ValueError, match="can't be persistent when handler is unnamed"): + ConversationHandler( + self.entry_points, states=self.states, fallbacks=[], persistent=True + ) + + @pytest.mark.asyncio + async def test_check_update_returns_non(self, app, user1): + """checks some cases where updates should not be handled""" + conv_handler = ConversationHandler([], {}, [], per_message=True, per_chat=True) + assert not conv_handler.check_update('not an Update') + assert not conv_handler.check_update(Update(0)) + assert not conv_handler.check_update( + Update(0, callback_query=CallbackQuery('1', from_user=user1, chat_instance='1')) + ) + + @pytest.mark.asyncio + async def test_handlers_generate_warning(self, recwarn): + """this function tests all handler + per_* setting combinations.""" + + # the warning message action needs to be set to always, + # otherwise only the first occurrence will be issued + filterwarnings(action="always", category=PTBUserWarning) + + # this class doesn't do anything, its just not the Update class + class NotUpdate: + pass + + recwarn.clear() + + # this conversation handler has the string, string_regex, Pollhandler and TypeHandler + # which should all generate a warning no matter the per_* setting. TypeHandler should + # not when the class is Update + ConversationHandler( + entry_points=[StringCommandHandler("code", self.code)], + states={ + self.BREWING: [ + StringRegexHandler("code", self.code), + PollHandler(self.code), + TypeHandler(NotUpdate, self.code), + ], + }, + fallbacks=[TypeHandler(Update, self.code)], + ) + + # these handlers should all raise a warning when per_chat is True + ConversationHandler( + entry_points=[ShippingQueryHandler(self.code)], + states={ + self.BREWING: [ + InlineQueryHandler(self.code), + PreCheckoutQueryHandler(self.code), + PollAnswerHandler(self.code), + ], + }, + fallbacks=[ChosenInlineResultHandler(self.code)], + per_chat=True, + ) + + # the CallbackQueryHandler should *not* raise when per_message is True, + # but any other one should + ConversationHandler( + entry_points=[CallbackQueryHandler(self.code)], + states={ + self.BREWING: [CommandHandler("code", self.code)], + }, + fallbacks=[CallbackQueryHandler(self.code)], + per_message=True, + ) + + # the CallbackQueryHandler should raise when per_message is False + ConversationHandler( + entry_points=[CommandHandler("code", self.code)], + states={ + self.BREWING: [CommandHandler("code", self.code)], + }, + fallbacks=[CallbackQueryHandler(self.code)], + per_message=False, + ) + + # adding a nested conv to a conversation with timeout should warn + child = ConversationHandler( + entry_points=[CommandHandler("code", self.code)], + states={ + self.BREWING: [CommandHandler("code", self.code)], + }, + fallbacks=[CommandHandler("code", self.code)], + ) + + ConversationHandler( + entry_points=[CommandHandler("code", self.code)], + states={ + self.BREWING: [child], + }, + fallbacks=[CommandHandler("code", self.code)], + conversation_timeout=42, + ) + + # If per_message is True, per_chat should also be True, since msg ids are not unique + ConversationHandler( + entry_points=[CallbackQueryHandler(self.code, "code")], + states={ + self.BREWING: [CallbackQueryHandler(self.code, "code")], + }, + fallbacks=[CallbackQueryHandler(self.code, "code")], + per_message=True, + per_chat=False, + ) + + # the overall number of handlers throwing a warning is 13 + assert len(recwarn) == 13 + # now we test the messages, they are raised in the order they are inserted + # into the conversation handler + assert str(recwarn[0].message) == ( + "The `ConversationHandler` only handles updates of type `telegram.Update`. " + "StringCommandHandler handles updates of type `str`." + ) + assert str(recwarn[1].message) == ( + "The `ConversationHandler` only handles updates of type `telegram.Update`. " + "StringRegexHandler handles updates of type `str`." + ) + assert str(recwarn[2].message) == ( + "PollHandler will never trigger in a conversation since it has no information " + "about the chat or the user who voted in it. Do you mean the " + "`PollAnswerHandler`?" + ) + assert str(recwarn[3].message) == ( + "The `ConversationHandler` only handles updates of type `telegram.Update`. " + "The TypeHandler is set to handle NotUpdate." + ) + + per_faq_link = ( + " Read this FAQ entry to learn more about the per_* settings: " + "https://github.com/python-telegram-bot/python-telegram-bot/wiki" + "/Frequently-Asked-Questions#what-do-the-per_-settings-in-conversationhandler-do." + ) + + assert str(recwarn[4].message) == ( + "Updates handled by ShippingQueryHandler only have information about the user," + " so this handler won't ever be triggered if `per_chat=True`." + per_faq_link + ) + assert str(recwarn[5].message) == ( + "Updates handled by ChosenInlineResultHandler only have information about the user," + " so this handler won't ever be triggered if `per_chat=True`." + per_faq_link + ) + assert str(recwarn[6].message) == ( + "Updates handled by InlineQueryHandler only have information about the user," + " so this handler won't ever be triggered if `per_chat=True`." + per_faq_link + ) + assert str(recwarn[7].message) == ( + "Updates handled by PreCheckoutQueryHandler only have information about the user," + " so this handler won't ever be triggered if `per_chat=True`." + per_faq_link + ) + assert str(recwarn[8].message) == ( + "Updates handled by PollAnswerHandler only have information about the user," + " so this handler won't ever be triggered if `per_chat=True`." + per_faq_link + ) + assert str(recwarn[9].message) == ( + "If 'per_message=True', all entry points, state handlers, and fallbacks must be " + "'CallbackQueryHandler', since no other handlers have a message context." + + per_faq_link + ) + assert str(recwarn[10].message) == ( + "If 'per_message=False', 'CallbackQueryHandler' will not be tracked for " + "every message." + per_faq_link + ) + assert str(recwarn[11].message) == ( + "Using `conversation_timeout` with nested conversations is currently not " + "supported. You can still try to use it, but it will likely behave differently" + " from what you expect." + ) + + assert str(recwarn[12].message) == ( + "If 'per_message=True' is used, 'per_chat=True' should also be used, " + "since message IDs are not globally unique." + ) + + # this for loop checks if the correct stacklevel is used when generating the warning + for warning in recwarn: + assert warning.filename == __file__, "incorrect stacklevel!" + + @pytest.mark.parametrize( + 'attr', + [ + 'entry_points', + 'states', + 'fallbacks', + 'per_chat', + 'per_user', + 'per_message', + 'name', + 'persistent', + 'allow_reentry', + 'conversation_timeout', + 'map_to_parent', + ], + indirect=False, + ) + def test_immutable(self, attr): + ch = ConversationHandler(entry_points=[], states={}, fallbacks=[]) + with pytest.raises(AttributeError, match=f'You can not assign a new value to {attr}'): + setattr(ch, attr, True) + + def test_per_all_false(self): + with pytest.raises(ValueError, match="can't all be 'False'"): + ConversationHandler( + entry_points=[], + states={}, + fallbacks=[], + per_chat=False, + per_user=False, + per_message=False, + ) + + @pytest.mark.asyncio + @pytest.mark.parametrize('raise_ahs', [True, False]) + async def test_basic_and_app_handler_stop(self, app, bot, user1, user2, raise_ahs): + handler = ConversationHandler( + entry_points=self.entry_points, states=self.states, fallbacks=self.fallbacks + ) + app.add_handler(handler) + + async def callback(_, __): + self.test_flag = True + + app.add_handler(TypeHandler(object, callback), group=100) + self.raise_app_handler_stop = raise_ahs + + # User one, starts the state machine. + message = Message( + 0, + None, + self.group, + from_user=user1, + text='/start', + entities=[ + MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0, length=len('/start')) + ], + bot=bot, + ) + async with app: + await app.process_update(Update(update_id=0, message=message)) + assert self.current_state[user1.id] == self.THIRSTY + assert self.test_flag == (not raise_ahs) + + # The user is thirsty and wants to brew coffee. + message.text = '/brew' + message.entities[0].length = len('/brew') + await app.process_update(Update(update_id=0, message=message)) + assert self.current_state[user1.id] == self.BREWING + assert self.test_flag == (not raise_ahs) + + # Lets see if an invalid command makes sure, no state is changed. + message.text = '/nothing' + message.entities[0].length = len('/nothing') + await app.process_update(Update(update_id=0, message=message)) + assert self.current_state[user1.id] == self.BREWING + assert self.test_flag is True + self.test_flag = False + + # Lets see if the state machine still works by pouring coffee. + message.text = '/pourCoffee' + message.entities[0].length = len('/pourCoffee') + await app.process_update(Update(update_id=0, message=message)) + assert self.current_state[user1.id] == self.DRINKING + assert self.test_flag == (not raise_ahs) + + # Let's now verify that for another user, who did not start yet, + # the state has not been changed. + message.from_user = user2 + await app.process_update(Update(update_id=0, message=message)) + with pytest.raises(KeyError): + self.current_state[user2.id] + + @pytest.mark.asyncio + async def test_conversation_handler_end(self, caplog, app, bot, user1): + handler = ConversationHandler( + entry_points=self.entry_points, states=self.states, fallbacks=self.fallbacks + ) + app.add_handler(handler) + + message = Message( + 0, + None, + self.group, + from_user=user1, + text='/start', + entities=[ + MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0, length=len('/start')) + ], + bot=bot, + ) + + async with app: + await app.process_update(Update(update_id=0, message=message)) + message.text = '/brew' + message.entities[0].length = len('/brew') + await app.process_update(Update(update_id=0, message=message)) + message.text = '/pourCoffee' + message.entities[0].length = len('/pourCoffee') + await app.process_update(Update(update_id=0, message=message)) + message.text = '/end' + message.entities[0].length = len('/end') + caplog.clear() + with caplog.at_level(logging.ERROR): + await app.process_update(Update(update_id=0, message=message)) + assert len(caplog.records) == 0 + assert self.current_state[user1.id] == self.END + + # make sure that the conversation has ended by checking that the start command is + # accepted again + message.text = '/start' + message.entities[0].length = len('/start') + assert handler.check_update(Update(update_id=0, message=message)) + + @pytest.mark.asyncio + async def test_conversation_handler_fallback(self, app, bot, user1, user2): + handler = ConversationHandler( + entry_points=self.entry_points, states=self.states, fallbacks=self.fallbacks + ) + app.add_handler(handler) + + # first check if fallback will not trigger start when not started + message = Message( + 0, + None, + self.group, + from_user=user1, + text='/eat', + entities=[MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0, length=len('/eat'))], + bot=bot, + ) + + async with app: + await app.process_update(Update(update_id=0, message=message)) + with pytest.raises(KeyError): + self.current_state[user1.id] + + # User starts the state machine. + message.text = '/start' + message.entities[0].length = len('/start') + await app.process_update(Update(update_id=0, message=message)) + assert self.current_state[user1.id] == self.THIRSTY + + # The user is thirsty and wants to brew coffee. + message.text = '/brew' + message.entities[0].length = len('/brew') + await app.process_update(Update(update_id=0, message=message)) + assert self.current_state[user1.id] == self.BREWING + + # Now a fallback command is issued + message.text = '/eat' + message.entities[0].length = len('/eat') + await app.process_update(Update(update_id=0, message=message)) + assert self.current_state[user1.id] == self.THIRSTY + + @pytest.mark.asyncio + async def test_unknown_state_warning(self, app, bot, user1, recwarn): + def build_callback(state): + async def callback(_, __): + return state + + return callback + + handler = ConversationHandler( + entry_points=[CommandHandler("start", build_callback(1))], + states={ + 1: [TypeHandler(Update, build_callback(69))], + 2: [TypeHandler(Update, build_callback(42))], + }, + fallbacks=self.fallbacks, + name="xyz", + ) + app.add_handler(handler) + message = Message( + 0, + None, + self.group, + from_user=user1, + text='/start', + entities=[ + MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0, length=len('/start')) + ], + bot=bot, + ) + async with app: + await app.process_update(Update(update_id=0, message=message)) + try: + await app.process_update(Update(update_id=1, message=message)) + except Exception as exc: + print(exc) + raise exc + assert len(recwarn) == 1 + assert str(recwarn[0].message) == ( + "Handler returned state 69 which is unknown to the ConversationHandler xyz." + ) + + @pytest.mark.asyncio + async def test_conversation_handler_per_chat(self, app, bot, user1, user2): + handler = ConversationHandler( + entry_points=self.entry_points, + states=self.states, + fallbacks=self.fallbacks, + per_user=False, + ) + app.add_handler(handler) + + # User one, starts the state machine. + message = Message( + 0, + None, + self.group, + from_user=user1, + text='/start', + entities=[ + MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0, length=len('/start')) + ], + bot=bot, + ) + + async with app: + await app.process_update(Update(update_id=0, message=message)) + + # The user is thirsty and wants to brew coffee. + message.text = '/brew' + message.entities[0].length = len('/brew') + await app.process_update(Update(update_id=0, message=message)) + + # Let's now verify that for another user, who did not start yet, + # the state will be changed because they are in the same group. + message.from_user = user2 + message.text = '/pourCoffee' + message.entities[0].length = len('/pourCoffee') + await app.process_update(Update(update_id=0, message=message)) + + # Check that we're in the DRINKING state by checking that the corresponding command + # is accepted + message.from_user = user1 + message.text = '/startCoding' + message.entities[0].length = len('/startCoding') + assert handler.check_update(Update(update_id=0, message=message)) + message.from_user = user2 + assert handler.check_update(Update(update_id=0, message=message)) + + @pytest.mark.asyncio + async def test_conversation_handler_per_user(self, app, bot, user1): + handler = ConversationHandler( + entry_points=self.entry_points, + states=self.states, + fallbacks=self.fallbacks, + per_chat=False, + ) + app.add_handler(handler) + + # User one, starts the state machine. + message = Message( + 0, + None, + self.group, + from_user=user1, + text='/start', + entities=[ + MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0, length=len('/start')) + ], + bot=bot, + ) + async with app: + await app.process_update(Update(update_id=0, message=message)) + + # The user is thirsty and wants to brew coffee. + message.text = '/brew' + message.entities[0].length = len('/brew') + await app.process_update(Update(update_id=0, message=message)) + + # Let's now verify that for the same user in a different group, the state will still be + # updated + message.chat = self.second_group + message.text = '/pourCoffee' + message.entities[0].length = len('/pourCoffee') + await app.process_update(Update(update_id=0, message=message)) + + # Check that we're in the DRINKING state by checking that the corresponding command + # is accepted + message.chat = self.group + message.text = '/startCoding' + message.entities[0].length = len('/startCoding') + assert handler.check_update(Update(update_id=0, message=message)) + message.chat = self.second_group + assert handler.check_update(Update(update_id=0, message=message)) + + @pytest.mark.asyncio + @pytest.mark.parametrize('inline', [True, False]) + @pytest.mark.filterwarnings("ignore: If 'per_message=True' is used, 'per_chat=True'") + async def test_conversation_handler_per_message(self, app, bot, user1, user2, inline): + async def entry(update, context): + return 1 + + async def one(update, context): + return 2 + + async def two(update, context): + return ConversationHandler.END + + handler = ConversationHandler( + entry_points=[CallbackQueryHandler(entry)], + states={ + 1: [CallbackQueryHandler(one, pattern='^1$')], + 2: [CallbackQueryHandler(two, pattern='^2$')], + }, + fallbacks=[], + per_message=True, + per_chat=not inline, + ) + app.add_handler(handler) + + # User one, starts the state machine. + message = ( + Message(0, None, self.group, from_user=user1, text='msg w/ inlinekeyboard', bot=bot) + if not inline + else None + ) + inline_message_id = '42' if inline else None + + async with app: + cbq_1 = CallbackQuery( + 0, + user1, + None, + message=message, + data='1', + bot=bot, + inline_message_id=inline_message_id, + ) + cbq_2 = CallbackQuery( + 0, + user1, + None, + message=message, + data='2', + bot=bot, + inline_message_id=inline_message_id, + ) + await app.process_update(Update(update_id=0, callback_query=cbq_1)) + + # Make sure that we're in the correct state + assert handler.check_update(Update(0, callback_query=cbq_1)) + assert not handler.check_update(Update(0, callback_query=cbq_2)) + + await app.process_update(Update(update_id=0, callback_query=cbq_1)) + + # Make sure that we're in the correct state + assert not handler.check_update(Update(0, callback_query=cbq_1)) + assert handler.check_update(Update(0, callback_query=cbq_2)) + + # Let's now verify that for a different user in the same group, the state will not be + # updated + cbq_2.from_user = user2 + await app.process_update(Update(update_id=0, callback_query=cbq_2)) + + cbq_2.from_user = user1 + assert not handler.check_update(Update(0, callback_query=cbq_1)) + assert handler.check_update(Update(0, callback_query=cbq_2)) + + @pytest.mark.asyncio + async def test_end_on_first_message(self, app, bot, user1): + handler = ConversationHandler( + entry_points=[CommandHandler('start', self.start_end)], states={}, fallbacks=[] + ) + app.add_handler(handler) + + # User starts the state machine and immediately ends it. + message = Message( + 0, + None, + self.group, + from_user=user1, + text='/start', + entities=[ + MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0, length=len('/start')) + ], + bot=bot, + ) + async with app: + await app.process_update(Update(update_id=0, message=message)) + assert handler.check_update(Update(update_id=0, message=message)) + + @pytest.mark.asyncio + async def test_end_on_first_message_non_blocking_handler(self, app, bot, user1): + handler = ConversationHandler( + entry_points=[CommandHandler('start', callback=self.start_end, block=False)], + states={}, + fallbacks=[], + ) + app.add_handler(handler) + + # User starts the state machine with a non-blocking function that immediately ends the + # conversation. non-blocking results are resolved when the users state is queried next + # time. + message = Message( + 0, + None, + self.group, + from_user=user1, + text='/start', + entities=[ + MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0, length=len('/start')) + ], + bot=bot, + ) + async with app: + await app.process_update(Update(update_id=0, message=message)) + # give the task a chance to finish + await asyncio.sleep(0.1) + + # Let's check that processing the same update again is accepted. this confirms that + # a) the pending state is correctly resolved + # b) the conversation has ended + assert handler.check_update(Update(0, message=message)) + + @pytest.mark.asyncio + async def test_none_on_first_message(self, app, bot, user1): + handler = ConversationHandler( + entry_points=[MessageHandler(filters.ALL, self.start_none)], states={}, fallbacks=[] + ) + app.add_handler(handler) + + # User starts the state machine and a callback function returns None + message = Message(0, None, self.group, from_user=user1, text='/start', bot=bot) + async with app: + await app.process_update(Update(update_id=0, message=message)) + # Check that the same message is accepted again, i.e. the conversation immediately + # ended + assert handler.check_update(Update(0, message=message)) + + @pytest.mark.asyncio + async def test_none_on_first_message_non_blocking_handler(self, app, bot, user1): + handler = ConversationHandler( + entry_points=[CommandHandler('start', self.start_none, block=False)], + states={}, + fallbacks=[], + ) + app.add_handler(handler) + + # User starts the state machine with a non-blocking handler that returns None + # non-blocking results are resolved when the users state is queried next time. + message = Message( + 0, + None, + self.group, + text='/start', + from_user=user1, + entities=[ + MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0, length=len('/start')) + ], + bot=bot, + ) + async with app: + await app.process_update(Update(update_id=0, message=message)) + # Give the task a chance to finish + await asyncio.sleep(0.1) + + # Let's check that processing the same update again is accepted. this confirms that + # a) the pending state is correctly resolved + # b) the conversation has ended + assert handler.check_update(Update(0, message=message)) + + @pytest.mark.asyncio + async def test_per_chat_message_without_chat(self, bot, user1): + handler = ConversationHandler( + entry_points=[CommandHandler('start', self.start_end)], states={}, fallbacks=[] + ) + cbq = CallbackQuery(0, user1, None, None, bot=bot) + update = Update(0, callback_query=cbq) + assert not handler.check_update(update) + + @pytest.mark.asyncio + async def test_channel_message_without_chat(self, bot): + handler = ConversationHandler( + entry_points=[MessageHandler(filters.ALL, self.start_end)], states={}, fallbacks=[] + ) + message = Message(0, date=None, chat=Chat(0, Chat.CHANNEL, 'Misses Test'), bot=bot) + + update = Update(0, channel_post=message) + assert not handler.check_update(update) + + update = Update(0, edited_channel_post=message) + assert not handler.check_update(update) + + @pytest.mark.asyncio + async def test_all_update_types(self, app, bot, user1): + handler = ConversationHandler( + entry_points=[CommandHandler('start', self.start_end)], states={}, fallbacks=[] + ) + message = Message(0, None, self.group, from_user=user1, text='ignore', bot=bot) + callback_query = CallbackQuery(0, user1, None, message=message, data='data', bot=bot) + chosen_inline_result = ChosenInlineResult(0, user1, 'query', bot=bot) + inline_query = InlineQuery(0, user1, 'query', 0, bot=bot) + pre_checkout_query = PreCheckoutQuery(0, user1, 'USD', 100, [], bot=bot) + shipping_query = ShippingQuery(0, user1, [], None, bot=bot) + assert not handler.check_update(Update(0, callback_query=callback_query)) + assert not handler.check_update(Update(0, chosen_inline_result=chosen_inline_result)) + assert not handler.check_update(Update(0, inline_query=inline_query)) + assert not handler.check_update(Update(0, message=message)) + assert not handler.check_update(Update(0, pre_checkout_query=pre_checkout_query)) + assert not handler.check_update(Update(0, shipping_query=shipping_query)) + + @pytest.mark.asyncio + @pytest.mark.parametrize('jq', [True, False]) + async def test_no_running_job_queue_warning(self, app, bot, user1, recwarn, jq): + handler = ConversationHandler( + entry_points=self.entry_points, + states=self.states, + fallbacks=self.fallbacks, + conversation_timeout=0.5, + ) + # save app.job_queue in temp variable jqueue + # and then set app.job_queue to None. + jqueue = app.job_queue + if not jq: + app.job_queue = None + app.add_handler(handler) + + message = Message( + 0, + None, + self.group, + from_user=user1, + text='/start', + entities=[ + MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0, length=len('/start')) + ], + bot=bot, + ) + + async with app: + await app.process_update(Update(update_id=0, message=message)) + await asyncio.sleep(0.5) + assert len(recwarn) == 1 + assert str(recwarn[0].message).startswith("Ignoring `conversation_timeout`") + assert ("is not running" if jq else "has no JobQueue.") in str(recwarn[0].message) + # now set app.job_queue back to it's original value + app.job_queue = jqueue + + @pytest.mark.asyncio + async def test_schedule_job_exception(self, app, bot, user1, monkeypatch, caplog): + def mocked_run_once(*a, **kw): + raise Exception("job error") + + class DictJB(JobQueue): + pass + + app.job_queue = DictJB() + monkeypatch.setattr(app.job_queue, "run_once", mocked_run_once) + handler = ConversationHandler( + entry_points=self.entry_points, + states=self.states, + fallbacks=self.fallbacks, + conversation_timeout=100, + ) + app.add_handler(handler) + + message = Message( + 0, + None, + self.group, + from_user=user1, + text='/start', + entities=[ + MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0, length=len('/start')) + ], + bot=bot, + ) + + async with app: + await app.start() + + with caplog.at_level(logging.ERROR): + await app.process_update(Update(update_id=0, message=message)) + await asyncio.sleep(0.5) + + assert len(caplog.records) == 1 + assert caplog.records[0].message == "Failed to schedule timeout." + assert str(caplog.records[0].exc_info[1]) == "job error" + + await app.stop() + + @pytest.mark.asyncio + async def test_non_blocking_exception(self, app, bot, user1, caplog): + """Here we make sure that when a non-blocking handler raises an + exception, the state isn't changed. + """ + error = Exception('task exception') + + async def conv_entry(*a, **kw): + return 1 + + async def raise_error(*a, **kw): + raise error + + handler = ConversationHandler( + entry_points=[CommandHandler("start", conv_entry)], + states={1: [MessageHandler(filters.Text(['error']), raise_error)]}, + fallbacks=self.fallbacks, + block=False, + ) + app.add_handler(handler) + + message = Message( + 0, + None, + self.group, + from_user=user1, + text='/start', + entities=[ + MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0, length=len('/start')) + ], + bot=bot, + ) + # start the conversation + async with app: + await app.process_update(Update(update_id=0, message=message)) + await asyncio.sleep(0.1) + message.text = "error" + await app.process_update(Update(update_id=0, message=message)) + await asyncio.sleep(0.1) + caplog.clear() + with caplog.at_level(logging.ERROR): + # This also makes sure that we're still in the same state + assert handler.check_update(Update(0, message=message)) + assert len(caplog.records) == 1 + assert ( + caplog.records[0].message + == "Task function raised exception. Falling back to old state 1" + ) + assert caplog.records[0].exc_info[1] is error + + @pytest.mark.asyncio + async def test_conversation_timeout(self, app, bot, user1): + handler = ConversationHandler( + entry_points=self.entry_points, + states=self.states, + fallbacks=self.fallbacks, + conversation_timeout=0.5, + ) + app.add_handler(handler) + + # Start state machine, then reach timeout + start_message = Message( + 0, + None, + self.group, + from_user=user1, + text='/start', + entities=[ + MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0, length=len('/start')) + ], + bot=bot, + ) + brew_message = Message( + 0, + None, + self.group, + from_user=user1, + text='/brew', + entities=[ + MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0, length=len('/brew')) + ], + bot=bot, + ) + pour_coffee_message = Message( + 0, + None, + self.group, + from_user=user1, + text='/pourCoffee', + entities=[ + MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0, length=len('/pourCoffee')) + ], + bot=bot, + ) + async with app: + await app.start() + + await app.process_update(Update(update_id=0, message=start_message)) + assert handler.check_update(Update(0, message=brew_message)) + await asyncio.sleep(0.75) + assert handler.check_update(Update(0, message=start_message)) + + # Start state machine, do something, then reach timeout + await app.process_update(Update(update_id=1, message=start_message)) + assert handler.check_update(Update(0, message=brew_message)) + # assert handler.conversations.get((self.group.id, user1.id)) == self.THIRSTY + # start_message.text = '/brew' + # start_message.entities[0].length = len('/brew') + await app.process_update(Update(update_id=2, message=brew_message)) + assert handler.check_update(Update(0, message=pour_coffee_message)) + # assert handler.conversations.get((self.group.id, user1.id)) == self.BREWING + await asyncio.sleep(0.7) + assert handler.check_update(Update(0, message=start_message)) + # assert handler.conversations.get((self.group.id, user1.id)) is None + + await app.stop() + + @pytest.mark.asyncio + async def test_timeout_not_triggered_on_conv_end_non_blocking(self, bot, app, user1): + def timeout(*a, **kw): + self.test_flag = True + + self.states.update({ConversationHandler.TIMEOUT: [TypeHandler(Update, timeout)]}) + handler = ConversationHandler( + entry_points=self.entry_points, + states=self.states, + fallbacks=self.fallbacks, + conversation_timeout=0.5, + block=False, + ) + app.add_handler(handler) + + message = Message( + 0, + None, + self.group, + from_user=user1, + text='/start', + entities=[ + MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0, length=len('/start')) + ], + bot=bot, + ) + async with app: + # start the conversation + await app.process_update(Update(update_id=0, message=message)) + await asyncio.sleep(0.1) + message.text = '/brew' + message.entities[0].length = len('/brew') + await app.process_update(Update(update_id=1, message=message)) + await asyncio.sleep(0.1) + message.text = '/pourCoffee' + message.entities[0].length = len('/pourCoffee') + await app.process_update(Update(update_id=2, message=message)) + await asyncio.sleep(0.1) + message.text = '/end' + message.entities[0].length = len('/end') + await app.process_update(Update(update_id=3, message=message)) + await asyncio.sleep(1) + # assert timeout handler didn't get called + assert self.test_flag is False + + @pytest.mark.asyncio + async def test_conversation_timeout_application_handler_stop(self, app, bot, user1, recwarn): + handler = ConversationHandler( + entry_points=self.entry_points, + states=self.states, + fallbacks=self.fallbacks, + conversation_timeout=0.5, + ) + + def timeout(*args, **kwargs): + raise ApplicationHandlerStop() + + self.states.update({ConversationHandler.TIMEOUT: [TypeHandler(Update, timeout)]}) + app.add_handler(handler) + + # Start state machine, then reach timeout + message = Message( + 0, + None, + self.group, + text='/start', + from_user=user1, + entities=[ + MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0, length=len('/start')) + ], + bot=bot, + ) + brew_message = Message( + 0, + None, + self.group, + from_user=user1, + text='/brew', + entities=[ + MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0, length=len('/brew')) + ], + bot=bot, + ) + + async with app: + await app.start() + + await app.process_update(Update(update_id=0, message=message)) + # Make sure that we're in the next state + assert handler.check_update(Update(0, message=brew_message)) + await app.process_update(Update(0, message=brew_message)) + await asyncio.sleep(0.9) + # Check that conversation has ended by checking that the start messages is accepted + # again + assert handler.check_update(Update(0, message=message)) + assert len(recwarn) == 1 + assert str(recwarn[0].message).startswith('ApplicationHandlerStop in TIMEOUT') + + await app.stop() + + @pytest.mark.asyncio + async def test_conversation_handler_timeout_update_and_context(self, app, bot, user1): + context = None + + async def start_callback(u, c): + nonlocal context, self + context = c + return await self.start(u, c) + + # Start state machine, then reach timeout + message = Message( + 0, + None, + self.group, + from_user=user1, + text='/start', + entities=[ + MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0, length=len('/start')) + ], + bot=bot, + ) + update = Update(update_id=0, message=message) + + async def timeout_callback(u, c): + nonlocal update, context + assert u is update + assert c is context + + self.is_timeout = (u is update) and (c is context) + + states = self.states + timeout_handler = CommandHandler('start', timeout_callback) + states.update({ConversationHandler.TIMEOUT: [timeout_handler]}) + handler = ConversationHandler( + entry_points=[CommandHandler('start', start_callback)], + states=states, + fallbacks=self.fallbacks, + conversation_timeout=0.5, + ) + app.add_handler(handler) + + async with app: + await app.start() + + await app.process_update(update) + await asyncio.sleep(0.9) + # check that the conversation has ended by checking that the start message is accepted + assert handler.check_update(Update(0, message=message)) + assert self.is_timeout + + await app.stop() + + @flaky(3, 1) + @pytest.mark.asyncio + async def test_conversation_timeout_keeps_extending(self, app, bot, user1): + handler = ConversationHandler( + entry_points=self.entry_points, + states=self.states, + fallbacks=self.fallbacks, + conversation_timeout=0.5, + ) + app.add_handler(handler) + + # Start state machine, wait, do something, verify the timeout is extended. + # t=0 /start (timeout=.5) + # t=.35 /brew (timeout=.85) + # t=.5 original timeout + # t=.6 /pourCoffee (timeout=1.1) + # t=.85 second timeout + # t=1.1 actual timeout + message = Message( + 0, + None, + self.group, + from_user=user1, + text='/start', + entities=[ + MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0, length=len('/start')) + ], + bot=bot, + ) + + async with app: + await app.start() + + await app.process_update(Update(update_id=0, message=message)) + message.text = '/brew' + message.entities[0].length = len('/brew') + assert handler.check_update(Update(0, message=message)) + await asyncio.sleep(0.35) # t=.35 + assert handler.check_update(Update(0, message=message)) + await app.process_update(Update(update_id=0, message=message)) + message.text = '/pourCoffee' + message.entities[0].length = len('/pourCoffee') + assert handler.check_update(Update(0, message=message)) + await asyncio.sleep(0.25) # t=.6 + assert handler.check_update(Update(0, message=message)) + await app.process_update(Update(update_id=0, message=message)) + message.text = '/startCoding' + message.entities[0].length = len('/startCoding') + assert handler.check_update(Update(0, message=message)) + await asyncio.sleep(0.4) # t=1.0 + assert handler.check_update(Update(0, message=message)) + await asyncio.sleep(0.3) # t=1.3 + assert not handler.check_update(Update(0, message=message)) + message.text = '/start' + message.entities[0].length = len('/start') + assert handler.check_update(Update(0, message=message)) + + await app.stop() + + @pytest.mark.asyncio + async def test_conversation_timeout_two_users(self, app, bot, user1, user2): + handler = ConversationHandler( + entry_points=self.entry_points, + states=self.states, + fallbacks=self.fallbacks, + conversation_timeout=0.5, + ) + app.add_handler(handler) + + # Start state machine, do something as second user, then reach timeout + message = Message( + 0, + None, + self.group, + from_user=user1, + text='/start', + entities=[ + MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0, length=len('/start')) + ], + bot=bot, + ) + + async with app: + await app.start() + + await app.process_update(Update(update_id=0, message=message)) + message.text = '/brew' + message.entities[0].length = len('/brew') + assert handler.check_update(Update(0, message=message)) + message.from_user = user2 + await app.process_update(Update(update_id=0, message=message)) + message.text = '/start' + message.entities[0].length = len('/start') + # Make sure that user2s conversation has not yet started + assert handler.check_update(Update(0, message=message)) + await app.process_update(Update(update_id=0, message=message)) + message.text = '/brew' + message.entities[0].length = len('/brew') + assert handler.check_update(Update(0, message=message)) + await asyncio.sleep(0.7) + # check that both conversations have ended by checking that the start message is + # accepted again + message.text = '/start' + message.entities[0].length = len('/start') + message.from_user = user1 + assert handler.check_update(Update(0, message=message)) + message.from_user = user2 + assert handler.check_update(Update(0, message=message)) + + await app.stop() + + @pytest.mark.asyncio + async def test_conversation_handler_timeout_state(self, app, bot, user1): + states = self.states + states.update( + { + ConversationHandler.TIMEOUT: [ + CommandHandler('brew', self.passout), + MessageHandler(~filters.Regex('oding'), self.passout2), + ] + } + ) + handler = ConversationHandler( + entry_points=self.entry_points, + states=states, + fallbacks=self.fallbacks, + conversation_timeout=0.5, + ) + app.add_handler(handler) + + # CommandHandler timeout + message = Message( + 0, + None, + self.group, + from_user=user1, + text='/start', + entities=[ + MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0, length=len('/start')) + ], + bot=bot, + ) + + async with app: + await app.start() + + await app.process_update(Update(update_id=0, message=message)) + message.text = '/brew' + message.entities[0].length = len('/brew') + await app.process_update(Update(update_id=0, message=message)) + await asyncio.sleep(0.7) + # check that conversation has ended by checking that start cmd is accepted again + message.text = '/start' + message.entities[0].length = len('/start') + assert handler.check_update(Update(0, message=message)) + assert self.is_timeout + + # MessageHandler timeout + self.is_timeout = False + message.text = '/start' + message.entities[0].length = len('/start') + await app.process_update(Update(update_id=1, message=message)) + await asyncio.sleep(0.7) + # check that conversation has ended by checking that start cmd is accepted again + assert handler.check_update(Update(0, message=message)) + assert self.is_timeout + + # Timeout but no valid handler + self.is_timeout = False + await app.process_update(Update(update_id=0, message=message)) + message.text = '/brew' + message.entities[0].length = len('/brew') + await app.process_update(Update(update_id=0, message=message)) + message.text = '/startCoding' + message.entities[0].length = len('/startCoding') + await app.process_update(Update(update_id=0, message=message)) + await asyncio.sleep(0.7) + # check that conversation has ended by checking that start cmd is accepted again + message.text = '/start' + message.entities[0].length = len('/start') + assert handler.check_update(Update(0, message=message)) + assert not self.is_timeout + + await app.stop() + + @pytest.mark.asyncio + async def test_conversation_handler_timeout_state_context(self, app, bot, user1): + states = self.states + states.update( + { + ConversationHandler.TIMEOUT: [ + CommandHandler('brew', self.passout_context), + MessageHandler(~filters.Regex('oding'), self.passout2_context), + ] + } + ) + handler = ConversationHandler( + entry_points=self.entry_points, + states=states, + fallbacks=self.fallbacks, + conversation_timeout=0.5, + ) + app.add_handler(handler) + + # CommandHandler timeout + message = Message( + 0, + None, + self.group, + from_user=user1, + text='/start', + entities=[ + MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0, length=len('/start')) + ], + bot=bot, + ) + async with app: + await app.start() + + await app.process_update(Update(update_id=0, message=message)) + message.text = '/brew' + message.entities[0].length = len('/brew') + await app.process_update(Update(update_id=0, message=message)) + await asyncio.sleep(0.7) + # check that conversation has ended by checking that start cmd is accepted again + message.text = '/start' + message.entities[0].length = len('/start') + assert handler.check_update(Update(0, message=message)) + assert self.is_timeout + + # MessageHandler timeout + self.is_timeout = False + message.text = '/start' + message.entities[0].length = len('/start') + await app.process_update(Update(update_id=1, message=message)) + await asyncio.sleep(0.7) + # check that conversation has ended by checking that start cmd is accepted again + assert handler.check_update(Update(0, message=message)) + assert self.is_timeout + + # Timeout but no valid handler + self.is_timeout = False + await app.process_update(Update(update_id=0, message=message)) + message.text = '/brew' + message.entities[0].length = len('/brew') + await app.process_update(Update(update_id=0, message=message)) + message.text = '/startCoding' + message.entities[0].length = len('/startCoding') + await app.process_update(Update(update_id=0, message=message)) + await asyncio.sleep(0.7) + # check that conversation has ended by checking that start cmd is accepted again + message.text = '/start' + message.entities[0].length = len('/start') + assert handler.check_update(Update(0, message=message)) + assert not self.is_timeout + + await app.stop() + + @pytest.mark.asyncio + async def test_conversation_timeout_cancel_conflict(self, app, bot, user1): + # Start state machine, wait half the timeout, + # then call a callback that takes more than the timeout + # t=0 /start (timeout=.5) + # t=.25 /slowbrew (sleep .5) + # | t=.5 original timeout (should not execute) + # | t=.75 /slowbrew returns (timeout=1.25) + # t=1.25 timeout + + async def slowbrew(_update, context): + await asyncio.sleep(0.25) + # Let's give to the original timeout a chance to execute + await asyncio.sleep(0.25) + # By returning None we do not override the conversation state so + # we can see if the timeout has been executed + + states = self.states + states[self.THIRSTY].append(CommandHandler('slowbrew', slowbrew)) + states.update({ConversationHandler.TIMEOUT: [MessageHandler(None, self.passout2)]}) + + handler = ConversationHandler( + entry_points=self.entry_points, + states=states, + fallbacks=self.fallbacks, + conversation_timeout=0.5, + ) + app.add_handler(handler) + + # CommandHandler timeout + message = Message( + 0, + None, + self.group, + from_user=user1, + text='/start', + entities=[ + MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0, length=len('/start')) + ], + bot=bot, + ) + + async with app: + await app.start() + await app.process_update(Update(update_id=0, message=message)) + await asyncio.sleep(0.25) + message.text = '/slowbrew' + message.entities[0].length = len('/slowbrew') + await app.process_update(Update(update_id=0, message=message)) + # Check that conversation has not ended by checking that start cmd is not accepted + message.text = '/start' + message.entities[0].length = len('/start') + assert not handler.check_update(Update(0, message=message)) + assert not self.is_timeout + + await asyncio.sleep(0.7) + # Check that conversation has ended by checking that start cmd is accepted again + message.text = '/start' + message.entities[0].length = len('/start') + assert handler.check_update(Update(0, message=message)) + assert self.is_timeout + + await app.stop() + + @pytest.mark.asyncio + async def test_nested_conversation_handler(self, app, bot, user1, user2): + self.nested_states[self.DRINKING] = [ + ConversationHandler( + entry_points=self.drinking_entry_points, + states=self.drinking_states, + fallbacks=self.drinking_fallbacks, + map_to_parent=self.drinking_map_to_parent, + ) + ] + handler = ConversationHandler( + entry_points=self.entry_points, states=self.nested_states, fallbacks=self.fallbacks + ) + app.add_handler(handler) + + # User one, starts the state machine. + message = Message( + 0, + None, + self.group, + from_user=user1, + text='/start', + bot=bot, + entities=[ + MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0, length=len('/start')) + ], + ) + async with app: + await app.process_update(Update(update_id=0, message=message)) + assert self.current_state[user1.id] == self.THIRSTY + + # The user is thirsty and wants to brew coffee. + message.text = '/brew' + message.entities[0].length = len('/brew') + await app.process_update(Update(update_id=0, message=message)) + assert self.current_state[user1.id] == self.BREWING + + # Lets pour some coffee. + message.text = '/pourCoffee' + message.entities[0].length = len('/pourCoffee') + await app.process_update(Update(update_id=0, message=message)) + assert self.current_state[user1.id] == self.DRINKING + + # The user is holding the cup + message.text = '/hold' + message.entities[0].length = len('/hold') + await app.process_update(Update(update_id=0, message=message)) + assert self.current_state[user1.id] == self.HOLDING + + # The user is sipping coffee + message.text = '/sip' + message.entities[0].length = len('/sip') + await app.process_update(Update(update_id=0, message=message)) + assert self.current_state[user1.id] == self.SIPPING + + # The user is swallowing + message.text = '/swallow' + message.entities[0].length = len('/swallow') + await app.process_update(Update(update_id=0, message=message)) + assert self.current_state[user1.id] == self.SWALLOWING + + # The user is holding the cup again + message.text = '/hold' + message.entities[0].length = len('/hold') + await app.process_update(Update(update_id=0, message=message)) + assert self.current_state[user1.id] == self.HOLDING + + # The user wants to replenish the coffee supply + message.text = '/replenish' + message.entities[0].length = len('/replenish') + await app.process_update(Update(update_id=0, message=message)) + assert self.current_state[user1.id] == self.REPLENISHING + # check that we're in the right state now by checking that the update is accepted + message.text = '/pourCoffee' + message.entities[0].length = len('/pourCoffee') + assert handler.check_update(Update(0, message=message)) + + # The user wants to drink their coffee again) + await app.process_update(Update(update_id=0, message=message)) + assert self.current_state[user1.id] == self.DRINKING + + # The user is now ready to start coding + message.text = '/startCoding' + message.entities[0].length = len('/startCoding') + await app.process_update(Update(update_id=0, message=message)) + assert self.current_state[user1.id] == self.CODING + + # The user decides it's time to drink again + message.text = '/drinkMore' + message.entities[0].length = len('/drinkMore') + await app.process_update(Update(update_id=0, message=message)) + assert self.current_state[user1.id] == self.DRINKING + + # The user is holding their cup + message.text = '/hold' + message.entities[0].length = len('/hold') + await app.process_update(Update(update_id=0, message=message)) + assert self.current_state[user1.id] == self.HOLDING + + # The user wants to end with the drinking and go back to coding + message.text = '/end' + message.entities[0].length = len('/end') + await app.process_update(Update(update_id=0, message=message)) + assert self.current_state[user1.id] == self.END + # check that we're in the right state now by checking that the update is accepted + message.text = '/drinkMore' + message.entities[0].length = len('/drinkMore') + assert handler.check_update(Update(0, message=message)) + + # The user wants to drink once more + await app.process_update(Update(update_id=0, message=message)) + assert self.current_state[user1.id] == self.DRINKING + + # The user wants to stop altogether + message.text = '/stop' + message.entities[0].length = len('/stop') + await app.process_update(Update(update_id=0, message=message)) + assert self.current_state[user1.id] == self.STOPPING + # check that the conversation has ended by checking that the start cmd is accepted + message.text = '/start' + message.entities[0].length = len('/start') + assert handler.check_update(Update(0, message=message)) + + @pytest.mark.asyncio + async def test_nested_conversation_application_handler_stop(self, app, bot, user1, user2): + self.nested_states[self.DRINKING] = [ + ConversationHandler( + entry_points=self.drinking_entry_points, + states=self.drinking_states, + fallbacks=self.drinking_fallbacks, + map_to_parent=self.drinking_map_to_parent, + ) + ] + handler = ConversationHandler( + entry_points=self.entry_points, states=self.nested_states, fallbacks=self.fallbacks + ) + + def test_callback(u, c): + self.test_flag = True + + app.add_handler(handler) + app.add_handler(TypeHandler(Update, test_callback), group=1) + self.raise_app_handler_stop = True + + # User one, starts the state machine. + message = Message( + 0, + None, + self.group, + text='/start', + bot=bot, + from_user=user1, + entities=[ + MessageEntity(type=MessageEntity.BOT_COMMAND, offset=0, length=len('/start')) + ], + ) + async with app: + await app.process_update(Update(update_id=0, message=message)) + assert self.current_state[user1.id] == self.THIRSTY + assert not self.test_flag + + # The user is thirsty and wants to brew coffee. + message.text = '/brew' + message.entities[0].length = len('/brew') + await app.process_update(Update(update_id=0, message=message)) + assert self.current_state[user1.id] == self.BREWING + assert not self.test_flag + + # Lets pour some coffee. + message.text = '/pourCoffee' + message.entities[0].length = len('/pourCoffee') + await app.process_update(Update(update_id=0, message=message)) + assert self.current_state[user1.id] == self.DRINKING + assert not self.test_flag + + # The user is holding the cup + message.text = '/hold' + message.entities[0].length = len('/hold') + await app.process_update(Update(update_id=0, message=message)) + assert self.current_state[user1.id] == self.HOLDING + assert not self.test_flag + + # The user is sipping coffee + message.text = '/sip' + message.entities[0].length = len('/sip') + await app.process_update(Update(update_id=0, message=message)) + assert self.current_state[user1.id] == self.SIPPING + assert not self.test_flag + + # The user is swallowing + message.text = '/swallow' + message.entities[0].length = len('/swallow') + await app.process_update(Update(update_id=0, message=message)) + assert self.current_state[user1.id] == self.SWALLOWING + assert not self.test_flag + + # The user is holding the cup again + message.text = '/hold' + message.entities[0].length = len('/hold') + await app.process_update(Update(update_id=0, message=message)) + assert self.current_state[user1.id] == self.HOLDING + assert not self.test_flag + + # The user wants to replenish the coffee supply + message.text = '/replenish' + message.entities[0].length = len('/replenish') + await app.process_update(Update(update_id=0, message=message)) + assert self.current_state[user1.id] == self.REPLENISHING + # check that we're in the right state now by checking that the update is accepted + message.text = '/pourCoffee' + message.entities[0].length = len('/pourCoffee') + assert handler.check_update(Update(0, message=message)) + assert not self.test_flag + + # The user wants to drink their coffee again + await app.process_update(Update(update_id=0, message=message)) + assert self.current_state[user1.id] == self.DRINKING + assert not self.test_flag + + # The user is now ready to start coding + message.text = '/startCoding' + message.entities[0].length = len('/startCoding') + await app.process_update(Update(update_id=0, message=message)) + assert self.current_state[user1.id] == self.CODING + assert not self.test_flag + + # The user decides it's time to drink again + message.text = '/drinkMore' + message.entities[0].length = len('/drinkMore') + await app.process_update(Update(update_id=0, message=message)) + assert self.current_state[user1.id] == self.DRINKING + assert not self.test_flag + + # The user is holding their cup + message.text = '/hold' + message.entities[0].length = len('/hold') + await app.process_update(Update(update_id=0, message=message)) + assert self.current_state[user1.id] == self.HOLDING + assert not self.test_flag + + # The user wants to end with the drinking and go back to coding + message.text = '/end' + message.entities[0].length = len('/end') + await app.process_update(Update(update_id=0, message=message)) + assert self.current_state[user1.id] == self.END + # check that we're in the right state now by checking that the update is accepted + message.text = '/drinkMore' + message.entities[0].length = len('/drinkMore') + assert handler.check_update(Update(0, message=message)) + assert not self.test_flag + + # The user wants to drink once more + message.text = '/drinkMore' + message.entities[0].length = len('/drinkMore') + await app.process_update(Update(update_id=0, message=message)) + assert self.current_state[user1.id] == self.DRINKING + assert not self.test_flag + + # The user wants to stop altogether + message.text = '/stop' + message.entities[0].length = len('/stop') + await app.process_update(Update(update_id=0, message=message)) + assert self.current_state[user1.id] == self.STOPPING + # check that the conv has ended by checking that the start cmd is accepted + message.text = '/start' + message.entities[0].length = len('/start') + assert handler.check_update(Update(0, message=message)) + assert not self.test_flag + + @pytest.mark.asyncio + @pytest.mark.parametrize('callback_raises', [True, False]) + async def test_timeout_non_block(self, app, user1, callback_raises): + event = asyncio.Event() + + async def callback(_, __): + await event.wait() + if callback_raises: + raise RuntimeError + return 1 + + conv_handler = ConversationHandler( + entry_points=[MessageHandler(filters.ALL, callback=callback, block=False)], + states={ConversationHandler.TIMEOUT: [TypeHandler(Update, self.passout2)]}, + fallbacks=[], + conversation_timeout=0.5, + ) + app.add_handler(conv_handler) + + async with app: + await app.start() + + message = Message( + 0, + None, + self.group, + text='/start', + from_user=user1, + ) + assert conv_handler.check_update(Update(0, message=message)) + await app.process_update(Update(0, message=message)) + await asyncio.sleep(0.7) + assert not self.is_timeout + event.set() + await asyncio.sleep(0.7) + assert self.is_timeout == (not callback_raises) + + await app.stop() + + @pytest.mark.asyncio + async def test_no_timeout_on_end(self, app, user1): + + conv_handler = ConversationHandler( + entry_points=[MessageHandler(filters.ALL, callback=self.start_end)], + states={ConversationHandler.TIMEOUT: [TypeHandler(Update, self.passout2)]}, + fallbacks=[], + conversation_timeout=0.5, + ) + app.add_handler(conv_handler) + + async with app: + await app.start() + + message = Message( + 0, + None, + self.group, + text='/start', + from_user=user1, + ) + assert conv_handler.check_update(Update(0, message=message)) + await app.process_update(Update(0, message=message)) + await asyncio.sleep(0.7) + assert not self.is_timeout + + await app.stop() + + @pytest.mark.asyncio + async def test_conversation_handler_block_dont_override(self, app): + """This just makes sure that we don't change any attributes of the handlers of the conv""" + conv_handler = ConversationHandler( + entry_points=self.entry_points, + states=self.states, + fallbacks=self.fallbacks, + block=False, + ) + + all_handlers = conv_handler.entry_points + conv_handler.fallbacks + for state_handlers in conv_handler.states.values(): + all_handlers += state_handlers + + for handler in all_handlers: + assert handler.block + + conv_handler = ConversationHandler( + entry_points=[CommandHandler('start', self.start_end, block=False)], + states={1: [CommandHandler('start', self.start_end, block=False)]}, + fallbacks=[CommandHandler('start', self.start_end, block=False)], + block=True, + ) + + all_handlers = conv_handler.entry_points + conv_handler.fallbacks + for state_handlers in conv_handler.states.values(): + all_handlers += state_handlers + + for handler in all_handlers: + assert handler.block is False + + @pytest.mark.asyncio + @pytest.mark.parametrize('default_block', [True, False, None]) + @pytest.mark.parametrize('ch_block', [True, False, None]) + @pytest.mark.parametrize('handler_block', [True, False, None]) + @pytest.mark.parametrize('ext_bot', [True, False], ids=['ExtBot', 'Bot']) + async def test_blocking_resolution_order( + self, bot, default_block, ch_block, handler_block, ext_bot + ): + + event = asyncio.Event() + + async def callback(_, __): + await event.wait() + event.clear() + self.test_flag = True + return 1 + + if handler_block is not None: + handler = CommandHandler('start', callback=callback, block=handler_block) + fallback = MessageHandler(filters.ALL, callback, block=handler_block) + else: + handler = CommandHandler('start', callback=callback) + fallback = MessageHandler(filters.ALL, callback, block=handler_block) + + if default_block is not None: + defaults = Defaults(block=default_block) + else: + defaults = None + + if ch_block is not None: + conv_handler = ConversationHandler( + entry_points=[handler], + states={1: [handler]}, + fallbacks=[fallback], + block=ch_block, + ) + else: + conv_handler = ConversationHandler( + entry_points=[handler], + states={1: [handler]}, + fallbacks=[fallback], + ) + + bot = ExtBot(bot.token, defaults=defaults) if ext_bot else Bot(bot.token) + app = ApplicationBuilder().bot(bot).build() + app.add_handler(conv_handler) + + async with app: + start_message = make_command_message('/start', bot=bot) + fallback_message = make_command_message('/fallback', bot=bot) + + # This loop makes sure that we test all of entry points, states handler & fallbacks + for message in [start_message, start_message, fallback_message]: + process_update_task = asyncio.create_task( + app.process_update(Update(0, message=message)) + ) + if ( + # resolution order is handler_block -> ch_block -> default_block + # setting block=True/False on a lower priority setting may only have an effect + # if it wasn't set for the higher priority settings + (handler_block is False) + or ((handler_block is None) and (ch_block is False)) + or ( + (handler_block is None) + and (ch_block is None) + and ext_bot + and (default_block is False) + ) + ): + # check that the handler was called non-blocking by checking that + # `process_update` has finished + await asyncio.sleep(0.01) + assert process_update_task.done() + else: + # the opposite + assert not process_update_task.done() + + # In any case, the callback must not have finished + assert not self.test_flag + + # After setting the event, the callback must have finished and in the blocking + # case this leads to `process_update` finishing. + event.set() + await asyncio.sleep(0.01) + assert process_update_task.done() + assert self.test_flag + self.test_flag = False + + @pytest.mark.asyncio + async def test_waiting_state(self, app, user1): + event = asyncio.Event() + + async def callback_1(_, __): + self.test_flag = 1 + + async def callback_2(_, __): + self.test_flag = 2 + + async def callback_3(_, __): + self.test_flag = 3 + + async def blocking(_, __): + await event.wait() + return 1 + + conv_handler = ConversationHandler( + entry_points=[MessageHandler(filters.ALL, callback=blocking, block=False)], + states={ + ConversationHandler.WAITING: [ + MessageHandler(filters.Regex('1'), callback_1), + MessageHandler(filters.Regex('2'), callback_2), + ], + 1: [MessageHandler(filters.Regex('2'), callback_3)], + }, + fallbacks=[], + ) + app.add_handler(conv_handler) + + message = Message( + 0, + None, + self.group, + text='/start', + from_user=user1, + ) + + async with app: + await app.process_update(Update(0, message=message)) + assert not self.test_flag + message.text = '1' + await app.process_update(Update(0, message=message)) + assert self.test_flag == 1 + message.text = '2' + await app.process_update(Update(0, message=message)) + assert self.test_flag == 2 + event.set() + await asyncio.sleep(0.05) + self.test_flag = None + await app.process_update(Update(0, message=message)) + assert self.test_flag == 3 diff --git a/tests/test_dictpersistence.py b/tests/test_dictpersistence.py new file mode 100644 index 00000000000..34c5394cbf2 --- /dev/null +++ b/tests/test_dictpersistence.py @@ -0,0 +1,403 @@ +#!/usr/bin/env python +# +# A library that provides a Python interface to the Telegram Bot API +# Copyright (C) 2015-2022 +# Leandro Toledo de Souza +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser Public License for more details. +# +# You should have received a copy of the GNU Lesser Public License +# along with this program. If not, see [http://www.gnu.org/licenses/]. +import pytest + + +try: + import ujson as json +except ImportError: + import json + +from telegram.ext import DictPersistence + + +@pytest.fixture(autouse=True) +def reset_callback_data_cache(bot): + yield + bot.callback_data_cache.clear_callback_data() + bot.callback_data_cache.clear_callback_queries() + bot.arbitrary_callback_data = False + + +@pytest.fixture(scope="function") +def bot_data(): + return {'test1': 'test2', 'test3': {'test4': 'test5'}} + + +@pytest.fixture(scope="function") +def chat_data(): + return {-12345: {'test1': 'test2', 'test3': {'test4': 'test5'}}, -67890: {3: 'test4'}} + + +@pytest.fixture(scope="function") +def user_data(): + return {12345: {'test1': 'test2', 'test3': {'test4': 'test5'}}, 67890: {3: 'test4'}} + + +@pytest.fixture(scope="function") +def callback_data(): + return [('test1', 1000, {'button1': 'test0', 'button2': 'test1'})], {'test1': 'test2'} + + +@pytest.fixture(scope='function') +def conversations(): + return { + 'name1': {(123, 123): 3, (456, 654): 4}, + 'name2': {(123, 321): 1, (890, 890): 2}, + 'name3': {(123, 321): 1, (890, 890): 2}, + } + + +@pytest.fixture(scope='function') +def user_data_json(user_data): + return json.dumps(user_data) + + +@pytest.fixture(scope='function') +def chat_data_json(chat_data): + return json.dumps(chat_data) + + +@pytest.fixture(scope='function') +def bot_data_json(bot_data): + return json.dumps(bot_data) + + +@pytest.fixture(scope='function') +def callback_data_json(callback_data): + return json.dumps(callback_data) + + +@pytest.fixture(scope='function') +def conversations_json(conversations): + return """{"name1": {"[123, 123]": 3, "[456, 654]": 4}, "name2": + {"[123, 321]": 1, "[890, 890]": 2}, "name3": + {"[123, 321]": 1, "[890, 890]": 2}}""" + + +class TestDictPersistence: + """Just tests the DictPersistence interface. Integration of persistence into Applictation + is tested in TestBasePersistence!""" + + @pytest.mark.asyncio + async def test_slot_behaviour(self, mro_slots, recwarn): + inst = DictPersistence() + for attr in inst.__slots__: + assert getattr(inst, attr, 'err') != 'err', f"got extra slot '{attr}'" + assert len(mro_slots(inst)) == len(set(mro_slots(inst))), "duplicate slot" + + @pytest.mark.asyncio + async def test_no_json_given(self): + dict_persistence = DictPersistence() + assert await dict_persistence.get_user_data() == {} + assert await dict_persistence.get_chat_data() == {} + assert await dict_persistence.get_bot_data() == {} + assert await dict_persistence.get_callback_data() is None + assert await dict_persistence.get_conversations('noname') == {} + + @pytest.mark.asyncio + async def test_bad_json_string_given(self): + bad_user_data = 'thisisnojson99900()))(' + bad_chat_data = 'thisisnojson99900()))(' + bad_bot_data = 'thisisnojson99900()))(' + bad_callback_data = 'thisisnojson99900()))(' + bad_conversations = 'thisisnojson99900()))(' + with pytest.raises(TypeError, match='user_data'): + DictPersistence(user_data_json=bad_user_data) + with pytest.raises(TypeError, match='chat_data'): + DictPersistence(chat_data_json=bad_chat_data) + with pytest.raises(TypeError, match='bot_data'): + DictPersistence(bot_data_json=bad_bot_data) + with pytest.raises(TypeError, match='callback_data'): + DictPersistence(callback_data_json=bad_callback_data) + with pytest.raises(TypeError, match='conversations'): + DictPersistence(conversations_json=bad_conversations) + + @pytest.mark.asyncio + async def test_invalid_json_string_given(self): + bad_user_data = '["this", "is", "json"]' + bad_chat_data = '["this", "is", "json"]' + bad_bot_data = '["this", "is", "json"]' + bad_conversations = '["this", "is", "json"]' + bad_callback_data_1 = '[[["str", 3.14, {"di": "ct"}]], "is"]' + bad_callback_data_2 = '[[["str", "non-float", {"di": "ct"}]], {"di": "ct"}]' + bad_callback_data_3 = '[[[{"not": "a str"}, 3.14, {"di": "ct"}]], {"di": "ct"}]' + bad_callback_data_4 = '[[["wrong", "length"]], {"di": "ct"}]' + bad_callback_data_5 = '["this", "is", "json"]' + with pytest.raises(TypeError, match='user_data'): + DictPersistence(user_data_json=bad_user_data) + with pytest.raises(TypeError, match='chat_data'): + DictPersistence(chat_data_json=bad_chat_data) + with pytest.raises(TypeError, match='bot_data'): + DictPersistence(bot_data_json=bad_bot_data) + for bad_callback_data in [ + bad_callback_data_1, + bad_callback_data_2, + bad_callback_data_3, + bad_callback_data_4, + bad_callback_data_5, + ]: + with pytest.raises(TypeError, match='callback_data'): + DictPersistence(callback_data_json=bad_callback_data) + with pytest.raises(TypeError, match='conversations'): + DictPersistence(conversations_json=bad_conversations) + + @pytest.mark.asyncio + async def test_good_json_input( + self, user_data_json, chat_data_json, bot_data_json, conversations_json, callback_data_json + ): + dict_persistence = DictPersistence( + user_data_json=user_data_json, + chat_data_json=chat_data_json, + bot_data_json=bot_data_json, + conversations_json=conversations_json, + callback_data_json=callback_data_json, + ) + user_data = await dict_persistence.get_user_data() + assert isinstance(user_data, dict) + assert user_data[12345]['test1'] == 'test2' + assert user_data[67890][3] == 'test4' + + chat_data = await dict_persistence.get_chat_data() + assert isinstance(chat_data, dict) + assert chat_data[-12345]['test1'] == 'test2' + assert chat_data[-67890][3] == 'test4' + + bot_data = await dict_persistence.get_bot_data() + assert isinstance(bot_data, dict) + assert bot_data['test1'] == 'test2' + assert bot_data['test3']['test4'] == 'test5' + assert 'test6' not in bot_data + + callback_data = await dict_persistence.get_callback_data() + + assert isinstance(callback_data, tuple) + assert callback_data[0] == [('test1', 1000, {'button1': 'test0', 'button2': 'test1'})] + assert callback_data[1] == {'test1': 'test2'} + + conversation1 = await dict_persistence.get_conversations('name1') + assert isinstance(conversation1, dict) + assert conversation1[(123, 123)] == 3 + assert conversation1[(456, 654)] == 4 + with pytest.raises(KeyError): + conversation1[(890, 890)] + conversation2 = await dict_persistence.get_conversations('name2') + assert isinstance(conversation1, dict) + assert conversation2[(123, 321)] == 1 + assert conversation2[(890, 890)] == 2 + with pytest.raises(KeyError): + conversation2[(123, 123)] + + @pytest.mark.asyncio + async def test_good_json_input_callback_data_none(self): + dict_persistence = DictPersistence(callback_data_json='null') + assert dict_persistence.callback_data is None + assert dict_persistence.callback_data_json == 'null' + + @pytest.mark.asyncio + async def test_dict_outputs( + self, + user_data, + user_data_json, + chat_data, + chat_data_json, + bot_data, + bot_data_json, + callback_data_json, + conversations, + conversations_json, + ): + dict_persistence = DictPersistence( + user_data_json=user_data_json, + chat_data_json=chat_data_json, + bot_data_json=bot_data_json, + callback_data_json=callback_data_json, + conversations_json=conversations_json, + ) + assert dict_persistence.user_data == user_data + assert dict_persistence.chat_data == chat_data + assert dict_persistence.bot_data == bot_data + assert dict_persistence.bot_data == bot_data + assert dict_persistence.conversations == conversations + + @pytest.mark.asyncio + async def test_json_outputs( + self, user_data_json, chat_data_json, bot_data_json, callback_data_json, conversations_json + ): + dict_persistence = DictPersistence( + user_data_json=user_data_json, + chat_data_json=chat_data_json, + bot_data_json=bot_data_json, + callback_data_json=callback_data_json, + conversations_json=conversations_json, + ) + assert dict_persistence.user_data_json == user_data_json + assert dict_persistence.chat_data_json == chat_data_json + assert dict_persistence.callback_data_json == callback_data_json + assert dict_persistence.conversations_json == conversations_json + + @pytest.mark.asyncio + async def test_updating( + self, + user_data_json, + chat_data_json, + bot_data_json, + callback_data, + callback_data_json, + conversations, + conversations_json, + ): + dict_persistence = DictPersistence( + user_data_json=user_data_json, + chat_data_json=chat_data_json, + bot_data_json=bot_data_json, + callback_data_json=callback_data_json, + conversations_json=conversations_json, + ) + + user_data = await dict_persistence.get_user_data() + user_data[12345]['test3']['test4'] = 'test6' + assert dict_persistence.user_data != user_data + assert dict_persistence.user_data_json != json.dumps(user_data) + await dict_persistence.update_user_data(12345, user_data[12345]) + assert dict_persistence.user_data == user_data + assert dict_persistence.user_data_json == json.dumps(user_data) + await dict_persistence.drop_user_data(67890) + assert 67890 not in dict_persistence.user_data + dict_persistence._user_data = None + await dict_persistence.drop_user_data(123) + assert isinstance(await dict_persistence.get_user_data(), dict) + + chat_data = await dict_persistence.get_chat_data() + chat_data[-12345]['test3']['test4'] = 'test6' + assert dict_persistence.chat_data != chat_data + assert dict_persistence.chat_data_json != json.dumps(chat_data) + await dict_persistence.update_chat_data(-12345, chat_data[-12345]) + assert dict_persistence.chat_data == chat_data + assert dict_persistence.chat_data_json == json.dumps(chat_data) + await dict_persistence.drop_chat_data(-67890) + assert -67890 not in dict_persistence.chat_data + dict_persistence._chat_data = None + await dict_persistence.drop_chat_data(123) + assert isinstance(await dict_persistence.get_chat_data(), dict) + + bot_data = await dict_persistence.get_bot_data() + bot_data['test3']['test4'] = 'test6' + assert dict_persistence.bot_data != bot_data + assert dict_persistence.bot_data_json != json.dumps(bot_data) + await dict_persistence.update_bot_data(bot_data) + assert dict_persistence.bot_data == bot_data + assert dict_persistence.bot_data_json == json.dumps(bot_data) + + callback_data = await dict_persistence.get_callback_data() + callback_data[1]['test3'] = 'test4' + callback_data[0][0][2]['button2'] = 'test41' + assert dict_persistence.callback_data != callback_data + assert dict_persistence.callback_data_json != json.dumps(callback_data) + await dict_persistence.update_callback_data(callback_data) + assert dict_persistence.callback_data == callback_data + assert dict_persistence.callback_data_json == json.dumps(callback_data) + + conversation1 = await dict_persistence.get_conversations('name1') + conversation1[(123, 123)] = 5 + assert not dict_persistence.conversations['name1'] == conversation1 + await dict_persistence.update_conversation('name1', (123, 123), 5) + assert dict_persistence.conversations['name1'] == conversation1 + conversations['name1'][(123, 123)] = 5 + assert ( + dict_persistence.conversations_json + == DictPersistence._encode_conversations_to_json(conversations) + ) + assert await dict_persistence.get_conversations('name1') == conversation1 + + dict_persistence._conversations = None + await dict_persistence.update_conversation('name1', (123, 123), 5) + assert dict_persistence.conversations['name1'] == {(123, 123): 5} + assert await dict_persistence.get_conversations('name1') == {(123, 123): 5} + assert ( + dict_persistence.conversations_json + == DictPersistence._encode_conversations_to_json({"name1": {(123, 123): 5}}) + ) + + @pytest.mark.asyncio + async def test_no_data_on_init( + self, bot_data, user_data, chat_data, conversations, callback_data + ): + dict_persistence = DictPersistence() + + assert dict_persistence.user_data is None + assert dict_persistence.chat_data is None + assert dict_persistence.bot_data is None + assert dict_persistence.conversations is None + assert dict_persistence.callback_data is None + assert dict_persistence.user_data_json == 'null' + assert dict_persistence.chat_data_json == 'null' + assert dict_persistence.bot_data_json == 'null' + assert dict_persistence.conversations_json == 'null' + assert dict_persistence.callback_data_json == 'null' + + await dict_persistence.update_bot_data(bot_data) + await dict_persistence.update_user_data(12345, user_data[12345]) + await dict_persistence.update_chat_data(-12345, chat_data[-12345]) + await dict_persistence.update_conversation('name', (1, 1), 'new_state') + await dict_persistence.update_callback_data(callback_data) + + assert dict_persistence.user_data[12345] == user_data[12345] + assert dict_persistence.chat_data[-12345] == chat_data[-12345] + assert dict_persistence.bot_data == bot_data + assert dict_persistence.conversations['name'] == {(1, 1): 'new_state'} + assert dict_persistence.callback_data == callback_data + + @pytest.mark.asyncio + async def test_no_json_dumping_if_data_did_not_change( + self, bot_data, user_data, chat_data, conversations, callback_data, monkeypatch + ): + dict_persistence = DictPersistence() + + await dict_persistence.update_bot_data(bot_data) + await dict_persistence.update_user_data(12345, user_data[12345]) + await dict_persistence.update_chat_data(-12345, chat_data[-12345]) + await dict_persistence.update_conversation('name', (1, 1), 'new_state') + await dict_persistence.update_callback_data(callback_data) + + assert dict_persistence.user_data_json == json.dumps({12345: user_data[12345]}) + assert dict_persistence.chat_data_json == json.dumps({-12345: chat_data[-12345]}) + assert dict_persistence.bot_data_json == json.dumps(bot_data) + assert ( + dict_persistence.conversations_json + == DictPersistence._encode_conversations_to_json({'name': {(1, 1): 'new_state'}}) + ) + assert dict_persistence.callback_data_json == json.dumps(callback_data) + + flag = False + + def dumps(*args, **kwargs): + nonlocal flag + flag = True + + # Since the data doesn't change, json.dumps shoduln't be called beyond this point! + monkeypatch.setattr(json, 'dumps', dumps) + + await dict_persistence.update_bot_data(bot_data) + await dict_persistence.update_user_data(12345, user_data[12345]) + await dict_persistence.update_chat_data(-12345, chat_data[-12345]) + await dict_persistence.update_conversation('name', (1, 1), 'new_state') + await dict_persistence.update_callback_data(callback_data) + + assert not flag diff --git a/tests/test_document.py b/tests/test_document.py index c77ab345bbb..ca9dfb5f540 100644 --- a/tests/test_document.py +++ b/tests/test_document.py @@ -60,6 +60,11 @@ class TestDocument: document_file_id = '5a3128a4d2a04750b5b58397f3b5e812' document_file_unique_id = 'adc3145fd2e84d95b64d68eaa22aa33e' + def test_slot_behaviour(self, document, mro_slots): + for attr in document.__slots__: + assert getattr(document, attr, 'err') != 'err', f"got extra slot '{attr}'" + assert len(mro_slots(document)) == len(set(mro_slots(document))), "duplicate slot" + def test_creation(self, document): assert isinstance(document, Document) assert isinstance(document.file_id, str) diff --git a/tests/test_error.py b/tests/test_error.py index c425fdf9e3a..70a6426c480 100644 --- a/tests/test_error.py +++ b/tests/test_error.py @@ -142,12 +142,12 @@ def test_errors_pickling(self, exception, attributes): (InvalidCallbackData('test data')), ], ) - def test_slots_behavior(self, inst, mro_slots): + def test_slot_behaviour(self, inst, mro_slots): for attr in inst.__slots__: assert getattr(inst, attr, 'err') != 'err', f"got extra slot '{attr}'" assert len(mro_slots(inst)) == len(set(mro_slots(inst))), "duplicate slot" - def test_test_coverage(self): + def test_coverage(self): """ This test is only here to make sure that new errors will override __reduce__ and set __slots__ properly. @@ -178,3 +178,17 @@ def make_assertion(cls): ) make_assertion(TelegramError) + + def test_string_representations(self): + """We just randomly test a few of the subclasses - should suffice""" + e = TelegramError('This is a message') + assert repr(e) == "TelegramError('This is a message')" + assert str(e) == "This is a message" + + e = RetryAfter(42) + assert repr(e) == "RetryAfter('Flood control exceeded. Retry in 42.0 seconds')" + assert str(e) == 'Flood control exceeded. Retry in 42.0 seconds' + + e = BadRequest('This is a message') + assert repr(e) == "BadRequest('This is a message')" + assert str(e) == "This is a message" diff --git a/tests/test_gamehighscore.py b/tests/test_gamehighscore.py index 570ec235449..900f0f9329f 100644 --- a/tests/test_gamehighscore.py +++ b/tests/test_gamehighscore.py @@ -47,6 +47,8 @@ def test_de_json(self, bot): assert highscore.user == self.user assert highscore.score == self.score + assert GameHighScore.de_json(None, bot) is None + def test_to_dict(self, game_highscore): game_highscore_dict = game_highscore.to_dict() diff --git a/tests/test_inlinequery.py b/tests/test_inlinequery.py index fdd15a1fdf7..52487704a6f 100644 --- a/tests/test_inlinequery.py +++ b/tests/test_inlinequery.py @@ -42,6 +42,11 @@ class TestInlineQuery: offset = 'offset' location = Location(8.8, 53.1) + def test_slot_behaviour(self, inline_query, mro_slots): + for attr in inline_query.__slots__: + assert getattr(inline_query, attr, 'err') != 'err', f"got extra slot '{attr}'" + assert len(mro_slots(inline_query)) == len(set(mro_slots(inline_query))), "duplicate slot" + def test_de_json(self, bot): json_dict = { 'id': self.id_, diff --git a/tests/test_inlinequeryhandler.py b/tests/test_inlinequeryhandler.py index 995fc09086c..ccf32276760 100644 --- a/tests/test_inlinequeryhandler.py +++ b/tests/test_inlinequeryhandler.py @@ -142,6 +142,13 @@ async def test_context_pattern(self, app, inline_query): await app.process_update(inline_query) assert self.test_flag + update = Update( + update_id=0, inline_query=InlineQuery(id='id', from_user=None, query='', offset='') + ) + assert not handler.check_update(update) + update.inline_query.query = 'not_a_match' + assert not handler.check_update(update) + @pytest.mark.parametrize('chat_types', [[Chat.SENDER], [Chat.SENDER, Chat.SUPERGROUP], []]) @pytest.mark.parametrize( 'chat_type,result', [(Chat.SENDER, True), (Chat.CHANNEL, False), (None, False)] diff --git a/tests/test_message.py b/tests/test_message.py index daa7d377817..c526f2daae7 100644 --- a/tests/test_message.py +++ b/tests/test_message.py @@ -182,6 +182,12 @@ def message(bot): {'sender_chat': Chat(-123, 'discussion_channel')}, {'is_automatic_forward': True}, {'has_protected_content': True}, + { + 'entities': [ + MessageEntity(MessageEntity.BOLD, 0, 1), + MessageEntity(MessageEntity.TEXT_LINK, 2, 3, url='https://ptb.org'), + ] + }, ], ids=[ 'forwarded_user', @@ -234,6 +240,7 @@ def message(bot): 'sender_chat', 'is_automatic_forward', 'has_protected_content', + 'entities', ], ) def message_params(bot, request): @@ -318,6 +325,11 @@ def test_all_possibilities_de_json_and_to_dict(self, bot, message_params): assert new.to_dict() == message_params.to_dict() + def test_slot_behaviour(self, message, mro_slots): + for attr in message.__slots__: + assert getattr(message, attr, 'err') != 'err', f"got extra slot '{attr}'" + assert len(mro_slots(message)) == len(set(mro_slots(message))), "duplicate slot" + @pytest.mark.asyncio async def test_parse_entity(self): text = ( @@ -328,6 +340,9 @@ async def test_parse_entity(self): message = Message(1, self.from_user, self.date, self.chat, text=text, entities=[entity]) assert message.parse_entity(entity) == 'http://google.com' + with pytest.raises(RuntimeError, match='Message has no'): + Message(message_id=1, date=self.date, chat=self.chat).parse_entity(entity) + @pytest.mark.asyncio async def test_parse_caption_entity(self): caption = ( @@ -340,6 +355,9 @@ async def test_parse_caption_entity(self): ) assert message.parse_caption_entity(entity) == 'http://google.com' + with pytest.raises(RuntimeError, match='Message has no'): + Message(message_id=1, date=self.date, chat=self.chat).parse_entity(entity) + @pytest.mark.asyncio async def test_parse_entities(self): text = ( @@ -669,18 +687,21 @@ def test_effective_attachment(self, message_params): 'venue', ] - attachment = message_params.effective_attachment - if attachment: - condition = any( - message_params[message_type] is attachment - for message_type in expected_attachment_types - ) - assert condition, 'Got effective_attachment for unexpected type' - else: - condition = any( - message_params[message_type] for message_type in expected_attachment_types - ) - assert not condition, 'effective_attachment was None even though it should not be' + for _ in range(3): + # We run the same test multiple times to make sure that the caching is tested + + attachment = message_params.effective_attachment + if attachment: + condition = any( + message_params[message_type] is attachment + for message_type in expected_attachment_types + ) + assert condition, 'Got effective_attachment for unexpected type' + else: + condition = any( + message_params[message_type] for message_type in expected_attachment_types + ) + assert not condition, 'effective_attachment was None even though it should not be' @pytest.mark.asyncio async def test_reply_text(self, monkeypatch, message): diff --git a/tests/test_messagehandler.py b/tests/test_messagehandler.py index a727a0905f5..5251b1f2e95 100644 --- a/tests/test_messagehandler.py +++ b/tests/test_messagehandler.py @@ -161,6 +161,7 @@ def test_specific_filters(self, message): def test_other_update_types(self, false_update): handler = MessageHandler(None, self.callback) assert not handler.check_update(false_update) + assert not handler.check_update('string') def test_filters_returns_empty_dict(self): class DataFilter(MessageFilter): diff --git a/tests/test_photo.py b/tests/test_photo.py index 6c44f8b29f7..69fb34ced4c 100644 --- a/tests/test_photo.py +++ b/tests/test_photo.py @@ -74,6 +74,11 @@ class TestPhoto: # so we accept three different sizes here. Shouldn't be too much file_size = [29176, 27662] + def test_slot_behaviour(self, photo, mro_slots): + for attr in photo.__slots__: + assert getattr(photo, attr, 'err') != 'err', f"got extra slot '{attr}'" + assert len(mro_slots(photo)) == len(set(mro_slots(photo))), "duplicate slot" + def test_creation(self, thumb, photo): # Make sure file has been uploaded. assert isinstance(photo, PhotoSize) diff --git a/tests/test_picklepersistence.py b/tests/test_picklepersistence.py new file mode 100644 index 00000000000..fd3a77996bd --- /dev/null +++ b/tests/test_picklepersistence.py @@ -0,0 +1,1015 @@ +#!/usr/bin/env python +# +# A library that provides a Python interface to the Telegram Bot API +# Copyright (C) 2015-2022 +# Leandro Toledo de Souza +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser Public License for more details. +# +# You should have received a copy of the GNU Lesser Public License +# along with this program. If not, see [http://www.gnu.org/licenses/]. +import datetime +import os +import pickle +import gzip +from pathlib import Path + +import pytest + +from telegram.warnings import PTBUserWarning + +from telegram import Update, Message, User, Chat, Bot, TelegramObject +from telegram.ext import ( + PicklePersistence, + ContextTypes, + PersistenceInput, +) + + +@pytest.fixture(autouse=True) +def change_directory(tmp_path: Path): + orig_dir = Path.cwd() + # Switch to a temporary directory, so we don't have to worry about cleaning up files + os.chdir(tmp_path) + yield + # Go back to original directory + os.chdir(orig_dir) + + +@pytest.fixture(autouse=True) +def reset_callback_data_cache(bot): + yield + bot.callback_data_cache.clear_callback_data() + bot.callback_data_cache.clear_callback_queries() + bot.arbitrary_callback_data = False + + +@pytest.fixture(scope="function") +def bot_data(): + return {'test1': 'test2', 'test3': {'test4': 'test5'}} + + +@pytest.fixture(scope="function") +def chat_data(): + return {-12345: {'test1': 'test2', 'test3': {'test4': 'test5'}}, -67890: {3: 'test4'}} + + +@pytest.fixture(scope="function") +def user_data(): + return {12345: {'test1': 'test2', 'test3': {'test4': 'test5'}}, 67890: {3: 'test4'}} + + +@pytest.fixture(scope="function") +def callback_data(): + return [('test1', 1000, {'button1': 'test0', 'button2': 'test1'})], {'test1': 'test2'} + + +@pytest.fixture(scope='function') +def conversations(): + return { + 'name1': {(123, 123): 3, (456, 654): 4}, + 'name2': {(123, 321): 1, (890, 890): 2}, + 'name3': {(123, 321): 1, (890, 890): 2}, + } + + +@pytest.fixture(scope='function') +def pickle_persistence(): + return PicklePersistence( + filepath='pickletest', + single_file=False, + on_flush=False, + ) + + +@pytest.fixture(scope='function') +def pickle_persistence_only_bot(): + return PicklePersistence( + filepath='pickletest', + store_data=PersistenceInput(callback_data=False, user_data=False, chat_data=False), + single_file=False, + on_flush=False, + ) + + +@pytest.fixture(scope='function') +def pickle_persistence_only_chat(): + return PicklePersistence( + filepath='pickletest', + store_data=PersistenceInput(callback_data=False, user_data=False, bot_data=False), + single_file=False, + on_flush=False, + ) + + +@pytest.fixture(scope='function') +def pickle_persistence_only_user(): + return PicklePersistence( + filepath='pickletest', + store_data=PersistenceInput(callback_data=False, chat_data=False, bot_data=False), + single_file=False, + on_flush=False, + ) + + +@pytest.fixture(scope='function') +def pickle_persistence_only_callback(): + return PicklePersistence( + filepath='pickletest', + store_data=PersistenceInput(user_data=False, chat_data=False, bot_data=False), + single_file=False, + on_flush=False, + ) + + +@pytest.fixture(scope='function') +def bad_pickle_files(): + for name in [ + 'pickletest_user_data', + 'pickletest_chat_data', + 'pickletest_bot_data', + 'pickletest_callback_data', + 'pickletest_conversations', + 'pickletest', + ]: + Path(name).write_text('(())') + yield True + + +@pytest.fixture(scope='function') +def invalid_pickle_files(): + for name in [ + 'pickletest_user_data', + 'pickletest_chat_data', + 'pickletest_bot_data', + 'pickletest_callback_data', + 'pickletest_conversations', + 'pickletest', + ]: + # Just a random way to trigger pickle.UnpicklingError + # see https://stackoverflow.com/a/44422239/10606962 + with gzip.open(name, 'wb') as file: + pickle.dump([1, 2, 3], file) + yield True + + +@pytest.fixture(scope='function') +def good_pickle_files(user_data, chat_data, bot_data, callback_data, conversations): + data = { + 'user_data': user_data, + 'chat_data': chat_data, + 'bot_data': bot_data, + 'callback_data': callback_data, + 'conversations': conversations, + } + with Path('pickletest_user_data').open('wb') as f: + pickle.dump(user_data, f) + with Path('pickletest_chat_data').open('wb') as f: + pickle.dump(chat_data, f) + with Path('pickletest_bot_data').open('wb') as f: + pickle.dump(bot_data, f) + with Path('pickletest_callback_data').open('wb') as f: + pickle.dump(callback_data, f) + with Path('pickletest_conversations').open('wb') as f: + pickle.dump(conversations, f) + with Path('pickletest').open('wb') as f: + pickle.dump(data, f) + yield True + + +@pytest.fixture(scope='function') +def pickle_files_wo_bot_data(user_data, chat_data, callback_data, conversations): + data = { + 'user_data': user_data, + 'chat_data': chat_data, + 'conversations': conversations, + 'callback_data': callback_data, + } + with Path('pickletest_user_data').open('wb') as f: + pickle.dump(user_data, f) + with Path('pickletest_chat_data').open('wb') as f: + pickle.dump(chat_data, f) + with Path('pickletest_callback_data').open('wb') as f: + pickle.dump(callback_data, f) + with Path('pickletest_conversations').open('wb') as f: + pickle.dump(conversations, f) + with Path('pickletest').open('wb') as f: + pickle.dump(data, f) + yield True + + +@pytest.fixture(scope='function') +def pickle_files_wo_callback_data(user_data, chat_data, bot_data, conversations): + data = { + 'user_data': user_data, + 'chat_data': chat_data, + 'bot_data': bot_data, + 'conversations': conversations, + } + with Path('pickletest_user_data').open('wb') as f: + pickle.dump(user_data, f) + with Path('pickletest_chat_data').open('wb') as f: + pickle.dump(chat_data, f) + with Path('pickletest_bot_data').open('wb') as f: + pickle.dump(bot_data, f) + with Path('pickletest_conversations').open('wb') as f: + pickle.dump(conversations, f) + with Path('pickletest').open('wb') as f: + pickle.dump(data, f) + yield True + + +@pytest.fixture(scope='function') +def update(bot): + user = User(id=321, first_name='test_user', is_bot=False) + chat = Chat(id=123, type='group') + message = Message(1, datetime.datetime.now(), chat, from_user=user, text="Hi there", bot=bot) + return Update(0, message=message) + + +class TestPicklePersistence: + """Just tests the PicklePersistence interface. Integration of persistence into Applictation + is tested in TestBasePersistence!""" + + class DictSub(TelegramObject): # Used for testing our custom (Un)Pickler. + def __init__(self, private, normal, b): + self._private = private + self.normal = normal + self._bot = b + + class SlotsSub(TelegramObject): + __slots__ = ('new_var', '_private') + + def __init__(self, new_var, private): + self.new_var = new_var + self._private = private + + class NormalClass: + def __init__(self, my_var): + self.my_var = my_var + + @pytest.mark.asyncio + async def test_slot_behaviour(self, mro_slots, pickle_persistence): + inst = pickle_persistence + for attr in inst.__slots__: + assert getattr(inst, attr, 'err') != 'err', f"got extra slot '{attr}'" + assert len(mro_slots(inst)) == len(set(mro_slots(inst))), "duplicate slot" + + @pytest.mark.asyncio + @pytest.mark.parametrize('on_flush', (True, False)) + async def test_on_flush(self, pickle_persistence, on_flush): + pickle_persistence.on_flush = on_flush + pickle_persistence.single_file = True + file_path = Path(pickle_persistence.filepath) + + await pickle_persistence.update_callback_data('somedata') + assert file_path.is_file() != on_flush + + await pickle_persistence.update_bot_data('data') + assert file_path.is_file() != on_flush + + await pickle_persistence.update_user_data(123, 'data') + assert file_path.is_file() != on_flush + + await pickle_persistence.update_chat_data(123, 'data') + assert file_path.is_file() != on_flush + + await pickle_persistence.update_conversation('name', (1, 1), 'new_state') + assert file_path.is_file() != on_flush + + await pickle_persistence.flush() + assert file_path.is_file() + + @pytest.mark.asyncio + async def test_pickle_behaviour_with_slots(self, pickle_persistence): + bot_data = await pickle_persistence.get_bot_data() + bot_data['message'] = Message(3, datetime.datetime.now(), Chat(2, type='supergroup')) + await pickle_persistence.update_bot_data(bot_data) + retrieved = await pickle_persistence.get_bot_data() + assert retrieved == bot_data + + @pytest.mark.asyncio + async def test_no_files_present_multi_file(self, pickle_persistence): + assert await pickle_persistence.get_user_data() == {} + assert await pickle_persistence.get_chat_data() == {} + assert await pickle_persistence.get_bot_data() == {} + assert await pickle_persistence.get_callback_data() is None + assert await pickle_persistence.get_conversations('noname') == {} + + @pytest.mark.asyncio + async def test_no_files_present_single_file(self, pickle_persistence): + pickle_persistence.single_file = True + assert await pickle_persistence.get_user_data() == {} + assert await pickle_persistence.get_chat_data() == {} + assert await pickle_persistence.get_bot_data() == {} + assert await pickle_persistence.get_callback_data() is None + assert await pickle_persistence.get_conversations('noname') == {} + + @pytest.mark.asyncio + async def test_with_bad_multi_file(self, pickle_persistence, bad_pickle_files): + with pytest.raises(TypeError, match='pickletest_user_data'): + await pickle_persistence.get_user_data() + with pytest.raises(TypeError, match='pickletest_chat_data'): + await pickle_persistence.get_chat_data() + with pytest.raises(TypeError, match='pickletest_bot_data'): + await pickle_persistence.get_bot_data() + with pytest.raises(TypeError, match='pickletest_callback_data'): + await pickle_persistence.get_callback_data() + with pytest.raises(TypeError, match='pickletest_conversations'): + await pickle_persistence.get_conversations('name') + + @pytest.mark.asyncio + async def test_with_invalid_multi_file(self, pickle_persistence, invalid_pickle_files): + with pytest.raises(TypeError, match='pickletest_user_data does not contain'): + await pickle_persistence.get_user_data() + with pytest.raises(TypeError, match='pickletest_chat_data does not contain'): + await pickle_persistence.get_chat_data() + with pytest.raises(TypeError, match='pickletest_bot_data does not contain'): + await pickle_persistence.get_bot_data() + with pytest.raises(TypeError, match='pickletest_callback_data does not contain'): + await pickle_persistence.get_callback_data() + with pytest.raises(TypeError, match='pickletest_conversations does not contain'): + await pickle_persistence.get_conversations('name') + + @pytest.mark.asyncio + async def test_with_bad_single_file(self, pickle_persistence, bad_pickle_files): + pickle_persistence.single_file = True + with pytest.raises(TypeError, match='pickletest'): + await pickle_persistence.get_user_data() + with pytest.raises(TypeError, match='pickletest'): + await pickle_persistence.get_chat_data() + with pytest.raises(TypeError, match='pickletest'): + await pickle_persistence.get_bot_data() + with pytest.raises(TypeError, match='pickletest'): + await pickle_persistence.get_callback_data() + with pytest.raises(TypeError, match='pickletest'): + await pickle_persistence.get_conversations('name') + + @pytest.mark.asyncio + async def test_with_invalid_single_file(self, pickle_persistence, invalid_pickle_files): + pickle_persistence.single_file = True + with pytest.raises(TypeError, match='pickletest does not contain'): + await pickle_persistence.get_user_data() + with pytest.raises(TypeError, match='pickletest does not contain'): + await pickle_persistence.get_chat_data() + with pytest.raises(TypeError, match='pickletest does not contain'): + await pickle_persistence.get_bot_data() + with pytest.raises(TypeError, match='pickletest does not contain'): + await pickle_persistence.get_callback_data() + with pytest.raises(TypeError, match='pickletest does not contain'): + await pickle_persistence.get_conversations('name') + + @pytest.mark.asyncio + async def test_with_good_multi_file(self, pickle_persistence, good_pickle_files): + user_data = await pickle_persistence.get_user_data() + assert isinstance(user_data, dict) + assert user_data[12345]['test1'] == 'test2' + assert user_data[67890][3] == 'test4' + + chat_data = await pickle_persistence.get_chat_data() + assert isinstance(chat_data, dict) + assert chat_data[-12345]['test1'] == 'test2' + assert chat_data[-67890][3] == 'test4' + + bot_data = await pickle_persistence.get_bot_data() + assert isinstance(bot_data, dict) + assert bot_data['test1'] == 'test2' + assert bot_data['test3']['test4'] == 'test5' + assert 'test0' not in bot_data + + callback_data = await pickle_persistence.get_callback_data() + assert isinstance(callback_data, tuple) + assert callback_data[0] == [('test1', 1000, {'button1': 'test0', 'button2': 'test1'})] + assert callback_data[1] == {'test1': 'test2'} + + conversation1 = await pickle_persistence.get_conversations('name1') + assert isinstance(conversation1, dict) + assert conversation1[(123, 123)] == 3 + assert conversation1[(456, 654)] == 4 + with pytest.raises(KeyError): + conversation1[(890, 890)] + conversation2 = await pickle_persistence.get_conversations('name2') + assert isinstance(conversation1, dict) + assert conversation2[(123, 321)] == 1 + assert conversation2[(890, 890)] == 2 + with pytest.raises(KeyError): + conversation2[(123, 123)] + + @pytest.mark.asyncio + async def test_with_good_single_file(self, pickle_persistence, good_pickle_files): + pickle_persistence.single_file = True + user_data = await pickle_persistence.get_user_data() + assert isinstance(user_data, dict) + assert user_data[12345]['test1'] == 'test2' + assert user_data[67890][3] == 'test4' + + chat_data = await pickle_persistence.get_chat_data() + assert isinstance(chat_data, dict) + assert chat_data[-12345]['test1'] == 'test2' + assert chat_data[-67890][3] == 'test4' + + bot_data = await pickle_persistence.get_bot_data() + assert isinstance(bot_data, dict) + assert bot_data['test1'] == 'test2' + assert bot_data['test3']['test4'] == 'test5' + assert 'test0' not in bot_data + + callback_data = await pickle_persistence.get_callback_data() + assert isinstance(callback_data, tuple) + assert callback_data[0] == [('test1', 1000, {'button1': 'test0', 'button2': 'test1'})] + assert callback_data[1] == {'test1': 'test2'} + + conversation1 = await pickle_persistence.get_conversations('name1') + assert isinstance(conversation1, dict) + assert conversation1[(123, 123)] == 3 + assert conversation1[(456, 654)] == 4 + with pytest.raises(KeyError): + conversation1[(890, 890)] + conversation2 = await pickle_persistence.get_conversations('name2') + assert isinstance(conversation1, dict) + assert conversation2[(123, 321)] == 1 + assert conversation2[(890, 890)] == 2 + with pytest.raises(KeyError): + conversation2[(123, 123)] + + @pytest.mark.asyncio + async def test_with_multi_file_wo_bot_data(self, pickle_persistence, pickle_files_wo_bot_data): + user_data = await pickle_persistence.get_user_data() + assert isinstance(user_data, dict) + assert user_data[12345]['test1'] == 'test2' + assert user_data[67890][3] == 'test4' + + chat_data = await pickle_persistence.get_chat_data() + assert isinstance(chat_data, dict) + assert chat_data[-12345]['test1'] == 'test2' + assert chat_data[-67890][3] == 'test4' + + bot_data = await pickle_persistence.get_bot_data() + assert isinstance(bot_data, dict) + assert not bot_data.keys() + + callback_data = await pickle_persistence.get_callback_data() + assert isinstance(callback_data, tuple) + assert callback_data[0] == [('test1', 1000, {'button1': 'test0', 'button2': 'test1'})] + assert callback_data[1] == {'test1': 'test2'} + + conversation1 = await pickle_persistence.get_conversations('name1') + assert isinstance(conversation1, dict) + assert conversation1[(123, 123)] == 3 + assert conversation1[(456, 654)] == 4 + with pytest.raises(KeyError): + conversation1[(890, 890)] + conversation2 = await pickle_persistence.get_conversations('name2') + assert isinstance(conversation1, dict) + assert conversation2[(123, 321)] == 1 + assert conversation2[(890, 890)] == 2 + with pytest.raises(KeyError): + conversation2[(123, 123)] + + @pytest.mark.asyncio + async def test_with_multi_file_wo_callback_data( + self, pickle_persistence, pickle_files_wo_callback_data + ): + user_data = await pickle_persistence.get_user_data() + assert isinstance(user_data, dict) + assert user_data[12345]['test1'] == 'test2' + assert user_data[67890][3] == 'test4' + + chat_data = await pickle_persistence.get_chat_data() + assert isinstance(chat_data, dict) + assert chat_data[-12345]['test1'] == 'test2' + assert chat_data[-67890][3] == 'test4' + + bot_data = await pickle_persistence.get_bot_data() + assert isinstance(bot_data, dict) + assert bot_data['test1'] == 'test2' + assert bot_data['test3']['test4'] == 'test5' + assert 'test0' not in bot_data + + callback_data = await pickle_persistence.get_callback_data() + assert callback_data is None + + conversation1 = await pickle_persistence.get_conversations('name1') + assert isinstance(conversation1, dict) + assert conversation1[(123, 123)] == 3 + assert conversation1[(456, 654)] == 4 + with pytest.raises(KeyError): + conversation1[(890, 890)] + conversation2 = await pickle_persistence.get_conversations('name2') + assert isinstance(conversation1, dict) + assert conversation2[(123, 321)] == 1 + assert conversation2[(890, 890)] == 2 + with pytest.raises(KeyError): + conversation2[(123, 123)] + + @pytest.mark.asyncio + async def test_with_single_file_wo_bot_data( + self, pickle_persistence, pickle_files_wo_bot_data + ): + pickle_persistence.single_file = True + user_data = await pickle_persistence.get_user_data() + assert isinstance(user_data, dict) + assert user_data[12345]['test1'] == 'test2' + assert user_data[67890][3] == 'test4' + + chat_data = await pickle_persistence.get_chat_data() + assert isinstance(chat_data, dict) + assert chat_data[-12345]['test1'] == 'test2' + assert chat_data[-67890][3] == 'test4' + + bot_data = await pickle_persistence.get_bot_data() + assert isinstance(bot_data, dict) + assert not bot_data.keys() + + callback_data = await pickle_persistence.get_callback_data() + assert isinstance(callback_data, tuple) + assert callback_data[0] == [('test1', 1000, {'button1': 'test0', 'button2': 'test1'})] + assert callback_data[1] == {'test1': 'test2'} + + conversation1 = await pickle_persistence.get_conversations('name1') + assert isinstance(conversation1, dict) + assert conversation1[(123, 123)] == 3 + assert conversation1[(456, 654)] == 4 + with pytest.raises(KeyError): + conversation1[(890, 890)] + conversation2 = await pickle_persistence.get_conversations('name2') + assert isinstance(conversation1, dict) + assert conversation2[(123, 321)] == 1 + assert conversation2[(890, 890)] == 2 + with pytest.raises(KeyError): + conversation2[(123, 123)] + + @pytest.mark.asyncio + async def test_with_single_file_wo_callback_data( + self, pickle_persistence, pickle_files_wo_callback_data + ): + user_data = await pickle_persistence.get_user_data() + assert isinstance(user_data, dict) + assert user_data[12345]['test1'] == 'test2' + assert user_data[67890][3] == 'test4' + + chat_data = await pickle_persistence.get_chat_data() + assert isinstance(chat_data, dict) + assert chat_data[-12345]['test1'] == 'test2' + assert chat_data[-67890][3] == 'test4' + + bot_data = await pickle_persistence.get_bot_data() + assert isinstance(bot_data, dict) + assert bot_data['test1'] == 'test2' + assert bot_data['test3']['test4'] == 'test5' + assert 'test0' not in bot_data + + callback_data = await pickle_persistence.get_callback_data() + assert callback_data is None + + conversation1 = await pickle_persistence.get_conversations('name1') + assert isinstance(conversation1, dict) + assert conversation1[(123, 123)] == 3 + assert conversation1[(456, 654)] == 4 + with pytest.raises(KeyError): + conversation1[(890, 890)] + conversation2 = await pickle_persistence.get_conversations('name2') + assert isinstance(conversation1, dict) + assert conversation2[(123, 321)] == 1 + assert conversation2[(890, 890)] == 2 + with pytest.raises(KeyError): + conversation2[(123, 123)] + + @pytest.mark.asyncio + async def test_updating_multi_file(self, pickle_persistence, good_pickle_files): + user_data = await pickle_persistence.get_user_data() + user_data[12345]['test3']['test4'] = 'test6' + assert pickle_persistence.user_data != user_data + await pickle_persistence.update_user_data(12345, user_data[12345]) + assert pickle_persistence.user_data == user_data + with Path('pickletest_user_data').open('rb') as f: + user_data_test = dict(pickle.load(f)) + assert user_data_test == user_data + await pickle_persistence.drop_user_data(67890) + assert 67890 not in await pickle_persistence.get_user_data() + + chat_data = await pickle_persistence.get_chat_data() + chat_data[-12345]['test3']['test4'] = 'test6' + assert pickle_persistence.chat_data != chat_data + await pickle_persistence.update_chat_data(-12345, chat_data[-12345]) + assert pickle_persistence.chat_data == chat_data + with Path('pickletest_chat_data').open('rb') as f: + chat_data_test = dict(pickle.load(f)) + assert chat_data_test == chat_data + await pickle_persistence.drop_chat_data(-67890) + assert -67890 not in await pickle_persistence.get_chat_data() + + bot_data = await pickle_persistence.get_bot_data() + bot_data['test3']['test4'] = 'test6' + assert pickle_persistence.bot_data != bot_data + await pickle_persistence.update_bot_data(bot_data) + assert pickle_persistence.bot_data == bot_data + with Path('pickletest_bot_data').open('rb') as f: + bot_data_test = pickle.load(f) + assert bot_data_test == bot_data + + callback_data = await pickle_persistence.get_callback_data() + callback_data[1]['test3'] = 'test4' + assert pickle_persistence.callback_data != callback_data + await pickle_persistence.update_callback_data(callback_data) + assert pickle_persistence.callback_data == callback_data + with Path('pickletest_callback_data').open('rb') as f: + callback_data_test = pickle.load(f) + assert callback_data_test == callback_data + + conversation1 = await pickle_persistence.get_conversations('name1') + conversation1[(123, 123)] = 5 + assert not pickle_persistence.conversations['name1'] == conversation1 + await pickle_persistence.update_conversation('name1', (123, 123), 5) + assert pickle_persistence.conversations['name1'] == conversation1 + assert await pickle_persistence.get_conversations('name1') == conversation1 + with Path('pickletest_conversations').open('rb') as f: + conversations_test = dict(pickle.load(f)) + assert conversations_test['name1'] == conversation1 + + pickle_persistence.conversations = None + await pickle_persistence.update_conversation('name1', (123, 123), 5) + assert pickle_persistence.conversations['name1'] == {(123, 123): 5} + assert await pickle_persistence.get_conversations('name1') == {(123, 123): 5} + + @pytest.mark.asyncio + async def test_updating_single_file(self, pickle_persistence, good_pickle_files): + pickle_persistence.single_file = True + + user_data = await pickle_persistence.get_user_data() + user_data[12345]['test3']['test4'] = 'test6' + assert pickle_persistence.user_data != user_data + await pickle_persistence.update_user_data(12345, user_data[12345]) + assert pickle_persistence.user_data == user_data + with Path('pickletest').open('rb') as f: + user_data_test = dict(pickle.load(f))['user_data'] + assert user_data_test == user_data + await pickle_persistence.drop_user_data(67890) + assert 67890 not in await pickle_persistence.get_user_data() + + chat_data = await pickle_persistence.get_chat_data() + chat_data[-12345]['test3']['test4'] = 'test6' + assert pickle_persistence.chat_data != chat_data + await pickle_persistence.update_chat_data(-12345, chat_data[-12345]) + assert pickle_persistence.chat_data == chat_data + with Path('pickletest').open('rb') as f: + chat_data_test = dict(pickle.load(f))['chat_data'] + assert chat_data_test == chat_data + await pickle_persistence.drop_chat_data(-67890) + assert -67890 not in await pickle_persistence.get_chat_data() + + bot_data = await pickle_persistence.get_bot_data() + bot_data['test3']['test4'] = 'test6' + assert pickle_persistence.bot_data != bot_data + await pickle_persistence.update_bot_data(bot_data) + assert pickle_persistence.bot_data == bot_data + with Path('pickletest').open('rb') as f: + bot_data_test = pickle.load(f)['bot_data'] + assert bot_data_test == bot_data + + callback_data = await pickle_persistence.get_callback_data() + callback_data[1]['test3'] = 'test4' + assert pickle_persistence.callback_data != callback_data + await pickle_persistence.update_callback_data(callback_data) + assert pickle_persistence.callback_data == callback_data + with Path('pickletest').open('rb') as f: + callback_data_test = pickle.load(f)['callback_data'] + assert callback_data_test == callback_data + + conversation1 = await pickle_persistence.get_conversations('name1') + conversation1[(123, 123)] = 5 + assert not pickle_persistence.conversations['name1'] == conversation1 + await pickle_persistence.update_conversation('name1', (123, 123), 5) + assert pickle_persistence.conversations['name1'] == conversation1 + assert await pickle_persistence.get_conversations('name1') == conversation1 + with Path('pickletest').open('rb') as f: + conversations_test = dict(pickle.load(f))['conversations'] + assert conversations_test['name1'] == conversation1 + + pickle_persistence.conversations = None + await pickle_persistence.update_conversation('name1', (123, 123), 5) + assert pickle_persistence.conversations['name1'] == {(123, 123): 5} + assert await pickle_persistence.get_conversations('name1') == {(123, 123): 5} + + @pytest.mark.asyncio + async def test_updating_single_file_no_data(self, pickle_persistence): + pickle_persistence.single_file = True + assert not any( + [ + pickle_persistence.user_data, + pickle_persistence.chat_data, + pickle_persistence.bot_data, + pickle_persistence.callback_data, + pickle_persistence.conversations, + ] + ) + await pickle_persistence.flush() + with pytest.raises(FileNotFoundError, match='pickletest'): + open('pickletest', 'rb') + + @pytest.mark.asyncio + async def test_save_on_flush_multi_files(self, pickle_persistence, good_pickle_files): + # Should run without error + await pickle_persistence.flush() + pickle_persistence.on_flush = True + + user_data = await pickle_persistence.get_user_data() + user_data[54321] = {} + user_data[54321]['test9'] = 'test 10' + assert pickle_persistence.user_data != user_data + + await pickle_persistence.update_user_data(54321, user_data[54321]) + assert pickle_persistence.user_data == user_data + + await pickle_persistence.drop_user_data(0) + assert pickle_persistence.user_data == user_data + + with Path('pickletest_user_data').open('rb') as f: + user_data_test = dict(pickle.load(f)) + assert user_data_test != user_data + + chat_data = await pickle_persistence.get_chat_data() + chat_data[54321] = {} + chat_data[54321]['test9'] = 'test 10' + assert pickle_persistence.chat_data != chat_data + + await pickle_persistence.update_chat_data(54321, chat_data[54321]) + assert pickle_persistence.chat_data == chat_data + + await pickle_persistence.drop_chat_data(0) + assert pickle_persistence.user_data == user_data + + with Path('pickletest_chat_data').open('rb') as f: + chat_data_test = dict(pickle.load(f)) + assert chat_data_test != chat_data + + bot_data = await pickle_persistence.get_bot_data() + bot_data['test6'] = 'test 7' + assert pickle_persistence.bot_data != bot_data + + await pickle_persistence.update_bot_data(bot_data) + assert pickle_persistence.bot_data == bot_data + + with Path('pickletest_bot_data').open('rb') as f: + bot_data_test = pickle.load(f) + assert bot_data_test != bot_data + + callback_data = await pickle_persistence.get_callback_data() + callback_data[1]['test3'] = 'test4' + assert pickle_persistence.callback_data != callback_data + + await pickle_persistence.update_callback_data(callback_data) + assert pickle_persistence.callback_data == callback_data + + with Path('pickletest_callback_data').open('rb') as f: + callback_data_test = pickle.load(f) + assert callback_data_test != callback_data + + conversation1 = await pickle_persistence.get_conversations('name1') + conversation1[(123, 123)] = 5 + assert not pickle_persistence.conversations['name1'] == conversation1 + + await pickle_persistence.update_conversation('name1', (123, 123), 5) + assert pickle_persistence.conversations['name1'] == conversation1 + + with Path('pickletest_conversations').open('rb') as f: + conversations_test = dict(pickle.load(f)) + assert not conversations_test['name1'] == conversation1 + + await pickle_persistence.flush() + with Path('pickletest_user_data').open('rb') as f: + user_data_test = dict(pickle.load(f)) + assert user_data_test == user_data + + with Path('pickletest_chat_data').open('rb') as f: + chat_data_test = dict(pickle.load(f)) + assert chat_data_test == chat_data + + with Path('pickletest_bot_data').open('rb') as f: + bot_data_test = pickle.load(f) + assert bot_data_test == bot_data + + with Path('pickletest_conversations').open('rb') as f: + conversations_test = dict(pickle.load(f)) + assert conversations_test['name1'] == conversation1 + + @pytest.mark.asyncio + async def test_save_on_flush_single_files(self, pickle_persistence, good_pickle_files): + # Should run without error + await pickle_persistence.flush() + + pickle_persistence.on_flush = True + pickle_persistence.single_file = True + + user_data = await pickle_persistence.get_user_data() + user_data[54321] = {} + user_data[54321]['test9'] = 'test 10' + assert pickle_persistence.user_data != user_data + await pickle_persistence.update_user_data(54321, user_data[54321]) + assert pickle_persistence.user_data == user_data + with Path('pickletest').open('rb') as f: + user_data_test = dict(pickle.load(f))['user_data'] + assert user_data_test != user_data + + chat_data = await pickle_persistence.get_chat_data() + chat_data[54321] = {} + chat_data[54321]['test9'] = 'test 10' + assert pickle_persistence.chat_data != chat_data + await pickle_persistence.update_chat_data(54321, chat_data[54321]) + assert pickle_persistence.chat_data == chat_data + with Path('pickletest').open('rb') as f: + chat_data_test = dict(pickle.load(f))['chat_data'] + assert chat_data_test != chat_data + + bot_data = await pickle_persistence.get_bot_data() + bot_data['test6'] = 'test 7' + assert pickle_persistence.bot_data != bot_data + await pickle_persistence.update_bot_data(bot_data) + assert pickle_persistence.bot_data == bot_data + with Path('pickletest').open('rb') as f: + bot_data_test = pickle.load(f)['bot_data'] + assert bot_data_test != bot_data + + callback_data = await pickle_persistence.get_callback_data() + callback_data[1]['test3'] = 'test4' + assert pickle_persistence.callback_data != callback_data + await pickle_persistence.update_callback_data(callback_data) + assert pickle_persistence.callback_data == callback_data + with Path('pickletest').open('rb') as f: + callback_data_test = pickle.load(f)['callback_data'] + assert callback_data_test != callback_data + + conversation1 = await pickle_persistence.get_conversations('name1') + conversation1[(123, 123)] = 5 + assert not pickle_persistence.conversations['name1'] == conversation1 + await pickle_persistence.update_conversation('name1', (123, 123), 5) + assert pickle_persistence.conversations['name1'] == conversation1 + with Path('pickletest').open('rb') as f: + conversations_test = dict(pickle.load(f))['conversations'] + assert not conversations_test['name1'] == conversation1 + + await pickle_persistence.flush() + with Path('pickletest').open('rb') as f: + user_data_test = dict(pickle.load(f))['user_data'] + assert user_data_test == user_data + + with Path('pickletest').open('rb') as f: + chat_data_test = dict(pickle.load(f))['chat_data'] + assert chat_data_test == chat_data + + with Path('pickletest').open('rb') as f: + bot_data_test = pickle.load(f)['bot_data'] + assert bot_data_test == bot_data + + with Path('pickletest').open('rb') as f: + conversations_test = dict(pickle.load(f))['conversations'] + assert conversations_test['name1'] == conversation1 + + @pytest.mark.asyncio + async def test_custom_pickler_unpickler_simple( + self, pickle_persistence, update, good_pickle_files, bot, recwarn + ): + pickle_persistence.set_bot(bot) # assign the current bot to the persistence + data_with_bot = {'current_bot': update.message} + await pickle_persistence.update_chat_data( + 12345, data_with_bot + ) # also calls BotPickler.dumps() + + # Test that regular pickle load fails - + err_msg = ( + "A load persistent id instruction was encountered,\nbut no persistent_load " + "function was specified." + ) + with pytest.raises(pickle.UnpicklingError, match=err_msg): + with open('pickletest_chat_data', 'rb') as f: + pickle.load(f) + + # Test that our custom unpickler works as intended -- inserts the current bot + # We have to create a new instance otherwise unpickling is skipped + pp = PicklePersistence("pickletest", single_file=False, on_flush=False) + pp.set_bot(bot) # Set the bot + assert (await pp.get_chat_data())[12345]['current_bot'].get_bot() is bot + + # Now test that pickling of unknown bots in TelegramObjects will be replaced by None- + assert not len(recwarn) + data_with_bot = {} + async with Bot(bot.token) as other_bot: + data_with_bot['unknown_bot_in_user'] = User(1, 'Dev', False, bot=other_bot) + await pickle_persistence.update_chat_data(12345, data_with_bot) + assert len(recwarn) == 1 + assert recwarn[-1].category is PTBUserWarning + assert str(recwarn[-1].message).startswith("Unknown bot instance found.") + pp = PicklePersistence("pickletest", single_file=False, on_flush=False) + pp.set_bot(bot) + assert (await pp.get_chat_data())[12345]['unknown_bot_in_user']._bot is None + + @pytest.mark.asyncio + async def test_custom_pickler_unpickler_with_custom_objects( + self, bot, pickle_persistence, good_pickle_files + ): + dict_s = self.DictSub("private", 'normal', bot) + slot_s = self.SlotsSub("new_var", 'private_var') + regular = self.NormalClass(12) + + pickle_persistence.set_bot(bot) + await pickle_persistence.update_user_data( + 1232, {'sub_dict': dict_s, 'sub_slots': slot_s, 'r': regular} + ) + pp = PicklePersistence("pickletest", single_file=False, on_flush=False) + pp.set_bot(bot) # Set the bot + data = (await pp.get_user_data())[1232] + sub_dict = data['sub_dict'] + sub_slots = data['sub_slots'] + sub_regular = data['r'] + assert sub_dict._bot is bot + assert sub_dict.normal == dict_s.normal + assert sub_dict._private == dict_s._private + assert sub_slots.new_var == slot_s.new_var + assert sub_slots._private == slot_s._private + assert sub_slots._bot is None # We didn't set the bot, so it shouldn't have it here. + assert sub_regular.my_var == regular.my_var + + @pytest.mark.parametrize( + 'filepath', + ['pickletest', Path('pickletest')], + ids=['str filepath', 'pathlib.Path filepath'], + ) + @pytest.mark.asyncio + async def test_filepath_argument_types(self, filepath): + pick_persist = PicklePersistence( + filepath=filepath, + on_flush=False, + ) + await pick_persist.update_user_data(1, 1) + + assert (await pick_persist.get_user_data())[1] == 1 + assert Path(filepath).is_file() + + @pytest.mark.parametrize('singlefile', [True, False]) + @pytest.mark.parametrize('ud', [int, float, complex]) + @pytest.mark.parametrize('cd', [int, float, complex]) + @pytest.mark.parametrize('bd', [int, float, complex]) + @pytest.mark.asyncio + async def test_with_context_types(self, ud, cd, bd, singlefile): + cc = ContextTypes(user_data=ud, chat_data=cd, bot_data=bd) + persistence = PicklePersistence('pickletest', single_file=singlefile, context_types=cc) + + assert isinstance(await persistence.get_bot_data(), bd) + assert await persistence.get_bot_data() == 0 + + persistence.user_data = None + persistence.chat_data = None + await persistence.drop_user_data(123) + await persistence.drop_chat_data(123) + assert isinstance(await persistence.get_user_data(), dict) + assert isinstance(await persistence.get_chat_data(), dict) + persistence.user_data = None + persistence.chat_data = None + await persistence.update_user_data(1, ud(1)) + await persistence.update_chat_data(1, cd(1)) + await persistence.update_bot_data(bd(1)) + assert (await persistence.get_user_data())[1] == 1 + assert (await persistence.get_chat_data())[1] == 1 + assert await persistence.get_bot_data() == 1 + + await persistence.flush() + persistence = PicklePersistence('pickletest', single_file=singlefile, context_types=cc) + assert isinstance((await persistence.get_user_data())[1], ud) + assert (await persistence.get_user_data())[1] == 1 + assert isinstance((await persistence.get_chat_data())[1], cd) + assert (await persistence.get_chat_data())[1] == 1 + assert isinstance(await persistence.get_bot_data(), bd) + assert await persistence.get_bot_data() == 1 + + @pytest.mark.asyncio + async def test_no_write_if_data_did_not_change( + self, pickle_persistence, bot_data, user_data, chat_data, conversations, callback_data + ): + pickle_persistence.single_file = True + pickle_persistence.on_flush = False + + await pickle_persistence.update_bot_data(bot_data) + await pickle_persistence.update_user_data(12345, user_data[12345]) + await pickle_persistence.update_chat_data(-12345, chat_data[-12345]) + await pickle_persistence.update_conversation('name', (1, 1), 'new_state') + await pickle_persistence.update_callback_data(callback_data) + + assert pickle_persistence.filepath.is_file() + pickle_persistence.filepath.unlink() + assert not pickle_persistence.filepath.is_file() + + await pickle_persistence.update_bot_data(bot_data) + await pickle_persistence.update_user_data(12345, user_data[12345]) + await pickle_persistence.update_chat_data(-12345, chat_data[-12345]) + await pickle_persistence.update_conversation('name', (1, 1), 'new_state') + await pickle_persistence.update_callback_data(callback_data) + + assert not pickle_persistence.filepath.is_file() diff --git a/tests/test_poll.py b/tests/test_poll.py index 1c32b66a94d..2f8b042828c 100644 --- a/tests/test_poll.py +++ b/tests/test_poll.py @@ -217,6 +217,18 @@ def test_parse_entity(self, poll): assert poll.parse_explanation_entity(entity) == 'http://google.com' + with pytest.raises(RuntimeError, match='Poll has no'): + Poll( + 'id', + 'question', + [PollOption('text', voter_count=0)], + total_voter_count=0, + is_closed=False, + is_anonymous=False, + type=Poll.QUIZ, + allows_multiple_answers=False, + ).parse_explanation_entity(entity) + def test_parse_entities(self, poll): entity = MessageEntity(type=MessageEntity.URL, offset=13, length=17) entity_2 = MessageEntity(type=MessageEntity.BOLD, offset=13, length=1) diff --git a/tests/test_precheckoutquery.py b/tests/test_precheckoutquery.py index c782e066729..88a81ae5643 100644 --- a/tests/test_precheckoutquery.py +++ b/tests/test_precheckoutquery.py @@ -46,6 +46,12 @@ class TestPreCheckoutQuery: from_user = User(0, '', False) order_info = OrderInfo() + def test_slot_behaviour(self, pre_checkout_query, mro_slots): + inst = pre_checkout_query + for attr in inst.__slots__: + assert getattr(inst, attr, 'err') != 'err', f"got extra slot '{attr}'" + assert len(mro_slots(inst)) == len(set(mro_slots(inst))), "duplicate slot" + def test_de_json(self, bot): json_dict = { 'id': self.id_, diff --git a/tests/test_replykeyboardmarkup.py b/tests/test_replykeyboardmarkup.py index 405a85cd78a..6480d1d3167 100644 --- a/tests/test_replykeyboardmarkup.py +++ b/tests/test_replykeyboardmarkup.py @@ -39,6 +39,12 @@ class TestReplyKeyboardMarkup: one_time_keyboard = True selective = True + def test_slot_behaviour(self, reply_keyboard_markup, mro_slots): + inst = reply_keyboard_markup + for attr in inst.__slots__: + assert getattr(inst, attr, 'err') != 'err', f"got extra slot '{attr}'" + assert len(mro_slots(inst)) == len(set(mro_slots(inst))), "duplicate slot" + @flaky(3, 1) @pytest.mark.asyncio async def test_send_message_with_reply_keyboard_markup( @@ -100,6 +106,12 @@ def test_expected_values(self, reply_keyboard_markup): assert reply_keyboard_markup.one_time_keyboard == self.one_time_keyboard assert reply_keyboard_markup.selective == self.selective + def test_wrong_keyboard_inputs(self): + with pytest.raises(ValueError): + ReplyKeyboardMarkup([['button1'], 'Button2']) + with pytest.raises(ValueError): + ReplyKeyboardMarkup('button') + def test_to_dict(self, reply_keyboard_markup): reply_keyboard_markup_dict = reply_keyboard_markup.to_dict() diff --git a/tests/test_requestdata.py b/tests/test_requestdata.py index 3a38a8e82b8..2f254b6b26d 100644 --- a/tests/test_requestdata.py +++ b/tests/test_requestdata.py @@ -133,6 +133,11 @@ def mixed_rqs(mixed_params) -> RequestData: class TestRequestData: + def test_slot_behaviour(self, simple_rqs, mro_slots): + for attr in simple_rqs.__slots__: + assert getattr(simple_rqs, attr, 'err') != 'err', f"got extra slot '{attr}'" + assert len(mro_slots(simple_rqs)) == len(set(mro_slots(simple_rqs))), "duplicate slot" + def test_contains_files(self, simple_rqs, file_rqs, mixed_rqs): assert not simple_rqs.contains_files assert file_rqs.contains_files diff --git a/tests/test_requestparameter.py b/tests/test_requestparameter.py index aaf9ea75027..52c404a1e5c 100644 --- a/tests/test_requestparameter.py +++ b/tests/test_requestparameter.py @@ -38,6 +38,12 @@ def test_init(self): assert request_parameter.value == 'value' assert request_parameter.input_files is None + def test_slot_behaviour(self, mro_slots): + inst = RequestParameter('name', 'value', [1, 2]) + for attr in inst.__slots__: + assert getattr(inst, attr, 'err') != 'err', f"got extra slot '{attr}'" + assert len(mro_slots(inst)) == len(set(mro_slots(inst))), "duplicate slot" + @pytest.mark.parametrize( 'value, expected', [ diff --git a/tests/test_shippingquery.py b/tests/test_shippingquery.py index d9415436a6d..ce6cdb745a3 100644 --- a/tests/test_shippingquery.py +++ b/tests/test_shippingquery.py @@ -40,6 +40,12 @@ class TestShippingQuery: from_user = User(0, '', False) shipping_address = ShippingAddress('GB', '', 'London', '12 Grimmauld Place', '', 'WC1') + def test_slot_behaviour(self, shipping_query, mro_slots): + inst = shipping_query + for attr in inst.__slots__: + assert getattr(inst, attr, 'err') != 'err', f"got extra slot '{attr}'" + assert len(mro_slots(inst)) == len(set(mro_slots(inst))), "duplicate slot" + def test_de_json(self, bot): json_dict = { 'id': TestShippingQuery.id_, diff --git a/tests/test_stringregexhandler.py b/tests/test_stringregexhandler.py index b7db2ec5bbe..56adc5e7fc1 100644 --- a/tests/test_stringregexhandler.py +++ b/tests/test_stringregexhandler.py @@ -19,6 +19,7 @@ import asyncio import pytest +import re from telegram import ( Bot, @@ -97,8 +98,12 @@ async def callback_pattern(self, update, context): self.test_flag = context.matches[0].groupdict() == {'begin': 't', 'end': ' message'} @pytest.mark.asyncio - async def test_basic(self, app): - handler = StringRegexHandler('(?P.*)est(?P.*)', self.callback) + @pytest.mark.parametrize('compile', (True, False)) + async def test_basic(self, app, compile): + pattern = '(?P.*)est(?P.*)' + if compile: + pattern = re.compile('(?P.*)est(?P.*)') + handler = StringRegexHandler(pattern, self.callback) app.add_handler(handler) assert handler.check_update('test message') diff --git a/tests/test_trackingdict.py b/tests/test_trackingdict.py index 7f8849a693f..f6e5e91cd15 100644 --- a/tests/test_trackingdict.py +++ b/tests/test_trackingdict.py @@ -35,6 +35,11 @@ def data() -> dict: class TestTrackingDict: + def test_slot_behaviour(self, td, mro_slots): + for attr in td.__slots__: + assert getattr(td, attr, 'err') != 'err', f"got extra slot '{attr}'" + assert len(mro_slots(td)) == len(set(mro_slots(td))), "duplicate slot" + def test_representations(self, td, data): assert repr(td) == repr(data) assert str(td) == str(data) @@ -159,3 +164,10 @@ def test_iter(self, td, data): td.update_no_track({2: 2, 3: 3, 4: 4}) assert not td.pop_accessed_keys() assert list(iter(td)) == list(iter(data)) + + def test_mark_as_accessed(self, td): + td[1] = 2 + assert td.pop_accessed_keys() == {1} + assert td.pop_accessed_keys() == set() + td.mark_as_accessed(1) + assert td.pop_accessed_keys() == {1} diff --git a/tests/test_user.py b/tests/test_user.py index d4f621d3ec2..ad9195c9d91 100644 --- a/tests/test_user.py +++ b/tests/test_user.py @@ -65,6 +65,11 @@ class TestUser: can_read_all_group_messages = True supports_inline_queries = False + def test_slot_behaviour(self, user, mro_slots): + for attr in user.__slots__: + assert getattr(user, attr, 'err') != 'err', f"got extra slot '{attr}'" + assert len(mro_slots(user)) == len(set(mro_slots(user))), "duplicate slot" + def test_de_json(self, json_dict, bot): user = User.de_json(json_dict, bot) diff --git a/tests/test_video.py b/tests/test_video.py index 141f20a2668..b12806d278c 100644 --- a/tests/test_video.py +++ b/tests/test_video.py @@ -67,6 +67,11 @@ class TestVideo: video_file_id = '5a3128a4d2a04750b5b58397f3b5e812' video_file_unique_id = 'adc3145fd2e84d95b64d68eaa22aa33e' + def test_slot_behaviour(self, video, mro_slots): + for attr in video.__slots__: + assert getattr(video, attr, 'err') != 'err', f"got extra slot '{attr}'" + assert len(mro_slots(video)) == len(set(mro_slots(video))), "duplicate slot" + def test_creation(self, video): # Make sure file has been uploaded. assert isinstance(video, Video) diff --git a/tests/test_videonote.py b/tests/test_videonote.py index 915a0e88615..7f11bee7b69 100644 --- a/tests/test_videonote.py +++ b/tests/test_videonote.py @@ -60,6 +60,11 @@ class TestVideoNote: videonote_file_id = '5a3128a4d2a04750b5b58397f3b5e812' videonote_file_unique_id = 'adc3145fd2e84d95b64d68eaa22aa33e' + def test_slot_behaviour(self, video_note, mro_slots): + for attr in video_note.__slots__: + assert getattr(video_note, attr, 'err') != 'err', f"got extra slot '{attr}'" + assert len(mro_slots(video_note)) == len(set(mro_slots(video_note))), "duplicate slot" + def test_creation(self, video_note): # Make sure file has been uploaded. assert isinstance(video_note, VideoNote) diff --git a/tests/test_voice.py b/tests/test_voice.py index 4190e95cdcf..08d492c7798 100644 --- a/tests/test_voice.py +++ b/tests/test_voice.py @@ -58,6 +58,11 @@ class TestVoice: voice_file_id = '5a3128a4d2a04750b5b58397f3b5e812' voice_file_unique_id = 'adc3145fd2e84d95b64d68eaa22aa33e' + def test_slot_behaviour(self, voice, mro_slots): + for attr in voice.__slots__: + assert getattr(voice, attr, 'err') != 'err', f"got extra slot '{attr}'" + assert len(mro_slots(voice)) == len(set(mro_slots(voice))), "duplicate slot" + @pytest.mark.asyncio async def test_creation(self, voice): # Make sure file has been uploaded. diff --git a/tests/test_voicechat.py b/tests/test_voicechat.py index 92251b073df..ef93cb22c86 100644 --- a/tests/test_voicechat.py +++ b/tests/test_voicechat.py @@ -111,14 +111,20 @@ def test_de_json(self, user1, user2, bot): assert voice_chat_participants.users[0].id == user1.id assert voice_chat_participants.users[1].id == user2.id - def test_to_dict(self, user1, user2): - voice_chat_participants = VoiceChatParticipantsInvited([user1, user2]) + @pytest.mark.parametrize('use_users', (True, False)) + def test_to_dict(self, user1, user2, use_users): + voice_chat_participants = VoiceChatParticipantsInvited( + [user1, user2] if use_users else None + ) voice_chat_dict = voice_chat_participants.to_dict() assert isinstance(voice_chat_dict, dict) - assert voice_chat_dict["users"] == [user1.to_dict(), user2.to_dict()] - assert voice_chat_dict["users"][0]["id"] == user1.id - assert voice_chat_dict["users"][1]["id"] == user2.id + if use_users: + assert voice_chat_dict["users"] == [user1.to_dict(), user2.to_dict()] + assert voice_chat_dict["users"][0]["id"] == user1.id + assert voice_chat_dict["users"][1]["id"] == user2.id + else: + assert voice_chat_dict == {} def test_equality(self, user1, user2): a = VoiceChatParticipantsInvited([user1]) From bafd4ede8fc99617c74379637b111be230e440c2 Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Sun, 17 Apr 2022 08:43:23 +0200 Subject: [PATCH 120/153] pre-commit --- telegram/_bot.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/telegram/_bot.py b/telegram/_bot.py index e204c3d8b03..4fce356b7e2 100644 --- a/telegram/_bot.py +++ b/telegram/_bot.py @@ -4792,7 +4792,8 @@ async def restrict_chat_member( """ Use this method to restrict a user in a supergroup. The bot must be an administrator in the supergroup for this to work and must have the appropriate admin rights. Pass - :obj:`True` for all boolean parameters in :class:`telegram.ChatPermissions` to lift restrictions from a user. + :obj:`True` for all boolean parameters in :class:`telegram.ChatPermissions` to lift + restrictions from a user. Args: chat_id (:obj:`int` | :obj:`str`): Unique identifier for the target chat or username From 0c9be83b8f740f886f122b7027b1e52fd5f3b5b5 Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Sun, 17 Apr 2022 09:38:14 +0200 Subject: [PATCH 121/153] some typos, wording etc --- .github/workflows/test.yml | 1 - .pre-commit-config.yaml | 2 +- .../telegram.ext.applicationbuilder.rst | 2 +- examples/passportbot.html | 2 +- pyproject.toml | 2 +- telegram/_bot.py | 6 +- telegram/_callbackquery.py | 2 +- telegram/_chat.py | 98 +++++++++++-------- telegram/_chatjoinrequest.py | 4 +- telegram/_files/inputfile.py | 10 +- telegram/_message.py | 72 +++++++------- telegram/_payment/precheckoutquery.py | 2 +- telegram/_payment/shippingquery.py | 2 +- telegram/_user.py | 52 +++++----- 14 files changed, 139 insertions(+), 118 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 0c28ee6ccde..73a639512b9 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -4,7 +4,6 @@ on: branches: - master - v14 - - asyncio push: branches: - master diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 31751c4477e..5d7979a6d61 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -45,7 +45,7 @@ repos: - tornado>=6.1 - APScheduler==3.6.3 - cachetools==4.2.2 - - . # this basically does `pip install -e .`n + - . # this basically does `pip install -e .` - id: mypy name: mypy-examples files: ^examples/.*\.py$ diff --git a/docs/source/telegram.ext.applicationbuilder.rst b/docs/source/telegram.ext.applicationbuilder.rst index fbdec5357a1..cff3899c492 100644 --- a/docs/source/telegram.ext.applicationbuilder.rst +++ b/docs/source/telegram.ext.applicationbuilder.rst @@ -1,4 +1,4 @@ -:github_url: https://github.com/python-telegram-bot/python-telegram-bot/blob/master/telegram/ext/_builders.py +:github_url: https://github.com/python-telegram-bot/python-telegram-bot/blob/master/telegram/ext/_applicationbuilder.py telegram.ext.ApplicationBuilder =============================== diff --git a/examples/passportbot.html b/examples/passportbot.html index b25c51f6a50..f793a7779fc 100644 --- a/examples/passportbot.html +++ b/examples/passportbot.html @@ -18,7 +18,7 @@

Telegram passport test

"use strict"; Telegram.Passport.createAuthButton('telegram_passport_auth', { - bot_id: 703777048, // YOUR BOT ID + bot_id: 1234567890, // YOUR BOT ID scope: { data: [{ type: 'id_document', diff --git a/pyproject.toml b/pyproject.toml index a6c381a6ffb..4ea2d7badaa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,4 +7,4 @@ skip-string-normalization = true # so that pre-commit run --all-files does the correct thing # see https://github.com/psf/black/issues/1778 force-exclude = '^(?!/(telegram|examples|tests)/).*\.py$' -include = '(telegram|examples|tests)/.*\.py$' \ No newline at end of file +include = '(telegram|examples|tests)/.*\.py$' diff --git a/telegram/_bot.py b/telegram/_bot.py index 4fce356b7e2..8b65f3045ab 100644 --- a/telegram/_bot.py +++ b/telegram/_bot.py @@ -364,7 +364,9 @@ async def _send_message( async def initialize(self) -> None: """Initialize resources used by this class. Currently calls :meth:`get_me` to cache :attr:`bot` and calls :meth:`telegram.request.BaseRequest.initialize` for - :attr:`request`. + the request objects used by this bot. + + .. versionadded:: 14.0 """ if self._initialized: self._logger.debug('This Bot is already initialized.') @@ -377,6 +379,8 @@ async def initialize(self) -> None: async def shutdown(self) -> None: """Stop & clear resources used by this class. Currently just calls :meth:`telegram.request.BaseRequest.shutdown` for the request objects used by this bot. + + .. versionadded:: 14.0 """ if not self._initialized: self._logger.debug('This Bot is already shut down. Returning.') diff --git a/telegram/_callbackquery.py b/telegram/_callbackquery.py index b49dc7926c4..e24fe2841ac 100644 --- a/telegram/_callbackquery.py +++ b/telegram/_callbackquery.py @@ -156,7 +156,7 @@ async def answer( ) -> bool: """Shortcut for:: - bot.answer_callback_query(update.callback_query.id, *args, **kwargs) + await bot.answer_callback_query(update.callback_query.id, *args, **kwargs) For the documentation of the arguments, please see :meth:`telegram.Bot.answer_callback_query`. diff --git a/telegram/_chat.py b/telegram/_chat.py index 4fb931a29f6..a9310312186 100644 --- a/telegram/_chat.py +++ b/telegram/_chat.py @@ -310,7 +310,7 @@ async def leave( ) -> bool: """Shortcut for:: - bot.leave_chat(update.effective_chat.id, *args, **kwargs) + await bot.leave_chat(update.effective_chat.id, *args, **kwargs) For the documentation of the arguments, please see :meth:`telegram.Bot.leave_chat`. @@ -337,7 +337,7 @@ async def get_administrators( ) -> List['ChatMember']: """Shortcut for:: - bot.get_chat_administrators(update.effective_chat.id, *args, **kwargs) + await bot.get_chat_administrators(update.effective_chat.id, *args, **kwargs) For the documentation of the arguments, please see :meth:`telegram.Bot.get_chat_administrators`. @@ -368,7 +368,7 @@ async def get_member_count( ) -> int: """Shortcut for:: - bot.get_chat_member_count(update.effective_chat.id, *args, **kwargs) + await bot.get_chat_member_count(update.effective_chat.id, *args, **kwargs) For the documentation of the arguments, please see :meth:`telegram.Bot.get_chat_member_count`. @@ -396,7 +396,7 @@ async def get_member( ) -> 'ChatMember': """Shortcut for:: - bot.get_chat_member(update.effective_chat.id, *args, **kwargs) + await bot.get_chat_member(update.effective_chat.id, *args, **kwargs) For the documentation of the arguments, please see :meth:`telegram.Bot.get_chat_member`. @@ -427,7 +427,7 @@ async def ban_member( ) -> bool: """Shortcut for:: - bot.ban_chat_member(update.effective_chat.id, *args, **kwargs) + await bot.ban_chat_member(update.effective_chat.id, *args, **kwargs) For the documentation of the arguments, please see :meth:`telegram.Bot.ban_chat_member`. @@ -458,7 +458,7 @@ async def ban_sender_chat( ) -> bool: """Shortcut for:: - bot.ban_chat_sender_chat(chat_id=update.effective_chat.id, *args, **kwargs) + await bot.ban_chat_sender_chat(chat_id=update.effective_chat.id, *args, **kwargs) For the documentation of the arguments, please see :meth:`telegram.Bot.ban_chat_sender_chat`. @@ -490,7 +490,11 @@ async def ban_chat( ) -> bool: """Shortcut for:: - bot.ban_chat_sender_chat(sender_chat_id=update.effective_chat.id, *args, **kwargs) + await bot.ban_chat_sender_chat( + sender_chat_id=update.effective_chat.id, + *args, + **kwargs + ) For the documentation of the arguments, please see :meth:`telegram.Bot.ban_chat_sender_chat`. @@ -522,7 +526,7 @@ async def unban_sender_chat( ) -> bool: """Shortcut for:: - bot.unban_chat_sender_chat(chat_id=update.effective_chat.id, *args, **kwargs) + await bot.unban_chat_sender_chat(chat_id=update.effective_chat.id, *args, **kwargs) For the documentation of the arguments, please see :meth:`telegram.Bot.unban_chat_sender_chat`. @@ -554,7 +558,11 @@ async def unban_chat( ) -> bool: """Shortcut for:: - bot.unban_chat_sender_chat(sender_chat_id=update.effective_chat.id, *args, **kwargs) + await bot.unban_chat_sender_chat( + sender_chat_id=update.effective_chat.id, + *args, + **kwargs + ) For the documentation of the arguments, please see :meth:`telegram.Bot.unban_chat_sender_chat`. @@ -587,7 +595,7 @@ async def unban_member( ) -> bool: """Shortcut for:: - bot.unban_chat_member(update.effective_chat.id, *args, **kwargs) + await bot.unban_chat_member(update.effective_chat.id, *args, **kwargs) For the documentation of the arguments, please see :meth:`telegram.Bot.unban_chat_member`. @@ -628,7 +636,7 @@ async def promote_member( ) -> bool: """Shortcut for:: - bot.promote_chat_member(update.effective_chat.id, *args, **kwargs) + await bot.promote_chat_member(update.effective_chat.id, *args, **kwargs) For the documentation of the arguments, please see :meth:`telegram.Bot.promote_chat_member`. @@ -673,7 +681,7 @@ async def restrict_member( ) -> bool: """Shortcut for:: - bot.restrict_chat_member(update.effective_chat.id, *args, **kwargs) + await bot.restrict_chat_member(update.effective_chat.id, *args, **kwargs) For the documentation of the arguments, please see :meth:`telegram.Bot.restrict_chat_member`. @@ -707,7 +715,7 @@ async def set_permissions( ) -> bool: """Shortcut for:: - bot.set_chat_permissions(update.effective_chat.id, *args, **kwargs) + await bot.set_chat_permissions(update.effective_chat.id, *args, **kwargs) For the documentation of the arguments, please see :meth:`telegram.Bot.set_chat_permissions`. @@ -738,7 +746,11 @@ async def set_administrator_custom_title( ) -> bool: """Shortcut for:: - bot.set_chat_administrator_custom_title(update.effective_chat.id, *args, **kwargs) + await bot.set_chat_administrator_custom_title( + update.effective_chat.id, + *args, + **kwargs + ) For the documentation of the arguments, please see :meth:`telegram.Bot.set_chat_administrator_custom_title`. @@ -770,7 +782,7 @@ async def pin_message( ) -> bool: """Shortcut for:: - bot.pin_chat_message(chat_id=update.effective_chat.id, + await bot.pin_chat_message(chat_id=update.effective_chat.id, *args, **kwargs) @@ -803,7 +815,7 @@ async def unpin_message( ) -> bool: """Shortcut for:: - bot.unpin_chat_message(chat_id=update.effective_chat.id, + await bot.unpin_chat_message(chat_id=update.effective_chat.id, *args, **kwargs) @@ -834,7 +846,7 @@ async def unpin_all_messages( ) -> bool: """Shortcut for:: - bot.unpin_all_chat_messages(chat_id=update.effective_chat.id, + await bot.unpin_all_chat_messages(chat_id=update.effective_chat.id, *args, **kwargs) @@ -873,7 +885,7 @@ async def send_message( ) -> 'Message': """Shortcut for:: - bot.send_message(update.effective_chat.id, *args, **kwargs) + await bot.send_message(update.effective_chat.id, *args, **kwargs) For the documentation of the arguments, please see :meth:`telegram.Bot.send_message`. @@ -916,7 +928,7 @@ async def send_media_group( ) -> List['Message']: """Shortcut for:: - bot.send_media_group(update.effective_chat.id, *args, **kwargs) + await bot.send_media_group(update.effective_chat.id, *args, **kwargs) For the documentation of the arguments, please see :meth:`telegram.Bot.send_media_group`. @@ -949,7 +961,7 @@ async def send_chat_action( ) -> bool: """Shortcut for:: - bot.send_chat_action(update.effective_chat.id, *args, **kwargs) + await bot.send_chat_action(update.effective_chat.id, *args, **kwargs) For the documentation of the arguments, please see :meth:`telegram.Bot.send_chat_action`. @@ -990,7 +1002,7 @@ async def send_photo( ) -> 'Message': """Shortcut for:: - bot.send_photo(update.effective_chat.id, *args, **kwargs) + await bot.send_photo(update.effective_chat.id, *args, **kwargs) For the documentation of the arguments, please see :meth:`telegram.Bot.send_photo`. @@ -1037,7 +1049,7 @@ async def send_contact( ) -> 'Message': """Shortcut for:: - bot.send_contact(update.effective_chat.id, *args, **kwargs) + await bot.send_contact(update.effective_chat.id, *args, **kwargs) For the documentation of the arguments, please see :meth:`telegram.Bot.send_contact`. @@ -1088,7 +1100,7 @@ async def send_audio( ) -> 'Message': """Shortcut for:: - bot.send_audio(update.effective_chat.id, *args, **kwargs) + await bot.send_audio(update.effective_chat.id, *args, **kwargs) For the documentation of the arguments, please see :meth:`telegram.Bot.send_audio`. @@ -1141,7 +1153,7 @@ async def send_document( ) -> 'Message': """Shortcut for:: - bot.send_document(update.effective_chat.id, *args, **kwargs) + await bot.send_document(update.effective_chat.id, *args, **kwargs) For the documentation of the arguments, please see :meth:`telegram.Bot.send_document`. @@ -1186,7 +1198,7 @@ async def send_dice( ) -> 'Message': """Shortcut for:: - bot.send_dice(update.effective_chat.id, *args, **kwargs) + await bot.send_dice(update.effective_chat.id, *args, **kwargs) For the documentation of the arguments, please see :meth:`telegram.Bot.send_dice`. @@ -1225,7 +1237,7 @@ async def send_game( ) -> 'Message': """Shortcut for:: - bot.send_game(update.effective_chat.id, *args, **kwargs) + await bot.send_game(update.effective_chat.id, *args, **kwargs) For the documentation of the arguments, please see :meth:`telegram.Bot.send_game`. @@ -1284,7 +1296,7 @@ async def send_invoice( ) -> 'Message': """Shortcut for:: - bot.send_invoice(update.effective_chat.id, *args, **kwargs) + await bot.send_invoice(update.effective_chat.id, *args, **kwargs) For the documentation of the arguments, please see :meth:`telegram.Bot.send_invoice`. @@ -1357,7 +1369,7 @@ async def send_location( ) -> 'Message': """Shortcut for:: - bot.send_location(update.effective_chat.id, *args, **kwargs) + await bot.send_location(update.effective_chat.id, *args, **kwargs) For the documentation of the arguments, please see :meth:`telegram.Bot.send_location`. @@ -1410,7 +1422,7 @@ async def send_animation( ) -> 'Message': """Shortcut for:: - bot.send_animation(update.effective_chat.id, *args, **kwargs) + await bot.send_animation(update.effective_chat.id, *args, **kwargs) For the documentation of the arguments, please see :meth:`telegram.Bot.send_animation`. @@ -1457,7 +1469,7 @@ async def send_sticker( ) -> 'Message': """Shortcut for:: - bot.send_sticker(update.effective_chat.id, *args, **kwargs) + await bot.send_sticker(update.effective_chat.id, *args, **kwargs) For the documentation of the arguments, please see :meth:`telegram.Bot.send_sticker`. @@ -1504,7 +1516,7 @@ async def send_venue( ) -> 'Message': """Shortcut for:: - bot.send_venue(update.effective_chat.id, *args, **kwargs) + await bot.send_venue(update.effective_chat.id, *args, **kwargs) For the documentation of the arguments, please see :meth:`telegram.Bot.send_venue`. @@ -1560,7 +1572,7 @@ async def send_video( ) -> 'Message': """Shortcut for:: - bot.send_video(update.effective_chat.id, *args, **kwargs) + await bot.send_video(update.effective_chat.id, *args, **kwargs) For the documentation of the arguments, please see :meth:`telegram.Bot.send_video`. @@ -1612,7 +1624,7 @@ async def send_video_note( ) -> 'Message': """Shortcut for:: - bot.send_video_note(update.effective_chat.id, *args, **kwargs) + await bot.send_video_note(update.effective_chat.id, *args, **kwargs) For the documentation of the arguments, please see :meth:`telegram.Bot.send_video_note`. @@ -1660,7 +1672,7 @@ async def send_voice( ) -> 'Message': """Shortcut for:: - bot.send_voice(update.effective_chat.id, *args, **kwargs) + await bot.send_voice(update.effective_chat.id, *args, **kwargs) For the documentation of the arguments, please see :meth:`telegram.Bot.send_voice`. @@ -1716,7 +1728,7 @@ async def send_poll( ) -> 'Message': """Shortcut for:: - bot.send_poll(update.effective_chat.id, *args, **kwargs) + await bot.send_poll(update.effective_chat.id, *args, **kwargs) For the documentation of the arguments, please see :meth:`telegram.Bot.send_poll`. @@ -1770,7 +1782,7 @@ async def send_copy( ) -> 'MessageId': """Shortcut for:: - bot.copy_message(chat_id=update.effective_chat.id, *args, **kwargs) + await bot.copy_message(chat_id=update.effective_chat.id, *args, **kwargs) For the documentation of the arguments, please see :meth:`telegram.Bot.copy_message`. @@ -1817,7 +1829,7 @@ async def copy_message( ) -> 'MessageId': """Shortcut for:: - bot.copy_message(from_chat_id=update.effective_chat.id, *args, **kwargs) + await bot.copy_message(from_chat_id=update.effective_chat.id, *args, **kwargs) For the documentation of the arguments, please see :meth:`telegram.Bot.copy_message`. @@ -1854,7 +1866,7 @@ async def export_invite_link( ) -> str: """Shortcut for:: - bot.export_chat_invite_link(chat_id=update.effective_chat.id, *args, **kwargs) + await bot.export_chat_invite_link(chat_id=update.effective_chat.id, *args, **kwargs) For the documentation of the arguments, please see :meth:`telegram.Bot.export_chat_invite_link`. @@ -1888,7 +1900,7 @@ async def create_invite_link( ) -> 'ChatInviteLink': """Shortcut for:: - bot.create_chat_invite_link(chat_id=update.effective_chat.id, *args, **kwargs) + await bot.create_chat_invite_link(chat_id=update.effective_chat.id, *args, **kwargs) For the documentation of the arguments, please see :meth:`telegram.Bot.create_chat_invite_link`. @@ -1931,7 +1943,7 @@ async def edit_invite_link( ) -> 'ChatInviteLink': """Shortcut for:: - bot.edit_chat_invite_link(chat_id=update.effective_chat.id, *args, **kwargs) + await bot.edit_chat_invite_link(chat_id=update.effective_chat.id, *args, **kwargs) For the documentation of the arguments, please see :meth:`telegram.Bot.edit_chat_invite_link`. @@ -1970,7 +1982,7 @@ async def revoke_invite_link( ) -> 'ChatInviteLink': """Shortcut for:: - bot.revoke_chat_invite_link(chat_id=update.effective_chat.id, *args, **kwargs) + await bot.revoke_chat_invite_link(chat_id=update.effective_chat.id, *args, **kwargs) For the documentation of the arguments, please see :meth:`telegram.Bot.revoke_chat_invite_link`. @@ -2002,7 +2014,7 @@ async def approve_join_request( ) -> bool: """Shortcut for:: - bot.approve_chat_join_request(chat_id=update.effective_chat.id, *args, **kwargs) + await bot.approve_chat_join_request(chat_id=update.effective_chat.id, *args, **kwargs) For the documentation of the arguments, please see :meth:`telegram.Bot.approve_chat_join_request`. @@ -2034,7 +2046,7 @@ async def decline_join_request( ) -> bool: """Shortcut for:: - bot.decline_chat_join_request(chat_id=update.effective_chat.id, *args, **kwargs) + await bot.decline_chat_join_request(chat_id=update.effective_chat.id, *args, **kwargs) For the documentation of the arguments, please see :meth:`telegram.Bot.decline_chat_join_request`. diff --git a/telegram/_chatjoinrequest.py b/telegram/_chatjoinrequest.py index daca7b41791..6581cd55e47 100644 --- a/telegram/_chatjoinrequest.py +++ b/telegram/_chatjoinrequest.py @@ -119,7 +119,7 @@ async def approve( ) -> bool: """Shortcut for:: - bot.approve_chat_join_request(chat_id=update.effective_chat.id, + await bot.approve_chat_join_request(chat_id=update.effective_chat.id, user_id=update.effective_user.id, *args, **kwargs) For the documentation of the arguments, please see @@ -149,7 +149,7 @@ async def decline( ) -> bool: """Shortcut for:: - bot.decline_chat_join_request(chat_id=update.effective_chat.id, + await bot.decline_chat_join_request(chat_id=update.effective_chat.id, user_id=update.effective_user.id, *args, **kwargs) For the documentation of the arguments, please see diff --git a/telegram/_files/inputfile.py b/telegram/_files/inputfile.py index ec77c1a1293..b2aa8ee3c1b 100644 --- a/telegram/_files/inputfile.py +++ b/telegram/_files/inputfile.py @@ -35,13 +35,19 @@ class InputFile: """This object represents a Telegram InputFile. + .. versionchanged:: 14.0 + The former attribute ``attach`` was renamed to :attr:`attach_name`. + Args: - obj (:term:`file object` | :obj:`bytes`): An open file descriptor or the files content as - bytes. + obj (:term:`file object` | :obj:`bytes` | :obj:`str`): An open file descriptor or the files + content as bytes or string. Note: If :paramref:`obj` is a string, it will be encoded as bytes via :external:obj:`obj.encode('utf-8') `. + + .. versionchanged:: 14.0 + Accept string input. filename (:obj:`str`, optional): Filename for this InputFile. Attributes: diff --git a/telegram/_message.py b/telegram/_message.py index 975d35cd1cf..8cfca085807 100644 --- a/telegram/_message.py +++ b/telegram/_message.py @@ -745,7 +745,7 @@ async def reply_text( ) -> 'Message': """Shortcut for:: - bot.send_message(update.effective_message.chat_id, *args, **kwargs) + await bot.send_message(update.effective_message.chat_id, *args, **kwargs) For the documentation of the arguments, please see :meth:`telegram.Bot.send_message`. @@ -797,7 +797,7 @@ async def reply_markdown( ) -> 'Message': """Shortcut for:: - bot.send_message( + await bot.send_message( update.effective_message.chat_id, parse_mode=ParseMode.MARKDOWN, *args, @@ -859,7 +859,7 @@ async def reply_markdown_v2( ) -> 'Message': """Shortcut for:: - bot.send_message( + await bot.send_message( update.effective_message.chat_id, parse_mode=ParseMode.MARKDOWN_V2, *args, @@ -917,7 +917,7 @@ async def reply_html( ) -> 'Message': """Shortcut for:: - bot.send_message( + await bot.send_message( update.effective_message.chat_id, parse_mode=ParseMode.HTML, *args, @@ -974,7 +974,7 @@ async def reply_media_group( ) -> List['Message']: """Shortcut for:: - bot.send_media_group(update.effective_message.chat_id, *args, **kwargs) + await bot.send_media_group(update.effective_message.chat_id, *args, **kwargs) For the documentation of the arguments, please see :meth:`telegram.Bot.send_media_group`. @@ -1026,7 +1026,7 @@ async def reply_photo( ) -> 'Message': """Shortcut for:: - bot.send_photo(update.effective_message.chat_id, *args, **kwargs) + await bot.send_photo(update.effective_message.chat_id, *args, **kwargs) For the documentation of the arguments, please see :meth:`telegram.Bot.send_photo`. @@ -1085,7 +1085,7 @@ async def reply_audio( ) -> 'Message': """Shortcut for:: - bot.send_audio(update.effective_message.chat_id, *args, **kwargs) + await bot.send_audio(update.effective_message.chat_id, *args, **kwargs) For the documentation of the arguments, please see :meth:`telegram.Bot.send_audio`. @@ -1146,7 +1146,7 @@ async def reply_document( ) -> 'Message': """Shortcut for:: - bot.send_document(update.effective_message.chat_id, *args, **kwargs) + await bot.send_document(update.effective_message.chat_id, *args, **kwargs) For the documentation of the arguments, please see :meth:`telegram.Bot.send_document`. @@ -1207,7 +1207,7 @@ async def reply_animation( ) -> 'Message': """Shortcut for:: - bot.send_animation(update.effective_message.chat_id, *args, **kwargs) + await bot.send_animation(update.effective_message.chat_id, *args, **kwargs) For the documentation of the arguments, please see :meth:`telegram.Bot.send_animation`. @@ -1262,7 +1262,7 @@ async def reply_sticker( ) -> 'Message': """Shortcut for:: - bot.send_sticker(update.effective_message.chat_id, *args, **kwargs) + await bot.send_sticker(update.effective_message.chat_id, *args, **kwargs) For the documentation of the arguments, please see :meth:`telegram.Bot.send_sticker`. @@ -1318,7 +1318,7 @@ async def reply_video( ) -> 'Message': """Shortcut for:: - bot.send_video(update.effective_message.chat_id, *args, **kwargs) + await bot.send_video(update.effective_message.chat_id, *args, **kwargs) For the documentation of the arguments, please see :meth:`telegram.Bot.send_video`. @@ -1378,7 +1378,7 @@ async def reply_video_note( ) -> 'Message': """Shortcut for:: - bot.send_video_note(update.effective_message.chat_id, *args, **kwargs) + await bot.send_video_note(update.effective_message.chat_id, *args, **kwargs) For the documentation of the arguments, please see :meth:`telegram.Bot.send_video_note`. @@ -1434,7 +1434,7 @@ async def reply_voice( ) -> 'Message': """Shortcut for:: - bot.send_voice(update.effective_message.chat_id, *args, **kwargs) + await bot.send_voice(update.effective_message.chat_id, *args, **kwargs) For the documentation of the arguments, please see :meth:`telegram.Bot.send_voice`. @@ -1492,7 +1492,7 @@ async def reply_location( ) -> 'Message': """Shortcut for:: - bot.send_location(update.effective_message.chat_id, *args, **kwargs) + await bot.send_location(update.effective_message.chat_id, *args, **kwargs) For the documentation of the arguments, please see :meth:`telegram.Bot.send_location`. @@ -1553,7 +1553,7 @@ async def reply_venue( ) -> 'Message': """Shortcut for:: - bot.send_venue(update.effective_message.chat_id, *args, **kwargs) + await bot.send_venue(update.effective_message.chat_id, *args, **kwargs) For the documentation of the arguments, please see :meth:`telegram.Bot.send_venue`. @@ -1612,7 +1612,7 @@ async def reply_contact( ) -> 'Message': """Shortcut for:: - bot.send_contact(update.effective_message.chat_id, *args, **kwargs) + await bot.send_contact(update.effective_message.chat_id, *args, **kwargs) For the documentation of the arguments, please see :meth:`telegram.Bot.send_contact`. @@ -1674,7 +1674,7 @@ async def reply_poll( ) -> 'Message': """Shortcut for:: - bot.send_poll(update.effective_message.chat_id, *args, **kwargs) + await bot.send_poll(update.effective_message.chat_id, *args, **kwargs) For the documentation of the arguments, please see :meth:`telegram.Bot.send_poll`. @@ -1732,7 +1732,7 @@ async def reply_dice( ) -> 'Message': """Shortcut for:: - bot.send_dice(update.effective_message.chat_id, *args, **kwargs) + await bot.send_dice(update.effective_message.chat_id, *args, **kwargs) For the documentation of the arguments, please see :meth:`telegram.Bot.send_dice`. @@ -1773,7 +1773,7 @@ async def reply_chat_action( ) -> bool: """Shortcut for:: - bot.send_chat_action(update.effective_message.chat_id, *args, **kwargs) + await bot.send_chat_action(update.effective_message.chat_id, *args, **kwargs) For the documentation of the arguments, please see :meth:`telegram.Bot.send_chat_action`. @@ -1810,7 +1810,7 @@ async def reply_game( ) -> 'Message': """Shortcut for:: - bot.send_game(update.effective_message.chat_id, *args, **kwargs) + await bot.send_game(update.effective_message.chat_id, *args, **kwargs) For the documentation of the arguments, please see :meth:`telegram.Bot.send_game`. @@ -1879,7 +1879,7 @@ async def reply_invoice( ) -> 'Message': """Shortcut for:: - bot.send_invoice(update.effective_message.chat_id, *args, **kwargs) + await bot.send_invoice(update.effective_message.chat_id, *args, **kwargs) For the documentation of the arguments, please see :meth:`telegram.Bot.send_invoice`. @@ -1952,7 +1952,7 @@ async def forward( ) -> 'Message': """Shortcut for:: - bot.forward_message(chat_id=chat_id, + await bot.forward_message(chat_id=chat_id, from_chat_id=update.effective_message.chat_id, message_id=update.effective_message.message_id, *args, @@ -2004,7 +2004,7 @@ async def copy( ) -> 'MessageId': """Shortcut for:: - bot.copy_message(chat_id=chat_id, + await bot.copy_message(chat_id=chat_id, from_chat_id=update.effective_message.chat_id, message_id=update.effective_message.message_id, *args, @@ -2056,7 +2056,7 @@ async def reply_copy( ) -> 'MessageId': """Shortcut for:: - bot.copy_message(chat_id=message.chat.id, + await bot.copy_message(chat_id=message.chat.id, from_chat_id=from_chat_id, message_id=message_id, *args, @@ -2111,7 +2111,7 @@ async def edit_text( ) -> Union['Message', bool]: """Shortcut for:: - bot.edit_message_text(chat_id=message.chat_id, + await bot.edit_message_text(chat_id=message.chat_id, message_id=message.message_id, *args, **kwargs) @@ -2158,7 +2158,7 @@ async def edit_caption( ) -> Union['Message', bool]: """Shortcut for:: - bot.edit_message_caption(chat_id=message.chat_id, + await bot.edit_message_caption(chat_id=message.chat_id, message_id=message.message_id, *args, **kwargs) @@ -2203,7 +2203,7 @@ async def edit_media( ) -> Union['Message', bool]: """Shortcut for:: - bot.edit_message_media(chat_id=message.chat_id, + await bot.edit_message_media(chat_id=message.chat_id, message_id=message.message_id, *args, **kwargs) @@ -2245,7 +2245,7 @@ async def edit_reply_markup( ) -> Union['Message', bool]: """Shortcut for:: - bot.edit_message_reply_markup(chat_id=message.chat_id, + await bot.edit_message_reply_markup(chat_id=message.chat_id, message_id=message.message_id, *args, **kwargs) @@ -2291,7 +2291,7 @@ async def edit_live_location( ) -> Union['Message', bool]: """Shortcut for:: - bot.edit_message_live_location(chat_id=message.chat_id, + await bot.edit_message_live_location(chat_id=message.chat_id, message_id=message.message_id, *args, **kwargs) @@ -2337,7 +2337,7 @@ async def stop_live_location( ) -> Union['Message', bool]: """Shortcut for:: - bot.stop_message_live_location(chat_id=message.chat_id, + await bot.stop_message_live_location(chat_id=message.chat_id, message_id=message.message_id, *args, **kwargs) @@ -2380,7 +2380,7 @@ async def set_game_score( ) -> Union['Message', bool]: """Shortcut for:: - bot.set_game_score(chat_id=message.chat_id, + await bot.set_game_score(chat_id=message.chat_id, message_id=message.message_id, *args, **kwargs) @@ -2422,7 +2422,7 @@ async def get_game_high_scores( ) -> List['GameHighScore']: """Shortcut for:: - bot.get_game_high_scores(chat_id=message.chat_id, + await bot.get_game_high_scores(chat_id=message.chat_id, message_id=message.message_id, *args, **kwargs) @@ -2460,7 +2460,7 @@ async def delete( ) -> bool: """Shortcut for:: - bot.delete_message(chat_id=message.chat_id, + await bot.delete_message(chat_id=message.chat_id, message_id=message.message_id, *args, **kwargs) @@ -2492,7 +2492,7 @@ async def stop_poll( ) -> Poll: """Shortcut for:: - bot.stop_poll(chat_id=message.chat_id, + await bot.stop_poll(chat_id=message.chat_id, message_id=message.message_id, *args, **kwargs) @@ -2526,7 +2526,7 @@ async def pin( ) -> bool: """Shortcut for:: - bot.pin_chat_message(chat_id=message.chat_id, + await bot.pin_chat_message(chat_id=message.chat_id, message_id=message.message_id, *args, **kwargs) @@ -2558,7 +2558,7 @@ async def unpin( ) -> bool: """Shortcut for:: - bot.unpin_chat_message(chat_id=message.chat_id, + await bot.unpin_chat_message(chat_id=message.chat_id, message_id=message.message_id, *args, **kwargs) diff --git a/telegram/_payment/precheckoutquery.py b/telegram/_payment/precheckoutquery.py index 897eb4c7d9a..9fc8a6022b2 100644 --- a/telegram/_payment/precheckoutquery.py +++ b/telegram/_payment/precheckoutquery.py @@ -126,7 +126,7 @@ async def answer( # pylint: disable=invalid-name ) -> bool: """Shortcut for:: - bot.answer_pre_checkout_query(update.pre_checkout_query.id, *args, **kwargs) + await bot.answer_pre_checkout_query(update.pre_checkout_query.id, *args, **kwargs) For the documentation of the arguments, please see :meth:`telegram.Bot.answer_pre_checkout_query`. diff --git a/telegram/_payment/shippingquery.py b/telegram/_payment/shippingquery.py index 901e3e28f14..595a7409f48 100644 --- a/telegram/_payment/shippingquery.py +++ b/telegram/_payment/shippingquery.py @@ -100,7 +100,7 @@ async def answer( # pylint: disable=invalid-name ) -> bool: """Shortcut for:: - bot.answer_shipping_query(update.shipping_query.id, *args, **kwargs) + await bot.answer_shipping_query(update.shipping_query.id, *args, **kwargs) For the documentation of the arguments, please see :meth:`telegram.Bot.answer_shipping_query`. diff --git a/telegram/_user.py b/telegram/_user.py index e880c5f2967..760826f9ab9 100644 --- a/telegram/_user.py +++ b/telegram/_user.py @@ -176,7 +176,7 @@ async def get_profile_photos( """ Shortcut for:: - bot.get_user_profile_photos(update.effective_user.id, *args, **kwargs) + await bot.get_user_profile_photos(update.effective_user.id, *args, **kwargs) For the documentation of the arguments, please see :meth:`telegram.Bot.get_user_profile_photos`. @@ -265,7 +265,7 @@ async def pin_message( ) -> bool: """Shortcut for:: - bot.pin_chat_message(chat_id=update.effective_user.id, + await bot.pin_chat_message(chat_id=update.effective_user.id, *args, **kwargs) @@ -297,7 +297,7 @@ async def unpin_message( ) -> bool: """Shortcut for:: - bot.unpin_chat_message(chat_id=update.effective_user.id, + await bot.unpin_chat_message(chat_id=update.effective_user.id, *args, **kwargs) @@ -327,7 +327,7 @@ async def unpin_all_messages( ) -> bool: """Shortcut for:: - bot.unpin_all_chat_messages(chat_id=update.effective_user.id, + await bot.unpin_all_chat_messages(chat_id=update.effective_user.id, *args, **kwargs) @@ -366,7 +366,7 @@ async def send_message( ) -> 'Message': """Shortcut for:: - bot.send_message(update.effective_user.id, *args, **kwargs) + await bot.send_message(update.effective_user.id, *args, **kwargs) For the documentation of the arguments, please see :meth:`telegram.Bot.send_message`. @@ -412,7 +412,7 @@ async def send_photo( ) -> 'Message': """Shortcut for:: - bot.send_photo(update.effective_user.id, *args, **kwargs) + await bot.send_photo(update.effective_user.id, *args, **kwargs) For the documentation of the arguments, please see :meth:`telegram.Bot.send_photo`. @@ -456,7 +456,7 @@ async def send_media_group( ) -> List['Message']: """Shortcut for:: - bot.send_media_group(update.effective_user.id, *args, **kwargs) + await bot.send_media_group(update.effective_user.id, *args, **kwargs) For the documentation of the arguments, please see :meth:`telegram.Bot.send_media_group`. @@ -502,7 +502,7 @@ async def send_audio( ) -> 'Message': """Shortcut for:: - bot.send_audio(update.effective_user.id, *args, **kwargs) + await bot.send_audio(update.effective_user.id, *args, **kwargs) For the documentation of the arguments, please see :meth:`telegram.Bot.send_audio`. @@ -544,7 +544,7 @@ async def send_chat_action( ) -> bool: """Shortcut for:: - bot.send_chat_action(update.effective_user.id, *args, **kwargs) + await bot.send_chat_action(update.effective_user.id, *args, **kwargs) For the documentation of the arguments, please see :meth:`telegram.Bot.send_chat_action`. @@ -585,7 +585,7 @@ async def send_contact( ) -> 'Message': """Shortcut for:: - bot.send_contact(update.effective_user.id, *args, **kwargs) + await bot.send_contact(update.effective_user.id, *args, **kwargs) For the documentation of the arguments, please see :meth:`telegram.Bot.send_contact`. @@ -628,7 +628,7 @@ async def send_dice( ) -> 'Message': """Shortcut for:: - bot.send_dice(update.effective_user.id, *args, **kwargs) + await bot.send_dice(update.effective_user.id, *args, **kwargs) For the documentation of the arguments, please see :meth:`telegram.Bot.send_dice`. @@ -673,7 +673,7 @@ async def send_document( ) -> 'Message': """Shortcut for:: - bot.send_document(update.effective_user.id, *args, **kwargs) + await bot.send_document(update.effective_user.id, *args, **kwargs) For the documentation of the arguments, please see :meth:`telegram.Bot.send_document`. @@ -718,7 +718,7 @@ async def send_game( ) -> 'Message': """Shortcut for:: - bot.send_game(update.effective_user.id, *args, **kwargs) + await bot.send_game(update.effective_user.id, *args, **kwargs) For the documentation of the arguments, please see :meth:`telegram.Bot.send_game`. @@ -777,7 +777,7 @@ async def send_invoice( ) -> 'Message': """Shortcut for:: - bot.send_invoice(update.effective_user.id, *args, **kwargs) + await bot.send_invoice(update.effective_user.id, *args, **kwargs) For the documentation of the arguments, please see :meth:`telegram.Bot.send_invoice`. @@ -850,7 +850,7 @@ async def send_location( ) -> 'Message': """Shortcut for:: - bot.send_location(update.effective_user.id, *args, **kwargs) + await bot.send_location(update.effective_user.id, *args, **kwargs) For the documentation of the arguments, please see :meth:`telegram.Bot.send_location`. @@ -903,7 +903,7 @@ async def send_animation( ) -> 'Message': """Shortcut for:: - bot.send_animation(update.effective_user.id, *args, **kwargs) + await bot.send_animation(update.effective_user.id, *args, **kwargs) For the documentation of the arguments, please see :meth:`telegram.Bot.send_animation`. @@ -950,7 +950,7 @@ async def send_sticker( ) -> 'Message': """Shortcut for:: - bot.send_sticker(update.effective_user.id, *args, **kwargs) + await bot.send_sticker(update.effective_user.id, *args, **kwargs) For the documentation of the arguments, please see :meth:`telegram.Bot.send_sticker`. @@ -998,7 +998,7 @@ async def send_video( ) -> 'Message': """Shortcut for:: - bot.send_video(update.effective_user.id, *args, **kwargs) + await bot.send_video(update.effective_user.id, *args, **kwargs) For the documentation of the arguments, please see :meth:`telegram.Bot.send_video`. @@ -1054,7 +1054,7 @@ async def send_venue( ) -> 'Message': """Shortcut for:: - bot.send_venue(update.effective_user.id, *args, **kwargs) + await bot.send_venue(update.effective_user.id, *args, **kwargs) For the documentation of the arguments, please see :meth:`telegram.Bot.send_venue`. @@ -1105,7 +1105,7 @@ async def send_video_note( ) -> 'Message': """Shortcut for:: - bot.send_video_note(update.effective_user.id, *args, **kwargs) + await bot.send_video_note(update.effective_user.id, *args, **kwargs) For the documentation of the arguments, please see :meth:`telegram.Bot.send_video_note`. @@ -1153,7 +1153,7 @@ async def send_voice( ) -> 'Message': """Shortcut for:: - bot.send_voice(update.effective_user.id, *args, **kwargs) + await bot.send_voice(update.effective_user.id, *args, **kwargs) For the documentation of the arguments, please see :meth:`telegram.Bot.send_voice`. @@ -1209,7 +1209,7 @@ async def send_poll( ) -> 'Message': """Shortcut for:: - bot.send_poll(update.effective_user.id, *args, **kwargs) + await bot.send_poll(update.effective_user.id, *args, **kwargs) For the documentation of the arguments, please see :meth:`telegram.Bot.send_poll`. @@ -1263,7 +1263,7 @@ async def send_copy( ) -> 'MessageId': """Shortcut for:: - bot.copy_message(chat_id=update.effective_user.id, *args, **kwargs) + await bot.copy_message(chat_id=update.effective_user.id, *args, **kwargs) For the documentation of the arguments, please see :meth:`telegram.Bot.copy_message`. @@ -1310,7 +1310,7 @@ async def copy_message( ) -> 'MessageId': """Shortcut for:: - bot.copy_message(from_chat_id=update.effective_user.id, *args, **kwargs) + await bot.copy_message(from_chat_id=update.effective_user.id, *args, **kwargs) For the documentation of the arguments, please see :meth:`telegram.Bot.copy_message`. @@ -1348,7 +1348,7 @@ async def approve_join_request( ) -> bool: """Shortcut for:: - bot.approve_chat_join_request(user_id=update.effective_user.id, *args, **kwargs) + await bot.approve_chat_join_request(user_id=update.effective_user.id, *args, **kwargs) For the documentation of the arguments, please see :meth:`telegram.Bot.approve_chat_join_request`. @@ -1380,7 +1380,7 @@ async def decline_join_request( ) -> bool: """Shortcut for:: - bot.decline_chat_join_request(user_id=update.effective_user.id, *args, **kwargs) + await bot.decline_chat_join_request(user_id=update.effective_user.id, *args, **kwargs) For the documentation of the arguments, please see :meth:`telegram.Bot.decline_chat_join_request`. From 696e622db7d4c7d2e67cc7cb297e81a1d40a978a Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Sun, 17 Apr 2022 10:07:58 +0200 Subject: [PATCH 122/153] Apply black style to code blocks in docs --- .github/CONTRIBUTING.rst | 5 +- README.rst | 6 +- README_RAW.rst | 6 +- telegram/_callbackquery.py | 3 +- telegram/_chat.py | 24 ++---- telegram/_chatjoinrequest.py | 10 ++- telegram/_chatmemberupdated.py | 2 +- telegram/_inline/inlinequery.py | 8 +- telegram/_message.py | 126 +++++++++++++--------------- telegram/_user.py | 18 ++-- telegram/_utils/defaultvalue.py | 3 +- telegram/ext/_applicationbuilder.py | 2 +- telegram/ext/_commandhandler.py | 9 +- telegram/ext/_utils/stack.py | 2 +- 14 files changed, 103 insertions(+), 121 deletions(-) diff --git a/.github/CONTRIBUTING.rst b/.github/CONTRIBUTING.rst index 68a98e2b27b..e34cb1dd76b 100644 --- a/.github/CONTRIBUTING.rst +++ b/.github/CONTRIBUTING.rst @@ -254,11 +254,12 @@ break the API classes. For example: # GOOD def __init__(self, id, name, last_name=None, **kwargs): - self.last_name = last_name + self.last_name = last_name + # BAD def __init__(self, id, name, last_name=None): - self.last_name = last_name + self.last_name = last_name .. _`Code of Conduct`: https://www.python.org/psf/codeofconduct/ diff --git a/README.rst b/README.rst index ad83ab74398..fbdfd5a6e62 100644 --- a/README.rst +++ b/README.rst @@ -190,8 +190,10 @@ This library uses the ``logging`` module. To set up logging to standard output, .. code:: python import logging - logging.basicConfig(level=logging.DEBUG, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') + + logging.basicConfig( + level=logging.DEBUG, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" + ) at the beginning of your script. diff --git a/README_RAW.rst b/README_RAW.rst index e94a0d4bf62..fb4184d56e1 100644 --- a/README_RAW.rst +++ b/README_RAW.rst @@ -165,8 +165,10 @@ This library uses the ``logging`` module. To set up logging to standard output, .. code:: python import logging - logging.basicConfig(level=logging.DEBUG, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') + + logging.basicConfig( + level=logging.DEBUG, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" + ) at the beginning of your script. diff --git a/telegram/_callbackquery.py b/telegram/_callbackquery.py index e24fe2841ac..84a5c7f4d0f 100644 --- a/telegram/_callbackquery.py +++ b/telegram/_callbackquery.py @@ -728,7 +728,8 @@ async def copy_message( from_chat_id=update.message.chat_id, message_id=update.message.message_id, *args, - **kwargs) + **kwargs + ) For the documentation of the arguments, please see :meth:`telegram.Message.copy`. diff --git a/telegram/_chat.py b/telegram/_chat.py index a9310312186..ea9469721e6 100644 --- a/telegram/_chat.py +++ b/telegram/_chat.py @@ -491,9 +491,7 @@ async def ban_chat( """Shortcut for:: await bot.ban_chat_sender_chat( - sender_chat_id=update.effective_chat.id, - *args, - **kwargs + sender_chat_id=update.effective_chat.id, *args, **kwargs ) For the documentation of the arguments, please see @@ -559,9 +557,7 @@ async def unban_chat( """Shortcut for:: await bot.unban_chat_sender_chat( - sender_chat_id=update.effective_chat.id, - *args, - **kwargs + sender_chat_id=update.effective_chat.id, *args, **kwargs ) For the documentation of the arguments, please see @@ -747,9 +743,7 @@ async def set_administrator_custom_title( """Shortcut for:: await bot.set_chat_administrator_custom_title( - update.effective_chat.id, - *args, - **kwargs + update.effective_chat.id, *args, **kwargs ) For the documentation of the arguments, please see @@ -782,9 +776,7 @@ async def pin_message( ) -> bool: """Shortcut for:: - await bot.pin_chat_message(chat_id=update.effective_chat.id, - *args, - **kwargs) + await bot.pin_chat_message(chat_id=update.effective_chat.id, *args, **kwargs) For the documentation of the arguments, please see :meth:`telegram.Bot.pin_chat_message`. @@ -815,9 +807,7 @@ async def unpin_message( ) -> bool: """Shortcut for:: - await bot.unpin_chat_message(chat_id=update.effective_chat.id, - *args, - **kwargs) + await bot.unpin_chat_message(chat_id=update.effective_chat.id, *args, **kwargs) For the documentation of the arguments, please see :meth:`telegram.Bot.unpin_chat_message`. @@ -846,9 +836,7 @@ async def unpin_all_messages( ) -> bool: """Shortcut for:: - await bot.unpin_all_chat_messages(chat_id=update.effective_chat.id, - *args, - **kwargs) + await bot.unpin_all_chat_messages(chat_id=update.effective_chat.id, *args, **kwargs) For the documentation of the arguments, please see :meth:`telegram.Bot.unpin_all_chat_messages`. diff --git a/telegram/_chatjoinrequest.py b/telegram/_chatjoinrequest.py index 6581cd55e47..f8c60431f1d 100644 --- a/telegram/_chatjoinrequest.py +++ b/telegram/_chatjoinrequest.py @@ -119,8 +119,9 @@ async def approve( ) -> bool: """Shortcut for:: - await bot.approve_chat_join_request(chat_id=update.effective_chat.id, - user_id=update.effective_user.id, *args, **kwargs) + await bot.approve_chat_join_request( + chat_id=update.effective_chat.id, user_id=update.effective_user.id, *args, **kwargs + ) For the documentation of the arguments, please see :meth:`telegram.Bot.approve_chat_join_request`. @@ -149,8 +150,9 @@ async def decline( ) -> bool: """Shortcut for:: - await bot.decline_chat_join_request(chat_id=update.effective_chat.id, - user_id=update.effective_user.id, *args, **kwargs) + await bot.decline_chat_join_request( + chat_id=update.effective_chat.id, user_id=update.effective_user.id, *args, **kwargs + ) For the documentation of the arguments, please see :meth:`telegram.Bot.decline_chat_join_request`. diff --git a/telegram/_chatmemberupdated.py b/telegram/_chatmemberupdated.py index 640b7a2289a..8b4bcd659c4 100644 --- a/telegram/_chatmemberupdated.py +++ b/telegram/_chatmemberupdated.py @@ -136,7 +136,7 @@ def difference( """Computes the difference between :attr:`old_chat_member` and :attr:`new_chat_member`. Example: - .. code:: python + .. code:: pycon >>> chat_member_updated.difference() {'custom_title': ('old title', 'new title')} diff --git a/telegram/_inline/inlinequery.py b/telegram/_inline/inlinequery.py index 026d5ebc481..cb8adc2e219 100644 --- a/telegram/_inline/inlinequery.py +++ b/telegram/_inline/inlinequery.py @@ -131,10 +131,10 @@ async def answer( """Shortcut for:: await bot.answer_inline_query( - update.inline_query.id, - *args, - current_offset=self.offset if auto_pagination else None, - **kwargs + update.inline_query.id, + *args, + current_offset=self.offset if auto_pagination else None, + **kwargs ) For the documentation of the arguments, please see diff --git a/telegram/_message.py b/telegram/_message.py index 8cfca085807..902cdc6e154 100644 --- a/telegram/_message.py +++ b/telegram/_message.py @@ -797,7 +797,7 @@ async def reply_markdown( ) -> 'Message': """Shortcut for:: - await bot.send_message( + await bot.send_message( update.effective_message.chat_id, parse_mode=ParseMode.MARKDOWN, *args, @@ -859,7 +859,7 @@ async def reply_markdown_v2( ) -> 'Message': """Shortcut for:: - await bot.send_message( + await bot.send_message( update.effective_message.chat_id, parse_mode=ParseMode.MARKDOWN_V2, *args, @@ -917,7 +917,7 @@ async def reply_html( ) -> 'Message': """Shortcut for:: - await bot.send_message( + await bot.send_message( update.effective_message.chat_id, parse_mode=ParseMode.HTML, *args, @@ -1952,11 +1952,13 @@ async def forward( ) -> 'Message': """Shortcut for:: - await bot.forward_message(chat_id=chat_id, - from_chat_id=update.effective_message.chat_id, - message_id=update.effective_message.message_id, - *args, - **kwargs) + await bot.forward_message( + chat_id=chat_id, + from_chat_id=update.effective_message.chat_id, + message_id=update.effective_message.message_id, + *args, + **kwargs + ) For the documentation of the arguments, please see :meth:`telegram.Bot.forward_message`. @@ -2004,11 +2006,13 @@ async def copy( ) -> 'MessageId': """Shortcut for:: - await bot.copy_message(chat_id=chat_id, - from_chat_id=update.effective_message.chat_id, - message_id=update.effective_message.message_id, - *args, - **kwargs) + await bot.copy_message( + chat_id=chat_id, + from_chat_id=update.effective_message.chat_id, + message_id=update.effective_message.message_id, + *args, + **kwargs + ) For the documentation of the arguments, please see :meth:`telegram.Bot.copy_message`. @@ -2056,11 +2060,13 @@ async def reply_copy( ) -> 'MessageId': """Shortcut for:: - await bot.copy_message(chat_id=message.chat.id, - from_chat_id=from_chat_id, - message_id=message_id, - *args, - **kwargs) + await bot.copy_message( + chat_id=message.chat.id, + from_chat_id=from_chat_id, + message_id=message_id, + *args, + **kwargs + ) For the documentation of the arguments, please see :meth:`telegram.Bot.copy_message`. @@ -2111,10 +2117,9 @@ async def edit_text( ) -> Union['Message', bool]: """Shortcut for:: - await bot.edit_message_text(chat_id=message.chat_id, - message_id=message.message_id, - *args, - **kwargs) + await bot.edit_message_text( + chat_id=message.chat_id, message_id=message.message_id, *args, **kwargs + ) For the documentation of the arguments, please see :meth:`telegram.Bot.edit_message_text`. @@ -2158,10 +2163,9 @@ async def edit_caption( ) -> Union['Message', bool]: """Shortcut for:: - await bot.edit_message_caption(chat_id=message.chat_id, - message_id=message.message_id, - *args, - **kwargs) + await bot.edit_message_caption( + chat_id=message.chat_id, message_id=message.message_id, *args, **kwargs + ) For the documentation of the arguments, please see :meth:`telegram.Bot.edit_message_caption`. @@ -2203,10 +2207,9 @@ async def edit_media( ) -> Union['Message', bool]: """Shortcut for:: - await bot.edit_message_media(chat_id=message.chat_id, - message_id=message.message_id, - *args, - **kwargs) + await bot.edit_message_media( + chat_id=message.chat_id, message_id=message.message_id, *args, **kwargs + ) For the documentation of the arguments, please see :meth:`telegram.Bot.edit_message_media`. @@ -2245,10 +2248,9 @@ async def edit_reply_markup( ) -> Union['Message', bool]: """Shortcut for:: - await bot.edit_message_reply_markup(chat_id=message.chat_id, - message_id=message.message_id, - *args, - **kwargs) + await bot.edit_message_reply_markup( + chat_id=message.chat_id, message_id=message.message_id, *args, **kwargs + ) For the documentation of the arguments, please see :meth:`telegram.Bot.edit_message_reply_markup`. @@ -2291,10 +2293,9 @@ async def edit_live_location( ) -> Union['Message', bool]: """Shortcut for:: - await bot.edit_message_live_location(chat_id=message.chat_id, - message_id=message.message_id, - *args, - **kwargs) + await bot.edit_message_live_location( + chat_id=message.chat_id, message_id=message.message_id, *args, **kwargs + ) For the documentation of the arguments, please see :meth:`telegram.Bot.edit_message_live_location`. @@ -2337,10 +2338,9 @@ async def stop_live_location( ) -> Union['Message', bool]: """Shortcut for:: - await bot.stop_message_live_location(chat_id=message.chat_id, - message_id=message.message_id, - *args, - **kwargs) + await bot.stop_message_live_location( + chat_id=message.chat_id, message_id=message.message_id, *args, **kwargs + ) For the documentation of the arguments, please see :meth:`telegram.Bot.stop_message_live_location`. @@ -2380,10 +2380,9 @@ async def set_game_score( ) -> Union['Message', bool]: """Shortcut for:: - await bot.set_game_score(chat_id=message.chat_id, - message_id=message.message_id, - *args, - **kwargs) + await bot.set_game_score( + chat_id=message.chat_id, message_id=message.message_id, *args, **kwargs + ) For the documentation of the arguments, please see :meth:`telegram.Bot.set_game_score`. @@ -2422,10 +2421,9 @@ async def get_game_high_scores( ) -> List['GameHighScore']: """Shortcut for:: - await bot.get_game_high_scores(chat_id=message.chat_id, - message_id=message.message_id, - *args, - **kwargs) + await bot.get_game_high_scores( + chat_id=message.chat_id, message_id=message.message_id, *args, **kwargs + ) For the documentation of the arguments, please see :meth:`telegram.Bot.get_game_high_scores`. @@ -2460,10 +2458,9 @@ async def delete( ) -> bool: """Shortcut for:: - await bot.delete_message(chat_id=message.chat_id, - message_id=message.message_id, - *args, - **kwargs) + await bot.delete_message( + chat_id=message.chat_id, message_id=message.message_id, *args, **kwargs + ) For the documentation of the arguments, please see :meth:`telegram.Bot.delete_message`. @@ -2492,10 +2489,9 @@ async def stop_poll( ) -> Poll: """Shortcut for:: - await bot.stop_poll(chat_id=message.chat_id, - message_id=message.message_id, - *args, - **kwargs) + await bot.stop_poll( + chat_id=message.chat_id, message_id=message.message_id, *args, **kwargs + ) For the documentation of the arguments, please see :meth:`telegram.Bot.stop_poll`. @@ -2526,10 +2522,9 @@ async def pin( ) -> bool: """Shortcut for:: - await bot.pin_chat_message(chat_id=message.chat_id, - message_id=message.message_id, - *args, - **kwargs) + await bot.pin_chat_message( + chat_id=message.chat_id, message_id=message.message_id, *args, **kwargs + ) For the documentation of the arguments, please see :meth:`telegram.Bot.pin_chat_message`. @@ -2558,10 +2553,9 @@ async def unpin( ) -> bool: """Shortcut for:: - await bot.unpin_chat_message(chat_id=message.chat_id, - message_id=message.message_id, - *args, - **kwargs) + await bot.unpin_chat_message( + chat_id=message.chat_id, message_id=message.message_id, *args, **kwargs + ) For the documentation of the arguments, please see :meth:`telegram.Bot.unpin_chat_message`. diff --git a/telegram/_user.py b/telegram/_user.py index 760826f9ab9..1bd2c5361d2 100644 --- a/telegram/_user.py +++ b/telegram/_user.py @@ -173,8 +173,7 @@ async def get_profile_photos( pool_timeout: ODVInput[float] = DEFAULT_NONE, api_kwargs: JSONDict = None, ) -> Optional['UserProfilePhotos']: - """ - Shortcut for:: + """Shortcut for:: await bot.get_user_profile_photos(update.effective_user.id, *args, **kwargs) @@ -238,8 +237,7 @@ def mention_html(self, name: str = None) -> str: return helpers_mention_html(self.id, self.full_name) def mention_button(self, name: str = None) -> InlineKeyboardButton: - """ - Shortcut for:: + """Shortcut for:: InlineKeyboardButton(text=name, url=f"tg://user?id={update.effective_user.id}") @@ -265,9 +263,7 @@ async def pin_message( ) -> bool: """Shortcut for:: - await bot.pin_chat_message(chat_id=update.effective_user.id, - *args, - **kwargs) + await bot.pin_chat_message(chat_id=update.effective_user.id, *args, **kwargs) For the documentation of the arguments, please see :meth:`telegram.Bot.pin_chat_message`. @@ -297,9 +293,7 @@ async def unpin_message( ) -> bool: """Shortcut for:: - await bot.unpin_chat_message(chat_id=update.effective_user.id, - *args, - **kwargs) + await bot.unpin_chat_message(chat_id=update.effective_user.id, *args, **kwargs) For the documentation of the arguments, please see :meth:`telegram.Bot.unpin_chat_message`. @@ -327,9 +321,7 @@ async def unpin_all_messages( ) -> bool: """Shortcut for:: - await bot.unpin_all_chat_messages(chat_id=update.effective_user.id, - *args, - **kwargs) + await bot.unpin_all_chat_messages(chat_id=update.effective_user.id, *args, **kwargs) For the documentation of the arguments, please see :meth:`telegram.Bot.unpin_all_chat_messages`. diff --git a/telegram/_utils/defaultvalue.py b/telegram/_utils/defaultvalue.py index c4f4ef71742..581678e00a9 100644 --- a/telegram/_utils/defaultvalue.py +++ b/telegram/_utils/defaultvalue.py @@ -101,8 +101,7 @@ def get_value(obj: OT) -> OT: @staticmethod def get_value(obj: Union[OT, 'DefaultValue[OT]']) -> OT: - """ - Shortcut for:: + """Shortcut for:: return obj.value if isinstance(obj, DefaultValue) else obj diff --git a/telegram/ext/_applicationbuilder.py b/telegram/ext/_applicationbuilder.py index cb02d65753c..d92f1a7b5e8 100644 --- a/telegram/ext/_applicationbuilder.py +++ b/telegram/ext/_applicationbuilder.py @@ -104,7 +104,7 @@ class ApplicationBuilder(Generic[BT, CCT, UD, CD, BD, JQ]): Example: .. code:: python - application = ApplicationBuilder().token('TOKEN').build() + application = ApplicationBuilder().token("TOKEN").build() Please see the description of the individual methods for information on which arguments can be set and what the defaults are when not called. When no default is mentioned, the argument will diff --git a/telegram/ext/_commandhandler.py b/telegram/ext/_commandhandler.py index 8e3c29f17d4..36473fd5171 100644 --- a/telegram/ext/_commandhandler.py +++ b/telegram/ext/_commandhandler.py @@ -178,20 +178,21 @@ class PrefixHandler(CommandHandler): .. code:: python - PrefixHandler('!', 'test', callback) # will respond to '!test'. + PrefixHandler("!", "test", callback) # will respond to '!test'. Multiple prefixes, single command: .. code:: python - PrefixHandler(['!', '#'], 'test', callback) # will respond to '!test' and '#test'. + PrefixHandler(["!", "#"], "test", callback) # will respond to '!test' and '#test'. Multiple prefixes and commands: .. code:: python - PrefixHandler(['!', '#'], ['test', 'help'], callback) # will respond to '!test', \ - '#test', '!help' and '#help'. + PrefixHandler( + ["!", "#"], ["test", "help"], callback + ) # will respond to '!test', '#test', '!help' and '#help'. By default, the handler listens to messages as well as edited messages. To change this behavior diff --git a/telegram/ext/_utils/stack.py b/telegram/ext/_utils/stack.py index 6b2324a80c8..e4df470d715 100644 --- a/telegram/ext/_utils/stack.py +++ b/telegram/ext/_utils/stack.py @@ -34,7 +34,7 @@ def was_called_by(frame: Optional[FrameType], caller: Path) -> bool: """Checks if the passed frame was called by the specified file. Example: - .. code:: python + .. code:: pycon >>> was_called_by(inspect.currentframe(), Path(__file__)) True From b0605a7096c35ca8b1e4d9865e1a87b90af93963 Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Sun, 17 Apr 2022 10:13:59 +0200 Subject: [PATCH 123/153] avoid the term "asynchronously" --- telegram/ext/_application.py | 6 +++--- telegram/ext/_chatjoinrequesthandler.py | 2 +- telegram/ext/_conversationhandler.py | 2 +- telegram/ext/_handler.py | 2 +- telegram/ext/_pollanswerhandler.py | 2 +- telegram/ext/_pollhandler.py | 2 +- telegram/ext/_precheckoutqueryhandler.py | 2 +- telegram/ext/_shippingqueryhandler.py | 2 +- tests/test_application.py | 2 +- 9 files changed, 11 insertions(+), 11 deletions(-) diff --git a/telegram/ext/_application.py b/telegram/ext/_application.py index 74302293656..0adf2572701 100644 --- a/telegram/ext/_application.py +++ b/telegram/ext/_application.py @@ -85,7 +85,7 @@ async def conversation_callback(update, context): raise ApplicationHandlerStop(next_state) Note: - Has no effect, if the handler or error handler is run asynchronously. + Has no effect, if the handler or error handler is run in a non-blocking way. Args: state (:obj:`object`, optional): The next state of the conversation. @@ -766,8 +766,8 @@ async def __create_task_callback( except Exception as exception: if isinstance(exception, ApplicationHandlerStop): warn( - 'ApplicationHandlerStop is not supported with asynchronously ' - 'running handlers.', + 'ApplicationHandlerStop is not supported with handlers ' + 'running non-blocking.', stacklevel=1, ) diff --git a/telegram/ext/_chatjoinrequesthandler.py b/telegram/ext/_chatjoinrequesthandler.py index 13e9b43054d..70e5fa51e1d 100644 --- a/telegram/ext/_chatjoinrequesthandler.py +++ b/telegram/ext/_chatjoinrequesthandler.py @@ -50,7 +50,7 @@ async def callback(update: Update, context: CallbackContext) Attributes: callback (:term:`coroutine function`): The callback function for this handler. - block (:obj:`bool`): Determines whether the callback will run asynchronously. + block (:obj:`bool`): Determines whether the callback will run in a blocking way.. """ diff --git a/telegram/ext/_conversationhandler.py b/telegram/ext/_conversationhandler.py index 3870558986e..5e52c035f0e 100644 --- a/telegram/ext/_conversationhandler.py +++ b/telegram/ext/_conversationhandler.py @@ -249,7 +249,7 @@ class ConversationHandler(Handler[Update, CCT]): when :attr:`per_message`, :attr:`per_chat`, :attr:`per_user` are all :obj:`False`. Attributes: - block (:obj:`bool`): Determines whether the callback will run asynchronously. Always + block (:obj:`bool`): Determines whether the callback will run in a blocking way.. Always :obj:`True` since conversation handlers handle any non-blocking callbacks internally. """ diff --git a/telegram/ext/_handler.py b/telegram/ext/_handler.py index a12602b6623..a6c2bed6250 100644 --- a/telegram/ext/_handler.py +++ b/telegram/ext/_handler.py @@ -56,7 +56,7 @@ async def callback(update: Update, context: CallbackContext) Attributes: callback (:term:`coroutine function`): The callback function for this handler. - block (:obj:`bool`): Determines whether the callback will run asynchronously. + block (:obj:`bool`): Determines whether the callback will run in a blocking way.. """ diff --git a/telegram/ext/_pollanswerhandler.py b/telegram/ext/_pollanswerhandler.py index a5a9276fed0..8d97b6361e2 100644 --- a/telegram/ext/_pollanswerhandler.py +++ b/telegram/ext/_pollanswerhandler.py @@ -48,7 +48,7 @@ async def callback(update: Update, context: CallbackContext) Attributes: callback (:term:`coroutine function`): The callback function for this handler. - block (:obj:`bool`): Determines whether the callback will run asynchronously. + block (:obj:`bool`): Determines whether the callback will run in a blocking way.. """ diff --git a/telegram/ext/_pollhandler.py b/telegram/ext/_pollhandler.py index d6b37a7824a..38a8519156b 100644 --- a/telegram/ext/_pollhandler.py +++ b/telegram/ext/_pollhandler.py @@ -47,7 +47,7 @@ async def callback(update: Update, context: CallbackContext) Attributes: callback (:term:`coroutine function`): The callback function for this handler. - block (:obj:`bool`): Determines whether the callback will run asynchronously. + block (:obj:`bool`): Determines whether the callback will run in a blocking way.. """ diff --git a/telegram/ext/_precheckoutqueryhandler.py b/telegram/ext/_precheckoutqueryhandler.py index 30e5f919275..35bf0afe80f 100644 --- a/telegram/ext/_precheckoutqueryhandler.py +++ b/telegram/ext/_precheckoutqueryhandler.py @@ -46,7 +46,7 @@ async def callback(update: Update, context: CallbackContext) Attributes: callback (:term:`coroutine function`): The callback function for this handler. - block (:obj:`bool`): Determines whether the callback will run asynchronously. + block (:obj:`bool`): Determines whether the callback will run in a blocking way.. """ diff --git a/telegram/ext/_shippingqueryhandler.py b/telegram/ext/_shippingqueryhandler.py index e26a2028ef2..4f59521b5b7 100644 --- a/telegram/ext/_shippingqueryhandler.py +++ b/telegram/ext/_shippingqueryhandler.py @@ -46,7 +46,7 @@ async def callback(update: Update, context: CallbackContext) Attributes: callback (:term:`coroutine function`): The callback function for this handler. - block (:obj:`bool`): Determines whether the callback will run asynchronously. + block (:obj:`bool`): Determines whether the callback will run in a blocking way.. """ diff --git a/tests/test_application.py b/tests/test_application.py index cd76034becc..afeda89cc7c 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -945,7 +945,7 @@ async def callback(update, context): assert recwarn[0].category is PTBUserWarning assert ( str(recwarn[0].message) - == 'ApplicationHandlerStop is not supported with asynchronously running handlers.' + == 'ApplicationHandlerStop is not supported with handlers running non-blocking.' ) assert ( Path(recwarn[0].filename) == PROJECT_ROOT_PATH / 'telegram' / 'ext' / '_application.py' From 6938ed8544c579112d7deff59742fb62dcc90108 Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Sun, 17 Apr 2022 10:26:03 +0200 Subject: [PATCH 124/153] Some docstring improvements for app --- telegram/ext/_application.py | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/telegram/ext/_application.py b/telegram/ext/_application.py index 0adf2572701..8bee5d142d2 100644 --- a/telegram/ext/_application.py +++ b/telegram/ext/_application.py @@ -273,8 +273,8 @@ def running(self) -> bool: @property def concurrent_updates(self) -> int: - """:obj:`int`: Indicates the number of concurrent updates set. A value of ``0`` indicates - updates are *not* being processed concurrently. + """:obj:`int`: The number of concurrent updates that will be processed in parallel. A + value of ``0`` indicates updates are *not* being processed concurrently. """ return self._concurrent_updates @@ -418,8 +418,9 @@ async def start(self) -> None: :attr:`persistence` is set. Note: - This does *not* start fetching updates from Telegram. You need to either start - :attr:`updater` manually or use one of :meth:`run_polling` or :meth:`run_webhook`. + This does *not* start fetching updates from Telegram. To fetch updates, you need to + either start :attr:`updater` manually or use one of :meth:`run_polling` or + :meth:`run_webhook`. .. seealso:: :meth:`stop` @@ -523,9 +524,12 @@ def run_polling( drop_pending_updates: bool = None, close_loop: bool = True, ) -> None: - """Starts polling updates from Telegram using :meth:`telegram.ext.Updater.start_polling`. + """Convenience method that takes care of initializing and starting the app, + polling updates from Telegram using :meth:`telegram.ext.Updater.start_polling` and + a graceful shutdown of the app on exit. .. seealso:: + :meth:`initialize`, :meth:`start`, :meth:`stop`, :meth:`shutdown` :meth:`telegram.ext.Updater.start_polling`, :meth:`run_webhook` Args: @@ -603,15 +607,18 @@ def run_webhook( max_connections: int = 40, close_loop: bool = True, ) -> None: - """ - Starts a small http server to listen for updates via webhook using - :meth:`telegram.ext.Updater.start_webhook`. If :paramref:`cert` + """Convenience method that takes care of initializing and starting the app, + polling updates from Telegram using :meth:`telegram.ext.Updater.start_webhook` and + a graceful shutdown of the app on exit. + + If :paramref:`cert` and :paramref:`key` are not provided, the webhook will be started directly on ``http://listen:port/url_path``, so SSL can be handled by another application. Else, the webhook will be started on ``https://listen:port/url_path``. Also calls :meth:`telegram.Bot.set_webhook` as required. .. seealso:: + :meth:`initialize`, :meth:`start`, :meth:`stop`, :meth:`shutdown` :meth:`telegram.ext.Updater.start_webhook`, :meth:`run_polling` Args: From 68287570c972c7dbaefe68b28a95f7a1ba8c8542 Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Sun, 17 Apr 2022 10:29:39 +0200 Subject: [PATCH 125/153] Fix a warning on custom CC attributes --- telegram/ext/_callbackqueryhandler.py | 2 +- telegram/ext/_chatjoinrequesthandler.py | 2 +- telegram/ext/_chatmemberhandler.py | 2 +- telegram/ext/_choseninlineresulthandler.py | 2 +- telegram/ext/_commandhandler.py | 4 ++-- telegram/ext/_handler.py | 2 +- telegram/ext/_inlinequeryhandler.py | 2 +- telegram/ext/_messagehandler.py | 2 +- telegram/ext/_pollanswerhandler.py | 2 +- telegram/ext/_pollhandler.py | 2 +- telegram/ext/_precheckoutqueryhandler.py | 2 +- telegram/ext/_shippingqueryhandler.py | 2 +- telegram/ext/_stringcommandhandler.py | 2 +- telegram/ext/_stringregexhandler.py | 2 +- telegram/ext/_typehandler.py | 2 +- 15 files changed, 16 insertions(+), 16 deletions(-) diff --git a/telegram/ext/_callbackqueryhandler.py b/telegram/ext/_callbackqueryhandler.py index e5a7fff5b92..c543a2c784b 100644 --- a/telegram/ext/_callbackqueryhandler.py +++ b/telegram/ext/_callbackqueryhandler.py @@ -61,7 +61,7 @@ class CallbackQueryHandler(Handler[Update, CCT]): .. versionadded:: 13.6 Warning: - When setting :paramref:`block` to :obj:`True`, you cannot rely on adding custom + When setting :paramref:`block` to :obj:`False`, you cannot rely on adding custom attributes to :class:`telegram.ext.CallbackContext`. See its docs for more info. Args: diff --git a/telegram/ext/_chatjoinrequesthandler.py b/telegram/ext/_chatjoinrequesthandler.py index 70e5fa51e1d..b89e2ca557f 100644 --- a/telegram/ext/_chatjoinrequesthandler.py +++ b/telegram/ext/_chatjoinrequesthandler.py @@ -30,7 +30,7 @@ class ChatJoinRequestHandler(Handler[Update, CCT]): :attr:`telegram.Update.chat_join_request`. Warning: - When setting :paramref:`block` to :obj:`True`, you cannot rely on adding custom + When setting :paramref:`block` to :obj:`False`, you cannot rely on adding custom attributes to :class:`telegram.ext.CallbackContext`. See its docs for more info. .. versionadded:: 13.8 diff --git a/telegram/ext/_chatmemberhandler.py b/telegram/ext/_chatmemberhandler.py index 7ea2388bbc4..67fbbf40d4c 100644 --- a/telegram/ext/_chatmemberhandler.py +++ b/telegram/ext/_chatmemberhandler.py @@ -34,7 +34,7 @@ class ChatMemberHandler(Handler[Update, CCT]): .. versionadded:: 13.4 Warning: - When setting :paramref:`block` to :obj:`True`, you cannot rely on adding custom + When setting :paramref:`block` to :obj:`False`, you cannot rely on adding custom attributes to :class:`telegram.ext.CallbackContext`. See its docs for more info. Args: diff --git a/telegram/ext/_choseninlineresulthandler.py b/telegram/ext/_choseninlineresulthandler.py index 9f40abc4d98..2262378368c 100644 --- a/telegram/ext/_choseninlineresulthandler.py +++ b/telegram/ext/_choseninlineresulthandler.py @@ -37,7 +37,7 @@ class ChosenInlineResultHandler(Handler[Update, CCT]): :attr:`telegram.Update.chosen_inline_result`. Warning: - When setting :paramref:`block` to :obj:`True`, you cannot rely on adding custom + When setting :paramref:`block` to :obj:`False`, you cannot rely on adding custom attributes to :class:`telegram.ext.CallbackContext`. See its docs for more info. Args: diff --git a/telegram/ext/_commandhandler.py b/telegram/ext/_commandhandler.py index 36473fd5171..ac3f1c59db4 100644 --- a/telegram/ext/_commandhandler.py +++ b/telegram/ext/_commandhandler.py @@ -48,7 +48,7 @@ class CommandHandler(Handler[Update, CCT]): * :class:`CommandHandler` does *not* handle (edited) channel posts. Warning: - When setting :paramref:`block` to :obj:`True`, you cannot rely on adding custom + When setting :paramref:`block` to :obj:`False`, you cannot rely on adding custom attributes to :class:`telegram.ext.CallbackContext`. See its docs for more info. Args: @@ -202,7 +202,7 @@ class PrefixHandler(CommandHandler): * :class:`PrefixHandler` does *not* handle (edited) channel posts. Warning: - When setting :paramref:`block` to :obj:`True`, you cannot rely on adding custom + When setting :paramref:`block` to :obj:`False`, you cannot rely on adding custom attributes to :class:`telegram.ext.CallbackContext`. See its docs for more info. Args: diff --git a/telegram/ext/_handler.py b/telegram/ext/_handler.py index a6c2bed6250..e9b6cd13085 100644 --- a/telegram/ext/_handler.py +++ b/telegram/ext/_handler.py @@ -35,7 +35,7 @@ class Handler(Generic[UT, CCT], ABC): """The base class for all update handlers. Create custom handlers by inheriting from it. Warning: - When setting :paramref:`block` to :obj:`True`, you cannot rely on adding custom + When setting :paramref:`block` to :obj:`False`, you cannot rely on adding custom attributes to :class:`telegram.ext.CallbackContext`. See its docs for more info. .. versionchanged:: 14.0 diff --git a/telegram/ext/_inlinequeryhandler.py b/telegram/ext/_inlinequeryhandler.py index 5f9fed0ff61..f6e44dcd61d 100644 --- a/telegram/ext/_inlinequeryhandler.py +++ b/telegram/ext/_inlinequeryhandler.py @@ -47,7 +47,7 @@ class InlineQueryHandler(Handler[Update, CCT]): documentation of the :mod:`re` module for more information. Warning: - * When setting :paramref:`block` to :obj:`True`, you cannot rely on adding custom + * When setting :paramref:`block` to :obj:`False`, you cannot rely on adding custom attributes to :class:`telegram.ext.CallbackContext`. See its docs for more info. * :attr:`telegram.InlineQuery.chat_type` will not be set for inline queries from secret chats and may not be set for inline queries coming from third-party clients. These diff --git a/telegram/ext/_messagehandler.py b/telegram/ext/_messagehandler.py index 97f3ef26855..89c02eab0dd 100644 --- a/telegram/ext/_messagehandler.py +++ b/telegram/ext/_messagehandler.py @@ -36,7 +36,7 @@ class MessageHandler(Handler[Update, CCT]): """Handler class to handle Telegram messages. They might contain text, media or status updates. Warning: - When setting :paramref:`block` to :obj:`True`, you cannot rely on adding custom + When setting :paramref:`block` to :obj:`False`, you cannot rely on adding custom attributes to :class:`telegram.ext.CallbackContext`. See its docs for more info. Args: diff --git a/telegram/ext/_pollanswerhandler.py b/telegram/ext/_pollanswerhandler.py index 8d97b6361e2..d96cfc6252e 100644 --- a/telegram/ext/_pollanswerhandler.py +++ b/telegram/ext/_pollanswerhandler.py @@ -30,7 +30,7 @@ class PollAnswerHandler(Handler[Update, CCT]): :attr:`poll answer `. Warning: - When setting :paramref:`block` to :obj:`True`, you cannot rely on adding custom + When setting :paramref:`block` to :obj:`False`, you cannot rely on adding custom attributes to :class:`telegram.ext.CallbackContext`. See its docs for more info. Args: diff --git a/telegram/ext/_pollhandler.py b/telegram/ext/_pollhandler.py index 38a8519156b..39f586a37c4 100644 --- a/telegram/ext/_pollhandler.py +++ b/telegram/ext/_pollhandler.py @@ -29,7 +29,7 @@ class PollHandler(Handler[Update, CCT]): """Handler class to handle Telegram updates that contain a :attr:`poll `. Warning: - When setting :paramref:`block` to :obj:`True`, you cannot rely on adding custom + When setting :paramref:`block` to :obj:`False`, you cannot rely on adding custom attributes to :class:`telegram.ext.CallbackContext`. See its docs for more info. Args: diff --git a/telegram/ext/_precheckoutqueryhandler.py b/telegram/ext/_precheckoutqueryhandler.py index 35bf0afe80f..ef4e4893948 100644 --- a/telegram/ext/_precheckoutqueryhandler.py +++ b/telegram/ext/_precheckoutqueryhandler.py @@ -28,7 +28,7 @@ class PreCheckoutQueryHandler(Handler[Update, CCT]): """Handler class to handle Telegram :attr:`telegram.Update.pre_checkout_query`. Warning: - When setting :paramref:`block` to :obj:`True`, you cannot rely on adding custom + When setting :paramref:`block` to :obj:`False`, you cannot rely on adding custom attributes to :class:`telegram.ext.CallbackContext`. See its docs for more info. Args: diff --git a/telegram/ext/_shippingqueryhandler.py b/telegram/ext/_shippingqueryhandler.py index 4f59521b5b7..e2a7506dd99 100644 --- a/telegram/ext/_shippingqueryhandler.py +++ b/telegram/ext/_shippingqueryhandler.py @@ -28,7 +28,7 @@ class ShippingQueryHandler(Handler[Update, CCT]): """Handler class to handle Telegram :attr:`telegram.Update.shipping_query`. Warning: - When setting :paramref:`block` to :obj:`True`, you cannot rely on adding custom + When setting :paramref:`block` to :obj:`False`, you cannot rely on adding custom attributes to :class:`telegram.ext.CallbackContext`. See its docs for more info. Args: diff --git a/telegram/ext/_stringcommandhandler.py b/telegram/ext/_stringcommandhandler.py index 06821d615ec..e2cf25029b4 100644 --- a/telegram/ext/_stringcommandhandler.py +++ b/telegram/ext/_stringcommandhandler.py @@ -40,7 +40,7 @@ class StringCommandHandler(Handler[str, CCT]): put in the queue. For example to send messages with the bot using command line or API. Warning: - When setting :paramref:`block` to :obj:`True`, you cannot rely on adding custom + When setting :paramref:`block` to :obj:`False`, you cannot rely on adding custom attributes to :class:`telegram.ext.CallbackContext`. See its docs for more info. Args: diff --git a/telegram/ext/_stringregexhandler.py b/telegram/ext/_stringregexhandler.py index 381d0152f4a..0f2932b4caa 100644 --- a/telegram/ext/_stringregexhandler.py +++ b/telegram/ext/_stringregexhandler.py @@ -43,7 +43,7 @@ class StringRegexHandler(Handler[str, CCT]): put in the queue. For example to send messages with the bot using command line or API. Warning: - When setting :paramref:`block` to :obj:`True`, you cannot rely on adding custom + When setting :paramref:`block` to :obj:`False`, you cannot rely on adding custom attributes to :class:`telegram.ext.CallbackContext`. See its docs for more info. Args: diff --git a/telegram/ext/_typehandler.py b/telegram/ext/_typehandler.py index 7b53f5c1f43..b50a5a4592b 100644 --- a/telegram/ext/_typehandler.py +++ b/telegram/ext/_typehandler.py @@ -33,7 +33,7 @@ class TypeHandler(Handler[UT, CCT]): """Handler class to handle updates of custom types. Warning: - When setting :paramref:`block` to :obj:`True`, you cannot rely on adding custom + When setting :paramref:`block` to :obj:`False`, you cannot rely on adding custom attributes to :class:`telegram.ext.CallbackContext`. See its docs for more info. Args: From ec784a76afa4e2339d8cbd4dc47b11666bf4416f Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Sun, 17 Apr 2022 10:39:15 +0200 Subject: [PATCH 126/153] Improve usage of AbstractAsyncContextManager --- telegram/ext/_application.py | 20 +++++++++++++++++++- telegram/ext/_updater.py | 20 +++++++++++++++++++- telegram/request/_baserequest.py | 2 +- 3 files changed, 39 insertions(+), 3 deletions(-) diff --git a/telegram/ext/_application.py b/telegram/ext/_application.py index 8bee5d142d2..33522e12f10 100644 --- a/telegram/ext/_application.py +++ b/telegram/ext/_application.py @@ -22,6 +22,7 @@ import itertools import logging from collections import defaultdict +from contextlib import AbstractAsyncContextManager from copy import deepcopy from pathlib import Path from types import TracebackType, MappingProxyType @@ -101,7 +102,7 @@ def __init__(self, state: object = None) -> None: self.state = state -class Application(Generic[BT, CCT, UD, CD, BD, JQ]): +class Application(Generic[BT, CCT, UD, CD, BD, JQ], AbstractAsyncContextManager): """This class dispatches all kinds of updates to its registered handlers, and is the entry point to a PTB application. @@ -109,6 +110,23 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ]): This class may not be initialized directly. Use :class:`telegram.ext.ApplicationBuilder` or :meth:`builder` (for convenience). + Instances of this class can be used as asyncio context managers, where + + .. code:: python + + async with application: + # code + + is roughly equivalent to + + .. code:: python + + try: + await application.initialize() + # code + finally: + await application.shutdown() + .. versionchanged:: 14.0 * Initialization is now done through the :class:`telegram.ext.ApplicationBuilder`. diff --git a/telegram/ext/_updater.py b/telegram/ext/_updater.py index c7c3272c001..fe074d45d07 100644 --- a/telegram/ext/_updater.py +++ b/telegram/ext/_updater.py @@ -20,6 +20,7 @@ import asyncio import logging import ssl +from contextlib import AbstractAsyncContextManager from pathlib import Path from types import TracebackType from typing import ( @@ -45,11 +46,28 @@ _UpdaterType = TypeVar('_UpdaterType', bound="Updater") -class Updater: +class Updater(AbstractAsyncContextManager): """This class fetches updates for the bot either via long polling or by starting a webhook server. Received updates are enqueued into the :attr:`update_queue` and may be fetched from there to handle them appropriately. + Instances of this class can be used as asyncio context managers, where + + .. code:: python + + async with updater: + # code + + is roughly equivalent to + + .. code:: python + + try: + await updater.initialize() + # code + finally: + await updater.shutdown() + .. versionchanged:: 14.0 * Removed argument and attribute ``user_sig_handler`` diff --git a/telegram/request/_baserequest.py b/telegram/request/_baserequest.py index 639f639e486..b6ce388fc7c 100644 --- a/telegram/request/_baserequest.py +++ b/telegram/request/_baserequest.py @@ -71,7 +71,7 @@ class BaseRequest( await request_object.initialize() # code finally: - await request_object.stop() + await request_object.shutdown() """ __slots__ = () From 26b0dc48439d7449b83da7558801c8f79a5e542b Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Sun, 17 Apr 2022 10:45:38 +0200 Subject: [PATCH 127/153] Some versioning directives --- docs/source/telegram.request.rst | 2 ++ telegram/request/_baserequest.py | 2 ++ telegram/request/_httpxrequest.py | 2 ++ 3 files changed, 6 insertions(+) diff --git a/docs/source/telegram.request.rst b/docs/source/telegram.request.rst index ce724e76f28..5279a8af171 100644 --- a/docs/source/telegram.request.rst +++ b/docs/source/telegram.request.rst @@ -3,6 +3,8 @@ telegram.request Module ======================= +.. versionadded:: 14.0 + .. toctree:: telegram.request.baserequest telegram.request.requestdata diff --git a/telegram/request/_baserequest.py b/telegram/request/_baserequest.py index b6ce388fc7c..76411218223 100644 --- a/telegram/request/_baserequest.py +++ b/telegram/request/_baserequest.py @@ -72,6 +72,8 @@ class BaseRequest( # code finally: await request_object.shutdown() + + .. versionadded:: 14.0 """ __slots__ = () diff --git a/telegram/request/_httpxrequest.py b/telegram/request/_httpxrequest.py index ddbfe8bab10..67a2790127d 100644 --- a/telegram/request/_httpxrequest.py +++ b/telegram/request/_httpxrequest.py @@ -40,6 +40,8 @@ class HTTPXRequest(BaseRequest): """Implementation of :class:`~telegram.request.BaseRequest` using the library `httpx `_. + .. versionadded:: 14.0 + Args: connection_pool_size (:obj:`int`, optional): Number of connections to keep in the connection pool. Defaults to ``1``. From d7ff6923c4dbec0a0b0249ba86f7ede70124c5ba Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Sun, 17 Apr 2022 10:57:21 +0200 Subject: [PATCH 128/153] fix tests --- tests/test_application.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_application.py b/tests/test_application.py index afeda89cc7c..8dbd7f19684 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -1018,7 +1018,7 @@ async def error_handler(update, context): assert recwarn[0].category is PTBUserWarning assert ( str(recwarn[0].message) - == 'ApplicationHandlerStop is not supported with asynchronously running handlers.' + == 'ApplicationHandlerStop is not supported with handlers running non-blocking.' ) assert ( Path(recwarn[0].filename) == PROJECT_ROOT_PATH / 'telegram' / 'ext' / '_application.py' From d7b9643d672737b5e13d824c8e7c1a35c56867d9 Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Sun, 17 Apr 2022 13:59:47 +0200 Subject: [PATCH 129/153] small edits for tests --- tests/test_jobqueue.py | 8 ++++++++ tests/test_message.py | 20 ++++++++++---------- 2 files changed, 18 insertions(+), 10 deletions(-) diff --git a/tests/test_jobqueue.py b/tests/test_jobqueue.py index 5dca6dfa6fd..6d31ff6d712 100644 --- a/tests/test_jobqueue.py +++ b/tests/test_jobqueue.py @@ -421,6 +421,14 @@ async def test_get_jobs(self, job_queue): assert job_queue.get_jobs_by_name('name1') == (job1, job2) assert job_queue.get_jobs_by_name('name2') == (job3,) + @pytest.mark.asyncio + async def test_job_run(self, app): + job = app.job_queue.run_repeating(self.job_run_once, 0.02) + await asyncio.sleep(0.05) + assert self.result == 0 + await job.run(app) + assert self.result == 1 + @pytest.mark.asyncio async def test_enable_disable_job(self, job_queue): job = job_queue.run_repeating(self.job_run_once, 0.2) diff --git a/tests/test_message.py b/tests/test_message.py index c526f2daae7..876e22f43dd 100644 --- a/tests/test_message.py +++ b/tests/test_message.py @@ -639,26 +639,26 @@ async def test_parse_entities_url_emoji(self): def test_chat_id(self, message): assert message.chat_id == message.chat.id - @pytest.mark.parametrize('type', argvalues=[Chat.SUPERGROUP, Chat.CHANNEL]) - def test_link_with_username(self, message, type): + @pytest.mark.parametrize('type_', argvalues=[Chat.SUPERGROUP, Chat.CHANNEL]) + def test_link_with_username(self, message, type_): message.chat.username = 'username' - message.chat.type = type + message.chat.type = type_ assert message.link == f'https://t.me/{message.chat.username}/{message.message_id}' @pytest.mark.parametrize( - 'type, id', argvalues=[(Chat.CHANNEL, -1003), (Chat.SUPERGROUP, -1003)] + 'type_, id_', argvalues=[(Chat.CHANNEL, -1003), (Chat.SUPERGROUP, -1003)] ) - def test_link_with_id(self, message, type, id): + def test_link_with_id(self, message, type_, id_): message.chat.username = None - message.chat.id = id - message.chat.type = type + message.chat.id = id_ + message.chat.type = type_ # The leading - for group ids/ -100 for supergroup ids isn't supposed to be in the link assert message.link == f'https://t.me/c/{3}/{message.message_id}' - @pytest.mark.parametrize('id, username', argvalues=[(None, 'username'), (-3, None)]) - def test_link_private_chats(self, message, id, username): + @pytest.mark.parametrize('id_, username', argvalues=[(None, 'username'), (-3, None)]) + def test_link_private_chats(self, message, id_, username): message.chat.type = Chat.PRIVATE - message.chat.id = id + message.chat.id = id_ message.chat.username = username assert message.link is None message.chat.type = Chat.GROUP From b6eef7e482fec8969f6af5aa9f21dce52d3280fc Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Sun, 17 Apr 2022 14:07:36 +0200 Subject: [PATCH 130/153] Review --- telegram/request/_requestparameter.py | 5 +++++ tests/test_message.py | 7 ++++++- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/telegram/request/_requestparameter.py b/telegram/request/_requestparameter.py index e6764c2302f..844327ff183 100644 --- a/telegram/request/_requestparameter.py +++ b/telegram/request/_requestparameter.py @@ -87,6 +87,11 @@ def _value_and_input_file_from_input( # pylint: disable=too-many-return-stateme Note that we use this for *all* files to be uploaded. This is not documented in the official API, but has been confirmed to be supported in the official Bot API repository. See https://github.com/tdlib/telegram-bot-api/issues/167 + + This method does no special casing for enums because + * all enums in tg.constants are subclasses of int/str, so they are already json-dumpable + * if a user passes a custom enum, it's unlikely that we can actually properly handle it + even with some special casing. """ if isinstance(value, datetime): return to_timestamp(value), [] diff --git a/tests/test_message.py b/tests/test_message.py index 876e22f43dd..c12c5647888 100644 --- a/tests/test_message.py +++ b/tests/test_message.py @@ -322,9 +322,14 @@ class TestMessage: def test_all_possibilities_de_json_and_to_dict(self, bot, message_params): new = Message.de_json(message_params.to_dict(), bot) - assert new.to_dict() == message_params.to_dict() + # Checking that none of the attributes are dicts is a best effort approach to ensure that + # de_json converts everything to proper classes without having to write special tests for + # every single case + for slot in new.__slots__: + assert not isinstance(new[slot], dict) + def test_slot_behaviour(self, message, mro_slots): for attr in message.__slots__: assert getattr(message, attr, 'err') != 'err', f"got extra slot '{attr}'" From c61afc8c19f59a9633de11c1da4e534001c7f564 Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Sun, 17 Apr 2022 18:01:56 +0200 Subject: [PATCH 131/153] Update timerbot a bit --- examples/timerbot.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/examples/timerbot.py b/examples/timerbot.py index 8ff874dea96..3f1131dc430 100644 --- a/examples/timerbot.py +++ b/examples/timerbot.py @@ -27,7 +27,6 @@ logging.basicConfig( format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', level=logging.INFO ) -logger = logging.getLogger(__name__) # Define a few command handlers. These usually take the two arguments update and @@ -44,7 +43,7 @@ async def start(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: async def alarm(context: CallbackContext.DEFAULT_TYPE) -> None: """Send the alarm message.""" job = context.job - await context.bot.send_message(job.context, text='Beep!') + await context.bot.send_message(job.chat_id, text=f'Beep! {job.context} seconds are over!') def remove_job_if_exists(name: str, context: CallbackContext.DEFAULT_TYPE) -> bool: @@ -68,7 +67,7 @@ async def set_timer(update: Update, context: CallbackContext.DEFAULT_TYPE) -> No return job_removed = remove_job_if_exists(str(chat_id), context) - context.job_queue.run_once(alarm, due, context=chat_id, name=str(chat_id)) + context.job_queue.run_once(alarm, due, chat_id=chat_id, name=str(chat_id), context=due) text = 'Timer successfully set!' if job_removed: From 70ee0ffdfc91c51982b857b81affaf53bc1afe2b Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Sun, 17 Apr 2022 20:58:11 +0200 Subject: [PATCH 132/153] shutdown with signal handlers --- telegram/ext/_application.py | 58 +++++++++++++++++++++++++++++++++++- tests/test_application.py | 32 ++++++++++++++++++++ 2 files changed, 89 insertions(+), 1 deletion(-) diff --git a/telegram/ext/_application.py b/telegram/ext/_application.py index 33522e12f10..66bbc058bb4 100644 --- a/telegram/ext/_application.py +++ b/telegram/ext/_application.py @@ -21,6 +21,7 @@ import inspect import itertools import logging +import signal from collections import defaultdict from contextlib import AbstractAsyncContextManager from copy import deepcopy @@ -42,6 +43,8 @@ Set, Mapping, DefaultDict, + Sequence, + NoReturn, ) from telegram import Update @@ -541,11 +544,16 @@ def run_polling( allowed_updates: List[str] = None, drop_pending_updates: bool = None, close_loop: bool = True, + stop_signals: Optional[Sequence[int]] = (signal.SIGINT, signal.SIGTERM, signal.SIGABRT), ) -> None: """Convenience method that takes care of initializing and starting the app, polling updates from Telegram using :meth:`telegram.ext.Updater.start_polling` and a graceful shutdown of the app on exit. + The app will shut down when :exc:`KeyboardInterrupt` or :exc:`SystemExit` is raised. + On unix, the app will also shut down on receiving the signals specified by + :paramref:`stop_signals`. + .. seealso:: :meth:`initialize`, :meth:`start`, :meth:`stop`, :meth:`shutdown` :meth:`telegram.ext.Updater.start_polling`, :meth:`run_webhook` @@ -582,6 +590,16 @@ def run_polling( .. seealso:: :meth:`asyncio.loop.close` + stop_signals (Sequence[:obj:`int`] | :obj:`None`, optional): Signals that will shut + down the app. Pass :obj:`None` to not use stop signals. + Defaults to :data:`signal.SIGINT`, :data:`signal.SIGTERM` and + :data:`signal.SIGABRT`. + + Caution: + Not every :class:`asyncio.AbstractEventLoop` implements + :meth:`asyncio.loop.add_signal_handler`. Most notably, the standard event loop + on Windows, :class:`asyncio.ProactorEventLoop`, does not implement this method. + If this method is not available, stop signals can not be set. Raises: :exc:`RuntimeError`: If the Application does not have an :class:`telegram.ext.Updater`. @@ -608,6 +626,7 @@ def error_callback(exc: TelegramError) -> None: error_callback=error_callback, # if there is an error in fetching updates ), close_loop=close_loop, + stop_signals=stop_signals, ) def run_webhook( @@ -624,11 +643,16 @@ def run_webhook( ip_address: str = None, max_connections: int = 40, close_loop: bool = True, + stop_signals: Optional[Sequence[int]] = (signal.SIGINT, signal.SIGTERM, signal.SIGABRT), ) -> None: """Convenience method that takes care of initializing and starting the app, polling updates from Telegram using :meth:`telegram.ext.Updater.start_webhook` and a graceful shutdown of the app on exit. + The app will shut down when :exc:`KeyboardInterrupt` or :exc:`SystemExit` is raised. + On unix, the app will also shut down on receiving the signals specified by + :paramref:`stop_signals`. + If :paramref:`cert` and :paramref:`key` are not provided, the webhook will be started directly on ``http://listen:port/url_path``, so SSL can be handled by another @@ -668,6 +692,16 @@ def run_webhook( .. seealso:: :meth:`asyncio.loop.close` + stop_signals (Sequence[:obj:`int`] | :obj:`None`, optional): Signals that will shut + down the app. Pass :obj:`None` to not use stop signals. + Defaults to :data:`signal.SIGINT`, :data:`signal.SIGTERM` and + :data:`signal.SIGABRT`. + + Caution: + Not every :class:`asyncio.AbstractEventLoop` implements + :meth:`asyncio.loop.add_signal_handler`. Most notably, the standard event loop + on Windows, :class:`asyncio.ProactorEventLoop`, does not implement this method. + If this method is not available, stop signals can not be set. """ if not self.updater: raise RuntimeError( @@ -689,13 +723,35 @@ def run_webhook( max_connections=max_connections, ), close_loop=close_loop, + stop_signals=stop_signals, ) - def __run(self, updater_coroutine: Coroutine, close_loop: bool = True) -> None: + @staticmethod + def _raise_system_exit() -> NoReturn: + raise SystemExit + + def __run( + self, + updater_coroutine: Coroutine, + stop_signals: Optional[Sequence[int]], + close_loop: bool = True, + ) -> None: # Calling get_event_loop() should still be okay even in py3.10+ as long as there is a # running event loop or we are in the main thread, which are the intended use cases. # See the docs of get_event_loop() and get_running_loop() for more info loop = asyncio.get_event_loop() + + try: + for sig in stop_signals or []: + loop.add_signal_handler(sig, self._raise_system_exit) + except NotImplementedError as exc: + warn( + f'Could not add signal handlers for the stop signals {stop_signals} due to ' + f'exception `{exc!r}`. If your event loop does not implement `add_signal_handler`,' + f' please pass `stop_signals=None`.', + stacklevel=3, + ) + try: loop.run_until_complete(self.initialize()) loop.run_until_complete(updater_coroutine) # one of updater.start_webhook/polling diff --git a/tests/test_application.py b/tests/test_application.py index 8dbd7f19684..5b446f3dbaa 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -1611,6 +1611,7 @@ def test_run_without_updater(self, bot): app.run_polling() @pytest.mark.parametrize('method', ['start', 'initialize']) + @pytest.mark.filterwarnings('ignore::telegram.warnings.PTBUserWarning') def test_run_error_in_application(self, bot, monkeypatch, method): shutdowns = [] @@ -1632,6 +1633,7 @@ async def shutdown(*args, **kwargs): assert shutdowns == [True, True] @pytest.mark.parametrize('method', ['start_polling', 'start_webhook']) + @pytest.mark.filterwarnings('ignore::telegram.warnings.PTBUserWarning') def test_run_error_in_updater(self, bot, monkeypatch, method): shutdowns = [] @@ -1654,3 +1656,33 @@ async def shutdown(*args, **kwargs): assert not app.running assert not app.updater.running assert shutdowns == [True, True] + + @pytest.mark.skipif( + platform.system() != 'Windows', + reason="Only really relevant on windows", + ) + @pytest.mark.parametrize('method', ['start_polling', 'start_webhook']) + @pytest.mark.asyncio + async def test_run_stop_signal_warning_windows(self, bot, method, monkeypatch): + async def raise_method(*args, **kwargs): + raise RuntimeError('Test Exception') + + # monkeypatch.setattr(Updater, method, raise_method) + app = ApplicationBuilder().token(bot.token).build() + with pytest.raises( + PTBUserWarning, match='Could not add signal handlers for the stop signals' + ) as exc_info: + if 'polling' in method: + app.run_polling(close_loop=False) + else: + app.run_webhook(close_loop=False) + + assert exc_info.traceback[0].path == Path(__file__), "stacklevel is incorrect!" + + with pytest.raises(RuntimeError, match="This event loop is already running"): + # this is somewhat silly: app.run_*() won't work as the pytest asyncio loop is already + # running, but we only care about checking that no warning is issued ... + if 'polling' in method: + app.run_polling(close_loop=False, stop_signals=None) + else: + app.run_webhook(close_loop=False, stop_signals=None) From 54c83dc44de2ac54dc0caf434080578ba67f648c Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Sun, 17 Apr 2022 21:18:44 +0200 Subject: [PATCH 133/153] try improving tests --- tests/test_application.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/tests/test_application.py b/tests/test_application.py index 5b446f3dbaa..c6d3f1c4e50 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -1662,13 +1662,9 @@ async def shutdown(*args, **kwargs): reason="Only really relevant on windows", ) @pytest.mark.parametrize('method', ['start_polling', 'start_webhook']) + @pytest.mark.filterwarnings("ignore:was never awaited") @pytest.mark.asyncio - async def test_run_stop_signal_warning_windows(self, bot, method, monkeypatch): - async def raise_method(*args, **kwargs): - raise RuntimeError('Test Exception') - - # monkeypatch.setattr(Updater, method, raise_method) - app = ApplicationBuilder().token(bot.token).build() + async def test_run_stop_signal_warning_windows(self, app, method): with pytest.raises( PTBUserWarning, match='Could not add signal handlers for the stop signals' ) as exc_info: From b8f41a58006d1053220edd7fb14d3c8296220cbf Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Sun, 17 Apr 2022 21:27:19 +0200 Subject: [PATCH 134/153] try again --- tests/test_application.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_application.py b/tests/test_application.py index c6d3f1c4e50..632f010d515 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -1662,7 +1662,7 @@ async def shutdown(*args, **kwargs): reason="Only really relevant on windows", ) @pytest.mark.parametrize('method', ['start_polling', 'start_webhook']) - @pytest.mark.filterwarnings("ignore:was never awaited") + @pytest.mark.filterwarnings(r"ignore:coroutine '[\w\.\_]+' was never awaited") @pytest.mark.asyncio async def test_run_stop_signal_warning_windows(self, app, method): with pytest.raises( From 0a83834d196d2466f9b42019033af3406611e6d1 Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Sun, 17 Apr 2022 22:07:51 +0200 Subject: [PATCH 135/153] One last try --- tests/test_application.py | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/tests/test_application.py b/tests/test_application.py index 632f010d515..78155183639 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -1662,23 +1662,28 @@ async def shutdown(*args, **kwargs): reason="Only really relevant on windows", ) @pytest.mark.parametrize('method', ['start_polling', 'start_webhook']) - @pytest.mark.filterwarnings(r"ignore:coroutine '[\w\.\_]+' was never awaited") - @pytest.mark.asyncio - async def test_run_stop_signal_warning_windows(self, app, method): - with pytest.raises( - PTBUserWarning, match='Could not add signal handlers for the stop signals' - ) as exc_info: + def test_run_stop_signal_warning_windows(self, bot, method, recwarn, monkeypatch): + async def raise_method(*args, **kwargs): + raise RuntimeError('Prevent Actually Running') + + monkeypatch.setattr(Application, 'initialize', raise_method) + app = ApplicationBuilder().token(bot.token).build() + + with pytest.raises(RuntimeError, match='Prevent Actually Running'): if 'polling' in method: app.run_polling(close_loop=False) else: app.run_webhook(close_loop=False) - assert exc_info.traceback[0].path == Path(__file__), "stacklevel is incorrect!" + assert len(recwarn) == 1 + assert str(recwarn[0].message).startswith('Could not add signal handlers for the stop') + assert recwarn[0].filename == __file__, "stacklevel is incorrect!" - with pytest.raises(RuntimeError, match="This event loop is already running"): - # this is somewhat silly: app.run_*() won't work as the pytest asyncio loop is already - # running, but we only care about checking that no warning is issued ... + recwarn.clear() + with pytest.raises(RuntimeError, match='Prevent Actually Running'): if 'polling' in method: app.run_polling(close_loop=False, stop_signals=None) else: app.run_webhook(close_loop=False, stop_signals=None) + + assert len(recwarn) == 0 From 193c43568a4890a5c04772a755245d7792d46564 Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Mon, 18 Apr 2022 15:59:25 +0200 Subject: [PATCH 136/153] fix some docs of HTTPXRequest --- telegram/request/_httpxrequest.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/telegram/request/_httpxrequest.py b/telegram/request/_httpxrequest.py index 67a2790127d..9a378c07ea1 100644 --- a/telegram/request/_httpxrequest.py +++ b/telegram/request/_httpxrequest.py @@ -61,19 +61,22 @@ class HTTPXRequest(BaseRequest): .. _the docs of httpx: https://www.python-httpx.org/environment_variables/#proxies read_timeout (:obj:`float` | :obj:`None`, optional): If passed, specifies the maximum - amount of time (in seconds) to wait for a response from Telegram's server instead - of the time specified during creating of this object. Defaults to ``5``. + amount of time (in seconds) to wait for a response from Telegram's server. + This value is used unless a different value is passed to :meth:`do_request`. + Defaults to ``5``. write_timeout (:obj:`float` | :obj:`None`, optional): If passed, specifies the maximum amount of time (in seconds) to wait for a write operation to complete (in terms of - a network socket; i.e. POSTing a request or uploading a file) instead of the time - specified during creating of this object. Defaults to ``5``. + a network socket; i.e. POSTing a request or uploading a file). + This value is used unless a different value is passed to :meth:`do_request`. + Defaults to ``5``. connect_timeout (:obj:`float` | :obj:`None`, optional): If passed, specifies the maximum amount of time (in seconds) to wait for a connection attempt to a server - to succeed instead of the time specified during creating of this object. Defaults - to ``5``. + to succeed. This value is used unless a different value is passed to + :meth:`do_request`. Defaults to ``5``. pool_timeout (:obj:`float` | :obj:`None`, optional): If passed, specifies the maximum - amount of time (in seconds) to wait for a connection to become available instead - of the time specified during creating of this object. Defaults to ``1``. + amount of time (in seconds) to wait for a connection to become available. + This value is used unless a different value is passed to :meth:`do_request`. + Defaults to ``1``. Warning: With a finite pool timeout, you must expect :exc:`telegram.error.TimedOut` From 079fca1e3e484b99185e860c4f330f24b8fd6abf Mon Sep 17 00:00:00 2001 From: Harshil <37377066+harshil21@users.noreply.github.com> Date: Tue, 19 Apr 2022 00:16:13 +0530 Subject: [PATCH 137/153] move definition of InitAppBuilder type alias to bottom of file --- telegram/ext/_applicationbuilder.py | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/telegram/ext/_applicationbuilder.py b/telegram/ext/_applicationbuilder.py index d92f1a7b5e8..5cda8840d1f 100644 --- a/telegram/ext/_applicationbuilder.py +++ b/telegram/ext/_applicationbuilder.py @@ -53,19 +53,6 @@ InBD = TypeVar('InBD') BuilderType = TypeVar('BuilderType', bound='ApplicationBuilder') -if TYPE_CHECKING: - DEF_CCT = CallbackContext.DEFAULT_TYPE # type: ignore[misc] - InitApplicationBuilder = ( - ApplicationBuilder[ # noqa: F821 # pylint: disable=used-before-assignment - ExtBot, - DEF_CCT, - Dict, - Dict, - Dict, - JobQueue, - ] - ) - _BOT_CHECKS = [ ('request', 'request instance'), @@ -902,3 +889,15 @@ def updater(self: BuilderType, updater: Optional[Updater]) -> BuilderType: self._updater = updater return self + + +InitApplicationBuilder = ( # This is defined all the way down here so that its type is inferred + ApplicationBuilder[ # by Pylance correctly. + ExtBot, + CallbackContext.DEFAULT_TYPE, + Dict, + Dict, + Dict, + JobQueue, + ] +) From a1bf431c17dfda8475b4446d5583a7a4b39bbee3 Mon Sep 17 00:00:00 2001 From: Harshil <37377066+harshil21@users.noreply.github.com> Date: Tue, 19 Apr 2022 00:21:15 +0530 Subject: [PATCH 138/153] try to make codacy happy w.r.t hardcoded token in example --- examples/paymentbot.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/paymentbot.py b/examples/paymentbot.py index 28ae5899b0e..1536635a28d 100644 --- a/examples/paymentbot.py +++ b/examples/paymentbot.py @@ -24,7 +24,7 @@ ) logger = logging.getLogger(__name__) -PAYMENT_PROVIDER_TOKEN = 'TOKEN' +PAYMENT_PROVIDER_TOKEN = "PAYMENT_PROVIDER_TOKEN" async def start_callback(update: Update, context: CallbackContext.DEFAULT_TYPE) -> None: From 74bf2df485090a4fa8252d4c7746389e3c813772 Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Mon, 18 Apr 2022 21:45:08 +0200 Subject: [PATCH 139/153] add a debug print to failing test --- tests/test_application.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/test_application.py b/tests/test_application.py index 78155183639..fb08103ff4c 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -1675,6 +1675,10 @@ async def raise_method(*args, **kwargs): else: app.run_webhook(close_loop=False) + for record in recwarn: + print(record) + print(record.message) + print(str(record.message)) assert len(recwarn) == 1 assert str(recwarn[0].message).startswith('Could not add signal handlers for the stop') assert recwarn[0].filename == __file__, "stacklevel is incorrect!" From bd8fe9e0f6513dcfbe3ff21299263c600f215fa9 Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Mon, 18 Apr 2022 22:15:54 +0200 Subject: [PATCH 140/153] try fixing the failing test --- setup.cfg | 4 ++++ tests/conftest.py | 5 ----- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/setup.cfg b/setup.cfg index 5df067d852c..f13a3ad56ff 100644 --- a/setup.cfg +++ b/setup.cfg @@ -27,6 +27,10 @@ addopts = --no-success-flaky-report -rsxX filterwarnings = error ignore::DeprecationWarning +; ignore:Tasks created via `Application\.create_task` while the application is not running + ignore::ResourceWarning +; TODO: Write so good code that we don't need to ignore ResourceWarnings anymore + ; Unfortunately due to https://github.com/pytest-dev/pytest/issues/8343 we can't have this here ; and instead do a trick directly in tests/conftest.py ; ignore::telegram.utils.deprecate.TelegramDeprecationWarning diff --git a/tests/conftest.py b/tests/conftest.py index e1903b45876..2dbf6b5ff52 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -255,11 +255,6 @@ def class_thumb_file(): f.close() -def pytest_configure(config): - config.addinivalue_line('filterwarnings', 'ignore::ResourceWarning') - # TODO: Write so good code that we don't need to ignore ResourceWarnings anymore - - def make_bot(bot_info, **kwargs): """ Tests are executed on tg.ext.ExtBot, as that class only extends the functionality of tg.bot From a7a11bc2fcf031a188092026e60b2effb0eeaa87 Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Mon, 18 Apr 2022 22:23:34 +0200 Subject: [PATCH 141/153] create_task while not running: warn instead of log --- setup.cfg | 2 +- telegram/ext/_application.py | 5 +++-- tests/test_application.py | 12 ++++++------ tests/test_jobqueue.py | 4 ++-- 4 files changed, 12 insertions(+), 11 deletions(-) diff --git a/setup.cfg b/setup.cfg index f13a3ad56ff..1541905a06a 100644 --- a/setup.cfg +++ b/setup.cfg @@ -27,7 +27,7 @@ addopts = --no-success-flaky-report -rsxX filterwarnings = error ignore::DeprecationWarning -; ignore:Tasks created via `Application\.create_task` while the application is not running + ignore:Tasks created via `Application\.create_task` while the application is not running ignore::ResourceWarning ; TODO: Write so good code that we don't need to ignore ResourceWarnings anymore diff --git a/telegram/ext/_application.py b/telegram/ext/_application.py index 66bbc058bb4..88f1b8b683d 100644 --- a/telegram/ext/_application.py +++ b/telegram/ext/_application.py @@ -816,9 +816,10 @@ def __create_task( self.__create_task_tasks.add(task) task.add_done_callback(self.__create_task_done_callback) else: - _logger.warning( + warn( "Tasks created via `Application.create_task` while the application is not " - "running won't be automatically awaited!" + "running won't be automatically awaited!", + stacklevel=3, ) return task diff --git a/tests/test_application.py b/tests/test_application.py index fb08103ff4c..25b69f19374 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -1166,7 +1166,7 @@ async def callback(): @pytest.mark.asyncio @pytest.mark.parametrize('running', (True, False)) - async def test_create_task_awaiting_warning(self, app, running, caplog): + async def test_create_task_awaiting_warning(self, app, running, recwarn): async def callback(): await asyncio.sleep(0.1) return 43 @@ -1175,18 +1175,18 @@ async def callback(): if running: await app.start() - with caplog.at_level(logging.WARNING): - task = app.create_task(callback()) + task = app.create_task(callback()) if running: - assert len(caplog.records) == 0 + assert len(recwarn) == 0 assert not task.done() await app.stop() assert task.done() assert task.result() == 43 else: - assert len(caplog.records) == 1 - assert "won't be automatically awaited" in caplog.records[-1].getMessage() + assert len(recwarn) == 1 + assert "won't be automatically awaited" in str(recwarn[0].message) + assert recwarn[0].filename == __file__, "wrong stacklevel!" assert not task.done() await task diff --git a/tests/test_jobqueue.py b/tests/test_jobqueue.py index 6d31ff6d712..32634de7f4d 100644 --- a/tests/test_jobqueue.py +++ b/tests/test_jobqueue.py @@ -490,7 +490,7 @@ async def test_process_error_that_raises_errors(self, job_queue, app, caplog): with caplog.at_level(logging.ERROR): job = job_queue.run_once(self.job_with_exception, 0.1) await asyncio.sleep(0.15) - assert len(caplog.records) == 2 + assert len(caplog.records) == 1 rec = caplog.records[-1] assert 'An error was raised and an uncaught' in rec.getMessage() caplog.clear() @@ -509,7 +509,7 @@ async def test_process_error_that_raises_errors(self, job_queue, app, caplog): with caplog.at_level(logging.ERROR): job = job_queue.run_once(self.job_with_exception, 0.1) await asyncio.sleep(0.15) - assert len(caplog.records) == 2 + assert len(caplog.records) == 1 rec = caplog.records[-1] assert 'No error handlers are registered' in rec.getMessage() caplog.clear() From 9a541f0f9beb6fa38d76fa3018f25f9c0051094e Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Mon, 18 Apr 2022 22:28:27 +0200 Subject: [PATCH 142/153] second try for failing test --- tests/test_application.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/tests/test_application.py b/tests/test_application.py index 25b69f19374..c4682d68df1 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -1675,13 +1675,14 @@ async def raise_method(*args, **kwargs): else: app.run_webhook(close_loop=False) + assert len(recwarn) >= 1 + found = False for record in recwarn: print(record) - print(record.message) - print(str(record.message)) - assert len(recwarn) == 1 - assert str(recwarn[0].message).startswith('Could not add signal handlers for the stop') - assert recwarn[0].filename == __file__, "stacklevel is incorrect!" + if str(record.message).startswith('Could not add signal handlers for the stop'): + assert record.filename == __file__, "stacklevel is incorrect!" + found = True + assert found recwarn.clear() with pytest.raises(RuntimeError, match='Prevent Actually Running'): From c7677c2dcb805231d90480cf2ec3bdf994cb85b7 Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Tue, 19 Apr 2022 21:40:29 +0200 Subject: [PATCH 143/153] tiny doc fix --- telegram/ext/_application.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/telegram/ext/_application.py b/telegram/ext/_application.py index 88f1b8b683d..458062fcbed 100644 --- a/telegram/ext/_application.py +++ b/telegram/ext/_application.py @@ -1129,7 +1129,7 @@ def migrate_chat_data( :meth:`update_persistence`. Warning: - * Any data stored in :attr:`chat_data` at key `new_chat_id` will be overridden + * Any data stored in :attr:`chat_data` at key ``new_chat_id`` will be overridden * The key `old_chat_id` of :attr:`chat_data` will be deleted * This does not update the :attr:`~telegram.ext.Job.chat_id` attribute of any scheduled :class:`telegram.ext.Job`. From b4fa728b9e1077df0e34c62b73970879287b742f Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Tue, 19 Apr 2022 21:50:00 +0200 Subject: [PATCH 144/153] Another small doc fix --- telegram/ext/_application.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/telegram/ext/_application.py b/telegram/ext/_application.py index 458062fcbed..9008797a757 100644 --- a/telegram/ext/_application.py +++ b/telegram/ext/_application.py @@ -1154,6 +1154,9 @@ def migrate_chat_data( Mutually exclusive with passing :paramref:`message` new_chat_id (:obj:`int`, optional): The new chat ID. Mutually exclusive with passing :paramref:`message` + + Raises: + ValueError: Raised if the input is invalid. """ if message and (old_chat_id or new_chat_id): raise ValueError("Message and chat_id pair are mutually exclusive") From 12f13225c01130a4495b856c66da546f38d4b679 Mon Sep 17 00:00:00 2001 From: Harshil <37377066+harshil21@users.noreply.github.com> Date: Wed, 20 Apr 2022 22:24:53 +0530 Subject: [PATCH 145/153] jobqueue doc fixes --- telegram/ext/_jobqueue.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/telegram/ext/_jobqueue.py b/telegram/ext/_jobqueue.py index 7bd8c18eceb..40f919b3cc1 100644 --- a/telegram/ext/_jobqueue.py +++ b/telegram/ext/_jobqueue.py @@ -508,7 +508,7 @@ async def callback(context: CallbackContext) async def start(self) -> None: # this method async just in case future versions need that - """Starts the job_queue thread.""" + """Starts the job_queue.""" if not self.scheduler.running: self.scheduler.start() @@ -579,10 +579,10 @@ async def callback(context: CallbackContext) job (:class:`apscheduler.job.Job`, optional): The APS Job this job is a wrapper for. chat_id (:obj:`int`, optional): Chat id of the chat that this job is associated with. - ..versionadded:: 14.0 + .. versionadded:: 14.0 user_id (:obj:`int`, optional): User id of the user that this job is associated with. - ..versionadded:: 14.0 + .. versionadded:: 14.0 Attributes: callback (:term:`coroutine function`): The callback function that should be executed by the @@ -592,10 +592,10 @@ async def callback(context: CallbackContext) job (:class:`apscheduler.job.Job`): Optional. The APS Job this job is a wrapper for. chat_id (:obj:`int`): Optional. Chat id of the chat that this job is associated with. - ..versionadded:: 14.0 + .. versionadded:: 14.0 user_id (:obj:`int`): Optional. User id of the user that this job is associated with. - ..versionadded:: 14.0 + .. versionadded:: 14.0 """ __slots__ = ( From bafe934239415dd2357b971f35489a3ab5ee9214 Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Thu, 21 Apr 2022 22:47:32 +0200 Subject: [PATCH 146/153] Rework uploading files - tests for RequestData are not yet adjusted --- telegram/_bot.py | 10 ++-- telegram/_files/inputfile.py | 20 ++++--- telegram/_files/inputmedia.py | 12 ++--- telegram/_utils/enum.py | 36 +++++++++++++ telegram/_utils/files.py | 8 ++- telegram/constants.py | 42 ++++++--------- telegram/request/_baserequest.py | 2 - telegram/request/_requestdata.py | 12 ++++- telegram/request/_requestparameter.py | 75 ++++++++++++++++++--------- tests/test_bot.py | 3 ++ tests/test_constants.py | 7 ++- tests/test_files.py | 8 +++ tests/test_inputfile.py | 10 ++++ tests/test_requestdata.py | 4 ++ tests/test_requestparameter.py | 37 +++++++++---- 15 files changed, 201 insertions(+), 85 deletions(-) create mode 100644 telegram/_utils/enum.py diff --git a/telegram/_bot.py b/telegram/_bot.py index 8b65f3045ab..7a236ab6c4b 100644 --- a/telegram/_bot.py +++ b/telegram/_bot.py @@ -1036,7 +1036,7 @@ async def send_audio( if caption_entities: data['caption_entities'] = caption_entities if thumb: - data['thumb'] = parse_file_input(thumb) + data['thumb'] = parse_file_input(thumb, attach=True) return await self._send_message( # type: ignore[return-value] 'sendAudio', @@ -1171,7 +1171,7 @@ async def send_document( if disable_content_type_detection is not None: data['disable_content_type_detection'] = disable_content_type_detection if thumb: - data['thumb'] = parse_file_input(thumb) + data['thumb'] = parse_file_input(thumb, attach=True) return await self._send_message( # type: ignore[return-value] 'sendDocument', @@ -1410,7 +1410,7 @@ async def send_video( if height: data['height'] = height if thumb: - data['thumb'] = parse_file_input(thumb) + data['thumb'] = parse_file_input(thumb, attach=True) return await self._send_message( # type: ignore[return-value] 'sendVideo', @@ -1533,7 +1533,7 @@ async def send_video_note( if length is not None: data['length'] = length if thumb: - data['thumb'] = parse_file_input(thumb) + data['thumb'] = parse_file_input(thumb, attach=True) return await self._send_message( # type: ignore[return-value] 'sendVideoNote', @@ -1673,7 +1673,7 @@ async def send_animation( if height: data['height'] = height if thumb: - data['thumb'] = parse_file_input(thumb) + data['thumb'] = parse_file_input(thumb, attach=True) if caption: data['caption'] = caption if caption_entities: diff --git a/telegram/_files/inputfile.py b/telegram/_files/inputfile.py index b2aa8ee3c1b..a4ae183f11e 100644 --- a/telegram/_files/inputfile.py +++ b/telegram/_files/inputfile.py @@ -49,10 +49,15 @@ class InputFile: .. versionchanged:: 14.0 Accept string input. filename (:obj:`str`, optional): Filename for this InputFile. + attach (:obj:`bool`, optional): Pass :obj:`True` if the parameter this file belongs to in + the request to Telegram should point to the multipart data via an ``attach://`` URI. + Defaults to `False`. Attributes: input_file_content (:obj:`bytes`): The binary content of the file to send. - attach_name (:obj:`str`): Attach name. + attach_name (:obj:`str`, Optional ): If present, the parameter this file belongs to in + the request to Telegram should point to the multipart data via a an URI of the form + ``attach://`` URI. filename (:obj:`str`): Filename for the file to be sent. mimetype (:obj:`str`): The mimetype inferred from the file to be sent. @@ -60,14 +65,16 @@ class InputFile: __slots__ = ('filename', 'attach_name', 'input_file_content', 'mimetype') - def __init__(self, obj: Union[IO[bytes], bytes, str], filename: str = None): + def __init__( + self, obj: Union[IO[bytes], bytes, str], filename: str = None, attach: bool = False + ): if isinstance(obj, bytes): self.input_file_content = obj elif isinstance(obj, str): self.input_file_content = obj.encode('utf-8') else: self.input_file_content = obj.read() - self.attach_name = 'attached' + uuid4().hex + self.attach_name: Optional[str] = 'attached' + uuid4().hex if attach else None if ( not filename @@ -119,6 +126,7 @@ def field_tuple(self) -> FieldTuple: return self.filename, self.input_file_content, self.mimetype @property - def attach_uri(self) -> str: - """URI to insert into the JSON data for uploading the file.""" - return f'attach://{self.attach_name}' + def attach_uri(self) -> Optional[str]: + """URI to insert into the JSON data for uploading the file. Returns :obj:`None`, if + :attr:`attach_name` is :obj:`None`.""" + return f'attach://{self.attach_name}' if self.attach_name else None diff --git a/telegram/_files/inputmedia.py b/telegram/_files/inputmedia.py index efa4665db4d..9e2ccbe9c09 100644 --- a/telegram/_files/inputmedia.py +++ b/telegram/_files/inputmedia.py @@ -103,7 +103,7 @@ def to_dict(self) -> JSONDict: @staticmethod def _parse_thumb_input(thumb: Optional[FileInput]) -> Optional[Union[str, InputFile]]: - return parse_file_input(thumb) if thumb is not None else thumb + return parse_file_input(thumb, attach=True) if thumb is not None else thumb class InputMediaAnimation(InputMedia): @@ -184,7 +184,7 @@ def __init__( duration = media.duration if duration is None else duration media = media.file_id else: - media = parse_file_input(media, filename=filename) + media = parse_file_input(media, filename=filename, attach=True) super().__init__(InputMediaType.ANIMATION, media, caption, caption_entities, parse_mode) self.thumb = self._parse_thumb_input(thumb) @@ -240,7 +240,7 @@ def __init__( caption_entities: Union[List[MessageEntity], Tuple[MessageEntity, ...]] = None, filename: str = None, ): - media = parse_file_input(media, PhotoSize, filename=filename) + media = parse_file_input(media, PhotoSize, filename=filename, attach=True) super().__init__(InputMediaType.PHOTO, media, caption, caption_entities, parse_mode) @@ -331,7 +331,7 @@ def __init__( duration = duration if duration is not None else media.duration media = media.file_id else: - media = parse_file_input(media, filename=filename) + media = parse_file_input(media, filename=filename, attach=True) super().__init__(InputMediaType.VIDEO, media, caption, caption_entities, parse_mode) self.width = width @@ -422,7 +422,7 @@ def __init__( title = media.title if title is None else title media = media.file_id else: - media = parse_file_input(media, filename=filename) + media = parse_file_input(media, filename=filename, attach=True) super().__init__(InputMediaType.AUDIO, media, caption, caption_entities, parse_mode) self.thumb = self._parse_thumb_input(thumb) @@ -496,7 +496,7 @@ def __init__( caption_entities: Union[List[MessageEntity], Tuple[MessageEntity, ...]] = None, filename: str = None, ): - media = parse_file_input(media, Document, filename=filename) + media = parse_file_input(media, Document, filename=filename, attach=True) super().__init__(InputMediaType.DOCUMENT, media, caption, caption_entities, parse_mode) self.thumb = self._parse_thumb_input(thumb) self.disable_content_type_detection = disable_content_type_detection diff --git a/telegram/_utils/enum.py b/telegram/_utils/enum.py new file mode 100644 index 00000000000..63ff62645c4 --- /dev/null +++ b/telegram/_utils/enum.py @@ -0,0 +1,36 @@ +# +# A library that provides a Python interface to the Telegram Bot API +# Copyright (C) 2015-2022 +# Leandro Toledo de Souza +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser Public License for more details. +# +# You should have received a copy of the GNU Lesser Public License +# along with this program. If not, see [http://www.gnu.org/licenses/]. +"""This module contains a helper class for Enums that should be subclasses of `str`. + +Warning: + Contents of this module are intended to be used internally by the library and *not* by the + user. Changes to this module are not considered breaking changes and may not be documented in + the changelog. +""" +from enum import Enum + + +class StringEnum(str, Enum): + """Helper class for string enums where the value is not important to be displayed on + stringification. + """ + + __slots__ = () + + def __repr__(self) -> str: + return f'<{self.__class__.__name__}.{self.name}>' diff --git a/telegram/_utils/files.py b/telegram/_utils/files.py index b074d99f2fd..5197bb21ad3 100644 --- a/telegram/_utils/files.py +++ b/telegram/_utils/files.py @@ -58,6 +58,7 @@ def parse_file_input( file_input: Union[FileInput, 'TelegramObject'], tg_type: Type['TelegramObject'] = None, filename: str = None, + attach: bool = False, ) -> Union[str, 'InputFile', Any]: """ Parses input for sending files: @@ -77,6 +78,9 @@ def parse_file_input( :class:`telegram.Animation`. filename (:obj:`str`, optional): The filename. Only relevant in case an :class:`telegram.InputFile` is returned. + attach (:obj:`bool`, optional): Pass :obj:`True` if the parameter this file belongs to in + the request to Telegram should point to the multipart data via an ``attach://`` URI. + Defaults to `False`. Only relevant if an :class:`telegram.InputFile` is returned. Returns: :obj:`str` | :class:`telegram.InputFile` | :obj:`object`: The parsed input or the untouched @@ -94,9 +98,9 @@ def parse_file_input( out = file_input # type: ignore[assignment] return out if isinstance(file_input, bytes): - return InputFile(file_input, filename=filename) + return InputFile(file_input, filename=filename, attach=attach) if hasattr(file_input, 'read'): - return InputFile(cast(IO, file_input), filename=filename) + return InputFile(cast(IO, file_input), filename=filename, attach=attach) if tg_type and isinstance(file_input, tg_type): return file_input.file_id # type: ignore[attr-defined] return file_input diff --git a/telegram/constants.py b/telegram/constants.py index 5a513d14b82..36d71889f75 100644 --- a/telegram/constants.py +++ b/telegram/constants.py @@ -63,20 +63,10 @@ 'UpdateType', ] -from enum import Enum, IntEnum +from enum import IntEnum from typing import List - -class _StringEnum(str, Enum): - """Helper class for string enums where the value is not important to be displayed on - stringification. - """ - - __slots__ = () - - def __repr__(self) -> str: - return f'<{self.__class__.__name__}.{self.name}>' - +from telegram._utils.enum import StringEnum BOT_API_VERSION = '5.7' @@ -85,7 +75,7 @@ def __repr__(self) -> str: SUPPORTED_WEBHOOK_PORTS: List[int] = [443, 80, 88, 8443] -class BotCommandScopeType(_StringEnum): +class BotCommandScopeType(StringEnum): """This enum contains the available types of :class:`telegram.BotCommandScope`. The enum members of this enumeration are instances of :class:`str` and can be treated as such. @@ -125,7 +115,7 @@ class CallbackQueryLimit(IntEnum): :meth:`telegram.Bot.answer_callback_query`.""" -class ChatAction(_StringEnum): +class ChatAction(StringEnum): """This enum contains the available chat actions for :meth:`telegram.Bot.send_chat_action`. The enum members of this enumeration are instances of :class:`str` and can be treated as such. @@ -210,7 +200,7 @@ class ChatInviteLinkLimit(IntEnum): :meth:`telegram.Bot.create_chat_invite_link` and :meth:`telegram.Bot.edit_chat_invite_link`.""" -class ChatMemberStatus(_StringEnum): +class ChatMemberStatus(StringEnum): """This enum contains the available states for :class:`telegram.ChatMember`. The enum members of this enumeration are instances of :class:`str` and can be treated as such. @@ -233,7 +223,7 @@ class ChatMemberStatus(_StringEnum): """:obj:`str`: A :class:`telegram.ChatMember` who was restricted in this chat.""" -class ChatType(_StringEnum): +class ChatType(StringEnum): """This enum contains the available types of :class:`telegram.Chat`. The enum members of this enumeration are instances of :class:`str` and can be treated as such. @@ -255,7 +245,7 @@ class ChatType(_StringEnum): """:obj:`str`: A :class:`telegram.Chat` that is a channel.""" -class DiceEmoji(_StringEnum): +class DiceEmoji(StringEnum): """This enum contains the available emoji for :class:`telegram.Dice`/ :meth:`telegram.Bot.send_dice`. The enum members of this enumeration are instances of :class:`str` and can be treated as such. @@ -344,7 +334,7 @@ class InlineKeyboardMarkupLimit(IntEnum): """ -class InputMediaType(_StringEnum): +class InputMediaType(StringEnum): """This enum contains the available types of :class:`telegram.InputMedia`. The enum members of this enumeration are instances of :class:`str` and can be treated as such. @@ -383,7 +373,7 @@ class InlineQueryLimit(IntEnum): :meth:`telegram.Bot.answer_inline_query`.""" -class InlineQueryResultType(_StringEnum): +class InlineQueryResultType(StringEnum): """This enum contains the available types of :class:`telegram.InlineQueryResult`. The enum members of this enumeration are instances of :class:`str` and can be treated as such. @@ -457,7 +447,7 @@ class LocationLimit(IntEnum): """ -class MaskPosition(_StringEnum): +class MaskPosition(StringEnum): """This enum contains the available positions for :class:`telegram.MaskPosition`. The enum members of this enumeration are instances of :class:`str` and can be treated as such. @@ -476,7 +466,7 @@ class MaskPosition(_StringEnum): """:obj:`str`: Mask position for a sticker on the chin.""" -class MessageAttachmentType(_StringEnum): +class MessageAttachmentType(StringEnum): """This enum contains the available types of :class:`telegram.Message` that can bee seens as attachment. The enum members of this enumeration are instances of :class:`str` and can be treated as such. @@ -525,7 +515,7 @@ class MessageAttachmentType(_StringEnum): """:obj:`str`: Messages with :attr:`telegram.Message.venue`.""" -class MessageEntityType(_StringEnum): +class MessageEntityType(StringEnum): """This enum contains the available types of :class:`telegram.MessageEntity`. The enum members of this enumeration are instances of :class:`str` and can be treated as such. @@ -592,7 +582,7 @@ class MessageLimit(IntEnum): """ -class MessageType(_StringEnum): +class MessageType(StringEnum): """This enum contains the available types of :class:`telegram.Message` that can be seen as attachment. The enum members of this enumeration are instances of :class:`str` and can be treated as such. @@ -679,7 +669,7 @@ class MessageType(_StringEnum): """:obj:`str`: Messages with :attr:`telegram.Message.voice_chat_participants_invited`.""" -class ParseMode(_StringEnum): +class ParseMode(StringEnum): """This enum contains the available parse modes. The enum members of this enumeration are instances of :class:`str` and can be treated as such. @@ -719,7 +709,7 @@ class PollLimit(IntEnum): """:obj:`str`: Maximum number of available options for the poll.""" -class PollType(_StringEnum): +class PollType(StringEnum): """This enum contains the available types for :class:`telegram.Poll`/ :meth:`telegram.Bot.send_poll`. The enum members of this enumeration are instances of :class:`str` and can be treated as such. @@ -735,7 +725,7 @@ class PollType(_StringEnum): """:obj:`str`: quiz polls.""" -class UpdateType(_StringEnum): +class UpdateType(StringEnum): """This enum contains the available types of :class:`telegram.Update`. The enum members of this enumeration are instances of :class:`str` and can be treated as such. diff --git a/telegram/request/_baserequest.py b/telegram/request/_baserequest.py index 76411218223..ece9f2009d5 100644 --- a/telegram/request/_baserequest.py +++ b/telegram/request/_baserequest.py @@ -18,7 +18,6 @@ # along with this program. If not, see [http://www.gnu.org/licenses/]. """This module contains an abstract class to make POST and GET requests.""" import abc -import traceback from contextlib import AbstractAsyncContextManager from http import HTTPStatus from types import TracebackType @@ -280,7 +279,6 @@ async def _request_wrapper( except TelegramError as exc: raise exc except Exception as exc: - traceback.print_tb(exc.__traceback__) raise NetworkError(f"Unknown error in HTTP implementation: {repr(exc)}") from exc if HTTPStatus.OK <= code <= 299: diff --git a/telegram/request/_requestdata.py b/telegram/request/_requestdata.py index f541b75392e..6bcc0c97a79 100644 --- a/telegram/request/_requestdata.py +++ b/telegram/request/_requestdata.py @@ -58,14 +58,22 @@ def parameters(self) -> Dict[str, Union[str, int, List, Dict]]: (possibly nested) composition of lists, tuples and dictionaries, where each entry, key and value is of one of the mentioned types. """ - return {param.name: param.value for param in self._parameters} # type: ignore[misc] + return { + param.name: param.value # type: ignore[misc] + for param in self._parameters + if param.value is not None + } @property def json_parameters(self) -> Dict[str, str]: """Gives the parameters as mapping of parameter name to the respective JSON encoded value. """ - return {param.name: param.json_value for param in self._parameters} + return { + param.name: param.json_value + for param in self._parameters + if param.json_value is not None + } def url_encoded_parameters(self, encode_kwargs: Dict[str, Any] = None) -> str: """Encodes the parameters with :func:`urllib.parse.urlencode`. diff --git a/telegram/request/_requestparameter.py b/telegram/request/_requestparameter.py index 844327ff183..61fd306ec9d 100644 --- a/telegram/request/_requestparameter.py +++ b/telegram/request/_requestparameter.py @@ -19,11 +19,11 @@ """This module contains a class that describes a single parameter of a request to the Bot API.""" from dataclasses import dataclass from datetime import datetime -from enum import Enum from typing import Optional, List, Tuple from telegram import InputFile, InputMedia, TelegramObject from telegram._utils.datetime import to_timestamp +from telegram._utils.enum import StringEnum from telegram._utils.types import UploadFileDict try: @@ -46,13 +46,13 @@ class RequestParameter: Args: name (:obj:`str`): The name of the parameter. - value (:obj:`object`): The value of the parameter. Must be JSON-dumpable. + value (:obj:`object` | :obj:None`): The value of the parameter. Must be JSON-dumpable. input_files (List[:class:`telegram.InputFile`], optional): A list of files that should be uploaded along with this parameter. Attributes: name (:obj:`str`): The name of the parameter. - value (:obj:`object`): The value of the parameter. + value (:obj:`object` | :obj:None`): The value of the parameter. input_files (List[:class:`telegram.InputFile` | :obj:`None`): A list of files that should be uploaded along with this parameter. """ @@ -64,10 +64,15 @@ class RequestParameter: input_files: Optional[List[InputFile]] @property - def json_value(self) -> str: - """The JSON dumped :attr:`value`""" + def json_value(self) -> Optional[str]: + """The JSON dumped :attr:`value` or :obj:`None` if :attr:`value` is :obj:`None`. + The latter can currently only happen if :attr:`input_files` has exactly one element that + must not be uploaded via an attach:// URI. + """ if isinstance(self.value, str): return self.value + if self.value is None: + return None return json.dumps(self.value) @property @@ -75,41 +80,60 @@ def multipart_data(self) -> Optional[UploadFileDict]: """A dict with the file data to upload, if any.""" if not self.input_files: return None - return {input_file.attach_name: input_file.field_tuple for input_file in self.input_files} + return { + (input_file.attach_name or self.name): input_file.field_tuple + for input_file in self.input_files + } @staticmethod - def _value_and_input_file_from_input( # pylint: disable=too-many-return-statements + def _value_and_input_files_from_input( # pylint: disable=too-many-return-statements value: object, ) -> Tuple[object, List[InputFile]]: - """Converts `value` into something that we can json-dump. If `value` contains a file to be - uploaded, it will be returned as second return value and the corresponding attach:// value - will be returned as first return value. - Note that we use this for *all* files to be uploaded. This is not documented in the - official API, but has been confirmed to be supported in the official Bot API repository. - See https://github.com/tdlib/telegram-bot-api/issues/167 - - This method does no special casing for enums because - * all enums in tg.constants are subclasses of int/str, so they are already json-dumpable + """Converts `value` into something that we can json-dump. Returns two values: + 1. the JSON-dumpable value. Maybe be `None` in case the value is an InputFile which must + not be uploaded via an attach:// URI + 2. A list of InputFiles that should be uploaded for this value + + Note that we handle files differently depending on whether attaching them via an URI of the + form attach:// is documented to be allowed or not. + There was some confusion whether this worked for all files, so that we stick to the + documented ways for now. + See https://github.com/tdlib/telegram-bot-api/issues/167 and + https://github.com/tdlib/telegram-bot-api/issues/259 + + This method only does some special casing for our own helper class StringEnum, but not + for general enums. This is because: + * tg.constants currently only uses IntEnum as second enum type and json dumping that + is no problem * if a user passes a custom enum, it's unlikely that we can actually properly handle it even with some special casing. """ if isinstance(value, datetime): return to_timestamp(value), [] - if isinstance(value, Enum): + if isinstance(value, StringEnum): return value.value, [] if isinstance(value, InputFile): - return value.attach_uri, [ - value, - ] + if value.attach_uri: + return value.attach_uri, [ + value, + ] + return None, [value] + if isinstance(value, InputMedia) and isinstance(value.media, InputFile): # We call to_dict and change the returned dict instead of overriding # value.media in case the same value is reused for another request data = value.to_dict() - data['media'] = value.media.attach_uri + if value.media.attach_uri: + data['media'] = value.media.attach_uri + else: + data.pop('media', None) thumb = data.get('thumb', None) if isinstance(thumb, InputFile): - data['thumb'] = thumb.attach_uri + if thumb.attach_uri: + data['thumb'] = thumb.attach_uri + else: + data.pop('thumb', None) return data, [value.media, thumb] return data, [value.media] @@ -127,14 +151,15 @@ def from_input(cls, key: str, value: object) -> 'RequestParameter': param_values = [] input_files = [] for obj in value: - param_value, input_file = cls._value_and_input_file_from_input(obj) - param_values.append(param_value) + param_value, input_file = cls._value_and_input_files_from_input(obj) + if param_value is not None: + param_values.append(param_value) input_files.extend(input_file) return RequestParameter( name=key, value=param_values, input_files=input_files if input_files else None ) - param_value, input_files = cls._value_and_input_file_from_input(value) + param_value, input_files = cls._value_and_input_files_from_input(value) return RequestParameter( name=key, value=param_value, input_files=input_files if input_files else None ) diff --git a/tests/test_bot.py b/tests/test_bot.py index eba5573bb51..f2ef8d6a5b4 100644 --- a/tests/test_bot.py +++ b/tests/test_bot.py @@ -1646,6 +1646,7 @@ async def test_set_webhook_get_webhook_info_and_delete_webhook(self, bot, use_ip max_connections=max_connections, allowed_updates=allowed_updates, ip_address=ip if use_ip else None, + certificate=data_file('sslcert.pem').read_bytes() if use_ip else None, ) await asyncio.sleep(1) @@ -1654,12 +1655,14 @@ async def test_set_webhook_get_webhook_info_and_delete_webhook(self, bot, use_ip assert live_info.max_connections == max_connections assert live_info.allowed_updates == allowed_updates assert live_info.ip_address == ip + assert live_info.has_custom_certificate == use_ip await bot.delete_webhook() await asyncio.sleep(1) info = await bot.get_webhook_info() assert info.url == '' assert info.ip_address is None + assert info.has_custom_certificate is False @pytest.mark.parametrize('drop_pending_updates', [True, False]) @pytest.mark.asyncio diff --git a/tests/test_constants.py b/tests/test_constants.py index a034341e722..fe940566e3d 100644 --- a/tests/test_constants.py +++ b/tests/test_constants.py @@ -23,12 +23,12 @@ from flaky import flaky from telegram import constants -from telegram.constants import _StringEnum +from telegram._utils.enum import StringEnum from telegram.error import BadRequest from tests.conftest import data_file -class StrEnumTest(_StringEnum): +class StrEnumTest(StringEnum): FOO = 'foo' BAR = 'bar' @@ -39,6 +39,9 @@ class IntEnumTest(IntEnum): class TestConstants: + """Also test _utils.enum.StringEnum on the fly because tg.constants is currently the only + place where that class is used.""" + def test__all__(self): expected = { key diff --git a/tests/test_files.py b/tests/test_files.py index df9d227be03..6775c28dffd 100644 --- a/tests/test_files.py +++ b/tests/test_files.py @@ -97,3 +97,11 @@ def test_parse_file_input_tg_object(self): @pytest.mark.parametrize('obj', [{1: 2}, [1, 2], (1, 2)]) def test_parse_file_input_other(self, obj): assert telegram._utils.files.parse_file_input(obj) is obj + + @pytest.mark.parametrize('attach', [True, False]) + def test_parse_file_input_attach(self, attach): + source_file = data_file('text_file.txt') + parsed = telegram._utils.files.parse_file_input(source_file.read_bytes(), attach=attach) + + assert isinstance(parsed, InputFile) + assert bool(parsed.attach_name) is attach diff --git a/tests/test_inputfile.py b/tests/test_inputfile.py index 22634068fc6..0725920a219 100644 --- a/tests/test_inputfile.py +++ b/tests/test_inputfile.py @@ -56,6 +56,16 @@ def test_subprocess_pipe(self, png_file): # to kill it. pass + @pytest.mark.parametrize('attach', [True, False]) + def test_attach(self, attach): + input_file = InputFile('contents', attach=attach) + if attach: + assert isinstance(input_file.attach_name, str) + assert input_file.attach_uri == f'attach://{input_file.attach_name}' + else: + assert input_file.attach_name is None + assert input_file.attach_uri is None + def test_mimetypes(self, caplog): # Only test a few to make sure logic works okay assert InputFile(data_file('telegram.jpg').open('rb')).mimetype == 'image/jpeg' diff --git a/tests/test_requestdata.py b/tests/test_requestdata.py index 2f254b6b26d..0d84cf402a0 100644 --- a/tests/test_requestdata.py +++ b/tests/test_requestdata.py @@ -133,6 +133,7 @@ def mixed_rqs(mixed_params) -> RequestData: class TestRequestData: + # TODO: Adjust tests! def test_slot_behaviour(self, simple_rqs, mro_slots): for attr in simple_rqs.__slots__: assert getattr(simple_rqs, attr, 'err') != 'err', f"got extra slot '{attr}'" @@ -154,6 +155,7 @@ def test_parameters( # assert file_rqs.parameters == file_params # assert mixed_rqs.parameters == mixed_params + @pytest.mark.xfail(True, reason='Not adjusted yet') def test_json_parameters( self, simple_rqs, file_rqs, mixed_rqs, simple_jsons, file_jsons, mixed_jsons ): @@ -161,6 +163,7 @@ def test_json_parameters( assert file_rqs.json_parameters == file_jsons assert mixed_rqs.json_parameters == mixed_jsons + @pytest.mark.xfail(True, reason='Not adjusted yet') def test_json_payload( self, simple_rqs, file_rqs, mixed_rqs, simple_jsons, file_jsons, mixed_jsons ): @@ -168,6 +171,7 @@ def test_json_payload( assert file_rqs.json_payload == json.dumps(file_jsons).encode() assert mixed_rqs.json_payload == json.dumps(mixed_jsons).encode() + @pytest.mark.xfail(True, reason='Not adjusted yet') def test_multipart_data( self, simple_rqs, diff --git a/tests/test_requestparameter.py b/tests/test_requestparameter.py index 52c404a1e5c..f79bf8ce426 100644 --- a/tests/test_requestparameter.py +++ b/tests/test_requestparameter.py @@ -50,7 +50,7 @@ def test_slot_behaviour(self, mro_slots): (1, '1'), ('one', 'one'), (True, 'true'), - (None, 'null'), + (None, None), ([1, '1'], '[1, "1"]'), ({True: None}, '{"true": null}'), ((1,), '[1]'), @@ -60,15 +60,17 @@ def test_json_value(self, value, expected): request_parameter = RequestParameter('name', value, None) assert request_parameter.json_value == expected - def test_multipart_data(self): + def test_multiple_multipart_data(self): assert RequestParameter('name', 'value', []).multipart_data is None - input_file_1 = InputFile(data_file('telegram.jpg').read_bytes()) - input_file_2 = InputFile(data_file('telegram.jpg').read_bytes(), filename='custom') - request_parameter = RequestParameter('value', 'name', [input_file_1, input_file_2]) + input_file_1 = InputFile('data1', attach=True) + input_file_2 = InputFile('data2', filename='custom') + request_parameter = RequestParameter( + value='value', name='name', input_files=[input_file_1, input_file_2] + ) files = request_parameter.multipart_data assert files[input_file_1.attach_name] == input_file_1.field_tuple - assert files[input_file_2.attach_name] == input_file_2.field_tuple + assert files['name'] == input_file_2.field_tuple @pytest.mark.parametrize( ('value', 'expected_value'), @@ -97,15 +99,19 @@ def test_from_input_no_media(self, value, expected_value): assert request_parameter.input_files is None def test_from_input_inputfile(self): - inputfile_1 = InputFile(data_file('telegram.jpg').read_bytes(), 'inputfile_1') - inputfile_2 = InputFile(data_file('telegram.mp4').read_bytes(), 'inputfile_2') + inputfile_1 = InputFile('data1', filename='inputfile_1', attach=True) + inputfile_2 = InputFile('data2', filename='inputfile_2') request_parameter = RequestParameter.from_input('key', inputfile_1) assert request_parameter.value == inputfile_1.attach_uri assert request_parameter.input_files == [inputfile_1] + request_parameter = RequestParameter.from_input('key', inputfile_2) + assert request_parameter.value is None + assert request_parameter.input_files == [inputfile_2] + request_parameter = RequestParameter.from_input('key', [inputfile_1, inputfile_2]) - assert request_parameter.value == [inputfile_1.attach_uri, inputfile_2.attach_uri] + assert request_parameter.value == [inputfile_1.attach_uri] assert request_parameter.input_files == [inputfile_1, inputfile_2] def test_from_input_input_media(self): @@ -137,3 +143,16 @@ def test_from_input_input_media(self): input_media_thumb.thumb, input_media_no_thumb.media, ] + + def test_from_input_inputmedia_without_attach(self): + """This case will never happen, but we test it for completeness""" + input_media = InputMediaVideo( + data_file('telegram.png').read_bytes(), + thumb=data_file('telegram.png').read_bytes(), + parse_mode=None, + ) + input_media.media.attach_name = None + input_media.thumb.attach_name = None + request_parameter = RequestParameter.from_input('key', input_media) + assert request_parameter.value == {"type": "video"} + assert request_parameter.input_files == [input_media.media, input_media.thumb] From 0086ad57681a2b36fe0368a2e914d8a529a8f438 Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Fri, 22 Apr 2022 11:55:00 +0200 Subject: [PATCH 147/153] doc adjustments --- docs/source/conf.py | 2 +- telegram/_files/inputfile.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index 61a6cee75f8..a92952873c6 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -517,7 +517,7 @@ def autodoc_process_bases(app, name, obj, option, bases: list): base = str(base) # Special case because base classes are in std lib: - if "_StringEnum" in base: + if "StringEnum" in base == "": bases[idx] = ":class:`enum.Enum`" bases.insert(0, ':class:`str`') continue diff --git a/telegram/_files/inputfile.py b/telegram/_files/inputfile.py index a4ae183f11e..05e9d4eef3d 100644 --- a/telegram/_files/inputfile.py +++ b/telegram/_files/inputfile.py @@ -128,5 +128,6 @@ def field_tuple(self) -> FieldTuple: @property def attach_uri(self) -> Optional[str]: """URI to insert into the JSON data for uploading the file. Returns :obj:`None`, if - :attr:`attach_name` is :obj:`None`.""" + :attr:`attach_name` is :obj:`None`. + """ return f'attach://{self.attach_name}' if self.attach_name else None From 1aa2fc14d459491a2066016538b3a34080d19806 Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Fri, 22 Apr 2022 19:05:31 +0200 Subject: [PATCH 148/153] Test RequestData --- tests/test_request.py | 2 +- tests/test_requestdata.py | 58 ++++++++++++++++++++++++--------------- 2 files changed, 37 insertions(+), 23 deletions(-) diff --git a/tests/test_request.py b/tests/test_request.py index ad273d7f86a..8d38a8288f2 100644 --- a/tests/test_request.py +++ b/tests/test_request.py @@ -48,8 +48,8 @@ mixed_rqs, mixed_params, file_params, + inputfiles, simple_params, - inputfile, input_media_video, input_media_photo, ) diff --git a/tests/test_requestdata.py b/tests/test_requestdata.py index 0d84cf402a0..07fdf14263f 100644 --- a/tests/test_requestdata.py +++ b/tests/test_requestdata.py @@ -33,8 +33,8 @@ @pytest.fixture(scope='module') -def inputfile() -> InputFile: - return InputFile(data_file('telegram.jpg').read_bytes()) +def inputfiles() -> Dict[bool, InputFile]: + return {True: InputFile(obj='data', attach=True), False: InputFile(obj='data', attach=False)} @pytest.fixture(scope='module') @@ -59,8 +59,8 @@ def simple_params() -> Dict[str, Any]: return { 'string': 'string', 'integer': 1, - 'tg_object': MessageEntity('type', 1, 1).to_dict(), - 'list': [1, 'string', MessageEntity('type', 1, 1).to_dict()], + 'tg_object': MessageEntity('type', 1, 1), + 'list': [1, 'string', MessageEntity('type', 1, 1)], } @@ -82,23 +82,24 @@ def simple_rqs(simple_params) -> RequestData: @pytest.fixture(scope='module') -def file_params(inputfile, input_media_video, input_media_photo) -> Dict[str, Any]: +def file_params(inputfiles, input_media_video, input_media_photo) -> Dict[str, Any]: return { - 'inputfile': inputfile, + 'inputfile_attach': inputfiles[True], + 'inputfile_no_attach': inputfiles[False], 'inputmedia': input_media_video, 'inputmedia_list': [input_media_video, input_media_photo], } @pytest.fixture(scope='module') -def file_jsons(inputfile, input_media_video, input_media_photo) -> Dict[str, Any]: +def file_jsons(inputfiles, input_media_video, input_media_photo) -> Dict[str, Any]: input_media_video_dict = input_media_video.to_dict() input_media_video_dict['media'] = input_media_video.media.attach_uri input_media_video_dict['thumb'] = input_media_video.thumb.attach_uri input_media_photo_dict = input_media_photo.to_dict() input_media_photo_dict['media'] = input_media_photo.media.attach_uri return { - 'inputfile': inputfile.attach_uri, + 'inputfile_attach': inputfiles[True].attach_uri, 'inputmedia': json.dumps(input_media_video_dict), 'inputmedia_list': json.dumps([input_media_video_dict, input_media_photo_dict]), } @@ -133,7 +134,6 @@ def mixed_rqs(mixed_params) -> RequestData: class TestRequestData: - # TODO: Adjust tests! def test_slot_behaviour(self, simple_rqs, mro_slots): for attr in simple_rqs.__slots__: assert getattr(simple_rqs, attr, 'err') != 'err', f"got extra slot '{attr}'" @@ -145,17 +145,32 @@ def test_contains_files(self, simple_rqs, file_rqs, mixed_rqs): assert mixed_rqs.contains_files def test_parameters( - self, - simple_rqs, - simple_params, # file_rqs, mixed_rqs, file_params, mixed_params + self, simple_rqs, file_rqs, mixed_rqs, inputfiles, input_media_video, input_media_photo ): - assert simple_rqs.parameters == simple_params - # We don't test these for now since that's a struggle - # And the conversion part is already being tested in test_requestparameter.py - # assert file_rqs.parameters == file_params - # assert mixed_rqs.parameters == mixed_params + simple_params_expected = { + 'string': 'string', + 'integer': 1, + 'tg_object': MessageEntity('type', 1, 1).to_dict(), + 'list': [1, 'string', MessageEntity('type', 1, 1).to_dict()], + } + video_value = { + 'media': input_media_video.media.attach_uri, + 'thumb': input_media_video.thumb.attach_uri, + 'type': input_media_video.type, + } + photo_value = {'media': input_media_photo.media.attach_uri, 'type': input_media_photo.type} + file_params_expected = { + 'inputfile_attach': inputfiles[True].attach_uri, + 'inputmedia': video_value, + 'inputmedia_list': [video_value, photo_value], + } + mixed_params_expected = simple_params_expected.copy() + mixed_params_expected.update(file_params_expected) + + assert simple_rqs.parameters == simple_params_expected + assert file_rqs.parameters == file_params_expected + assert mixed_rqs.parameters == mixed_params_expected - @pytest.mark.xfail(True, reason='Not adjusted yet') def test_json_parameters( self, simple_rqs, file_rqs, mixed_rqs, simple_jsons, file_jsons, mixed_jsons ): @@ -163,7 +178,6 @@ def test_json_parameters( assert file_rqs.json_parameters == file_jsons assert mixed_rqs.json_parameters == mixed_jsons - @pytest.mark.xfail(True, reason='Not adjusted yet') def test_json_payload( self, simple_rqs, file_rqs, mixed_rqs, simple_jsons, file_jsons, mixed_jsons ): @@ -171,18 +185,18 @@ def test_json_payload( assert file_rqs.json_payload == json.dumps(file_jsons).encode() assert mixed_rqs.json_payload == json.dumps(mixed_jsons).encode() - @pytest.mark.xfail(True, reason='Not adjusted yet') def test_multipart_data( self, simple_rqs, file_rqs, mixed_rqs, - inputfile, + inputfiles, input_media_video, input_media_photo, ): expected = { - inputfile.attach_name: inputfile.field_tuple, + inputfiles[True].attach_name: inputfiles[True].field_tuple, + 'inputfile_no_attach': inputfiles[False].field_tuple, input_media_photo.media.attach_name: input_media_photo.media.field_tuple, input_media_video.media.attach_name: input_media_video.media.field_tuple, input_media_video.thumb.attach_name: input_media_video.thumb.field_tuple, From f1f7e2462eb275e37de96e5e03e480fc84a19353 Mon Sep 17 00:00:00 2001 From: Harshil <37377066+harshil21@users.noreply.github.com> Date: Sat, 23 Apr 2022 00:59:04 +0530 Subject: [PATCH 149/153] document a limit about send_document/voice --- telegram/_bot.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/telegram/_bot.py b/telegram/_bot.py index 7a236ab6c4b..08ebadf2757 100644 --- a/telegram/_bot.py +++ b/telegram/_bot.py @@ -1083,8 +1083,10 @@ async def send_document( changed in the future. Note: - The document argument can be either a file_id, an URL or a file from disk - ``open(filename, 'rb')`` + * The document argument can be either a file_id, an URL or a file from disk + ``open(filename, 'rb')``. + + * Sending by URL will currently only work ``GIF``, ``PDF`` & ``ZIP`` files. Args: chat_id (:obj:`int` | :obj:`str`): Unique identifier for the target chat or username @@ -1717,14 +1719,17 @@ async def send_voice( ) -> Message: """ Use this method to send audio files, if you want Telegram clients to display the file - as a playable voice message. For this to work, your audio must be in an .ogg file + as a playable voice message. For this to work, your audio must be in an ``.ogg`` file encoded with OPUS (other formats may be sent as Audio or Document). Bots can currently send voice messages of up to :tg-const:`telegram.constants.FileSizeLimit.FILESIZE_UPLOAD` in size, this limit may be changed in the future. Note: - The voice argument can be either a file_id, an URL or a file from disk - ``open(filename, 'rb')`` + * The voice argument can be either a file_id, an URL or a file from disk + ``open(filename, 'rb')``. + + * To use this method, the file must have the type ``audio/ogg`` and be no more than 1MB + in size. 1-20MB voice notes will be sent as files. Args: chat_id (:obj:`int` | :obj:`str`): Unique identifier for the target chat or username From 88b8e192264e452b7e9ca34749fddf9d235cd6ac Mon Sep 17 00:00:00 2001 From: Poolitzer <25934244+Poolitzer@users.noreply.github.com> Date: Sat, 23 Apr 2022 10:14:51 +0200 Subject: [PATCH 150/153] Fix: Small doc fixes + improvements --- docs/source/conf.py | 2 +- docs/source/index.rst | 6 +-- telegram/ext/_application.py | 11 ++-- telegram/ext/_applicationbuilder.py | 76 ++++++++++++++-------------- telegram/ext/_conversationhandler.py | 2 +- telegram/ext/_extbot.py | 8 +-- telegram/ext/_updater.py | 6 ++- 7 files changed, 56 insertions(+), 55 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index a92952873c6..f0a3fb5e8bb 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -158,7 +158,7 @@ "announcement": 'PTB has undergone significant changes in v14. Please read the documentation ' 'carefully and also check out the transition guide in the ' '' - 'wiki', + 'wiki.', } # Add any paths that contain custom themes here, relative to this directory. diff --git a/docs/source/index.rst b/docs/source/index.rst index 4b8f910b262..9b42222ec18 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -10,19 +10,19 @@ Guides and tutorials ==================== If you're just starting out with the library, we recommend following our `"Your first Bot" `_ tutorial that you can find on our `wiki `_. -On our wiki you will also find guides like how to use handlers, webhooks, emoji, proxies and much more. +While being there, you will also find guides to learn how to use handlers, webhooks, proxies, making your bot persistent, and much more. Examples ======== -A great way to learn is by looking at examples. Ours can be found in our `examples folder on Github `_. +A great way to learn is by looking at examples. Ours can be found in our `examples folder on Github `_. Reference ========= Below you can find a reference of all the classes and methods in python-telegram-bot. -Apart from the `telegram.ext` package the objects should reflect the types defined in the `official Telegram Bot API documentation `_. +Apart from the `telegram.ext` package and the `Auxiliary` modules, the objects reflect the types defined in the `official Telegram Bot API documentation `_. .. toctree:: telegram.ext diff --git a/telegram/ext/_application.py b/telegram/ext/_application.py index 9008797a757..fd4c1d2b408 100644 --- a/telegram/ext/_application.py +++ b/telegram/ext/_application.py @@ -77,7 +77,7 @@ class ApplicationHandlerStop(Exception): """ Raise this in a handler or an error handler to prevent execution of any other handler (even in - different group). + different groups). In order to use this exception in a :class:`telegram.ext.ConversationHandler`, pass the optional :paramref:`state` parameter instead of returning the next state: @@ -168,7 +168,7 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AbstractAsyncContextManager) .. seealso:: :meth:`add_handler`, :meth:`add_handlers`. - error_handlers (Dict[:term:`coroutine function`, :obj:`bool`]): A dict, where the keys are + error_handlers (Dict[:term:`coroutine function`, :obj:`bool`]): A dictionary where the keys are error handlers and the values indicate whether they are to be run blocking. .. seealso:: @@ -785,11 +785,11 @@ def create_task(self, coroutine: Coroutine, update: object = None) -> asyncio.Ta * If :paramref:`coroutine` raises an exception, it will be set on the task created by this method even though it's handled by :meth:`process_error`. * If the application is currently running, tasks created by this method will be - awaited by :meth:`stop`. + awaited with :meth:`stop`. Args: coroutine (:term:`coroutine function`): The coroutine to run as task. - update (:obj:`object`, optional): If passed, will be passed to :meth:`process_error` + update (:obj:`object`, optional): If set, will be passed to :meth:`process_error` as additional information for the error handlers. Moreover, the corresponding :attr:`chat_data` and :attr:`user_data` entries will be updated in the next run of :meth:`update_persistence` after the :paramref:`coroutine` is finished. @@ -974,7 +974,7 @@ def add_handler(self, handler: Handler[Any, CCT], group: int = DEFAULT_GROUP) -> The priority/order of handlers is determined as follows: * Priority of the group (lower group number == higher priority) - * The first handler in a group which should handle an update (see + * The first handler in a group which can handle an update (see :attr:`telegram.ext.Handler.check_update`) will be used. Other handlers from the group will not be used. The order in which handlers were added to the group defines the priority. @@ -1031,7 +1031,6 @@ def add_handlers( sequence(s) matters. See :meth:`add_handler` for details. .. versionadded:: 14.0 - .. seealso:: :meth:`add_handler` Args: handlers (List[:class:`telegram.ext.Handler`] | \ diff --git a/telegram/ext/_applicationbuilder.py b/telegram/ext/_applicationbuilder.py index 5cda8840d1f..d2d67549bcd 100644 --- a/telegram/ext/_applicationbuilder.py +++ b/telegram/ext/_applicationbuilder.py @@ -103,7 +103,7 @@ class ApplicationBuilder(Generic[BT, CCT, UD, CD, BD, JQ]): * Unless a custom :class:`telegram.Bot` instance is set via :meth:`bot`, :meth:`build` will use :class:`telegram.ext.ExtBot` for the bot. - .. _`builder pattern`: https://en.wikipedia.org/wiki/Builder_pattern. + .. _`builder pattern`: https://en.wikipedia.org/wiki/Builder_pattern """ __slots__ = ( @@ -278,7 +278,7 @@ def build( def application_class( self: BuilderType, application_class: Type[Application], kwargs: Dict[str, object] = None ) -> BuilderType: - """Sets a custom subclass to be used instead of :class:`telegram.ext.Application`. The + """Sets a custom subclass instead of :class:`telegram.ext.Application`. The subclass's ``__init__`` should look like this .. code:: python @@ -301,7 +301,7 @@ def __init__(self, custom_arg_1, custom_arg_2, ..., **kwargs): return self def token(self: BuilderType, token: str) -> BuilderType: - """Sets the token to be used for :attr:`telegram.ext.Application.bot`. + """Sets the token for :attr:`telegram.ext.Application.bot`. Args: token (:obj:`str`): The token. @@ -317,7 +317,7 @@ def token(self: BuilderType, token: str) -> BuilderType: return self def base_url(self: BuilderType, base_url: str) -> BuilderType: - """Sets the base URL to be used for :attr:`telegram.ext.Application.bot`. If not called, + """Sets the base URL for :attr:`telegram.ext.Application.bot`. If not called, will default to ``'https://api.telegram.org/bot'``. .. seealso:: :paramref:`telegram.Bot.base_url`, `Local Bot API Server BuilderType: return self def base_file_url(self: BuilderType, base_file_url: str) -> BuilderType: - """Sets the base file URL to be used for :attr:`telegram.ext.Application.bot`. If not + """Sets the base file URL for :attr:`telegram.ext.Application.bot`. If not called, will default to ``'https://api.telegram.org/file/bot'``. .. seealso:: :paramref:`telegram.Bot.base_file_url`, `Local Bot API Server None: ) def request(self: BuilderType, request: BaseRequest) -> BuilderType: - """Sets a :class:`telegram.request.BaseRequest` object to be used for the + """Sets a :class:`telegram.request.BaseRequest` instance for the :paramref:`telegram.Bot.request` parameter of :attr:`telegram.ext.Application.bot`. .. seealso:: :meth:`get_updates_request` Args: - request (:class:`telegram.request.BaseRequest`): The request object. + request (:class:`telegram.request.BaseRequest`): The request instance. Returns: :class:`ApplicationBuilder`: The same builder with the updated argument. @@ -413,7 +413,7 @@ def request(self: BuilderType, request: BaseRequest) -> BuilderType: return self def connection_pool_size(self: BuilderType, connection_pool_size: int) -> BuilderType: - """Sets the size of the connection pool to be used for the + """Sets the size of the connection pool for the :paramref:`~telegram.request.HTTPXRequest.connection_pool_size` parameter of :attr:`telegram.Bot.request`. Defaults to ``128``. @@ -428,7 +428,7 @@ def connection_pool_size(self: BuilderType, connection_pool_size: int) -> Builde return self def proxy_url(self: BuilderType, proxy_url: str) -> BuilderType: - """Sets the proxy to be used for the :paramref:`~telegram.request.HTTPXRequest.proxy_url` + """Sets the proxy for the :paramref:`~telegram.request.HTTPXRequest.proxy_url` parameter of :attr:`telegram.Bot.request`. Defaults to :obj:`None`. Args: @@ -443,7 +443,7 @@ def proxy_url(self: BuilderType, proxy_url: str) -> BuilderType: return self def connect_timeout(self: BuilderType, connect_timeout: Optional[float]) -> BuilderType: - """Sets the connection attempt timeout to be used for the + """Sets the connection attempt timeout for the :paramref:`~telegram.request.HTTPXRequest.connect_timeout` parameter of :attr:`telegram.Bot.request`. Defaults to ``5.0``. @@ -459,7 +459,7 @@ def connect_timeout(self: BuilderType, connect_timeout: Optional[float]) -> Buil return self def read_timeout(self: BuilderType, read_timeout: Optional[float]) -> BuilderType: - """Sets the waiting timeout to be used for the + """Sets the waiting timeout for the :paramref:`~telegram.request.HTTPXRequest.read_timeout` parameter of :attr:`telegram.Bot.request`. Defaults to ``5.0``. @@ -475,7 +475,7 @@ def read_timeout(self: BuilderType, read_timeout: Optional[float]) -> BuilderTyp return self def write_timeout(self: BuilderType, write_timeout: Optional[float]) -> BuilderType: - """Sets the write operation timeout to be used for the + """Sets the write operation timeout for the :paramref:`~telegram.request.HTTPXRequest.write_timeout` parameter of :attr:`telegram.Bot.request`. Defaults to ``5.0``. @@ -491,7 +491,7 @@ def write_timeout(self: BuilderType, write_timeout: Optional[float]) -> BuilderT return self def pool_timeout(self: BuilderType, pool_timeout: Optional[float]) -> BuilderType: - """Sets the connection pool's connection freeing timeout to be used for the + """Sets the connection pool's connection freeing timeout for the :paramref:`~telegram.request.HTTPXRequest.pool_timeout` parameter of :attr:`telegram.Bot.request`. Defaults to :obj:`None`. @@ -507,14 +507,14 @@ def pool_timeout(self: BuilderType, pool_timeout: Optional[float]) -> BuilderTyp return self def get_updates_request(self: BuilderType, get_updates_request: BaseRequest) -> BuilderType: - """Sets a :class:`telegram.request.BaseRequest` object to be used for the + """Sets a :class:`telegram.request.BaseRequest` instance for the :paramref:`~telegram.Bot.get_updates_request` parameter of :attr:`telegram.ext.Application.bot`. .. seealso:: :meth:`request` Args: - get_updates_request (:class:`telegram.request.BaseRequest`): The request object. + get_updates_request (:class:`telegram.request.BaseRequest`): The request instance. Returns: :class:`ApplicationBuilder`: The same builder with the updated argument. @@ -526,9 +526,9 @@ def get_updates_request(self: BuilderType, get_updates_request: BaseRequest) -> def get_updates_connection_pool_size( self: BuilderType, get_updates_connection_pool_size: int ) -> BuilderType: - """Sets the size of the connection pool to be used for the + """Sets the size of the connection pool for the :paramref:`telegram.request.HTTPXRequest.connection_pool_size` parameter which is used - for :meth:`telegram.Bot.get_updates`. Defaults to ``1``. + for the :meth:`telegram.Bot.get_updates` request. Defaults to ``1``. Args: get_updates_connection_pool_size (:obj:`int`): The size of the connection pool. @@ -541,7 +541,7 @@ def get_updates_connection_pool_size( return self def get_updates_proxy_url(self: BuilderType, get_updates_proxy_url: str) -> BuilderType: - """Sets the proxy to be used for the :paramref:`telegram.request.HTTPXRequest.proxy_url` + """Sets the proxy for the :paramref:`telegram.request.HTTPXRequest.proxy_url` parameter which is used for :meth:`telegram.Bot.get_updates`. Defaults to :obj:`None`. Args: @@ -558,9 +558,9 @@ def get_updates_proxy_url(self: BuilderType, get_updates_proxy_url: str) -> Buil def get_updates_connect_timeout( self: BuilderType, get_updates_connect_timeout: Optional[float] ) -> BuilderType: - """Sets the connection attempt timeout to be used for the + """Sets the connection attempt timeout for the :paramref:`telegram.request.HTTPXRequest.connect_timeout` parameter which is used for - :meth:`telegram.Bot.get_updates`. Defaults to ``5.0``. + the :meth:`telegram.Bot.get_updates` request. Defaults to ``5.0``. Args: get_updates_connect_timeout (:obj:`float`): See @@ -576,9 +576,9 @@ def get_updates_connect_timeout( def get_updates_read_timeout( self: BuilderType, get_updates_read_timeout: Optional[float] ) -> BuilderType: - """Sets the waiting timeout to be used for the - :paramref:`telegram.request.HTTPXRequest.read_timeout` parameter which is used for - :meth:`telegram.Bot.get_updates`. Defaults to ``5.0``. + """Sets the waiting timeout for the + :paramref:`telegram.request.HTTPXRequest.read_timeout` parameter which is used for the + :meth:`telegram.Bot.get_updates` request. Defaults to ``5.0``. Args: get_updates_read_timeout (:obj:`float`): See @@ -594,9 +594,9 @@ def get_updates_read_timeout( def get_updates_write_timeout( self: BuilderType, get_updates_write_timeout: Optional[float] ) -> BuilderType: - """Sets the write operation timeout to be used for the + """Sets the write operation timeout for the :paramref:`telegram.request.HTTPXRequest.write_timeout` parameter which is used for - :meth:`telegram.Bot.get_updates`. Defaults to ``5.0``. + the :meth:`telegram.Bot.get_updates` request. Defaults to ``5.0``. Args: get_updates_write_timeout (:obj:`float`): See @@ -612,9 +612,9 @@ def get_updates_write_timeout( def get_updates_pool_timeout( self: BuilderType, get_updates_pool_timeout: Optional[float] ) -> BuilderType: - """Sets the connection pool's connection freeing timeout to be used for the - :paramref:`~telegram.request.HTTPXRequest.pool_timeout` parameter which is used for - :meth:`telegram.Bot.get_updates`. Defaults to :obj:`None`. + """Sets the connection pool's connection freeing timeout for the + :paramref:`~telegram.request.HTTPXRequest.pool_timeout` parameter which is used for the + :meth:`telegram.Bot.get_updates` request. Defaults to :obj:`None`. Args: get_updates_pool_timeout (:obj:`float`): See @@ -633,7 +633,7 @@ def private_key( password: Union[bytes, FilePathInput] = None, ) -> BuilderType: """Sets the private key and corresponding password for decryption of telegram passport data - to be used for :attr:`telegram.ext.Application.bot`. + for :attr:`telegram.ext.Application.bot`. .. seealso:: `passportbot.py `_, `Telegram Passports @@ -666,14 +666,14 @@ def private_key( return self def defaults(self: BuilderType, defaults: 'Defaults') -> BuilderType: - """Sets the :class:`telegram.ext.Defaults` object to be used for + """Sets the :class:`telegram.ext.Defaults` instance for :attr:`telegram.ext.Application.bot`. .. seealso:: `Adding Defaults `_ Args: - defaults (:class:`telegram.ext.Defaults`): The defaults. + defaults (:class:`telegram.ext.Defaults`): The defaults instance. Returns: :class:`ApplicationBuilder`: The same builder with the updated argument. @@ -717,7 +717,7 @@ def bot( self: 'ApplicationBuilder[BT, CCT, UD, CD, BD, JQ]', bot: InBT, ) -> 'ApplicationBuilder[InBT, CCT, UD, CD, BD, JQ]': - """Sets a :class:`telegram.Bot` instance to be used for + """Sets a :class:`telegram.Bot` instance for :attr:`telegram.ext.Application.bot`. Instances of subclasses like :class:`telegram.ext.ExtBot` are also valid. @@ -736,7 +736,7 @@ def bot( return self # type: ignore[return-value] def update_queue(self: BuilderType, update_queue: Queue) -> BuilderType: - """Sets a :class:`asyncio.Queue` instance to be used for + """Sets a :class:`asyncio.Queue` instance for :attr:`telegram.ext.Application.update_queue`, i.e. the queue that the application will fetch updates from. Will also be used for the :attr:`telegram.ext.Application.updater`. If not called, a queue will be instantiated. @@ -759,7 +759,7 @@ def concurrent_updates(self: BuilderType, concurrent_updates: Union[bool, int]) Warning: Processing updates concurrently is not recommended when stateful handlers like - :class:`telegram.ext.ConversationHandler` are used. Only use this, when you are sure + :class:`telegram.ext.ConversationHandler` are used. Only use this if you are sure that your bot does not (explicitly or implicitly) rely on updates being processed sequentially. @@ -780,7 +780,7 @@ def job_queue( self: 'ApplicationBuilder[BT, CCT, UD, CD, BD, JQ]', job_queue: InJQ, ) -> 'ApplicationBuilder[BT, CCT, UD, CD, BD, InJQ]': - """Sets a :class:`telegram.ext.JobQueue` instance to be used for + """Sets a :class:`telegram.ext.JobQueue` instance for :attr:`telegram.ext.Application.job_queue`. If not called, a job queue will be instantiated. @@ -809,7 +809,7 @@ def job_queue( return self # type: ignore[return-value] def persistence(self: BuilderType, persistence: 'BasePersistence') -> BuilderType: - """Sets a :class:`telegram.ext.BasePersistence` instance to be used for + """Sets a :class:`telegram.ext.BasePersistence` instance for :attr:`telegram.ext.Application.persistence`. Note: @@ -843,7 +843,7 @@ def context_types( self: 'ApplicationBuilder[BT, CCT, UD, CD, BD, JQ]', context_types: 'ContextTypes[InCCT, InUD, InCD, InBD]', ) -> 'ApplicationBuilder[BT, InCCT, InUD, InCD, InBD, JQ]': - """Sets a :class:`telegram.ext.ContextTypes` instance to be used for + """Sets a :class:`telegram.ext.ContextTypes` instance for :attr:`telegram.ext.Application.context_types`. .. seealso:: `contexttypesbot.py BuilderType: - """Sets a :class:`telegram.ext.Updater` instance to be used for + """Sets a :class:`telegram.ext.Updater` instance for :attr:`telegram.ext.Application.updater`. The :attr:`telegram.ext.Updater.bot` and :attr:`telegram.ext.Updater.update_queue` will be used for :attr:`telegram.ext.Application.bot` and :attr:`telegram.ext.Application.update_queue`, diff --git a/telegram/ext/_conversationhandler.py b/telegram/ext/_conversationhandler.py index 5e52c035f0e..3f75af0eb82 100644 --- a/telegram/ext/_conversationhandler.py +++ b/telegram/ext/_conversationhandler.py @@ -185,7 +185,7 @@ class ConversationHandler(Handler[Update, CCT]): map a state to :attr:`END` to end the *parent* conversation from within the child conversation. For an example on nested :class:`ConversationHandler` s, see our `examples`_. - .. _`examples`: https://github.com/python-telegram-bot/python-telegram-bot/blob/master/examples + .. _`examples`: https://github.com/python-telegram-bot/python-telegram-bot/tree/master/examples#examples Args: entry_points (List[:class:`telegram.ext.Handler`]): A list of :obj:`Handler` objects that diff --git a/telegram/ext/_extbot.py b/telegram/ext/_extbot.py index 159fa9877f1..9f6d3678ec4 100644 --- a/telegram/ext/_extbot.py +++ b/telegram/ext/_extbot.py @@ -179,12 +179,12 @@ def insert_callback_data(self, update: Update) -> None: corresponding buttons within this update. Note: - Checks :attr:`telegram.Message.via_bot` and :attr:`telegram.Message.from_user` to check - if the reply markup (if any) was actually sent by this caches bot. If it was not, the - message will be returned unchanged. + Checks :attr:`telegram.Message.via_bot` and :attr:`telegram.Message.from_user` + to figure out if a) a reply markup exists and b) it was actually sent by this + cached bot. If not, the message will be returned unchanged. Note that this will fail for channel posts, as :attr:`telegram.Message.from_user` is - :obj:`None` for those! In the corresponding reply markups the callback data will be + :obj:`None` for those! In the corresponding reply markups, the callback data will be replaced by :class:`telegram.ext.InvalidCallbackData`. Warning: diff --git a/telegram/ext/_updater.py b/telegram/ext/_updater.py index fe074d45d07..73ec1cdf609 100644 --- a/telegram/ext/_updater.py +++ b/telegram/ext/_updater.py @@ -118,7 +118,7 @@ def running(self) -> bool: return self._running async def initialize(self) -> None: - """Initialize the Updater & the associated :attr:`bot` by calling + """Initializes the Updater & the associated :attr:`bot` by calling :meth:`telegram.Bot.initialize`. .. seealso:: @@ -221,7 +221,9 @@ async def start_polling( error_callback (Callable[[:exc:`telegram.error.TelegramError`], :obj:`None`], \ optional): Callback to handle :exc:`telegram.error.TelegramError` s that occur while calling :meth:`telegram.Bot.get_updates` during polling. Defaults to - :obj:`None`, in which case errors will be logged. + :obj:`None`, in which case errors will be logged. Callback signature:: + + def callback(error: telegram.error.TelegramError) Note: The :paramref:`error_callback` must *not* be a :term:`coroutine function`! If From 58c7781ed0404fc259994d3fe024d4241b0eeaf1 Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Sun, 24 Apr 2022 09:58:17 +0200 Subject: [PATCH 151/153] forgot to commit review --- telegram/_files/inputfile.py | 2 +- telegram/request/_requestparameter.py | 8 +++----- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/telegram/_files/inputfile.py b/telegram/_files/inputfile.py index 05e9d4eef3d..0f50efd7d31 100644 --- a/telegram/_files/inputfile.py +++ b/telegram/_files/inputfile.py @@ -55,7 +55,7 @@ class InputFile: Attributes: input_file_content (:obj:`bytes`): The binary content of the file to send. - attach_name (:obj:`str`, Optional ): If present, the parameter this file belongs to in + attach_name (:obj:`str`): Optional. If present, the parameter this file belongs to in the request to Telegram should point to the multipart data via a an URI of the form ``attach://`` URI. filename (:obj:`str`): Filename for the file to be sent. diff --git a/telegram/request/_requestparameter.py b/telegram/request/_requestparameter.py index 61fd306ec9d..311ff506f0d 100644 --- a/telegram/request/_requestparameter.py +++ b/telegram/request/_requestparameter.py @@ -46,13 +46,13 @@ class RequestParameter: Args: name (:obj:`str`): The name of the parameter. - value (:obj:`object` | :obj:None`): The value of the parameter. Must be JSON-dumpable. + value (:obj:`object` | :obj:`None`): The value of the parameter. Must be JSON-dumpable. input_files (List[:class:`telegram.InputFile`], optional): A list of files that should be uploaded along with this parameter. Attributes: name (:obj:`str`): The name of the parameter. - value (:obj:`object` | :obj:None`): The value of the parameter. + value (:obj:`object` | :obj:`None`): The value of the parameter. input_files (List[:class:`telegram.InputFile` | :obj:`None`): A list of files that should be uploaded along with this parameter. """ @@ -114,9 +114,7 @@ def _value_and_input_files_from_input( # pylint: disable=too-many-return-statem return value.value, [] if isinstance(value, InputFile): if value.attach_uri: - return value.attach_uri, [ - value, - ] + return value.attach_uri, [value] return None, [value] if isinstance(value, InputMedia) and isinstance(value.media, InputFile): From 66d9a13f96017553e658e02bddb191a7b6c18dd4 Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Sun, 24 Apr 2022 09:59:22 +0200 Subject: [PATCH 152/153] fix pre-commit --- telegram/ext/_application.py | 4 ++-- telegram/ext/_conversationhandler.py | 3 ++- telegram/ext/_extbot.py | 2 +- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/telegram/ext/_application.py b/telegram/ext/_application.py index fd4c1d2b408..d11323bf6fb 100644 --- a/telegram/ext/_application.py +++ b/telegram/ext/_application.py @@ -168,8 +168,8 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AbstractAsyncContextManager) .. seealso:: :meth:`add_handler`, :meth:`add_handlers`. - error_handlers (Dict[:term:`coroutine function`, :obj:`bool`]): A dictionary where the keys are - error handlers and the values indicate whether they are to be run blocking. + error_handlers (Dict[:term:`coroutine function`, :obj:`bool`]): A dictionary where the keys + are error handlers and the values indicate whether they are to be run blocking. .. seealso:: :meth:`add_error_handler` diff --git a/telegram/ext/_conversationhandler.py b/telegram/ext/_conversationhandler.py index 3f75af0eb82..43447d2ba2d 100644 --- a/telegram/ext/_conversationhandler.py +++ b/telegram/ext/_conversationhandler.py @@ -185,7 +185,8 @@ class ConversationHandler(Handler[Update, CCT]): map a state to :attr:`END` to end the *parent* conversation from within the child conversation. For an example on nested :class:`ConversationHandler` s, see our `examples`_. - .. _`examples`: https://github.com/python-telegram-bot/python-telegram-bot/tree/master/examples#examples + .. _`examples`: https://github.com/python-telegram-bot/python-telegram-bot/tree/master\ + /examples#examples Args: entry_points (List[:class:`telegram.ext.Handler`]): A list of :obj:`Handler` objects that diff --git a/telegram/ext/_extbot.py b/telegram/ext/_extbot.py index 9f6d3678ec4..ba526763489 100644 --- a/telegram/ext/_extbot.py +++ b/telegram/ext/_extbot.py @@ -181,7 +181,7 @@ def insert_callback_data(self, update: Update) -> None: Note: Checks :attr:`telegram.Message.via_bot` and :attr:`telegram.Message.from_user` to figure out if a) a reply markup exists and b) it was actually sent by this - cached bot. If not, the message will be returned unchanged. + bot. If not, the message will be returned unchanged. Note that this will fail for channel posts, as :attr:`telegram.Message.from_user` is :obj:`None` for those! In the corresponding reply markups, the callback data will be From 9389bc985d8ff6a96cace874da0a549f49bf16bb Mon Sep 17 00:00:00 2001 From: Hinrich Mahler <22366557+Bibo-Joshi@users.noreply.github.com> Date: Sun, 24 Apr 2022 10:30:07 +0200 Subject: [PATCH 153/153] finally fix certificate passing --- telegram/ext/_updater.py | 6 ++++-- tests/test_updater.py | 5 ++--- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/telegram/ext/_updater.py b/telegram/ext/_updater.py index 73ec1cdf609..dc469134b8e 100644 --- a/telegram/ext/_updater.py +++ b/telegram/ext/_updater.py @@ -513,7 +513,9 @@ async def _start_webhook( # We pass along the cert to the webhook if present. await self._bootstrap( - cert=cert, + # Passing a Path or string only works if the bot is running against a local bot API + # server, so let's read the contents + cert=Path(cert).read_bytes() if cert else None, max_retries=bootstrap_retries, drop_pending_updates=drop_pending_updates, webhook_url=webhook_url, @@ -590,7 +592,7 @@ async def _bootstrap( webhook_url: Optional[str], allowed_updates: Optional[List[str]], drop_pending_updates: bool = None, - cert: Union[str, Path] = None, + cert: Optional[bytes] = None, bootstrap_interval: float = 1, ip_address: str = None, max_connections: int = 40, diff --git a/tests/test_updater.py b/tests/test_updater.py index b2e1fb14c4c..c37287ba856 100644 --- a/tests/test_updater.py +++ b/tests/test_updater.py @@ -667,7 +667,7 @@ async def serve_forever(*args, **kwargs): ) expected_set_webhook = dict( - certificate='certificate', + certificate=data_file('sslcert.pem').read_bytes(), max_connections=47, allowed_updates=['message'], ip_address='123.456.789', @@ -680,11 +680,10 @@ async def serve_forever(*args, **kwargs): drop_pending_updates=True, ip_address='123.456.789', max_connections=47, - cert='certificate', + cert=str(data_file('sslcert.pem').resolve()), ) await updater.stop() - expected_set_webhook['certificate'] = data_file('sslcert.pem') await updater.start_webhook( listen='listen-ssl', port=42,