8000 fix: Update throttling logic to use exponential back-off (#223) · strands-agents/sdk-python@52c68aa · GitHub
[go: up one dir, main page]

Skip to content

Commit 52c68aa

Browse files
authored
fix: Update throttling logic to use exponential back-off (#223)
current_delay was being thrown away and not applied to subsequent retries Co-authored-by: Mackenzie Zastrow <zastrowm@users.noreply.github.com>
1 parent 4dd0819 commit 52c68aa

File tree

2 files changed

+57
-2
lines changed

2 files changed

+57
-2
lines changed

src/strands/event_loop/event_loop.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,7 @@ def event_loop_cycle(
136136
metrics: Metrics
137137

138138
# Retry loop for handling throttling exceptions
139+
current_delay = INITIAL_DELAY
139140
for attempt in range(MAX_ATTEMPTS):
140141
model_id = model.config.get("model_id") if hasattr(model, "config") else None
141142
model_invoke_span = tracer.start_model_invoke_span(
@@ -168,7 +169,7 @@ def event_loop_cycle(
168169

169170
# Handle throttling errors with exponential backoff
170171
should_retry, current_delay = handle_throttling_error(
171-
e, attempt, MAX_ATTEMPTS, INITIAL_DELAY, MAX_DELAY, callback_handler, kwargs
172+
e, attempt, MAX_ATTEMPTS, current_delay, MAX_DELAY, callback_handler, kwargs
172173
)
173174
if should_retry:
174175
continue

tests/strands/event_loop/test_event_loop.py

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,13 @@
1111
from strands.types.exceptions import ContextWindowOverflowException, EventLoopException, ModelThrottledException
1212

1313

14+
@pytest.fixture
15+
def mock_time():
16+
"""Fixture to mock the time module in the error_handler."""
17+
with unittest.mock.patch.object(strands.event_loop.error_handler, "time") as mock:
18+
yield mock
19+
20+
1421
@pytest.fixture
1522
def model():
1623
return unittest.mock.Mock()
@@ -157,8 +164,8 @@ def test_event_loop_cycle_text_response(
157164
assert tru_stop_reason == exp_stop_reason and tru_message == exp_message and tru_request_state == exp_request_state
158165

159166

160-
@unittest.mock.patch.object(strands.event_loop.error_handler, "time")
161167
def test_event_loop_cycle_text_response_throttling(
168+
mock_time,
162169
model,
163170
model_id,
164171
system_prompt,
@@ -191,6 +198,53 @@ def test_event_loop_cycle_text_response_throttling(
191198
exp_request_state = {}
192199

193200
assert tru_stop_reason == exp_stop_reason and tru_message == exp_message and tru_request_state == exp_request_state
201+
# Verify that sleep was called once with the initial delay
202+
mock_time.sleep.assert_called_once()
203+
204+
205+
def test_event_loop_cycle_exponential_backoff(
206+
mock_time,
207+
model,
208+
model_id,
209+
system_prompt,
210+
messages,
211+
tool_config,
212+
callback_handler,
213+
tool_handler,
214+
tool_execution_handler,
215+
):
216+
"""Test that the exponential backoff works correctly with multiple retries."""
217+
# Set up the model to raise throttling exceptions multiple times before succeeding
218+
model.converse.side_effect = [
219+
ModelThrottledException("ThrottlingException | ConverseStream"),
220+
ModelThrottledException("ThrottlingException | ConverseStream"),
221+
ModelThrottledException("ThrottlingException | ConverseStream"),
222+
[
223+
{"contentBlockDelta": {"delta": {"text": "test text"}}},
224+
{"contentBlockStop": {}},
225+
],
226+
]
227+
228+
tru_stop_reason, tru_message, _, tru_request_state = strands.event_loop.event_loop.event_loop_cycle(
229+
model=model,
230+
model_id=model_id,
231+
system_prompt=system_prompt,
232+
messages=messages,
233+
tool_config=tool_config,
234+
callback_handler=callback_handler,
235+
tool_handler=tool_handler,
236+
tool_execution_handler=tool_execution_handler,
237+
)
238+
239+
# Verify the final response
240+
assert tru_stop_reason == "end_turn"
241+
assert tru_message == {"role": "assistant", "content": [{"text": "test text"}]}
242+
assert tru_request_state == {}
243+
244+
# Verify that sleep was called with increasing delays
245+
# Initial delay is 4, then 8, then 16
246+
assert mock_time.sleep.call_count == 3
247+
assert mock_time.sleep.call_args_list == [call(4), call(8), call(16)]
194248

195249

196250
def test_event_loop_cycle_text_response_error(

0 commit comments

Comments
 (0)
0