chore: moved truncation logic to conversation manager and added should_truncate_results · NMsby/sdk-python@7c7f91e · GitHub

Commit 7c7f91e

chore: moved truncation logic to conversation manager and added should_truncate_results (strands-agents#192)
1 parent c28737c commit 7c7f91e

9 files changed: 166 additions, 394 deletions

src/strands/agent/conversation_manager/sliding_window_conversation_manager.py

Lines changed: 82 additions & 1 deletion
@@ -44,14 +44,16 @@ class SlidingWindowConversationManager(ConversationManager):
         invalid window states.
     """
 
-    def __init__(self, window_size: int = 40):
+    def __init__(self, window_size: int = 40, should_truncate_results: bool = True):
         """Initialize the sliding window conversation manager.
 
         Args:
             window_size: Maximum number of messages to keep in the agent's history.
                 Defaults to 40 messages.
+            should_truncate_results: Truncate tool results when a message is too large for the model's context window
         """
         self.window_size = window_size
+        self.should_truncate_results = should_truncate_results
 
     def apply_management(self, agent: "Agent") -> None:
         """Apply the sliding window to the agent's messages array to maintain a manageable history size.
@@ -127,6 +129,19 @@ def reduce_context(self, agent: "Agent", e: Optional[Exception] = None) -> None:
             converted.
         """
         messages = agent.messages
+
+        # Try to truncate the tool result first
+        last_message_idx_with_tool_results = self._find_last_message_with_tool_results(messages)
+        if last_message_idx_with_tool_results is not None and self.should_truncate_results:
+            logger.debug(
+                "message_index=<%s> | found message with tool results at index", last_message_idx_with_tool_results
+            )
+            results_truncated = self._truncate_tool_results(messages, last_message_idx_with_tool_results)
+            if results_truncated:
+                logger.debug("message_index=<%s> | tool results truncated", last_message_idx_with_tool_results)
+                return
+
+        # Try to trim index id when tool result cannot be truncated anymore
         # If the number of messages is less than the window_size, then we default to 2, otherwise, trim to window size
         trim_index = 2 if len(messages) <= self.window_size else len(messages) - self.window_size
 
@@ -151,3 +166,69 @@ def reduce_context(self, agent: "Agent", e: Optional[Exception] = None) -> None:
 
         # Overwrite message history
         messages[:] = messages[trim_index:]
+
+    def _truncate_tool_results(self, messages: Messages, msg_idx: int) -> bool:
+        """Truncate tool results in a message to reduce context size.
+
+        When a message contains tool results that are too large for the model's context window, this function
+        replaces the content of those tool results with a simple error message.
+
+        Args:
+            messages: The conversation message history.
+            msg_idx: Index of the message containing tool results to truncate.
+
+        Returns:
+            True if any changes were made to the message, False otherwise.
+        """
+        if msg_idx >= len(messages) or msg_idx < 0:
+            return False
+
+        message = messages[msg_idx]
+        changes_made = False
+        tool_result_too_large_message = "The tool result was too large!"
+        for i, content in enumerate(message.get("content", [])):
+            if isinstance(content, dict) and "toolResult" in content:
+                tool_result_content_text = next(
+                    (item["text"] for item in content["toolResult"]["content"] if "text" in item),
+                    "",
+                )
+                # make the overwriting logic togglable
+                if (
+                    message["content"][i]["toolResult"]["status"] == "error"
+                    and tool_result_content_text == tool_result_too_large_message
+                ):
+                    logger.info("ToolResult has already been updated, skipping overwrite")
+                    return False
+                # Update status to error with informative message
+                message["content"][i]["toolResult"]["status"] = "error"
+                message["content"][i]["toolResult"]["content"] = [{"text": tool_result_too_large_message}]
+                changes_made = True
+
+        return changes_made
+
+    def _find_last_message_with_tool_results(self, messages: Messages) -> Optional[int]:
+        """Find the index of the last message containing tool results.
+
+        This is useful for identifying messages that might need to be truncated to reduce context size.
+
+        Args:
+            messages: The conversation message history.
+
+        Returns:
+            Index of the last message with tool results, or None if no such message exists.
+        """
+        # Iterate backwards through all messages (from newest to oldest)
+        for idx in range(len(messages) - 1, -1, -1):
+            # Check if this message has any content with toolResult
+            current_message = messages[idx]
+            has_tool_result = False
+
+            for content in current_message.get("content", []):
+                if isinstance(content, dict) and "toolResult" in content:
+                    has_tool_result = True
+                    break
+
+            if has_tool_result:
+                return idx
+
+        return None
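
The new flag is surfaced on the conversation manager's constructor, so callers can opt out of the overwrite behaviour entirely. A minimal usage sketch follows; the import paths are inferred from the file layout in this commit and the Agent constructor's conversation_manager keyword is assumed rather than shown in the diff:

    from strands import Agent
    from strands.agent.conversation_manager import SlidingWindowConversationManager

    # Default behaviour: oversized tool results are overwritten with an error stub
    # before any messages are dropped from the window.
    manager = SlidingWindowConversationManager(window_size=40, should_truncate_results=True)

    # Opt out: reduce_context skips truncation and goes straight to window trimming.
    strict_manager = SlidingWindowConversationManager(window_size=40, should_truncate_results=False)

    agent = Agent(conversation_manager=manager)  # constructor keyword assumed for illustration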

src/strands/event_loop/error_handler.py

Lines changed: 2 additions & 67 deletions
@@ -6,14 +6,9 @@
 
 import logging
 import time
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Dict, Tuple
 
-from ..telemetry.metrics import EventLoopMetrics
-from ..types.content import Message, Messages
-from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException
-from ..types.models import Model
-from ..types.streaming import StopReason
-from .message_processor import find_last_message_with_tool_results, truncate_tool_results
+from ..types.exceptions import ModelThrottledException
 
 logger = logging.getLogger(__name__)
 
@@ -59,63 +54,3 @@ def handle_throttling_error(
 
         callback_handler(force_stop=True, force_stop_reason=str(e))
         return False, current_delay
-
-
-def handle_input_too_long_error(
-    e: ContextWindowOverflowException,
-    messages: Messages,
-    model: Model,
-    system_prompt: Optional[str],
-    tool_config: Any,
-    callback_handler: Any,
-    tool_handler: Any,
-    kwargs: Dict[str, Any],
-) -> Tuple[StopReason, Message, EventLoopMetrics, Any]:
-    """Handle 'Input is too long' errors by truncating tool results.
-
-    When a context window overflow exception occurs (input too long for the model), this function attempts to recover
-    by finding and truncating the most recent tool results in the conversation history. If truncation is successful,
-    the function will make a call to the event loop.
-
-    Args:
-        e: The ContextWindowOverflowException that occurred.
-        messages: The conversation message history.
-        model: Model provider for running inference.
-        system_prompt: System prompt for the model.
-        tool_config: Tool configuration for the conversation.
-        callback_handler: Callback for processing events as they happen.
-        tool_handler: Handler for tool execution.
-        kwargs: Additional arguments for the event loop.
-
-    Returns:
-        The results from the event loop call if successful.
-
-    Raises:
-        ContextWindowOverflowException: If messages cannot be truncated.
-    """
-    from .event_loop import recurse_event_loop  # Import here to avoid circular imports
-
-    # Find the last message with tool results
-    last_message_with_tool_results = find_last_message_with_tool_results(messages)
-
-    # If we found a message with toolResult
-    if last_message_with_tool_results is not None:
-        logger.debug("message_index=<%s> | found message with tool results at index", last_message_with_tool_results)
-
-        # Truncate the tool results in this message
-        truncate_tool_results(messages, last_message_with_tool_results)
-
-        return recurse_event_loop(
-            model=model,
-            system_prompt=system_prompt,
-            messages=messages,
-            tool_config=tool_config,
-            callback_handler=callback_handler,
-            tool_handler=tool_handler,
-            **kwargs,
-        )
-
-    # If we can't handle this error, pass it up
-    callback_handler(force_stop=True, force_stop_reason=str(e))
-    logger.error("an exception occurred in event_loop_cycle | %s", e)
-    raise ContextWindowOverflowException() from e

src/strands/event_loop/event_loop.py

Lines changed: 6 additions & 11 deletions
@@ -22,7 +22,7 @@
 from ..types.models import Model
 from ..types.streaming import Metrics, StopReason
 from ..types.tools import ToolConfig, ToolHandler, ToolResult, ToolUse
-from .error_handler import handle_input_too_long_error, handle_throttling_error
+from .error_handler import handle_throttling_error
 from .message_processor import clean_orphaned_empty_tool_uses
 from .streaming import stream_messages
 
@@ -160,16 +160,7 @@ def event_loop_cycle(
     except ContextWindowOverflowException as e:
         if model_invoke_span:
             tracer.end_span_with_error(model_invoke_span, str(e), e)
-        return handle_input_too_long_error(
-            e,
-            messages,
-            model,
-            system_prompt,
-            tool_config,
-            callback_handler,
-            tool_handler,
-            kwargs,
-        )
+        raise e
 
     except ModelThrottledException as e:
         if model_invoke_span:
@@ -248,6 +239,10 @@
         # Don't invoke the callback_handler or log the exception - we already did it when we
        # raised the exception and we don't need that duplication.
         raise
+    except ContextWindowOverflowException as e:
+        if cycle_span:
+            tracer.end_span_with_error(cycle_span, str(e), e)
+        raise e
     except Exception as e:
         if cycle_span:
             tracer.end_span_with_error(cycle_span, str(e), e)
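
With handle_input_too_long_error removed, event_loop_cycle no longer retries in place on a context overflow; it re-raises the exception (after closing its spans) and leaves recovery to whoever owns the conversation manager. A rough sketch of the recovery loop this implies, inferred from the tests in this commit rather than copied from the Agent source (run_with_context_recovery, run_cycle, and max_attempts are illustrative names):

    from strands.types.exceptions import ContextWindowOverflowException

    def run_with_context_recovery(agent, run_cycle, max_attempts=10):
        """Retry run_cycle(), asking the conversation manager to shrink context on overflow."""
        for _ in range(max_attempts):
            try:
                return run_cycle()
            except ContextWindowOverflowException as e:
                # reduce_context first overwrites the last oversized tool result (when
                # should_truncate_results is enabled) and otherwise trims the window;
                # if nothing can be reduced the exception ends up propagating, which is
                # what keeps the "doesnt_infinite_loop" tests below from looping forever.
                agent.conversation_manager.reduce_context(agent, e=e)
        raise ContextWindowOverflowException()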

src/strands/event_loop/message_processor.py

Lines changed: 1 addition & 58 deletions
@@ -5,7 +5,7 @@
 """
 
 import logging
-from typing import Dict, Optional, Set, Tuple
+from typing import Dict, Set, Tuple
 
 from ..types.content import Messages
 
@@ -103,60 +103,3 @@ def clean_orphaned_empty_tool_uses(messages: Messages) -> bool:
             logger.warning("failed to fix orphaned tool use | %s", e)
 
     return True
-
-
-def find_last_message_with_tool_results(messages: Messages) -> Optional[int]:
-    """Find the index of the last message containing tool results.
-
-    This is useful for identifying messages that might need to be truncated to reduce context size.
-
-    Args:
-        messages: The conversation message history.
-
-    Returns:
-        Index of the last message with tool results, or None if no such message exists.
-    """
-    # Iterate backwards through all messages (from newest to oldest)
-    for idx in range(len(messages) - 1, -1, -1):
-        # Check if this message has any content with toolResult
-        current_message = messages[idx]
-        has_tool_result = False
-
-        for content in current_message.get("content", []):
-            if isinstance(content, dict) and "toolResult" in content:
-                has_tool_result = True
-                break
-
-        if has_tool_result:
-            return idx
-
-    return None
-
-
-def truncate_tool_results(messages: Messages, msg_idx: int) -> bool:
-    """Truncate tool results in a message to reduce context size.
-
-    When a message contains tool results that are too large for the model's context window, this function replaces
-    the content of those tool results with a simple error message.
-
-    Args:
-        messages: The conversation message history.
-        msg_idx: Index of the message containing tool results to truncate.
-
-    Returns:
-        True if any changes were made to the message, False otherwise.
-    """
-    if msg_idx >= len(messages) or msg_idx < 0:
-        return False
-
-    message = messages[msg_idx]
-    changes_made = False
-
-    for i, content in enumerate(message.get("content", [])):
-        if isinstance(content, dict) and "toolResult" in content:
-            # Update status to error with informative message
-            message["content"][i]["toolResult"]["status"] = "error"
-            message["content"][i]["toolResult"]["content"] = [{"text": "The tool result was too large!"}]
-            changes_made = True
-
-    return changes_made

tests/strands/agent/test_agent.py

Lines changed: 38 additions & 2 deletions
@@ -438,7 +438,7 @@ def test_agent__call__retry_with_reduced_context(mock_model, agent, tool):
 
 
 def test_agent__call__always_sliding_window_conversation_manager_doesnt_infinite_loop(mock_model, agent, tool):
-    conversation_manager = SlidingWindowConversationManager(window_size=500)
+    conversation_manager = SlidingWindowConversationManager(window_size=500, should_truncate_results=False)
     conversation_manager_spy = unittest.mock.Mock(wraps=conversation_manager)
     agent.conversation_manager = conversation_manager_spy
 
@@ -484,10 +484,43 @@ def test_agent__call__null_conversation_window_manager__doesnt_infinite_loop(moc
     agent("Test!")
 
 
+def test_agent__call__tool_truncation_doesnt_infinite_loop(mock_model, agent):
+    messages: Messages = [
+        {"role": "user", "content": [{"text": "Hello!"}]},
+        {
+            "role": "assistant",
+            "content": [{"toolUse": {"toolUseId": "123", "input": {"hello": "world"}, "name": "test"}}],
+        },
+        {
+            "role": "user",
+            "content": [
+                {"toolResult": {"toolUseId": "123", "content": [{"text": "Some large input!"}], "status": "success"}}
+            ],
+        },
+    ]
+    agent.messages = messages
+
+    mock_model.mock_converse.side_effect = ContextWindowOverflowException(
+        RuntimeError("Input is too long for requested model")
+    )
+
+    with pytest.raises(ContextWindowOverflowException):
+        agent("Test!")
+
+
 def test_agent__call__retry_with_overwritten_tool(mock_model, agent, tool):
     conversation_manager_spy = unittest.mock.Mock(wraps=agent.conversation_manager)
     agent.conversation_manager = conversation_manager_spy
 
+    messages: Messages = [
+        {"role": "user", "content": [{"text": "Hello!"}]},
+        {
+            "role": "assistant",
+            "content": [{"text": "Hi!"}],
+        },
+    ]
+    agent.messages = messages
+
     mock_model.mock_converse.side_effect = [
         [
             {
@@ -504,6 +537,9 @@ def test_agent__call__retry_with_overwritten_tool(mock_model, agent, tool):
             {"contentBlockStop": {}},
             {"messageStop": {"stopReason": "tool_use"}},
         ],
+        # Will truncate the tool result
+        ContextWindowOverflowException(RuntimeError("Input is too long for requested model")),
+        # Will reduce the context
         ContextWindowOverflowException(RuntimeError("Input is too long for requested model")),
         [],
     ]
@@ -538,7 +574,7 @@ def test_agent__call__retry_with_overwritten_tool(mock_model, agent, tool):
         unittest.mock.ANY,
     )
 
-    conversation_manager_spy.reduce_context.assert_not_called()
+    assert conversation_manager_spy.reduce_context.call_count == 2
     assert conversation_manager_spy.apply_management.call_count == 1
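
The overwrite these tests exercise is easy to picture: _truncate_tool_results keeps the toolUseId but replaces the result's content and flips its status to error. An illustrative before/after, using the toolResult shape from the test fixtures above:

    # Before reduce_context (from the test fixture)
    {"toolResult": {"toolUseId": "123", "content": [{"text": "Some large input!"}], "status": "success"}}

    # After a truncating reduce_context call
    {"toolResult": {"toolUseId": "123", "content": [{"text": "The tool result was too large!"}], "status": "error"}}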

0 commit comments