-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathagent.py
More file actions
130 lines (103 loc) · 4.43 KB
/
agent.py
File metadata and controls
130 lines (103 loc) · 4.43 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
"""Core loop for the nano code agent."""
from __future__ import annotations
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any, Protocol
from apecode.tools import ToolRegistry
class ChatModel(Protocol):
    """Structural interface for chat model adapters.

    Any object exposing a matching ``complete`` method satisfies this
    protocol -- no inheritance required.
    """

    def complete(self, *, messages: list[dict[str, Any]], tools: list[dict[str, Any]]) -> dict[str, Any]:
        """Return a single assistant message for the given conversation."""
        ...
def _coerce_text(content: Any) -> str:
if content is None:
return ""
if isinstance(content, str):
return content
if isinstance(content, list):
parts: list[str] = []
for item in content:
if isinstance(item, dict) and item.get("type") == "text":
parts.append(str(item.get("text", "")))
return "".join(parts)
return str(content)
@dataclass(slots=True)
class AgentConfig:
"""Runtime knobs for agent execution."""
max_steps: int = 20
@dataclass(slots=True)
class AgentCallbacks:
"""Optional event callbacks. Agent stays framework-agnostic — no Rich import."""
on_status: Callable[[str], None] | None = None
"""Called with status text ("Thinking...") or empty string to clear."""
on_thinking: Callable[[str], None] | None = None
"""Called when the model returns reasoning_content."""
on_tool_call: Callable[[str, str], None] | None = None
"""Called before tool execution with (tool_name, arguments_json)."""
on_tool_result: Callable[[str, str], None] | None = None
"""Called after tool execution with (tool_name, result_text)."""
class NanoCodeAgent:
    """A tiny tool-calling loop built on the Chat Completions message format.

    Maintains the full conversation in ``self.messages`` and, on each
    ``run``, alternates between model calls and tool executions until the
    model answers without requesting a tool (or ``max_steps`` is exhausted).
    """

    def __init__(
        self,
        *,
        model: ChatModel,
        tools: ToolRegistry,
        system_prompt: str,
        config: AgentConfig | None = None,
        callbacks: AgentCallbacks | None = None,
        # Legacy single callback kept for backwards compat with tests
        on_tool_call: Callable[[str, str], None] | None = None,
    ) -> None:
        self.model = model
        self.tools = tools
        self.config = config or AgentConfig()
        # A full ``callbacks`` object wins over the legacy ``on_tool_call``.
        self.cb = callbacks or AgentCallbacks(on_tool_call=on_tool_call)
        self.messages: list[dict[str, Any]] = [
            {"role": "system", "content": system_prompt}
        ]

    def _fire(self, name: str, *args: Any) -> None:
        """Invoke the named callback with ``args`` if one was registered."""
        fn = getattr(self.cb, name, None)
        if fn is not None:
            fn(*args)

    def run(self, user_input: str) -> str:
        """Run one user turn to completion and return the final answer text.

        Raises:
            RuntimeError: if the model keeps requesting tools past
                ``config.max_steps`` iterations.
        """
        self.messages.append({"role": "user", "content": user_input})
        for _ in range(self.config.max_steps):
            self._fire("on_status", "Thinking...")
            try:
                assistant = self.model.complete(
                    messages=self.messages,
                    tools=self.tools.as_openai_tools(),
                )
            finally:
                # Clear the status line even if the model call raises;
                # otherwise the UI would be left stuck on "Thinking...".
                self._fire("on_status", "")
            # Surface provider "thinking" output if present.
            reasoning = assistant.get("reasoning_content")
            if reasoning:
                self._fire("on_thinking", str(reasoning))
            tool_calls = assistant.get("tool_calls") or []
            assistant_record: dict[str, Any] = {
                "role": "assistant",
                "content": assistant.get("content"),
            }
            # Preserve provider-specific fields (e.g. reasoning_content for
            # thinking models) so follow-up requests keep that context.
            for key in ("reasoning_content",):
                if assistant.get(key):
                    assistant_record[key] = assistant[key]
            if tool_calls:
                assistant_record["tool_calls"] = tool_calls
            self.messages.append(assistant_record)
            if not tool_calls:
                # No tool requests: the model's content is the final answer.
                return _coerce_text(assistant.get("content"))
            for call in tool_calls:
                call_id = str(call.get("id", ""))
                function = call.get("function") or {}
                name = str(function.get("name", ""))
                arguments = str(function.get("arguments", "{}"))
                self._fire("on_tool_call", name, arguments)
                result = self.tools.execute(name, arguments)
                self._fire("on_tool_result", name, result)
                # Echo the result back as a "tool" role message keyed to the
                # originating call id, per the Chat Completions convention.
                self.messages.append(
                    {
                        "role": "tool",
                        "tool_call_id": call_id,
                        "content": result,
                    }
                )
        raise RuntimeError(f"max steps exceeded ({self.config.max_steps})")