@@ -126,6 +126,7 @@ struct server_task_result {
 struct slot_params {
     bool stream       = true;
     bool cache_prompt = false; // remember the prompt to avoid reprocessing all prompt
+    std::string save_filepath; // where to save the slot data when generation finishes
 
     int32_t n_keep    = 0; // number of tokens to keep from initial prompt
     int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
@@ -249,6 +250,37 @@ struct server_slot {
         return state != SLOT_STATE_IDLE;
     }
 
+    struct restore_results {
+        size_t token_count; // tokens restored into the slot's cache
+        size_t nread;       // bytes read from the save file
+    };
+
+    // restore a previously saved KV cache state for this slot; throws on failure
+    restore_results restore(struct llama_context * ctx, const std::string & filepath) {
+        cache_tokens.resize(n_ctx);
+        size_t token_count = 0;
+        // id + 1: slot sequences start at 1, sequence 0 holds the system prompt
+        size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), id + 1, cache_tokens.data(), cache_tokens.size(), &token_count);
+        if (nread == 0) {
+            cache_tokens.resize(0);
+            throw std::runtime_error("Unable to restore slot, no available space in KV cache or invalid slot save file");
+        }
+        cache_tokens.resize(token_count);
+        return {token_count, nread};
+    }
+
+    struct save_results {
+        size_t token_count; // tokens saved
+        size_t nwrite;      // bytes written to the save file
+    };
+
+    // persist this slot's KV cache state to a file
+    save_results save(struct llama_context * ctx, const std::string & filepath) const {
+        const size_t token_count = cache_tokens.size();
+        const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), id + 1, cache_tokens.data(), token_count);
+        return {token_count, nwrite};
+    }
+
     void add_token(const completion_token_output & token) {
         if (!is_processing()) {
             SLT_WRN(*this, "%s", "slot is not processing\n");
@@ -893,6 +925,8 @@ struct server_context {
         slot.sparams.seed     = json_value(data, "seed",     default_sparams.seed);
         slot.sparams.n_probs  = json_value(data, "n_probs",  default_sparams.n_probs);
         slot.sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep);
+        // only build a save path when the request actually asks for one; otherwise save_filepath would equal slot_save_path and trigger a bogus save
+        slot.params.save_filepath = data.contains("save_filename") ? params.slot_save_path + json_value(data, "save_filename", std::string()) : std::string();
 
         // process "json_schema" and "grammar"
         if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) {
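
With this hunk in place, a completion request can opt into persisting the slot's state once generation finishes. A hypothetical invocation (the `save_filename` field comes from this patch; the server must be started with `--slot-save-path` so `params.slot_save_path` points at a writable directory):

    curl http://localhost:8080/completion -d '{
        "prompt": "Once upon a time",
        "n_predict": 64,
        "save_filename": "story.bin"
    }'

The state file then lands at `<slot_save_path>/story.bin` when the completion ends.
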
@@ -1581,6 +1615,17 @@ struct server_context {
                     break;
                 }
 
+                if (task.data.contains("restore_filename")) {
+                    std::string filename = task.data.at("restore_filename");
+                    std::string filepath = params.slot_save_path + filename;
+                    try {
+                        slot->restore(ctx, filepath);
+                    } catch (const std::exception & e) {
+                        send_error(task, e.what(), ERROR_TYPE_INVALID_REQUEST);
+                        break;
+                    }
+                }
+
                 if (task.data.contains("system_prompt")) {
                     std::string sys_prompt = json_value(task.data, "system_prompt", std::string());
                     system_prompt_set(sys_prompt);
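
The mirror image: a request carrying `restore_filename` reloads a previously saved state into the slot before the prompt is processed. A sketch under the same assumptions as above:

    curl http://localhost:8080/completion -d '{
        "prompt": "Once upon a time",
        "n_predict": 64,
        "cache_prompt": true,
        "restore_filename": "story.bin"
    }'

Pairing it with `"cache_prompt": true` lets the server match the restored tokens against the new prompt instead of re-evaluating the shared prefix.
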
@@ -1698,13 +1743,12 @@ struct server_context {
                     break;
                 }
 
-                const size_t token_count = slot->cache_tokens.size();
                 const int64_t t_start = ggml_time_us();
 
                 std::string filename = task.data.at("filename");
                 std::string filepath = task.data.at("filepath");
 
-                const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id + 1, slot->cache_tokens.data(), token_count);
+                const auto save_results = slot->save(ctx, filepath);
 
                 const int64_t t_end = ggml_time_us();
                 const double t_save_ms = (t_end - t_start) / 1000.0;
@@ -1716,8 +1760,9 @@ struct server_context {
                 result.data = json {
                     { "id_slot",   id_slot },
                     { "filename",  filename },
-                    { "n_saved",   token_count }, // tokens saved
-                    { "n_written", nwrite }, // bytes written
+                    { "filepath",  filepath },
+                    { "n_saved",   save_results.token_count }, // tokens saved
+                    { "n_written", save_results.nwrite },      // bytes written
                     { "timings", {
                         { "save_ms", t_save_ms }
                     } }
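
For reference, this handler serves the dedicated slot-save task; the upstream server exposes it as `POST /slots/{id}?action=save` (route per upstream — confirm it is unchanged in this fork). With the added `filepath` field, a request and its response now look roughly like this (illustrative values):

    curl -X POST "http://localhost:8080/slots/0?action=save" -d '{"filename": "state0.bin"}'

    {
        "id_slot": 0,
        "filename": "state0.bin",
        "filepath": "/tmp/slots/state0.bin",
        "n_saved": 128,
        "n_written": 9216,
        "timings": { "save_ms": 12.3 }
    }
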
@@ -1744,15 +1789,13 @@ struct server_context {
                 std::string filename = task.data.at("filename");
                 std::string filepath = task.data.at("filepath");
 
-                slot->cache_tokens.resize(slot->n_ctx);
-                size_t token_count = 0;
-                size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id + 1, slot->cache_tokens.data(), slot->cache_tokens.size(), &token_count);
-                if (nread == 0) {
-                    slot->cache_tokens.resize(0);
-                    send_error(task, "Unable to restore slot, no available space in KV cache or invalid slot save file", ERROR_TYPE_INVALID_REQUEST);
-                    break;
-                }
-                slot->cache_tokens.resize(token_count);
+                server_slot::restore_results restore_results {};
+                try {
+                    restore_results = slot->restore(ctx, filepath);
+                } catch (const std::exception & e) {
+                    send_error(task, e.what(), ERROR_TYPE_INVALID_REQUEST);
+                    break;
+                }
 
                 const int64_t t_end = ggml_time_us();
                 const double t_restore_ms = (t_end - t_start) / 1000.0;
@@ -1763,9 +1806,10 @@ struct server_context {
                 result.error = false;
                 result.data = json {
                     { "id_slot",    id_slot },
+                    { "filepath",   filepath },
                     { "filename",   filename },
-                    { "n_restored", token_count }, // tokens restored
-                    { "n_read",     nread }, // bytes read
+                    { "n_restored", restore_results.token_count }, // tokens restored
+                    { "n_read",     restore_results.nread },       // bytes read
                     { "timings", {
                         { "restore_ms", t_restore_ms }
                     } }
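
Likewise for restore, `POST /slots/{id}?action=restore` (same caveat on the route) now reports the resolved path. An illustrative round trip:

    curl -X POST "http://localhost:8080/slots/0?action=restore" -d '{"filename": "state0.bin"}'

    {
        "id_slot": 0,
        "filepath": "/tmp/slots/state0.bin",
        "filename": "state0.bin",
        "n_restored": 128,
        "n_read": 9216,
        "timings": { "restore_ms": 14.8 }
    }
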
@@ -2284,6 +2328,11 @@ struct server_context {
                     slot.print_timings();
                     send_final_response(slot);
                     metrics.on_prediction(slot);
+
+                    // autosave: the request supplied "save_filename", so persist the slot state now that generation is done
+                    if (!slot.params.save_filepath.empty()) {
+                        slot.save(ctx, slot.params.save_filepath);
+                    }
                 }
 
                 slot.i_batch = -1;
@@ -2865,7 +2914,7 @@ int main(int argc, char ** argv) {
 
     const auto handle_infill = [&handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
         json data = json::parse(req.body);
-            return handle_completions_generic(SERVER_TASK_CMPL_TYPE_INFILL, data, res);
+        return handle_completions_generic(SERVER_TASK_CMPL_TYPE_INFILL, data, res);
     };
 
     // TODO: maybe merge this function with "handle_completions_generic"