|
11 | 11 | from strands.types.exceptions import ContextWindowOverflowException, EventLoopException, ModelThrottledException
|
12 | 12 |
|
13 | 13 |
|
| 14 | +@pytest.fixture |
| 15 | +def mock_time
10000
span>(): |
| 16 | + """Fixture to mock the time module in the error_handler.""" |
| 17 | + with unittest.mock.patch.object(strands.event_loop.error_handler, "time") as mock: |
| 18 | + yield mock |
| 19 | + |
| 20 | + |
14 | 21 | @pytest.fixture
|
15 | 22 | def model():
|
16 | 23 | return unittest.mock.Mock()
|
@@ -157,8 +164,8 @@ def test_event_loop_cycle_text_response(
|
157 | 164 | assert tru_stop_reason == exp_stop_reason and tru_message == exp_message and tru_request_state == exp_request_state
|
158 | 165 |
|
159 | 166 |
|
160 |
| -@unittest.mock.patch.object(strands.event_loop.error_handler, "time") |
161 | 167 | def test_event_loop_cycle_text_response_throttling(
|
| 168 | + mock_time, |
162 | 169 | model,
|
163 | 170 | model_id,
|
164 | 171 | system_prompt,
|
@@ -191,6 +198,53 @@ def test_event_loop_cycle_text_response_throttling(
|
191 | 198 | exp_request_state = {}
|
192 | 199 |
|
193 | 200 | assert tru_stop_reason == exp_stop_reason and tru_message == exp_message and tru_request_state == exp_request_state
|
| 201 | + # Verify that sleep was called once with the initial delay |
| 202 | + mock_time.sleep.assert_called_once() |
| 203 | + |
| 204 | + |
| 205 | +def test_event_loop_cycle_exponential_backoff( |
| 206 | + mock_time, |
| 207 | + model, |
| 208 | + model_id, |
| 209 | + system_prompt, |
| 210 | + messages, |
| 211 | + tool_config, |
| 212 | + callback_handler, |
| 213 | + tool_handler, |
| 214 | + tool_execution_handler, |
| 215 | +): |
| 216 | + """Test that the exponential backoff works correctly with multiple retries.""" |
| 217 | + # Set up the model to raise throttling exceptions multiple times before succeeding |
| 218 | + model.converse.side_effect = [ |
| 219 | + ModelThrottledException("ThrottlingException | ConverseStream"), |
| 220 | + ModelThrottledException("ThrottlingException | ConverseStream"), |
| 221 | + ModelThrottledException("ThrottlingException | ConverseStream"), |
| 222 | + [ |
| 223 | + {"contentBlockDelta": {"delta": {"text": "test text"}}}, |
| 224 | + {"contentBlockStop": {}}, |
| 225 | + ], |
| 226 | + ] |
| 227 | + |
| 228 | + tru_stop_reason, tru_message, _, tru_request_state = strands.event_loop.event_loop.event_loop_cycle( |
| 229 | + model=model, |
| 230 | + model_id=model_id, |
| 231 | + system_prompt=system_prompt, |
| 232 | + messages=messages, |
| 233 | + tool_config=tool_config, |
| 234 | + callback_handler=callback_handler, |
| 235 | + tool_handler=tool_handler, |
| 236 | + tool_execution_handler=tool_execution_handler, |
| 237 | + ) |
| 238 | + |
| 239 | + # Verify the final response |
| 240 | + assert tru_stop_reason == "end_turn" |
| 241 | + assert tru_message == {"role": "assistant", "content": [{"text": "test text"}]} |
| 242 | + assert tru_request_state == {} |
| 243 | + |
| 244 | + # Verify that sleep was called with increasing delays |
| 245 | + # Initial delay is 4, then 8, then 16 |
| 246 | + assert mock_time.sleep.call_count == 3 |
| 247 | + assert mock_time.sleep.call_args_list == [call(4), call(8), call(16)] |
194 | 248 |
|
195 | 249 |
|
196 | 250 | def test_event_loop_cycle_text_response_error(
|
|
0 commit comments