Skip to content

Commit 06e51d6

Browse files
committed
conditioner: make text encoders optional for Flux
1 parent 6d84a30 commit 06e51d6

File tree

2 files changed

+115
-50
lines changed

2 files changed

+115
-50
lines changed

conditioner.hpp

Lines changed: 114 additions & 50 deletions
Original file line number | Diff line number | Diff line change
@@ -1009,39 +1009,83 @@ struct FluxCLIPEmbedder : public Conditioner {
10091009
std::shared_ptr<T5Runner> t5;
10101010
size_t chunk_len = 256;
10111011

1012+
bool use_clip_l = false;
1013+
bool use_t5 = false;
1014+
10121015
FluxCLIPEmbedder(ggml_backend_t backend,
10131016
std::map<std::string, enum ggml_type>& tensor_types,
10141017
int clip_skip = -1) {
1015-
clip_l = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "text_encoders.clip_l.transformer.text_model", OPENAI_CLIP_VIT_L_14, true);
1016-
t5 = std::make_shared<T5Runner>(backend, tensor_types, "text_encoders.t5xxl.transformer");
1018+
1019+
for (auto pair : tensor_types) {
1020+
if (pair.first.find("text_encoders.clip_l") != std::string::npos) {
1021+
use_clip_l = true;
1022+
} else if (pair.first.find("text_encoders.t5xxl") != std::string::npos) {
1023+
use_t5 = true;
1024+
}
1025+
}
1026+
1027+
if (!use_clip_l && !use_t5) {
1028+
LOG_WARN("IMPORTANT NOTICE: No text encoders provided, cannot process prompts!");
1029+
return;
1030+
}
1031+
1032+
if (use_clip_l) {
1033+
clip_l = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "text_encoders.clip_l.transformer.text_model", OPENAI_CLIP_VIT_L_14, true);
1034+
} else {
1035+
LOG_WARN("clip_l text encoder not found! Prompt adherence might be degraded.");
1036+
}
1037+
if (use_t5) {
1038+
t5 = std::make_shared<T5Runner>(backend, tensor_types, "text_encoders.t5xxl.transformer");
1039+
} else {
1040+
LOG_WARN("t5xxl text encoder not found! Prompt adherence might be degraded.");
1041+
}
10171042
set_clip_skip(clip_skip);
10181043
}
10191044

10201045
void set_clip_skip(int clip_skip) {
10211046
if (clip_skip <= 0) {
10221047
clip_skip = 2;
10231048
}
1024-
clip_l->set_clip_skip(clip_skip);
1049+
if (use_clip_l) {
1050+
clip_l->set_clip_skip(clip_skip);
1051+
}
10251052
}
10261053

10271054
void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) {
1028-
clip_l->get_param_tensors(tensors, "text_encoders.clip_l.transformer.text_model");
1029-
t5->get_param_tensors(tensors, "text_encoders.t5xxl.transformer");
1055+
if (use_clip_l) {
1056+
clip_l->get_param_tensors(tensors, "text_encoders.clip_l.transformer.text_model");
1057+
}
1058+
if (use_t5) {
1059+
t5->get_param_tensors(tensors, "text_encoders.t5xxl.transformer");
1060+
}
10301061
}
10311062

10321063
void alloc_params_buffer() {
1033-
clip_l->alloc_params_buffer();
1034-
t5->alloc_params_buffer();
1064+
if (use_clip_l) {
1065+
clip_l->alloc_params_buffer();
1066+
}
1067+
if (use_t5) {
1068+
t5->alloc_params_buffer();
1069+
}
10351070
}
10361071

10371072
void free_params_buffer() {
1038-
clip_l->free_params_buffer();
1039-
t5->free_params_buffer();
1073+
if (use_clip_l) {
1074+
clip_l->free_params_buffer();
1075+
}
1076+
if (use_t5) {
1077+
t5->free_params_buffer();
1078+
}
10401079
}
10411080

10421081
size_t get_params_buffer_size() {
1043-
size_t buffer_size = clip_l->get_params_buffer_size();
1044-
buffer_size += t5->get_params_buffer_size();
1082+
size_t buffer_size = 0;
1083+
if (use_clip_l) {
1084+
buffer_size += clip_l->get_params_buffer_size();
1085+
}
1086+
if (use_t5) {
1087+
buffer_size += t5->get_params_buffer_size();
1088+
}
10451089
return buffer_size;
10461090
}
10471091

