Commit 47672e5: support download from modelscope
1 parent: fd7855f
File tree: 4 files changed, +213 -24 lines

common/arg.cpp (6 additions, 1 deletion)
@@ -140,7 +140,12 @@ static void common_params_handle_model_default(
     // short-hand to avoid specifying --hf-file -> default it to --model
     if (hf_file.empty()) {
        if (model.empty()) {
-            auto auto_detected = common_get_hf_file(hf_repo, hf_token);
+            std::pair<std::string, std::string> auto_detected;
+            if (LLAMACPP_USE_MODELSCOPE_DEFINITION) {
+                auto_detected = common_get_ms_file(hf_repo, hf_token);
+            } else {
+                auto_detected = common_get_hf_file(hf_repo, hf_token);
+            }
            if (auto_detected.first.empty() || auto_detected.second.empty()) {
                exit(1); // built without CURL, error message already printed
            }
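This hunk turns the default-file lookup into a runtime dispatch: when the LLAMACPP_USE_MODELSCOPE environment flag is set, the <repo, file> pair comes from common_get_ms_file instead of common_get_hf_file. A minimal sketch of the same dispatch, factored into a helper (the helper name is invented for illustration; the two common_get_* functions and the macro are the ones declared in common/common.h):

    // sketch: hypothetical helper mirroring the dispatch in the hunk above
    static std::pair<std::string, std::string> resolve_default_file(
            const std::string & repo, const std::string & token) {
        if (LLAMACPP_USE_MODELSCOPE_DEFINITION) {
            return common_get_ms_file(repo, token); // ModelScope file listing
        }
        return common_get_hf_file(repo, token);     // Hugging Face manifest
    }

Either helper returns an empty pair when llama.cpp was built without CURL, which is why the caller exits immediately afterward.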

common/common.cpp (127 additions, 23 deletions)
@@ -907,7 +907,11 @@ struct common_init_result common_init_from_params(common_params & params) {
     llama_model * model = nullptr;

     if (!params.hf_repo.empty() && !params.hf_file.empty()) {
-        model = common_load_model_from_hf(params.hf_repo, params.hf_file, params.model, params.hf_token, mparams);
+        if (LLAMACPP_USE_MODELSCOPE_DEFINITION) {
+            model = common_load_model_from_ms(params.hf_repo, params.hf_file, params.model, params.hf_token, mparams);
+        } else {
+            model = common_load_model_from_hf(params.hf_repo, params.hf_file, params.model, params.hf_token, mparams);
+        }
     } else if (!params.model_url.empty()) {
         model = common_load_model_from_url(params.model_url, params.model, params.hf_token, mparams);
     } else {
@@ -1207,6 +1211,12 @@ static bool common_download_file(const std::string & url, const std::string & pa
     curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str());
     curl_easy_setopt(curl.get(), CURLOPT_FOLLOWLOCATION, 1L);

+    std::vector<std::string> _headers = {"User-Agent: llama-cpp"};
+    for (const auto & header : _headers) {
+        http_headers.ptr = curl_slist_append(http_headers.ptr, header.c_str());
+    }
+    curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers.ptr);
+
     // Check if hf-token or bearer-token was specified
     if (!hf_token.empty()) {
         std::string auth_header = "Authorization: Bearer " + hf_token;
@@ -1265,6 +1275,7 @@ static bool common_download_file(const std::string & url, const std::string & pa
     };

     common_load_model_from_url_headers headers;
+    bool should_download = false;

     {
         typedef size_t(*CURLOPT_HEADERFUNCTION_PTR)(char *, size_t, size_t, void *);
@@ -1293,32 +1304,35 @@ static bool common_download_file(const std::string & url, const std::string & pa
         curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 1L); // hide head request progress
         curl_easy_setopt(curl.get(), CURLOPT_HEADERFUNCTION, static_cast<CURLOPT_HEADERFUNCTION_PTR>(header_callback));
         curl_easy_setopt(curl.get(), CURLOPT_HEADERDATA, &headers);
+        if (!LLAMACPP_USE_MODELSCOPE_DEFINITION) {
+            bool was_perform_successful = curl_perform_with_retry(url, curl.get(), CURL_MAX_RETRY, CURL_RETRY_DELAY_SECONDS);
+            if (!was_perform_successful) {
+                return false;
+            }

-        bool was_perform_successful = curl_perform_with_retry(url, curl.get(), CURL_MAX_RETRY, CURL_RETRY_DELAY_SECONDS);
-        if (!was_perform_successful) {
-            return false;
-        }
-
-        long http_code = 0;
-        curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &http_code);
-        if (http_code != 200) {
-            // HEAD not supported, we don't know if the file has changed
-            // force trigger downloading
-            force_download = true;
-            LOG_ERR("%s: HEAD invalid http status code received: %ld\n", __func__, http_code);
+            long http_code = 0;
+            curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &http_code);
+            if (http_code != 200) {
+                // HEAD not supported, we don't know if the file has changed
+                // force trigger downloading
+                force_download = true;
+                LOG_ERR("%s: HEAD invalid http status code received: %ld\n", __func__, http_code);
+            }
+            should_download = !file_exists || force_download;
+            if (!should_download) {
+                if (!etag.empty() && etag != headers.etag) {
+                    LOG_WRN("%s: ETag header is different (%s != %s): triggering a new download\n", __func__, etag.c_str(), headers.etag.c_str());
+                    should_download = true;
+                } else if (!last_modified.empty() && last_modified != headers.last_modified) {
+                    LOG_WRN("%s: Last-Modified header is different (%s != %s): triggering a new download\n", __func__, last_modified.c_str(), headers.last_modified.c_str());
+                    should_download = true;
+                }
+            }
+        } else {
+            should_download = !file_exists;
         }
     }

-    bool should_download = !file_exists || force_download;
-    if (!should_download) {
-        if (!etag.empty() && etag != headers.etag) {
-            LOG_WRN("%s: ETag header is different (%s != %s): triggering a new download\n", __func__, etag.c_str(), headers.etag.c_str());
-            should_download = true;
-        } else if (!last_modified.empty() && last_modified != headers.last_modified) {
-            LOG_WRN("%s: Last-Modified header is different (%s != %s): triggering a new download\n", __func__, last_modified.c_str(), headers.last_modified.c_str());
-            should_download = true;
-        }
-    }
     if (should_download) {
         std::string path_temporary = path + ".downloadInProgress";
         if (file_exists) {
@@ -1507,6 +1521,20 @@ struct llama_model * common_load_model_from_hf(
     return common_load_model_from_url(model_url, local_path, hf_token, params);
 }

+struct llama_model * common_load_model_from_ms(
+        const std::string & repo,
+        const std::string & remote_path,
+        const std::string & local_path,
+        const std::string & ms_token,
+        const struct llama_model_params & params) {
+    std::string model_url = "https://" + MODELSCOPE_DOMAIN_DEFINITION + "/models/";
+    model_url += repo;
+    model_url += "/resolve/master/";
+    model_url += remote_path;
+    // modelscope does not support token in header
+    return common_load_model_from_url(model_url, local_path, "", params);
+}
+
 /**
  * Allow getting the HF file from the HF repo with tag (like ollama), for example:
  * - bartowski/Llama-3.2-3B-Instruct-GGUF:q4
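For a given repo and file, the new function builds a URL of the shape https://www.modelscope.cn/models/<repo>/resolve/master/<file>, with the domain overridable via MODELSCOPE_DOMAIN. A hedged illustration with made-up names (neither the repo nor the file appears in the commit):

    // illustrative only: the URL shape produced by common_load_model_from_ms
    std::string repo        = "Qwen/Qwen2.5-1.5B-Instruct-GGUF";   // hypothetical repo
    std::string remote_path = "qwen2.5-1.5b-instruct-q4_k_m.gguf"; // hypothetical file
    std::string model_url   = "https://www.modelscope.cn/models/" + repo
                            + "/resolve/master/" + remote_path;

Note that ms_token is accepted but deliberately not forwarded: per the in-code comment, ModelScope does not take the token as a request header, so the download proceeds unauthenticated.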
@@ -1581,6 +1609,82 @@ std::pair<std::string, std::string> common_get_hf_file(const std::string & hf_re
     return std::make_pair(hf_repo, gguf_file.at("rfilename"));
 }

+std::pair<std::string, std::string> common_get_ms_file(const std::string & ms_repo_with_tag, const std::string & ms_token) {
+    auto parts = string_split<std::string>(ms_repo_with_tag, ':');
+    std::string tag = parts.size() > 1 ? parts.back() : "Q4_K_M";
+    std::string hf_repo = parts[0];
+    if (string_split<std::string>(hf_repo, '/').size() != 2) {
+        throw std::invalid_argument("error: invalid MS repo format, expected <user>/<model>[:quant]\n");
+    }
+
+    // fetch model info from the ModelScope API
+    json model_info;
+    curl_ptr curl(curl_easy_init(), &curl_easy_cleanup);
+    curl_slist_ptr http_headers;
+    std::string res_str;
+    auto endpoint = MODELSCOPE_DOMAIN_DEFINITION;
+
+    std::string url = endpoint + "/api/v1/models/" + hf_repo + "/repo/files?Revision=master&Recursive=True";
+    curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str());
+    curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 1L);
+    typedef size_t(*CURLOPT_WRITEFUNCTION_PTR)(void * ptr, size_t size, size_t nmemb, void * data);
+    auto write_callback = [](void * ptr, size_t size, size_t nmemb, void * data) -> size_t {
+        static_cast<std::string *>(data)->append((char *) ptr, size * nmemb);
+        return size * nmemb;
+    };
+    curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, static_cast<CURLOPT_WRITEFUNCTION_PTR>(write_callback));
+    curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, &res_str);
+#if defined(_WIN32)
+    curl_easy_setopt(curl.get(), CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA);
+#endif
+    // set a llama-cpp User-Agent, matching the Hugging Face code path
+    http_headers.ptr = curl_slist_append(http_headers.ptr, "user-agent: llama-cpp");
+    http_headers.ptr = curl_slist_append(http_headers.ptr, "Accept: application/json");
+    curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers.ptr);
+
+    CURLcode res = curl_easy_perform(curl.get());
+
+    if (res != CURLE_OK) {
+        throw std::runtime_error("error: cannot make GET request to MS API");
+    }
+
+    long res_code;
+    curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &res_code);
+    if (res_code == 200) {
+        model_info = nlohmann::json::parse(res_str);
+    } else if (res_code == 401) {
+        throw std::runtime_error("error: model is private or does not exist; if you are accessing a gated model, please provide a valid MS token");
+    } else {
+        throw std::runtime_error(string_format("error from MS API, response code: %ld, data: %s", res_code, res_str.c_str()));
+    }
+
+    auto all_files = model_info["Data"]["Files"];
+
+    std::vector<std::string> all_available_files;
+    std::string gguf_file;
+    std::string upper_tag;
+    upper_tag.reserve(tag.size());
+    std::string lower_tag;
+    lower_tag.reserve(tag.size());
+    std::transform(tag.begin(), tag.end(), std::back_inserter(upper_tag), ::toupper);
+    std::transform(tag.begin(), tag.end(), std::back_inserter(lower_tag), ::tolower);
+    for (const auto & _file : all_files) {
+        auto file = _file["Path"].get<std::string>();
+        if (!string_ends_with(file, ".gguf")) {
+            continue;
+        }
+        if (file.find(upper_tag) != std::string::npos || file.find(lower_tag) != std::string::npos) {
+            gguf_file = file;
+        }
+        all_available_files.push_back(file);
+    }
+    if (gguf_file.empty()) {
+        gguf_file = all_available_files[0];
+    }
+
+    return std::make_pair(hf_repo, gguf_file);
+}
+
 #else

 struct llama_model * common_load_model_from_url(
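The selection rule in common_get_ms_file is: keep only .gguf entries, prefer the last one whose path contains the quant tag in upper or lower case (default tag Q4_K_M), and otherwise fall back to the first .gguf found. A self-contained sketch of that rule, with invented file names:

    #include <algorithm>
    #include <iostream>
    #include <string>
    #include <vector>

    // sketch of the tag-matching rule above; `files` stands in for the repo listing
    static std::string pick_gguf(const std::vector<std::string> & files, std::string tag) {
        std::string upper = tag, lower = tag;
        std::transform(tag.begin(), tag.end(), upper.begin(), ::toupper);
        std::transform(tag.begin(), tag.end(), lower.begin(), ::tolower);
        std::string chosen;
        std::string first_gguf;
        for (const auto & f : files) {
            if (f.size() < 5 || f.compare(f.size() - 5, 5, ".gguf") != 0) {
                continue; // skip non-GGUF entries, as string_ends_with does above
            }
            if (first_gguf.empty()) {
                first_gguf = f;
            }
            if (f.find(upper) != std::string::npos || f.find(lower) != std::string::npos) {
                chosen = f; // last match wins, matching the committed loop
            }
        }
        return chosen.empty() ? first_gguf : chosen;
    }

    int main() {
        std::vector<std::string> files = { "README.md", "m-q8_0.gguf", "m-q4_k_m.gguf" }; // invented
        std::cout << pick_gguf(files, "Q4_K_M") << '\n'; // prints m-q4_k_m.gguf
        return 0;
    }

One difference worth flagging: the committed code indexes all_available_files[0] without checking that the list is non-empty, so a repo containing no .gguf files would hit undefined behavior; the sketch returns an empty string instead.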

common/common.h (31 additions, 0 deletions)
@@ -25,6 +25,26 @@

 #define DEFAULT_MODEL_PATH "models/7B/ggml-model-f16.gguf"

+#define MODELSCOPE_DOMAIN_DEFINITION ([]() -> std::string { \
+    const char* ms_endpoint = std::getenv("MODELSCOPE_DOMAIN"); \
+    if (ms_endpoint == nullptr) { \
+        ms_endpoint = "www.modelscope.cn"; \
+    } \
+    return std::string(ms_endpoint); \
+}())
+
+#define LLAMACPP_USE_MODELSCOPE_DEFINITION ([]() -> bool { \
+    const char* use_modelscope = std::getenv("LLAMACPP_USE_MODELSCOPE"); \
+    if (use_modelscope == nullptr) { \
+        use_modelscope = "False"; \
+    } \
+    bool llamacpp_use_modelscope = false; \
+    if (std::string(use_modelscope) == "True" || std::string(use_modelscope) == "true") { \
+        llamacpp_use_modelscope = true; \
+    } \
+    return llamacpp_use_modelscope; \
+}())
+
 struct common_adapter_lora_info {
     std::string path;
     float scale;
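Both macros expand to an immediately-invoked lambda, so every use site re-reads the environment at runtime instead of fixing a value at build time. A standalone sketch of the same lookups (the demo main and printout are illustrative; the variable names and defaults come from the macros above):

    #include <cstdio>
    #include <cstdlib>
    #include <string>

    int main() {
        // MODELSCOPE_DOMAIN falls back to www.modelscope.cn when unset
        const char * ms_endpoint = std::getenv("MODELSCOPE_DOMAIN");
        std::string domain = ms_endpoint ? ms_endpoint : "www.modelscope.cn";

        // LLAMACPP_USE_MODELSCOPE is only honored as "True" or "true"
        const char * flag = std::getenv("LLAMACPP_USE_MODELSCOPE");
        bool use_modelscope = flag && (std::string(flag) == "True" || std::string(flag) == "true");

        std::printf("domain=%s use_modelscope=%d\n", domain.c_str(), use_modelscope);
        return 0;
    }

A consequence of this expansion is that each occurrence of LLAMACPP_USE_MODELSCOPE_DEFINITION calls std::getenv again; cheap, but worth keeping in mind when it appears inside download loops.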
@@ -559,10 +579,21 @@ struct llama_model * common_load_model_from_hf(
     const std::string & hf_token,
     const struct llama_model_params & params);

+struct llama_model * common_load_model_from_ms(
+    const std::string & repo,
+    const std::string & remote_path,
+    const std::string & local_path,
+    const std::string & ms_token,
+    const struct llama_model_params & params);
+
 std::pair<std::string, std::string> common_get_hf_file(
     const std::string & hf_repo_with_tag,
     const std::string & hf_token);

+std::pair<std::string, std::string> common_get_ms_file(
+    const std::string & ms_repo_with_tag,
+    const std::string & ms_token);
+
 // clear LoRA adapters from context, then apply new list of adapters
 void common_set_adapter_lora(struct llama_context * ctx, std::vector<common_adapter_lora_info> & lora);
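The two declarations mirror the existing Hugging Face pair: first resolve the <repo, file> names, then download and load. A hedged end-to-end sketch of the intended call sequence (the repo tag is invented, and a real caller would pass a proper local cache path):

    // sketch: intended call sequence for a ModelScope repo tag (names invented)
    llama_model_params mparams = llama_model_default_params();
    auto [repo, file] = common_get_ms_file("Qwen/Qwen2.5-1.5B-Instruct-GGUF:Q4_K_M", /*ms_token=*/"");
    llama_model * model = common_load_model_from_ms(repo, file, /*local_path=*/"model.gguf",
                                                    /*ms_token=*/"", mparams);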

examples/run/run.cpp (49 additions, 0 deletions)
@@ -718,6 +718,49 @@ class LlamaData {
         return download(url, bn, true, headers);
     }

+    int modelscope_dl(std::string & model, const std::string & bn) {
+        // find the second '/', which separates <user>/<model> from an explicit file path
+        size_t pos = model.find('/');
+        pos = model.find('/', pos + 1);
+        std::string hfr;
+        std::string hff;
+        std::vector<std::string> headers = { "user-agent: llama-cpp", "Accept: application/json" };
+        std::string url;
+        auto endpoint = MODELSCOPE_DOMAIN_DEFINITION;
+
+        if (pos == std::string::npos) {
+            auto [model_name, _] = extract_model_and_tag(model, "");
+            hfr = model_name;
+            std::string manifest_str;
+            url = endpoint + "/api/v1/models/" + hfr + "/repo/files?Revision=master&Recursive=True";
+            if (int ret = download(url, "", false, headers, &manifest_str)) {
+                return ret;
+            }
+            auto all_files = nlohmann::json::parse(manifest_str)["Data"]["Files"];
+
+            std::vector<std::string> all_available_files;
+            for (const auto & _file : all_files) {
+                auto file = _file["Path"].get<std::string>();
+                if (!string_ends_with(file, ".gguf")) {
+                    continue;
+                }
+                if (file.find("Q4_K_M") != std::string::npos || file.find("q4_k_m") != std::string::npos) {
+                    hff = file;
+                }
+                all_available_files.push_back(file);
+            }
+            if (hff.empty()) {
+                hff = all_available_files[0];
+            }
+
+        } else {
+            hfr = model.substr(0, pos);
+            hff = model.substr(pos + 1);
+        }
+        url = endpoint + "/models/" + hfr + "/resolve/master/" + hff;
+        return download(url, bn, true, headers);
+    }
+
     int ollama_dl(std::string & model, const std::string & bn) {
         const std::vector<std::string> headers = { "Accept: application/vnd.docker.distribution.manifest.v2+json" };
         if (model.find('/') == std::string::npos) {
@@ -835,6 +878,12 @@ class LlamaData {
             rm_until_substring(model_, "hf.co/");
             rm_until_substring(model_, "://");
             ret = huggingface_dl(model_, bn);
+        } else if (string_starts_with(model_, "ms://") || string_starts_with(model_, "modelscope://") ||
+                   model_.find("modelscope") != std::string::npos || LLAMACPP_USE_MODELSCOPE_DEFINITION) {
+            rm_until_substring(model_, "modelscope.cn/");
+            rm_until_substring(model_, "modelscope.ai/");
+            rm_until_substring(model_, "://");
+            ret = modelscope_dl(model_, bn);
         } else if ((string_starts_with(model_, "https://") || string_starts_with(model_, "http://")) &&
                    !string_starts_with(model_, "https://ollama.com/library/")) {
             ret = download(model_, bn, true);
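The new branch accepts ms:// and modelscope:// prefixes as well as full modelscope.cn / modelscope.ai URLs; the rm_until_substring calls normalize all of these to a bare <user>/<model>[/<file>] string before modelscope_dl runs. A small sketch of that normalization (strip_through imitates rm_until_substring here, and the model strings are examples):

    #include <iostream>
    #include <string>

    // imitation of rm_until_substring, for illustration: drop everything up to
    // and including the first occurrence of `sub`, when present
    static void strip_through(std::string & s, const std::string & sub) {
        size_t pos = s.find(sub);
        if (pos != std::string::npos) {
            s = s.substr(pos + sub.size());
        }
    }

    int main() {
        std::string m1 = "ms://Qwen/Qwen2.5-1.5B-Instruct-GGUF";             // hypothetical
        std::string m2 = "https://www.modelscope.cn/Qwen/some-model/x.gguf"; // hypothetical
        for (std::string * m : { &m1, &m2 }) {
            strip_through(*m, "modelscope.cn/");
            strip_through(*m, "modelscope.ai/");
            strip_through(*m, "://");
            std::cout << *m << '\n'; // both reduce to a <user>/<model>... form
        }
        return 0;
    }

Note the branch also fires whenever "modelscope" appears anywhere in the argument, or when LLAMACPP_USE_MODELSCOPE is set, so plain <user>/<model> strings route to ModelScope under that flag too.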
