`tool-call`: save/restore prompt cache · ochafik/llama.cpp@2562f5a · GitHub

Commit 2562f5a

Olivier Chafik committed

tool-call: save/restore prompt cache
1 parent e2a9ab6 commit 2562f5a

File tree

2 files changed (+92, -29 lines)

examples/agent/run.py

Lines changed: 43 additions & 13 deletions
@@ -10,6 +10,7 @@
 # ///
 import json
 import asyncio
+import hashlib
 import logging
 import os
 import aiohttp
@@ -157,24 +158,51 @@ async def main(
     if openai:
         api_key = os.environ.get('OPENAI_API_KEY')
 
-    tool_map, tools = await discover_tools(tools or [], logger=logger)
-
-    sys.stdout.write(f'🛠️ Tools: {", ".join(tool_map.keys()) if tool_map else "<none>"}\n')
+    completions_url = f'{endpoint}chat/completions'
 
-    messages = [
-        dict(
-            role='user',
-            content=goal,
-        )
-    ]
+    tool_map, tools = await discover_tools(tools or [], logger=logger)
 
+
     headers = {
         'Content-Type': 'application/json',
         'Authorization': f'Bearer {api_key}'
     }
     async with aiohttp.ClientSession(headers=headers) as session:
+
+        prompt_session_file = None
+        if not openai:
+            prompt_session_file = 'session.' + hashlib.sha256(json.dumps(dict(
+                model=model,
+                tools=tools,
+            )).encode()).hexdigest() + '.bin'
+
+            if os.path.exists(prompt_session_file):
+                logger.info('Found prompt cache %s', prompt_session_file)
+            else:
+                payload = dict(
+                    messages=[dict(role='user', content='')],
+                    tools=tools,
+                    max_tokens=1,
+                    save_filename=prompt_session_file,
+                )
+                logger.info('Computing prompt cache %s', prompt_session_file)
+                logger.debug('Calling %s: %s', completions_url, json.dumps(payload, indent=2))
+                async with aiohttp.ClientSession(headers=headers) as session:
+                    async with session.post(completions_url, json=payload) as response:
+                        logger.debug('Response: %s', response)
+                        response.raise_for_status()
+                        response = await response.json()
+
+        sys.stdout.write(f'🛠️ Tools: {", ".join(tool_map.keys()) if tool_map else "<none>"}\n')
+
+        messages = [
+            dict(
+                role='user',
+                content=goal,
+            )
+        ]
+
         for i in range(max_iterations or sys.maxsize):
-            url = f'{endpoint}chat/completions'
             payload = dict(
                 messages=messages,
                 model=model,
@@ -185,12 +213,14 @@ async def main(
                 seed=seed,
                 cache_prompt=cache_prompt,
             )) # type: ignore
+            if prompt_session_file and os.path.exists(prompt_session_file):
+                payload['restore_filename'] = prompt_session_file
 
-            logger.debug('Calling %s with %s', url, json.dumps(payload, indent=2))
-            async with session.post(url, json=payload) as response:
-                logger.debug('Response: %s', response)
+            logger.debug('Calling %s with %s', completions_url, json.dumps(payload, indent=2))
+            async with session.post(completions_url, json=payload) as response:
                 response.raise_for_status()
                 response = await response.json()
+                logger.debug('Response: %s', response)
 
             assert len(response['choices']) == 1
            choice = response['choices'][0]
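
The client-side flow above boils down to three steps: derive a session file name from everything that fixes the shared prompt prefix (model and tool schemas), prime the server's cache once with a throwaway 1-token request carrying `save_filename`, then attach `restore_filename` to every subsequent chat request. A minimal standalone sketch of that handshake follows; it uses the synchronous `requests` library instead of aiohttp for brevity, and the server URL, model name, and the assumption that client and server share a filesystem are all placeholders rather than part of the commit.

# Sketch only: mirrors the handshake added to examples/agent/run.py.
# Assumptions (not from the commit): the server URL and model name below, and that
# client and server share a filesystem, so the client's os.path.exists() check sees
# the file the server writes under its --slot-save-path directory.
import hashlib
import json
import os

import requests

endpoint = 'http://localhost:8080/v1/'           # assumed; run.py takes this as a flag
completions_url = f'{endpoint}chat/completions'

model = 'local-model'                            # placeholder
tools = []                                       # tool schemas, as built by discover_tools()

# Session file name is derived from everything that shapes the cached prompt prefix.
prompt_session_file = 'session.' + hashlib.sha256(json.dumps(dict(
    model=model,
    tools=tools,
)).encode()).hexdigest() + '.bin'

if not os.path.exists(prompt_session_file):
    # Prime the cache: a 1-token request whose only job is to make the server
    # evaluate the shared prefix and persist the slot state via `save_filename`.
    requests.post(completions_url, json=dict(
        messages=[dict(role='user', content='')],
        tools=tools,
        max_tokens=1,
        save_filename=prompt_session_file,
    )).raise_for_status()

# Every real request then asks the server to reload that slot state first.
response = requests.post(completions_url, json=dict(
    messages=[dict(role='user', content='List the files in the current directory.')],
    model=model,
    tools=tools,
    restore_filename=prompt_session_file,
))
response.raise_for_status()
print(json.dumps(response.json()['choices'][0], indent=2))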

examples/server/server.cpp

Lines changed: 49 additions & 16 deletions
@@ -126,6 +126,7 @@ struct server_task_result {
 struct slot_params {
     bool stream = true;
     bool cache_prompt = false; // remember the prompt to avoid reprocessing all prompt
+    std::string save_filepath; // Where to save the slot data when done.
 
     int32_t n_keep = 0; // number of tokens to keep from initial prompt
     int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
@@ -249,6 +250,34 @@ struct server_slot {
         return state != SLOT_STATE_IDLE;
     }
 
+    struct restore_results {
+        size_t token_count;
+        size_t nread;
+    };
+
+    restore_results restore(struct llama_context * ctx, const std::string & filepath) {
+        cache_tokens.resize(n_ctx);
+        size_t token_count = 0;
+        size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), id + 1, cache_tokens.data(), cache_tokens.size(), &token_count);
+        if (nread == 0) {
+            cache_tokens.resize(0);
+            throw std::runtime_error("Unable to restore slot, no available space in KV cache or invalid slot save file");
+        }
+        cache_tokens.resize(token_count);
+        return {token_count, nread};
+    }
+
+    struct save_results {
+        size_t token_count;
+        size_t nwrite;
+    };
+
+    save_results save(struct llama_context * ctx, const std::string & filepath) const {
+        const size_t token_count = cache_tokens.size();
+        const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), id + 1, cache_tokens.data(), token_count);
+        return {token_count, nwrite};
+    }
+
     void add_token(const completion_token_output & token) {
         if (!is_processing()) {
             SLT_WRN(*this, "%s", "slot is not processing\n");
@@ -893,6 +922,7 @@ struct server_context {
         slot.sparams.seed = json_value(data, "seed", default_sparams.seed);
         slot.sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs);
         slot.sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep);
+        slot.params.save_filepath = params.slot_save_path + json_value(data, "save_filename", std::string());
 
         // process "json_schema" and "grammar"
         if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) {
@@ -1581,6 +1611,12 @@ struct server_context {
                         break;
                     }
 
+                    if (task.data.contains("restore_filename")) {
+                        std::string filename = task.data.at("restore_filename");
+                        std::string filepath = params.slot_save_path + filename;
+                        slot->restore(ctx, filepath);
+                    }
+
                     if (task.data.contains("system_prompt")) {
                         std::string sys_prompt = json_value(task.data, "system_prompt", std::string());
                         system_prompt_set(sys_prompt);
@@ -1698,13 +1734,12 @@ struct server_context {
                         break;
                     }
 
-                    const size_t token_count = slot->cache_tokens.size();
                     const int64_t t_start = ggml_time_us();
 
                     std::string filename = task.data.at("filename");
                     std::string filepath = task.data.at("filepath");
 
-                    const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id + 1, slot->cache_tokens.data(), token_count);
+                    auto save_results = slot->save(ctx, filepath);
 
                     const int64_t t_end = ggml_time_us();
                     const double t_save_ms = (t_end - t_start) / 1000.0;
@@ -1716,8 +1751,9 @@ struct server_context {
                     result.data = json {
                         { "id_slot", id_slot },
                         { "filename", filename },
-                        { "n_saved", token_count }, // tokens saved
-                        { "n_written", nwrite }, // bytes written
+                        { "filepath", filepath },
+                        { "n_saved", save_results.token_count }, // tokens saved
+                        { "n_written", save_results.nwrite }, // bytes written
                         { "timings", {
                             { "save_ms", t_save_ms }
                         } }
@@ -1744,15 +1780,7 @@ struct server_context {
                     std::string filename = task.data.at("filename");
                     std::string filepath = task.data.at("filepath");
 
-                    slot->cache_tokens.resize(slot->n_ctx);
-                    size_t token_count = 0;
-                    size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id + 1, slot->cache_tokens.data(), slot->cache_tokens.size(), &token_count);
-                    if (nread == 0) {
-                        slot->cache_tokens.resize(0);
-                        send_error(task, "Unable to restore slot, no available space in KV cache or invalid slot save file", ERROR_TYPE_INVALID_REQUEST);
-                        break;
-                    }
-                    slot->cache_tokens.resize(token_count);
+                    auto restore_results = slot->restore(ctx, filepath);
 
                     const int64_t t_end = ggml_time_us();
                     const double t_restore_ms = (t_end - t_start) / 1000.0;
@@ -1763,9 +1791,10 @@ struct server_context {
                     result.error = false;
                     result.data = json {
                         { "id_slot", id_slot },
+                        { "filepath", filepath },
                         { "filename", filename },
-                        { "n_restored", token_count }, // tokens restored
-                        { "n_read", nread }, // bytes read
+                        { "n_restored", restore_results.token_count }, // tokens restored
+                        { "n_read", restore_results.nread }, // bytes read
                         { "timings", {
                             { "restore_ms", t_restore_ms }
                         } }
@@ -2284,6 +2313,10 @@ struct server_context {
                 slot.print_timings();
                 send_final_response(slot);
                 metrics.on_prediction(slot);
+
+                if (!slot.params.save_filepath.empty()) {
+                    slot.save(ctx, slot.params.save_filepath);
+                }
             }
 
             slot.i_batch = -1;
@@ -2865,7 +2898,7 @@ int main(int argc, char ** argv) {
 
     const auto handle_infill = [&handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
         json data = json::parse(req.body);
-        return handle_completions_generic(SERVER_TASK_CMPL_TYPE_INFILL, data, res);
+        return handle_completions_generic(SERVER_TASK_CMPL_TYPE_INFILL, data, res);
     };
 
     // TODO: maybe merge this function with "handle_completions_generic"
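
Because the pre-existing slot save/restore task handlers now route through the new server_slot::save() and restore() helpers, the server's existing /slots endpoints are a convenient way to exercise the same code path directly. A rough sketch follows; it assumes a server started with --slot-save-path and listening on localhost:8080, the endpoint shapes follow the server's existing slot save/restore API, and the listed response keys are the ones assembled in the handlers above.

# Rough sketch, not part of the commit: exercise the slot save/restore handlers
# that this commit refactors onto server_slot::save()/restore().
# Assumes llama-server is running at http://localhost:8080 with --slot-save-path set.
import requests

base_url = 'http://localhost:8080'

# Persist slot 0's cached tokens + KV state to <slot_save_path>/agent.bin
saved = requests.post(f'{base_url}/slots/0?action=save', json={'filename': 'agent.bin'})
saved.raise_for_status()
print(saved.json())      # per the handler above: id_slot, filename, filepath, n_saved, n_written, timings

# Later (for example after clearing the KV cache), load it back into slot 0
restored = requests.post(f'{base_url}/slots/0?action=restore', json={'filename': 'agent.bin'})
restored.raise_for_status()
print(restored.json())   # per the handler above: id_slot, filepath, filename, n_restored, n_read, timings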

0 commit comments