@@ -1071,18 +1115,23 @@ struct FluxCLIPEmbedder : public Conditioner {
10711115
for (const auto& item : parsed_attention) {
10721116
const std::string& curr_text = item.first;
10731117
float curr_weight = item.second;
1074-
1075-
std::vector<int> curr_tokens = clip_l_tokenizer.encode(curr_text, on_new_token_cb);
1076-
clip_l_tokens.insert(clip_l_tokens.end(), curr_tokens.begin(), curr_tokens.end());
1077-
clip_l_weights.insert(clip_l_weights.end(), curr_tokens.size(), curr_weight);
1078-
1079-
curr_tokens = t5_tokenizer.Encode(curr_text, true);
1080-
t5_tokens.insert(t5_tokens.end(), curr_tokens.begin(), curr_tokens.end());
1081-
t5_weights.insert(t5_weights.end(), curr_tokens.size(), curr_weight);
1118+
if (use_clip_l) {
1119+
std::vector<int> curr_tokens = clip_l_tokenizer.encode(curr_text, on_new_token_cb);
1120+
clip_l_tokens.insert(clip_l_tokens.end(), curr_tokens.begin(), curr_tokens.end());
1121+
clip_l_weights.insert(clip_l_weights.end(), curr_tokens.size(), curr_weight);
1122+
}
1123+
if (use_t5) {
1124+
std::vector<int> curr_tokens = t5_tokenizer.Encode(curr_text, true);
1125+
t5_tokens.insert(t5_tokens.end(), curr_tokens.begin(), curr_tokens.end());
1126+
t5_weights.insert(t5_weights.end(), curr_tokens.size(), curr_weight);
1127+
}
1128+
}
1129+
if (use_clip_l) {
1130+
clip_l_tokenizer.pad_tokens(clip_l_tokens, clip_l_weights, 77, padding);
1131+
}
1132+
if (use_t5) {
1133+
t5_tokenizer.pad_tokens(t5_tokens, t5_weights, NULL, max_length, padding);
10821134
}
1083-
1084-
clip_l_tokenizer.pad_tokens(clip_l_tokens, clip_l_weights, 77, padding);
1085-
t5_tokenizer.pad_tokens(t5_tokens, t5_weights, NULL, max_length, padding);
10861135

10871136
// for (int i = 0; i < clip_l_tokens.size(); i++) {
10881137
// std::cout << clip_l_tokens[i] << ":" << clip_l_weights[i] << ", ";
@@ -1114,34 +1163,36 @@ struct FluxCLIPEmbedder : public Conditioner {
11141163
struct ggml_tensor* pooled = NULL; // [768,]
11151164
std::vector<float> hidden_states_vec;
11161165

1117-
size_t chunk_count = t5_tokens.size() / chunk_len;
1166+
size_t chunk_count = std::max(clip_l_tokens.size() > 0 ? chunk_len : 0, t5_tokens.size()) / chunk_len;
11181167
for (int chunk_idx = 0; chunk_idx < chunk_count; chunk_idx++) {
11191168
// clip_l
11201169
if (chunk_idx == 0) {
1121-
size_t chunk_len_l = 77;
1122-
std::vector<int> chunk_tokens(clip_l_tokens.begin(),
1123-
clip_l_tokens.begin() + chunk_len_l);
1124-
std::vector<float> chunk_weights(clip_l_weights.begin(),
1125-
clip_l_weights.begin() + chunk_len_l);
1170+
if (use_clip_l) {
1171+
size_t chunk_len_l = 77;
1172+
std::vector<int> chunk_tokens(clip_l_tokens.begin(),
1173+
clip_l_tokens.begin() + chunk_len_l);
1174+
std::vector<float> chunk_weights(clip_l_weights.begin(),
1175+
clip_l_weights.begin() + chunk_len_l);
11261176

1127-
auto input_ids = vector_to_ggml_tensor_i32(work_ctx, chunk_tokens);
1128-
size_t max_token_idx = 0;
1177+
auto input_ids = vector_to_ggml_tensor_i32(work_ctx, chunk_tokens);
1178+
size_t max_token_idx = 0;
11291179

1130-
auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), clip_l_tokenizer.EOS_TOKEN_ID);
1131-
max_token_idx = std::min<size_t>(std::distance(chunk_tokens.begin(), it), chunk_tokens.size() - 1);
1180+
auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), clip_l_tokenizer.EOS_TOKEN_ID);
1181+
max_token_idx = std::min<size_t>(std::distance(chunk_tokens.begin(), it), chunk_tokens.size() - 1);
11321182

1133-
clip_l->compute(n_threads,
1134-
input_ids,
1135-
0,
1136-
NULL,
1137-
max_token_idx,
1138-
true,
1139-
&pooled,
1140-
work_ctx);
1183+
clip_l->compute(n_threads,
1184+
input_ids,
1185+
0,
1186+
NULL,
1187+
max_token_idx,
1188+
true,
1189+
&pooled,
1190+
work_ctx);
1191+
}
11411192
}
11421193

