Commit b83c781

Imatrix: first implementation attempt

Refactor imatrix implementation into main example; try to fix CI build.

1 parent f58bad8 commit b83c781

8 files changed: +453 -12 lines changed

clip.hpp

Lines changed: 2 additions & 0 deletions

@@ -661,6 +661,7 @@ class CLIPTextModel : public GGMLBlock {
         if (version == OPEN_CLIP_VIT_BIGG_14) {
             enum ggml_type wtype      = GGML_TYPE_F32;  // tensor_types.find(prefix + "text_projection") != tensor_types.end() ? tensor_types[prefix + "text_projection"] : GGML_TYPE_F32;
             params["text_projection"] = ggml_new_tensor_2d(ctx, wtype, projection_dim, hidden_size);
+            ggml_set_name(params["text_projection"], (prefix + "text_projection").c_str());
         }
     }

@@ -812,6 +813,7 @@ class CLIPProjection : public UnaryBlock {
         } else {
             params["weight"] = ggml_new_tensor_2d(ctx, wtype, in_features, out_features);
         }
+        ggml_set_name(params["weight"], (prefix + "weight").c_str());
     }

 public:
examples/cli/main.cpp

Lines changed: 44 additions & 0 deletions

@@ -22,6 +22,10 @@
 #define STB_IMAGE_RESIZE_STATIC
 #include "stb_image_resize.h"

+#define IMATRIX_IMPL
+#include "imatrix.hpp"
+static IMatrixCollector g_collector;
+
 const char* rng_type_to_str[] = {
     "std_default",
     "cuda",
@@ -147,6 +151,12 @@ struct SDParams {
     int preview_interval = 1;
     std::string preview_path = "preview.png";
     bool taesd_preview = false;
+
+    /* Imatrix params */
+
+    std::string imatrix_out = "";
+
+    std::vector<std::string> imatrix_in = {};
 };

 void print_params(SDParams params) {
@@ -225,6 +235,8 @@ void print_usage(int argc, const char* argv[]) {
     printf("  --upscale-repeats                  Run the ESRGAN upscaler this many times (default 1)\n");
     printf("  --type [TYPE]                      weight type (examples: f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0, q2_K, q3_K, q4_K)\n");
     printf("                                     If not specified, the default is the type of the weight file\n");
+    printf("  --imat-out [PATH]                  If set, compute the imatrix for this run and save it to the provided path\n");
+    printf("  --imat-in [PATH]                   Use imatrix for quantization.\n");
     printf("  --lora-model-dir [DIR]             lora model directory\n");
     printf("  -i, --init-img [IMAGE]             path to the input image, required by img2img\n");
     printf("  --mask [MASK]                      path to the mask image, required by img2img with mask\n");
@@ -719,6 +731,18 @@ void parse_args(int argc, const char** argv, SDParams& params) {
                 break;
             }
             params.preview_path = argv[i];
+        } else if (arg == "--imat-out") {
+            if (++i >= argc) {
+                invalid_arg = true;
+                break;
+            }
+            params.imatrix_out = argv[i];
+        } else if (arg == "--imat-in") {
+            if (++i >= argc) {
+                invalid_arg = true;
+                break;
+            }
+            params.imatrix_in.push_back(std::string(argv[i]));
         } else {
             fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
             print_usage(argc, argv);
@@ -896,6 +920,10 @@ void step_callback(int step, sd_image_t image) {
         stbi_write_png(preview_path, image.width, image.height, image.channel, image.data, 0);
 }

+static bool collect_imatrix(struct ggml_tensor* t, bool ask, void* user_data) {
+    return g_collector.collect_imatrix(t, ask, user_data);
+}
+
 int main(int argc, const char* argv[]) {
     SDParams params;
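The static collect_imatrix wrapper exists because the backend-eval hook is a C-style function-pointer API, which cannot bind to a member function; the stateless trampoline forwards each call to the collector. This commit routes through the g_collector global; an alternative (sketched below, purely illustrative) would recover the object from the user_data pointer instead, letting several collectors coexist:

    static bool collect_imatrix(struct ggml_tensor* t, bool ask, void* user_data) {
        return static_cast<IMatrixCollector*>(user_data)->collect_imatrix(t, ask, nullptr);
    }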

@@ -910,6 +938,19 @@ int main(int argc, const char* argv[]) {
         printf("%s", sd_get_system_info());
     }

+    if (params.imatrix_out != "") {
+        sd_set_backend_eval_callback((sd_graph_eval_callback_t)collect_imatrix, &params);
+    }
+    if (params.imatrix_out != "" || params.mode == CONVERT || params.wtype != SD_TYPE_COUNT) {
+        setConvertImatrixCollector((void*)&g_collector);
+        for (const auto& in_file : params.imatrix_in) {
+            printf("loading imatrix from '%s'\n", in_file.c_str());
+            if (!g_collector.load_imatrix(in_file.c_str())) {
+                printf("Failed to load %s\n", in_file.c_str());
+            }
+        }
+    }
+
     if (params.mode == CONVERT) {
         bool success = convert(params.model_path.c_str(), params.vae_path.c_str(), params.output_path.c_str(), params.wtype);
         if (!success) {
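Loading existing imatrix files before generating or converting lets statistics accumulate across sessions: the collector seeds itself from earlier runs, adds what the current run observes, and the quantizer sees the merged totals. A plausible merge rule inside load_imatrix (assumed here, mirroring how llama.cpp's collector combines files) adds both the running sums and the sample counts:

    // Assumed accumulation when an entry for a tensor already exists:
    for (size_t i = 0; i < e.values.size(); i++) {
        e.values[i] += loaded_values[i];  // sums of squared activations
    }
    e.ncall += loaded_ncall;              // number of contributing evaluations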
@@ -1233,6 +1274,9 @@ int main(int argc, const char* argv[]) {
         free(results[i].data);
         results[i].data = NULL;
     }
+    if (params.imatrix_out != "") {
+        g_collector.save_imatrix(params.imatrix_out);
+    }
     free(results);
     free_sd_ctx(sd_ctx);
     free(control_image_buffer);

ggml_extend.hpp

Lines changed: 38 additions & 8 deletions
@@ -23,9 +23,11 @@
 #include "ggml-alloc.h"
 #include "ggml-backend.h"
 #include "ggml-cpu.h"
+#include "ggml/src/ggml-impl.h"
 #include "ggml.h"

 #include "model.h"
+#include "util.h"

 #ifdef SD_USE_CUDA
 #include "ggml-cuda.h"
@@ -117,13 +119,6 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_kronecker(ggml_context* ctx, struct g
         b);
 }

-__STATIC_INLINE__ void ggml_log_callback_default(ggml_log_level level, const char* text, void* user_data) {
-    (void)level;
-    (void)user_data;
-    fputs(text, stderr);
-    fflush(stderr);
-}
-
 __STATIC_INLINE__ void ggml_tensor_set_f32_randn(struct ggml_tensor* tensor, std::shared_ptr<RNG> rng) {
     uint32_t n = (uint32_t)ggml_nelements(tensor);
     std::vector<float> random_numbers = rng->randn(n);
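The local copy of ggml_log_callback_default is dropped here because ggml itself exports a function with this name and signature (declared in ggml.h), making the duplicate definition in this header redundant.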
@@ -1312,7 +1307,39 @@ struct GGMLRunner {
             ggml_backend_cpu_set_n_threads(backend, n_threads);
         }

-        ggml_backend_graph_compute(backend, gf);
+        auto callback_eval = get_callback_eval();
+
+        if (!callback_eval) {
+            ggml_backend_graph_compute(backend, gf);
+        } else {
+            void* callback_eval_user_data = get_callback_eval_user_data();
+            for (int j0 = 0; j0 < gf->n_nodes; j0++) {
+                struct ggml_tensor* t = gf->nodes[j0];
+
+                // check if the user needs data from this node
+                bool need = callback_eval(t, true, callback_eval_user_data);
+
+                int j1 = j0;
+
+                // determine the range [j0, j1] of nodes that can be computed together
+                while (!need && j1 < gf->n_nodes - 1) {
+                    t    = gf->nodes[++j1];
+                    need = callback_eval(t, true, callback_eval_user_data);
+                }
+
+                struct ggml_cgraph gv = ggml_graph_view(gf, j0, j1 + 1);
+
+                ggml_backend_graph_compute_async(backend, &gv);
+
+                if (need && !callback_eval(t, false, callback_eval_user_data)) {
+                    break;
+                }
+
+                j0 = j1;
+            }
+            ggml_backend_synchronize(backend);
+        }
+
 #ifdef GGML_PERF
         ggml_graph_print(gf);
 #endif
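This loop implements the same two-phase "ask" protocol as llama.cpp's scheduler eval callback: each node is first offered to the callback with ask == true ("do you want to observe this node?"); consecutive unclaimed nodes are batched into one graph view and computed asynchronously; once a claimed node has been computed, the callback runs again with ask == false to read the data, and returning false aborts the remaining nodes. A minimal callback fitting this contract (hypothetical example):

    static bool my_eval_cb(struct ggml_tensor* t, bool ask, void* user_data) {
        if (ask) {
            // ask-phase: claim only matrix multiplications
            return t->op == GGML_OP_MUL_MAT;
        }
        // collect-phase: t is computed; t->src[1] holds the activations fed
        // into the weight t->src[0]. Accumulate statistics here.
        return true;  // false would abort the rest of the graph
    }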
@@ -1416,6 +1443,7 @@ class Linear : public UnaryBlock {
             wtype = GGML_TYPE_F32;
         }
         params["weight"] = ggml_new_tensor_2d(ctx, wtype, in_features, out_features);
+        ggml_set_name(params["weight"], (prefix + "weight").c_str());
         if (bias) {
             enum ggml_type wtype = GGML_TYPE_F32;  // tensor_types.find(prefix + "bias") != tensor_types.end() ? tensor_types[prefix + "bias"] : GGML_TYPE_F32;
             params["bias"] = ggml_new_tensor_1d(ctx, wtype, out_features);
@@ -1579,6 +1607,8 @@ class LayerNorm : public UnaryBlock {
         if (elementwise_affine) {
             enum ggml_type wtype = GGML_TYPE_F32;  // tensor_types.find(prefix + "weight") != tensor_types.end() ? tensor_types[prefix + "weight"] : GGML_TYPE_F32;
             params["weight"] = ggml_new_tensor_1d(ctx, wtype, normalized_shape);
+            ggml_set_name(params["weight"], (prefix + "weight").c_str());
+
             if (bias) {
                 enum ggml_type wtype = GGML_TYPE_F32;  // tensor_types.find(prefix + "bias") != tensor_types.end() ? tensor_types[prefix + "bias"] : GGML_TYPE_F32;
                 params["bias"] = ggml_new_tensor_1d(ctx, wtype, normalized_shape);
