Skip to content

Commit c6d2a57

Browse files
committed
Refactor imatrix api, fix build shared libs
1 parent 55f7f35 commit c6d2a57

File tree

3 files changed

+27
-20
lines changed

3 files changed

+27
-20
lines changed

examples/cli/main.cpp

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,6 @@
2424
#define STB_IMAGE_RESIZE_STATIC
2525
#include "stb_image_resize.h"
2626

27-
#define IMATRIX_IMPL
28-
#include "imatrix.hpp"
29-
static IMatrixCollector g_collector;
30-
3127
#define SAFE_STR(s) ((s) ? (s) : "")
3228
#define BOOL_STR(b) ((b) ? "true" : "false")
3329

@@ -770,10 +766,6 @@ void sd_log_cb(enum sd_log_level_t level, const char* log, void* data) {
770766
fflush(out_stream);
771767
}
772768

773-
static bool collect_imatrix(struct ggml_tensor* t, bool ask, void* user_data) {
774-
return g_collector.collect_imatrix(t, ask, user_data);
775-
}
776-
777769
int main(int argc, const char* argv[]) {
778770
SDParams params;
779771

@@ -799,13 +791,12 @@ int main(int argc, const char* argv[]) {
799791
}
800792

801793
if (params.imatrix_out != "") {
802-
sd_set_backend_eval_callback((sd_graph_eval_callback_t)collect_imatrix, &params);
794+
enableImatrixCollection();
803795
}
804796
if (params.imatrix_out != "" || params.mode == CONVERT || params.wtype != SD_TYPE_COUNT) {
805-
setConvertImatrixCollector((void*)&g_collector);
806797
for (const auto& in_file : params.imatrix_in) {
807798
printf("loading imatrix from '%s'\n", in_file.c_str());
808-
if (!g_collector.load_imatrix(in_file.c_str())) {
799+
if (!loadImatrix(in_file.c_str())) {
809800
printf("Failed to load %s\n", in_file.c_str());
810801
}
811802
}
@@ -1120,7 +1111,7 @@ int main(int argc, const char* argv[]) {
11201111
results[i].data = NULL;
11211112
}
11221113
if (params.imatrix_out != "") {
1123-
g_collector.save_imatrix(params.imatrix_out);
1114+
saveImatrix(params.imatrix_out.c_str());
11241115
}
11251116
free(results);
11261117
free_sd_ctx(sd_ctx);

model.cpp

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333

3434
#define ST_HEADER_SIZE_LEN 8
3535

36-
static IMatrixCollector* imatrix_collector = NULL;
36+
static IMatrixCollector imatrix_collector;
3737

3838
uint64_t read_u64(uint8_t* buffer) {
3939
// little endian
@@ -1984,7 +1984,7 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend
19841984

19851985
auto processed_name = convert_tensor_name(tensor_storage.name);
19861986
// LOG_DEBUG("%s",processed_name.c_str());
1987-
std::vector<float> imatrix = imatrix_collector ? imatrix_collector->get_values(processed_name) : std::vector<float>{};
1987+
std::vector<float> imatrix = imatrix_collector.get_values(processed_name);
19881988

19891989
convert_tensor((void*)read_buffer.data(), tensor_storage.type, dst_tensor->data,
19901990
dst_tensor->type, (int)tensor_storage.nelements() / (int)tensor_storage.ne[0], (int)tensor_storage.ne[0], imatrix);
@@ -2011,7 +2011,7 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend
20112011
// convert first, then copy to device memory
20122012
auto processed_name = convert_tensor_name(tensor_storage.name);
20132013
// LOG_DEBUG("%s",processed_name.c_str());
2014-
std::vector<float> imatrix = imatrix_collector ? imatrix_collector->get_values(processed_name) : std::vector<float>{};
2014+
std::vector<float> imatrix = imatrix_collector.get_values(processed_name);
20152015

20162016
convert_buffer.resize(ggml_nbytes(dst_tensor));
20172017
convert_tensor((void*)read_buffer.data(), tensor_storage.type,
@@ -2263,10 +2263,6 @@ int64_t ModelLoader::get_params_mem_size(ggml_backend_t backend, ggml_type type)
22632263
return mem_size;
22642264
}
22652265

2266-
void setConvertImatrixCollector(void* collector) {
2267-
imatrix_collector = ((IMatrixCollector*)collector);
2268-
}
2269-
22702266
bool convert(const char* model_path, const char* clip_l_path, const char* clip_g_path, const char* t5xxl_path, const char* diffusion_model_path, const char* vae_path, const char* output_path, sd_type_t output_type, const char* tensor_type_rules) {
22712267
ModelLoader model_loader;
22722268

@@ -2314,3 +2310,19 @@ bool convert(const char* model_path, const char* clip_l_path, const char* clip_g
23142310
bool success = model_loader.save_to_gguf_file(output_path, (ggml_type)output_type, tensor_type_rules);
23152311
return success;
23162312
}
2313+
2314+
bool loadImatrix(const char* imatrix_path) {
2315+
return imatrix_collector.load_imatrix(imatrix_path);
2316+
}
2317+
void saveImatrix(const char* imatrix_path) {
2318+
imatrix_collector.save_imatrix(imatrix_path);
2319+
}
2320+
static bool collect_imatrix(struct ggml_tensor* t, bool ask, void* user_data) {
2321+
return imatrix_collector.collect_imatrix(t, ask, user_data);
2322+
}
2323+
void enableImatrixCollection() {
2324+
sd_set_backend_eval_callback((sd_graph_eval_callback_t)collect_imatrix, NULL);
2325+
}
2326+
void disableImatrixCollection() {
2327+
sd_set_backend_eval_callback(NULL, NULL);
2328+
}

stable-diffusion.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,6 @@ SD_API void free_upscaler_ctx(upscaler_ctx_t* upscaler_ctx);
243243

244244
SD_API sd_image_t upscale(upscaler_ctx_t* upscaler_ctx, sd_image_t input_image, uint32_t upscale_factor);
245245

246-
SD_API void setConvertImatrixCollector(void * collector);
247246
SD_API bool convert(const char* model_path, const char* clip_l_path, const char* clip_g_path, const char* t5xxl_path, const char* diffusion_model_path,
248247
const char* vae_path,
249248
const char* output_path,
@@ -259,6 +258,11 @@ SD_API uint8_t* preprocess_canny(uint8_t* img,
259258
float strong,
260259
bool inverse);
261260

261+
SD_API bool loadImatrix(const char * imatrix_path);
262+
SD_API void saveImatrix(const char * imatrix_path);
263+
SD_API void enableImatrixCollection();
264+
SD_API void disableImatrixCollection();
265+
262266
#ifdef __cplusplus
263267
}
264268
#endif

0 commit comments

Comments
 (0)