diff --git a/common/arg.cpp b/common/arg.cpp index 75e8e0bd51aee..6633c3a1db8a4 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -43,6 +43,25 @@ std::initializer_list mmproj_examples = { // TODO: add LLAMA_EXAMPLE_SERVER when it's ready }; +static std::string read_file(const std::string & fname) { + std::ifstream file(fname); + if (!file) { + throw std::runtime_error(string_format("error: failed to open file '%s'\n", fname.c_str())); + } + std::string content((std::istreambuf_iterator(file)), std::istreambuf_iterator()); + file.close(); + return content; +} + +static void write_file(const std::string & fname, const std::string & content) { + std::ofstream file(fname); + if (!file) { + throw std::runtime_error(string_format("error: failed to open file '%s'\n", fname.c_str())); + } + file << content; + file.close(); +} + common_arg & common_arg::set_examples(std::initializer_list examples) { this->examples = std::move(examples); return *this; @@ -200,9 +219,11 @@ struct curl_slist_ptr { static bool curl_perform_with_retry(const std::string & url, CURL * curl, int max_attempts, int retry_delay_seconds) { int remaining_attempts = max_attempts; + char * method = nullptr; + curl_easy_getinfo(curl, CURLINFO_EFFECTIVE_METHOD, &method); while (remaining_attempts > 0) { - LOG_INF("%s: Trying to download from %s (attempt %d of %d)...\n", __func__ , url.c_str(), max_attempts - remaining_attempts + 1, max_attempts); + LOG_INF("%s: %s %s (attempt %d of %d)...\n", __func__ , method, url.c_str(), max_attempts - remaining_attempts + 1, max_attempts); CURLcode res = curl_easy_perform(curl); if (res == CURLE_OK) { @@ -213,6 +234,7 @@ static bool curl_perform_with_retry(const std::string & url, CURL * curl, int ma LOG_WRN("%s: curl_easy_perform() failed: %s, retrying after %d milliseconds...\n", __func__, curl_easy_strerror(res), exponential_backoff_delay); remaining_attempts--; + if (remaining_attempts == 0) break; std::this_thread::sleep_for(std::chrono::milliseconds(exponential_backoff_delay)); } @@ -231,8 +253,6 @@ static bool common_download_file_single(const std::string & url, const std::stri return false; } - bool force_download = false; - // Set the URL, allow to follow http redirection curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str()); curl_easy_setopt(curl.get(), CURLOPT_FOLLOWLOCATION, 1L); @@ -256,7 +276,7 @@ static bool common_download_file_single(const std::string & url, const std::stri // If the file exists, check its JSON metadata companion file. std::string metadata_path = path + ".json"; - nlohmann::json metadata; + nlohmann::json metadata; // TODO @ngxson : get rid of this json, use regex instead std::string etag; std::string last_modified; @@ -266,7 +286,7 @@ static bool common_download_file_single(const std::string & url, const std::stri if (metadata_in.good()) { try { metadata_in >> metadata; - LOG_INF("%s: previous metadata file found %s: %s\n", __func__, metadata_path.c_str(), metadata.dump().c_str()); + LOG_DBG("%s: previous metadata file found %s: %s\n", __func__, metadata_path.c_str(), metadata.dump().c_str()); if (metadata.contains("url") && metadata.at("url").is_string()) { auto previous_url = metadata.at("url").get(); if (previous_url != url) { @@ -296,7 +316,10 @@ static bool common_download_file_single(const std::string & url, const std::stri }; common_load_model_from_url_headers headers; + bool head_request_ok = false; + bool should_download = !file_exists; // by default, we should download if the file does not exist + // get ETag to see if the remote file has changed { typedef size_t(*CURLOPT_HEADERFUNCTION_PTR)(char *, size_t, size_t, void *); auto header_callback = [](char * buffer, size_t /*size*/, size_t n_items, void * userdata) -> size_t { @@ -325,23 +348,28 @@ static bool common_download_file_single(const std::string & url, const std::stri curl_easy_setopt(curl.get(), CURLOPT_HEADERFUNCTION, static_cast(header_callback)); curl_easy_setopt(curl.get(), CURLOPT_HEADERDATA, &headers); - bool was_perform_successful = curl_perform_with_retry(url, curl.get(), CURL_MAX_RETRY, CURL_RETRY_DELAY_SECONDS); + // we only allow retrying once for HEAD requests + // this is for the use case of using running offline (no internet), retrying can be annoying + bool was_perform_successful = curl_perform_with_retry(url, curl.get(), 1, 0); if (!was_perform_successful) { - return false; + head_request_ok = false; } long http_code = 0; curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &http_code); - if (http_code != 200) { - // HEAD not supported, we don't know if the file has changed - // force trigger downloading - force_download = true; - LOG_ERR("%s: HEAD invalid http status code received: %ld\n", __func__, http_code); + if (http_code == 200) { + head_request_ok = true; + } else { + LOG_WRN("%s: HEAD invalid http status code received: %ld\n", __func__, http_code); + head_request_ok = false; } } - bool should_download = !file_exists || force_download; - if (!should_download) { + // if head_request_ok is false, we don't have the etag or last-modified headers + // we leave should_download as-is, which is true if the file does not exist + if (head_request_ok) { + // check if ETag or Last-Modified headers are different + // if it is, we need to download the file again if (!etag.empty() && etag != headers.etag) { LOG_WRN("%s: ETag header is different (%s != %s): triggering a new download\n", __func__, etag.c_str(), headers.etag.c_str()); should_download = true; @@ -350,6 +378,7 @@ static bool common_download_file_single(const std::string & url, const std::stri should_download = true; } } + if (should_download) { std::string path_temporary = path + ".downloadInProgress"; if (file_exists) { @@ -424,13 +453,15 @@ static bool common_download_file_single(const std::string & url, const std::stri {"etag", headers.etag}, {"lastModified", headers.last_modified} }); - std::ofstream(metadata_path) << metadata.dump(4); - LOG_INF("%s: file metadata saved: %s\n", __func__, metadata_path.c_str()); + write_file(metadata_path, metadata.dump(4)); + LOG_DBG("%s: file metadata saved: %s\n", __func__, metadata_path.c_str()); if (rename(path_temporary.c_str(), path.c_str()) != 0) { LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str()); return false; } + } else { + LOG_INF("%s: using cached file: %s\n", __func__, path.c_str()); } return true; @@ -605,16 +636,37 @@ static struct common_hf_file_res common_get_hf_file(const std::string & hf_repo_ // Important: the User-Agent must be "llama-cpp" to get the "ggufFile" field in the response // User-Agent header is already set in common_remote_get_content, no need to set it here + // we use "=" to avoid clashing with other component, while still being allowed on windows + std::string cached_response_fname = "manifest=" + hf_repo + "=" + tag + ".json"; + string_replace_all(cached_response_fname, "/", "_"); + std::string cached_response_path = fs_get_cache_file(cached_response_fname); + // make the request common_remote_params params; params.headers = headers; - auto res = common_remote_get_content(url, params); - long res_code = res.first; - std::string res_str(res.second.data(), res.second.size()); + long res_code = 0; + std::string res_str; + bool use_cache = false; + try { + auto res = common_remote_get_content(url, params); + res_code = res.first; + res_str = std::string(res.second.data(), res.second.size()); + } catch (const std::exception & e) { + LOG_WRN("error: failed to get manifest: %s\n", e.what()); + LOG_WRN("try reading from cache\n"); + // try to read from cache + try { + res_str = read_file(cached_response_path); + res_code = 200; + use_cache = true; + } catch (const std::exception & e) { + throw std::runtime_error("error: failed to get manifest (check your internet connection)"); + } + } std::string ggufFile; std::string mmprojFile; - if (res_code == 200) { + if (res_code == 200 || res_code == 304) { // extract ggufFile.rfilename in json, using regex { std::regex pattern("\"ggufFile\"[\\s\\S]*?\"rfilename\"\\s*:\\s*\"([^\"]+)\""); @@ -631,6 +683,10 @@ static struct common_hf_file_res common_get_hf_file(const std::string & hf_repo_ mmprojFile = match[1].str(); } } + if (!use_cache) { + // if not using cached response, update the cache file + write_file(cached_response_path, res_str); + } } else if (res_code == 401) { throw std::runtime_error("error: model is private or does not exist; if you are accessing a gated model, please provide a valid HF token"); } else { @@ -1142,6 +1198,9 @@ bool common_params_parse(int argc, char ** argv, common_params & params, llama_e fprintf(stderr, "%s\n", ex.what()); ctx_arg.params = params_org; return false; + } catch (std::exception & ex) { + fprintf(stderr, "%s\n", ex.what()); + exit(1); // for other exceptions, we exit with status code 1 } return true; @@ -1442,13 +1501,9 @@ common_params_context common_params_parser_init(common_params & params, llama_ex {"-f", "--file"}, "FNAME", "a file containing the prompt (default: none)", [](common_params & params, const std::string & value) { - std::ifstream file(value); - if (!file) { - throw std::runtime_error(string_format("error: failed to open file '%s'\n", value.c_str())); - } + params.prompt = read_file(value); // store the external file name in params params.prompt_file = value; - std::copy(std::istreambuf_iterator(file), std::istreambuf_iterator(), back_inserter(params.prompt)); if (!params.prompt.empty() && params.prompt.back() == '\n') { params.prompt.pop_back(); } @@ -1458,11 +1513,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex {"-sysf", "--system-prompt-file"}, "FNAME", "a file containing the system prompt (default: none)", [](common_params & params, const std::string & value) { - std::ifstream file(value); - if (!file) { - throw std::runtime_error(string_format("error: failed to open file '%s'\n", value.c_str())); - } - std::copy(std::istreambuf_iterator(file), std::istreambuf_iterator(), back_inserter(params.system_prompt)); + params.system_prompt = read_file(value); if (!params.system_prompt.empty() && params.system_prompt.back() == '\n') { params.system_prompt.pop_back(); } @@ -1887,15 +1938,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex {"--grammar-file"}, "FNAME", "file to read grammar from", [](common_params & params, const std::string & value) { - std::ifstream file(value); - if (!file) { - throw std::runtime_error(string_format("error: failed to open file '%s'\n", value.c_str())); - } - std::copy( - std::istreambuf_iterator(file), - std::istreambuf_iterator(), - std::back_inserter(params.sampling.grammar) - ); + params.sampling.grammar = read_file(value); } ).set_sparam()); add_opt(common_arg( @@ -2815,14 +2858,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex "list of built-in templates:\n%s", list_builtin_chat_templates().c_str() ), [](common_params & params, const std::string & value) { - std::ifstream file(value); - if (!file) { - throw std::runtime_error(string_format("error: failed to open file '%s'\n", value.c_str())); - } - std::copy( - std::istreambuf_iterator(file), - std::istreambuf_iterator(), - std::back_inserter(params.chat_template)); + params.chat_template = read_file(value); } ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_CHAT_TEMPLATE_FILE")); add_opt(common_arg(