From 0e239dd3c1fbbf32597490ebcfc801fe29d03293 Mon Sep 17 00:00:00 2001 From: Kyra Date: Sun, 23 Mar 2025 11:30:56 +0100 Subject: [PATCH] Loras can now be found in subdirectories. Made to not change the execution flow in case of the lora being directly in the model folder so it stays optimized if the correct path is provided. /!\ Couldn't test unix version yet /!\ --- stable-diffusion.cpp | 16 ++++-- util.cpp | 113 +++++++++++++++++++++++++++++++++++++++++++ util.h | 2 + 3 files changed, 127 insertions(+), 4 deletions(-) diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index e38a6101f..e5d96428a 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -640,17 +640,25 @@ class StableDiffusionGGML { void apply_lora(const std::string& lora_name, float multiplier) { int64_t t0 = ggml_time_ms(); - std::string st_file_path = path_join(lora_model_dir, lora_name + ".safetensors"); - std::string ckpt_file_path = path_join(lora_model_dir, lora_name + ".ckpt"); + std::vector extensions = {".safetensors", ".ckpt"}; + std::string st_file_path = path_join(lora_model_dir, lora_name + extensions[0]); + std::string ckpt_file_path = path_join(lora_model_dir, lora_name + extensions[1]); std::string file_path; if (file_exists(st_file_path)) { file_path = st_file_path; } else if (file_exists(ckpt_file_path)) { file_path = ckpt_file_path; } else { - LOG_WARN("can not find %s or %s for lora %s", st_file_path.c_str(), ckpt_file_path.c_str(), lora_name.c_str()); - return; + file_path = get_filepath_from_dir(lora_model_dir, lora_name, &extensions); + + if (file_path.empty()) { + LOG_WARN("can not find %s or %s for lora %s in main directory or subdirectories", + st_file_path.c_str(), ckpt_file_path.c_str(), lora_name.c_str()); + return; + } + LOG_INFO("found lora %s in subdirectory: %s", lora_name.c_str(), file_path.c_str()); } + LoraModel lora(backend, file_path); if (!lora.load_from_file()) { LOG_WARN("load lora tensors from %s failed", file_path.c_str()); diff --git a/util.cpp b/util.cpp index da11a14d6..46c404089 100644 --- a/util.cpp +++ b/util.cpp @@ -75,6 +75,60 @@ std::string format(const char* fmt, ...) { #ifdef _WIN32 // code for windows #include +std::string get_filepath_from_dir_recursive( + const std::string& dir_path, + const std::string& file_name, + const std::vector* extensions = nullptr) { + + if (extensions) { + // Search with provided extensions + for (const auto& ext : *extensions) { + std::string file_path = path_join(dir_path, file_name + ext); + if (file_exists(file_path)) { + return file_path; + } + } + } else { + // Search for exact filename without extensions + std::string file_path = path_join(dir_path, file_name); + if (file_exists(file_path)) { + return file_path; + } + } + + // Check subdirectories + WIN32_FIND_DATA findData; + HANDLE hFind; + std::string search_path = path_join(dir_path, "*"); + + hFind = FindFirstFile(search_path.c_str(), &findData); + if (hFind != INVALID_HANDLE_VALUE) { + do { + if ((findData.dwFileAttributes & FILE_ATTRIBUTE_DIRECTORY) && + strcmp(findData.cFileName, ".") != 0 && + strcmp(findData.cFileName, "..") != 0) { + + std::string subdir = path_join(dir_path, findData.cFileName); + std::string result = get_filepath_from_dir_recursive(subdir, file_name, extensions); + + if (!result.empty()) { + FindClose(hFind); + return result; + } + } + } while (FindNextFile(hFind, &findData)); + + FindClose(hFind); + } + + return ""; +} + +std::string get_filepath_from_dir(const std::string& dir, const std::string& filename, const std::vector* extensions = nullptr) { + return get_filepath_from_dir_recursive(dir, filename, extensions); +} + + bool file_exists(const std::string& filename) { DWORD attributes = GetFileAttributesA(filename.c_str()); return (attributes != INVALID_FILE_ATTRIBUTES && !(attributes & FILE_ATTRIBUTE_DIRECTORY)); @@ -153,6 +207,65 @@ std::vector get_files_from_dir(const std::string& dir) { #include #include +std::string get_filepath_from_dir_recursive( + const std::string& dir_path, + const std::string& file_name, + const std::vector* extensions = nullptr) { + + DIR* dir = opendir(dir_path.c_str()); + if (dir == nullptr) { + return ""; + } + + std::string result = ""; + + if (extensions) { + for (const auto& ext : *extensions) { + std::string file_path = path_join(dir_path, file_name + ext); + if (file_exists(file_path)) { + closedir(dir); + return file_path; + } + } + } else { + std::string file_path = path_join(dir_path, file_name); + if (file_exists(file_path)) { + closedir(dir); + return file_path; + } + } + + // Check all subdirectories + struct dirent* entry; + while ((entry = readdir(dir)) != nullptr) { + if (entry->d_type == DT_DIR && + strcmp(entry->d_name, ".") != 0 && + strcmp(entry->d_name, "..") != 0) { + + std::string subdir = path_join(dir_path, entry->d_name); + + result = get_filepath_from_dir_recursive(subdir, file_name, extensions); + + if (!result.empty()) { + closedir(dir); + return result; + } + } + } + + closedir(dir); + return ""; +} + +std::string get_filepath_from_dir( + const std::string& dir_path, + const std::string& file_name, + const std::vector* extensions = nullptr) { + + return get_filepath_from_dir_recursive(dir_path, file_name, extensions); +} + + bool file_exists(const std::string& filename) { struct stat buffer; return (stat(filename.c_str(), &buffer) == 0 && S_ISREG(buffer.st_mode)); diff --git a/util.h b/util.h index 14fa812e5..913dd029b 100644 --- a/util.h +++ b/util.h @@ -15,6 +15,8 @@ std::string format(const char* fmt, ...); void replace_all_chars(std::string& str, char target, char replacement); +std::string get_filepath_from_dir(const std::string& dir, const std::string& filename, const std::vector* extensions); + bool file_exists(const std::string& filename); bool is_directory(const std::string& path); std::string get_full_path(const std::string& dir, const std::string& filename);