Skip to content

Commit 88a9772

Browse files
committed
Refactor for offline models
1 parent 226dfa1 commit 88a9772

File tree

1 file changed

+41
-42
lines changed

1 file changed

+41
-42
lines changed

src/huggingface_hub.cpp

Lines changed: 41 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,21 @@ std::string get_model_repo_path(const std::string &repo_id) {
9797
return model_folder;
9898
}
9999

100+
std::string find_outdated_file(const std::string &snapshot_dir,
101+
const std::string &filename) {
102+
for (const auto &version :
103+
std::filesystem::directory_iterator(snapshot_dir)) {
104+
for (const auto &file :
105+
std::filesystem::directory_iterator(version.path())) {
106+
if (file.path().filename() == filename) {
107+
return file.path();
108+
break;
109+
}
110+
}
111+
}
112+
return "";
113+
}
114+
100115
std::string create_cache_system(const std::string &cache_dir,
101116
const std::string &repo_id) {
102117
std::string model_folder = get_model_repo_path(repo_id);
@@ -179,7 +194,7 @@ std::string get_file_path(const std::string &cache_dir,
179194
expanded_cache_dir / model_folder / "refs" / "main";
180195

181196
if (!std::filesystem::exists(refs_file_path)) {
182-
log_info("refs file does not exist");
197+
log_debug("refs file does not exist");
183198
return "";
184199
}
185200
std::ifstream refs_file(refs_file_path);
@@ -397,51 +412,35 @@ struct DownloadResult hf_hub_download(const std::string &repo_id,
397412
if (std::holds_alternative<CURLcode>(commit_result)) {
398413
CURLcode err = std::get<CURLcode>(commit_result);
399414

400-
if (err == CURLE_COULDNT_RESOLVE_HOST ||
401-
err == CURLE_COULDNT_CONNECT) { // OFFLINE MODE
402-
std::string file_path = get_file_path(cache_dir, repo_id, filename);
403-
if (!file_path.empty()) {
404-
log_info("No connection. Using cached file.");
405-
result.path = file_path;
406-
result.success = true;
407-
return result;
408-
} else {
409-
std::string model_path = get_model_repo_path(repo_id);
410-
bool file_found = false;
411-
for (const auto &version : std::filesystem::directory_iterator(
412-
expand_user_home(cache_dir + "/" + model_path + "/snapshots"))) {
413-
for (const auto &file :
414-
std::filesystem::directory_iterator(version.path())) {
415-
if (file.path().filename() == filename) {
416-
result.path = file.path();
417-
file_found = true;
418-
break;
419-
}
420-
}
421-
if (file_found) {
422-
break;
423-
}
424-
}
425-
if (file_found) {
426-
log_info("No connection. Using outdated cached file: " + result.path);
427-
result.success = true;
428-
return result;
429-
}
430-
431-
log_error("Could not resolve host or connect to Hugging Face. "
432-
"Please check your internet connection.");
433-
result.success = false;
434-
return result;
435-
}
436-
} else if (err == CURLE_HTTP_RETURNED_ERROR) { // REPOSITORY NOT FOUND
437-
log_error("Repository not found: " + repo_id);
438-
result.success = false;
415+
std::string file_path = get_file_path(cache_dir, repo_id, filename);
416+
if (!file_path.empty()) {
417+
log_info("Using cached file.");
418+
result.path = file_path;
419+
result.success = true;
439420
return result;
440-
} else {
441-
log_error("Error getting model: " + std::string(curl_easy_strerror(err)));
421+
}
422+
423+
std::string model_path = get_model_repo_path(repo_id);
424+
std::string snapshot_path =
425+
expand_user_home(cache_dir + "/" + model_path + "/snapshots");
426+
if (!std::filesystem::exists(snapshot_path)) {
427+
log_info(snapshot_path);
428+
log_error("Repo not found (locally nor online): " + repo_id);
442429
result.success = false;
443430
return result;
444431
}
432+
433+
std::string outdated_file = find_outdated_file(snapshot_path, filename);
434+
if (!outdated_file.empty()) {
435+
log_info("Using outdated cached file " + outdated_file);
436+
result.path = outdated_file;
437+
result.success = true;
438+
return result;
439+
}
440+
441+
log_error("Error getting model: " + std::string(curl_easy_strerror(err)));
442+
result.success = false;
443+
return result;
445444
}
446445

447446
std::string latest_commit = std::get<std::string>(commit_result);

0 commit comments

Comments
 (0)