8000 Support `allow_paid_broadcast` in `AIORateLimiter` (#4627) · vavasik800/python-telegram-bot@61b87ba · GitHub
[go: up one dir, main page]

Skip to content

Commit 61b87ba

Browse files
authored
Support allow_paid_broadcast in AIORateLimiter (python-telegram-bot#4627)
1 parent dd592cd commit 61b87ba

File tree

2 files changed

+111
-23
lines changed

2 files changed

+111
-23
lines changed

telegram/ext/_aioratelimiter.py

Lines changed: 41 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
except ImportError:
3333
AIO_LIMITER_AVAILABLE = False
3434

35+
from telegram import constants
3536
from telegram._utils.logging import get_logger
3637
from telegram._utils.types import JSONDict
3738
from telegram.error import RetryAfter
@@ -86,7 +87,8 @@ class AIORateLimiter(BaseRateLimiter[int]):
8687
* A :exc:`~telegram.error.RetryAfter` exception will halt *all* requests for
8788
:attr:`~telegram.error.RetryAfter.retry_after` + 0.1 seconds. This may be stricter than
8889
necessary in some cases, e.g. the bot may hit a rate limit in one group but might still
89-
be allowed to send messages in another group.
90+
be allowed to send messages in another group or with
91+
:paramref:`~telegram.Bot.send_message.allow_paid_broadcast` set to :obj:`True`.
9092
9193
Tip:
9294
With `Bot API 7.1 <https://core.telegram.org/bots/api-changelog#october-31-2024>`_
@@ -96,10 +98,10 @@ class AIORateLimiter(BaseRateLimiter[int]):
9698
:tg-const:`telegram.constants.FloodLimit.PAID_MESSAGES_PER_SECOND` messages per second by
9799
paying a fee in Telegram Stars.
98100
99-
.. caution::
100-
This class currently doesn't take the
101-
:paramref:`~telegram.Bot.send_message.allow_paid_broadcast` parameter into account.
102-
This means that the rate limiting is applied just like for any other message.
101+
.. versionchanged:: NEXT.VERSION
102+
This class automatically takes the
103+
:paramref:`~telegram.Bot.send_message.allow_paid_broadcast` parameter into account and
104+
throttles the requests accordingly.
103105
104106
8000 Note:
105107
This class is to be understood as minimal effort reference implementation.
@@ -114,23 +116,25 @@ class AIORateLimiter(BaseRateLimiter[int]):
114116
Args:
115117
overall_max_rate (:obj:`float`): The maximum number of requests allowed for the entire bot
116118
per :paramref:`overall_time_period`. When set to 0, no rate limiting will be applied.
117-
Defaults to ``30``.
119+
Defaults to :tg-const:`telegram.constants.FloodLimit.MESSAGES_PER_SECOND`.
118120
overall_time_period (:obj:`float`): The time period (in seconds) during which the
119121
:paramref:`overall_max_rate` is enforced. When set to 0, no rate limiting will be
120-
applied. Defaults to 1.
122+
applied. Defaults to ``1``.
121123
group_max_rate (:obj:`float`): The maximum number of requests allowed for requests related
122124
to groups and channels per :paramref:`group_time_period`. When set to 0, no rate
123-
limiting will be applied. Defaults to 20.
125+
limiting will be applied. Defaults to
126+
:tg-const:`telegram.constants.FloodLimit.MESSAGES_PER_MINUTE_PER_GROUP`.
124127
group_time_period (:obj:`float`): The time period (in seconds) during which the
125128
:paramref:`group_max_rate` is enforced. When set to 0, no rate limiting will be
126-
applied. Defaults to 60.
129+
applied. Defaults to ``60``.
127130
max_retries (:obj:`int`): The maximum number of retries to be made in case of a
128131
:exc:`~telegram.error.RetryAfter` exception.
129132
If set to 0, no retries will be made. Defaults to ``0``.
130133
131134
"""
132135

133136
__slots__ = (
137+
"_apb_limiter",
134138
"_base_limiter",
135139
"_group_limiters",
136140
"_group_max_rate",
@@ -141,9 +145,9 @@ class AIORateLimiter(BaseRateLimiter[int]):
141145

142146
def __init__(
143147
self,
144-
overall_max_rate: float = 30,
148+
overall_max_rate: float = constants.FloodLimit.MESSAGES_PER_SECOND,
145149
overall_time_period: float = 1,
146-
group_max_rate: float = 20,
150+
group_max_rate: float = constants.FloodLimit.MESSAGES_PER_MINUTE_PER_GROUP,
147151
group_time_period: float = 60,
148152
max_retries: int = 0,
149153
) -> None:
@@ -167,6 +171,9 @@ def __init__(
167171
self._group_time_period = 0
168172

169173
self._group_limiters: dict[Union[str, int], AsyncLimiter] = {}
174+
self._apb_limiter: AsyncLimiter = AsyncLimiter(
175+
max_rate=constants.FloodLimit.PAID_MESSAGES_PER_SECOND, time_period=1
176+
)
170177
self._max_retries: int = max_retries
171178
self._retry_after_event = asyncio.Event()
172179
self._retry_after_event.set()
@@ -201,21 +208,30 @@ async def _run_request(
201208
self,
202209
chat: bool,
203210
group: Union[str, int, bool],
211+
allow_paid_broadcast: bool,
204212
callback: Callable[..., Coroutine[Any, Any, Union[bool, JSONDict, list[JSONDict]]]],
205213
args: Any,
206214
kwargs: dict[str, Any],
207215
) -> Union[bool, JSONDict, list[JSONDict]]:
208-
base_context = self._base_limiter if (chat and self._base_limiter) else null_context()
209-
group_context = (
210-
self._get_group_limiter(group) if group and self._group_max_rate else null_context()
211-
)
212-
213-
async with group_context, base_context:
216+
async def inner() -> Union[bool, JSONDict, list[JSONDict]]:
214217
# In case a retry_after was hit, we wait with processing the request
215218
await self._retry_after_event.wait()
216-
217219
return await callback(*args, **kwargs)
218220

221+
if allow_paid_broadcast:
222+
async with self._apb_limiter:
223+
return await inner()
224+
else:
225+
base_context = self._base_limiter if (chat and self._base_limiter) else null_context()
226+
group_context = (
227+
self._get_group_limiter(group)
228+
if group and self._group_max_rate
229+
else null_context()
230+
)
231+
232+
async with group_context, base_context:
233+
return await inner()
234+
219235
# mypy doesn't understand that the last run of the for loop raises an exception
220236
async def process_request(
221237
self,
@@ -242,6 +258,7 @@ async def process_request(
242258
group: Union[int, str, bool] = False
243259
chat: bool = False
244260
chat_id = data.get("chat_id")
261+
allow_paid_broadcast = data.get("allow_paid_broadcast", False)
245262
if chat_id is not None:
246263
chat = True
247264

@@ -257,7 +274,12 @@ async def process_request(
257274
for i in range(max_retries + 1):
258275
try:
259276
return await self._run_request(
260-
chat=chat, group=group, callback=callback, args=args, kwargs=kwargs
277+
chat=chat,
278+
group=group,
279+
allow_paid_broadcast=allow_paid_broadcast,
280+
callback=callback,
281+
args=args,
282+
kwargs=kwargs,
261283
)
262284
except RetryAfter as exc:
263285
if i == max_retries:

tests/ext/test_ratelimiter.py

Lines changed: 70 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import json
2727
import platform
2828
import time
29+
from collections import Counter
2930
from http import HTTPStatus
3031

3132
import pytest
@@ -148,7 +149,9 @@ async def do_request(self, *args, **kwargs):
148149
@pytest.mark.flaky(10, 1) # Timings aren't quite perfect
149150
class TestAIORateLimiter:
150151
count = 0
152+
apb_count = 0
151153
call_times = []
154+
apb_call_times = []
152155

153156
class CountRequest(BaseRequest):
154157
def __init__(self, retry_after=None):
@@ -161,8 +164,16 @@ async def shutdown(self) -> None:
161164
pass
162165

163166
async def do_request(self, *args, **kwargs):
164-
TestAIORateLimiter.count += 1
165-
TestAIORateLimiter.call_times.append(time.time())
167+
request_data = kwargs.get("request_data")
168+
allow_paid_broadcast = request_data.parameters.get("allow_paid_broadcast", False)
169+
170+
if allow_paid_broadcast:
171+
TestAIORateLimiter.apb_count += 1
172+
TestAIORateLimiter.apb_call_times.append(time.time())
173+
else:
174+
TestAIORateLimiter.count += 1
175+
TestAIORateLimiter.call_times.append(time.time())
176+
166177
if self.retry_after:
167178
raise RetryAfter(retry_after=1)
168179

@@ -190,10 +201,10 @@ async def do_request(self, *args, **kwargs):
190201

191202
@pytest.fixture(autouse=True)
192203
def _reset(self):
193-
self.count = 0
194204
TestAIORateLimiter.count = 0
195-
self.call_times = []
196205
TestAIORateLimiter.call_times = []
206+
TestAIORateLimiter.apb_count = 0
207+
TestAIORateLimiter.apb_call_times = []
197208

198209
@pytest.mark.parametrize("max_retries", [0, 1, 4])
199210
async def test_max_retries(self, bot, max_retries):
@@ -358,3 +369,58 @@ async def test_group_caching(self, bot, intermediate):
358369
finally:
359370
TestAIORateLimiter.count = 0
360371
TestAIORateLimiter.call_times = []
372+
373+
async def test_allow_paid_broadcast(self, bot):
374+
try:
375+
rl_bot = ExtBot(
376+
token=bot.token,
377+
request=self.CountRequest(retry_after=None),
378+
rate_limiter=AIORateLimiter(),
379+
)
380+
381+
async with rl_bot:
382+
apb_tasks = {}
383+
non_apb_tasks = {}
384+
for i in range(3000):
385+
apb_tasks[i] = asyncio.create_task(
386+
rl_bot.send_message(chat_id=-1, text="test", allow_paid_broadcast=True)
387+
)
388+
389+
number = 2
390+
for i in range(number):
391+
non_apb_tasks[i] = asyncio.create_task(
392+
rl_bot.send_message(chat_id=-1, text="test")
393+
)
394+
non_apb_tasks[i + number] = asyncio.create_task(
395+
rl_bot.send_message(chat_id=-1, text="test", allow_paid_broadcast=False)
396+
)
397+
398+
await asyncio.sleep(0.1)
399+
# We expect 5 non-apb requests:
400+
# 1: `get_me` from `async with rl_bot`
401+
# 2-5: `send_message`
402+
assert TestAIORateLimiter.count == 5
403+
assert sum(1 for task in non_apb_tasks.values() if task.done()) == 4
404+
405+
# ~2 second after start
406+
# We do the checks once all apb_tasks are done as apparently getting the timings
407+
# right to check after 1 second is hard
408+
await asyncio.sleep(2.1 - 0.1)
409+
assert all(task.done() for task in apb_tasks.values())
410+
411+
apb_call_times = [
412+
ct - TestAIORateLimiter.apb_call_times[0]
413+
for ct in TestAIORateLimiter.apb_call_times
414+
]
415+
apb_call_times_dict = Counter(map(int, apb_call_times))
416+
417+
# We expect ~2000 apb requests after the first second
418+
# 2000 (>>1000), since we have a floating window logic such that an initial
419+
# burst is allowed that is hard to measure in the tests
420+
assert apb_call_times_dict[0] <= 2000
421+
assert apb_call_times_dict[0] + apb_call_times_dict[1] < 3000
422+
assert sum(apb_call_times_dict.values()) == 3000
423+
424+
finally:
425+
# cleanup
426+
await asyncio.gather(*apb_tasks.values(), *non_apb_tasks.values())

0 commit comments

Comments
 (0)
0