11431194
// t5
1144-
{
1195+
if (use_t5) {
11451196
std::vector<int> chunk_tokens(t5_tokens.begin() + chunk_idx * chunk_len,
11461197
t5_tokens.begin() + (chunk_idx + 1) * chunk_len);
11471198
std::vector<float> chunk_weights(t5_weights.begin() + chunk_idx * chunk_len,
@@ -1169,8 +1220,12 @@ struct FluxCLIPEmbedder : public Conditioner {
11691220
float new_mean = ggml_tensor_mean(tensor);
11701221
ggml_tensor_scale(tensor, (original_mean / new_mean));
11711222
}
1223+
} else {
1224+
chunk_hidden_states = ggml_new_tensor_2d(work_ctx, GGML_TYPE_F32, 4096, chunk_len);
1225+
ggml_set_f32(chunk_hidden_states, 0.f);
11721226
}
11731227

1228+
11741229
int64_t t1 = ggml_time_ms();
11751230
LOG_DEBUG("computing condition graph completed, taking %" PRId64 " ms", t1 - t0);
11761231
if (force_zero_embeddings) {
@@ -1179,17 +1234,26 @@ struct FluxCLIPEmbedder : public Conditioner {
11791234
vec[i] = 0;
11801235
}
11811236
}
1182-
1237+
11831238
hidden_states_vec.insert(hidden_states_vec.end(),
1184-
(float*)chunk_hidden_states->data,
1185-
((float*)chunk_hidden_states->data) + ggml_nelements(chunk_hidden_states));
1239+
(float*)chunk_hidden_states->data,
1240+
((float*)chunk_hidden_states->data) + ggml_nelements(chunk_hidden_states));
1241+
}
1242+
1243+
if (hidden_states_vec.size() > 0) {
1244+
hidden_states = vector_to_ggml_tensor(work_ctx, hidden_states_vec);
1245+
hidden_states = ggml_reshape_2d(work_ctx,
1246+
hidden_states,
1247+
chunk_hidden_states->ne[0],
1248+
ggml_nelements(hidden_states) / chunk_hidden_states->ne[0]);
1249+
} else {
1250+
hidden_states = ggml_new_tensor_2d(work_ctx, GGML_TYPE_F32, 4096, 256);
1251+
ggml_set_f32(hidden_states, 0.f);
1252+
}
1253+
if (pooled == NULL) {
1254+
pooled = ggml_new_tensor_1d(work_ctx, GGML_TYPE_F32, 768);
1255+
ggml_set_f32(pooled, 0.f);
11861256
}
1187-
1188-
hidden_states = vector_to_ggml_tensor(work_ctx, hidden_states_vec);
1189-
hidden_states = ggml_reshape_2d(work_ctx,
1190-
hidden_states,
1191-
chunk_hidden_states->ne[0],
1192-
ggml_nelements(hidden_states) / chunk_hidden_states->ne[0]);
11931257
return SDCondition(hidden_states, pooled, NULL);
11941258
}
11951259

stable-diffusion.cpp

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -329,6 +329,7 @@ class StableDiffusionGGML {
329329
clip_backend = backend;
330330
bool use_t5xxl = false;
331331
if (sd_version_is_dit(version)) {
332+
// TODO: check if t5 is actually loaded?
332333
use_t5xxl = true;
333334
}
334335
if (!ggml_backend_is_cpu(backend) && use_t5xxl && conditioner_wtype != GGML_TYPE_F32) {

0 commit comments

Comments
 (0)