Skip to content

Commit 8ed3074

Browse files
committed
conditioner: make text encoders optional for Flux
1 parent ae78b97 commit 8ed3074

File tree

2 files changed

+115
-50
lines changed

2 files changed

+115
-50
lines changed

conditioner.hpp

Lines changed: 114 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1103,39 +1103,83 @@ struct FluxCLIPEmbedder : public Conditioner {
11031103
std::shared_ptr<T5Runner> t5;
11041104
size_t chunk_len = 256;
11051105

1106+
bool use_clip_l = false;
1107+
bool use_t5 = false;
1108+
11061109
FluxCLIPEmbedder(ggml_backend_t backend,
11071110
std::map<std::string, enum ggml_type>& tensor_types,
11081111
int clip_skip = -1) {
1109-
clip_l = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "text_encoders.clip_l.transformer.text_model", OPENAI_CLIP_VIT_L_14, true);
1110-
t5 = std::make_shared<T5Runner>(backend, tensor_types, "text_encoders.t5xxl.transformer");
1112+
1113+
for (auto pair : tensor_types) {
1114+
if (pair.first.find("text_encoders.clip_l") != std::string::npos) {
1115+
use_clip_l = true;
1116+
} else if (pair.first.find("text_encoders.t5xxl") != std::string::npos) {
1117+
use_t5 = true;
1118+
}
1119+
}
1120+
1121+
if (!use_clip_l && !use_t5) {
1122+
LOG_WARN("IMPORTANT NOTICE: No text encoders provided, cannot process prompts!");
1123+
return;
1124+
}
1125+
1126+
if (use_clip_l) {
1127+
clip_l = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "text_encoders.clip_l.transformer.text_model", OPENAI_CLIP_VIT_L_14, true);
1128+
} else {
1129+
LOG_WARN("clip_l text encoder not found! Prompt adherence might be degraded.");
1130+
}
1131+
if (use_t5) {
1132+
t5 = std::make_shared<T5Runner>(backend, tensor_types, "text_encoders.t5xxl.transformer");
1133+
} else {
1134+
LOG_WARN("t5xxl text encoder not found! Prompt adherence might be degraded.");
1135+
}
11111136
set_clip_skip(clip_skip);
11121137
}
11131138

11141139
void set_clip_skip(int clip_skip) {
11151140
if (clip_skip <= 0) {
11161141
clip_skip = 2;
11171142
}
1118-
clip_l->set_clip_skip(clip_skip);
1143+
if (use_clip_l) {
1144+
clip_l->set_clip_skip(clip_skip);
1145+
}
11191146
}
11201147

11211148
void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) {
1122-
clip_l->get_param_tensors(tensors, "text_encoders.clip_l.transformer.text_model");
1123-
t5->get_param_tensors(tensors, "text_encoders.t5xxl.transformer");
1149+
if (use_clip_l) {
1150+
clip_l->get_param_tensors(tensors, "text_encoders.clip_l.transformer.text_model");
1151+
}
1152+
if (use_t5) {
1153+
t5->get_param_tensors(tensors, "text_encoders.t5xxl.transformer");
1154+
}
11241155
}
11251156

11261157
void alloc_params_buffer() {
1127-
clip_l->alloc_params_buffer();
1128-
t5->alloc_params_buffer();
1158+
if (use_clip_l) {
1159+
clip_l->alloc_params_buffer();
1160+
}
1161+
if (use_t5) {
1162+
t5->alloc_params_buffer();
1163+
}
11291164
}
11301165

11311166
void free_params_buffer() {
1132-
clip_l->free_params_buffer();
1133-
t5->free_params_buffer();
1167+
if (use_clip_l) {
1168+
clip_l->free_params_buffer();
1169+
}
1170+
if (use_t5) {
1171+
t5->free_params_buffer();
1172+
}
11341173
}
11351174

11361175
size_t get_params_buffer_size() {
1137-
size_t buffer_size = clip_l->get_params_buffer_size();
1138-
buffer_size += t5->get_params_buffer_size();
1176+
size_t buffer_size = 0;
1177+
if (use_clip_l) {
1178+
buffer_size += clip_l->get_params_buffer_size();
1179+
}
1180+
if (use_t5) {
1181+
buffer_size += t5->get_params_buffer_size();
1182+
}
11391183
return buffer_size;
11401184
}
11411185

@@ -1165,18 +1209,23 @@ struct FluxCLIPEmbedder : public Conditioner {
11651209
for (const auto& item : parsed_attention) {
11661210
const std::string& curr_text = item.first;
11671211
float curr_weight = item.second;
1168-
1169-
std::vector<int> curr_tokens = clip_l_tokenizer.encode(curr_text, on_new_token_cb);
1170-
clip_l_tokens.insert(clip_l_tokens.end(), curr_tokens.begin(), curr_tokens.end());
1171-
clip_l_weights.insert(clip_l_weights.end(), curr_tokens.size(), curr_weight);
1172-
1173-
curr_tokens = t5_tokenizer.Encode(curr_text, true);
1174-
t5_tokens.insert(t5_tokens.end(), curr_tokens.begin(), curr_tokens.end());
1175-
t5_weights.insert(t5_weights.end(), curr_tokens.size(), curr_weight);
1212+
if (use_clip_l) {
1213+
std::vector<int> curr_tokens = clip_l_tokenizer.encode(curr_text, on_new_token_cb);
1214+
clip_l_tokens.insert(clip_l_tokens.end(), curr_tokens.begin(), curr_tokens.end());
1215+
clip_l_weights.insert(clip_l_weights.end(), curr_tokens.size(), curr_weight);
1216+
}
1217+
if (use_t5) {
1218+
std::vector<int> curr_tokens = t5_tokenizer.Encode(curr_text, true);
1219+
t5_tokens.insert(t5_tokens.end(), curr_tokens.begin(), curr_tokens.end());
1220+
t5_weights.insert(t5_weights.end(), curr_tokens.size(), curr_weight);
1221+
}
1222+
}
1223+
if (use_clip_l) {
1224+
clip_l_tokenizer.pad_tokens(clip_l_tokens, clip_l_weights, 77, padding);
1225+
}
1226+
if (use_t5) {
1227+
t5_tokenizer.pad_tokens(t5_tokens, t5_weights, NULL, max_length, padding);
11761228
}
1177-
1178-
clip_l_tokenizer.pad_tokens(clip_l_tokens, clip_l_weights, 77, padding);
1179-
t5_tokenizer.pad_tokens(t5_tokens, t5_weights, NULL, max_length, padding);
11801229

11811230
// for (int i = 0; i < clip_l_tokens.size(); i++) {
11821231
// std::cout << clip_l_tokens[i] << ":" << clip_l_weights[i] << ", ";
@@ -1208,34 +1257,36 @@ struct FluxCLIPEmbedder : public Conditioner {
12081257
struct ggml_tensor* pooled = NULL; // [768,]
12091258
std::vector<float> hidden_states_vec;
12101259

1211-
size_t chunk_count = t5_tokens.size() / chunk_len;
1260+
size_t chunk_count = std::max(clip_l_tokens.size() > 0 ? chunk_len : 0, t5_tokens.size()) / chunk_len;
12121261
for (int chunk_idx = 0; chunk_idx < chunk_count; chunk_idx++) {
12131262
// clip_l
12141263
if (chunk_idx == 0) {
1215-
size_t chunk_len_l = 77;
1216-
std::vector<int> chunk_tokens(clip_l_tokens.begin(),
1217-
clip_l_tokens.begin() + chunk_len_l);
1218-
std::vector<float> chunk_weights(clip_l_weights.begin(),
1219-
clip_l_weights.begin() + chunk_len_l);
1264+
if (use_clip_l) {
1265+
size_t chunk_len_l = 77;
1266+
std::vector<int> chunk_tokens(clip_l_tokens.begin(),
1267+
clip_l_tokens.begin() + chunk_len_l);
1268+
std::vector<float> chunk_weights(clip_l_weights.begin(),
1269+
clip_l_weights.begin() + chunk_len_l);
12201270

1221-
auto input_ids = vector_to_ggml_tensor_i32(work_ctx, chunk_tokens);
1222-
size_t max_token_idx = 0;
1271+
auto input_ids = vector_to_ggml_tensor_i32(work_ctx, chunk_tokens);
1272+
size_t max_token_idx = 0;
12231273

1224-
auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), clip_l_tokenizer.EOS_TOKEN_ID);
1225-
max_token_idx = std::min<size_t>(std::distance(chunk_tokens.begin(), it), chunk_tokens.size() - 1);
1274+
auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), clip_l_tokenizer.EOS_TOKEN_ID);
1275+
max_token_idx = std::min<size_t>(std::distance(chunk_tokens.begin(), it), chunk_tokens.size() - 1);
12261276

1227-
clip_l->compute(n_threads,
1228-
input_ids,
1229-
0,
1230-
NULL,
1231-
max_token_idx,
1232-
true,
1233-
&pooled,
1234-
work_ctx);
1277+
clip_l->compute(n_threads,
1278+
input_ids,
1279+
0,
1280+
NULL,
1281+
max_token_idx,
1282+
true,
1283+
&pooled,
1284+
work_ctx);
1285+
}
12351286
}
12361287

12371288
// t5
1238-
{
1289+
if (use_t5) {
12391290
std::vector<int> chunk_tokens(t5_tokens.begin() + chunk_idx * chunk_len,
12401291
t5_tokens.begin() + (chunk_idx + 1) * chunk_len);
12411292
std::vector<float> chunk_weights(t5_weights.begin() + chunk_idx * chunk_len,
@@ -1263,8 +1314,12 @@ struct FluxCLIPEmbedder : public Conditioner {
12631314
float new_mean = ggml_tensor_mean(tensor);
12641315
ggml_tensor_scale(tensor, (original_mean / new_mean));
12651316
}
1317+
} else {
1318+
chunk_hidden_states = ggml_new_tensor_2d(work_ctx, GGML_TYPE_F32, 4096, chunk_len);
1319+
ggml_set_f32(chunk_hidden_states, 0.f);
12661320
}
12671321

1322+
12681323
int64_t t1 = ggml_time_ms();
12691324
LOG_DEBUG("computing condition graph completed, taking %" PRId64 " ms", t1 - t0);
12701325
if (force_zero_embeddings) {
@@ -1273,17 +1328,26 @@ struct FluxCLIPEmbedder : public Conditioner {
12731328
vec[i] = 0;
12741329
}
12751330
}
1276-
1331+
12771332
hidden_states_vec.insert(hidden_states_vec.end(),
1278-
(float*)chunk_hidden_states->data,
1279-
((float*)chunk_hidden_states->data) + ggml_nelements(chunk_hidden_states));
1333+
(float*)chunk_hidden_states->data,
1334+
((float*)chunk_hidden_states->data) + ggml_nelements(chunk_hidden_states));
1335+
}
1336+
1337+
if (hidden_states_vec.size() > 0) {
1338+
hidden_states = vector_to_ggml_tensor(work_ctx, hidden_states_vec);
1339+
hidden_states = ggml_reshape_2d(work_ctx,
1340+
hidden_states,
1341+
chunk_hidden_states->ne[0],
1342+
ggml_nelements(hidden_states) / chunk_hidden_states->ne[0]);
1343+
} else {
1344+
hidden_states = ggml_new_tensor_2d(work_ctx, GGML_TYPE_F32, 4096, 256);
1345+
ggml_set_f32(hidden_states, 0.f);
1346+
}
1347+
if (pooled == NULL) {
1348+
pooled = ggml_new_tensor_1d(work_ctx, GGML_TYPE_F32, 768);
1349+
ggml_set_f32(pooled, 0.f);
12801350
}
1281-
1282-
hidden_states = vector_to_ggml_tensor(work_ctx, hidden_states_vec);
1283-
hidden_states = ggml_reshape_2d(work_ctx,
1284-
hidden_states,
1285-
chunk_hidden_states->ne[0],
1286-
ggml_nelements(hidden_states) / chunk_hidden_states->ne[0]);
12871351
return SDCondition(hidden_states, pooled, NULL);
12881352
}
12891353

stable-diffusion.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -329,6 +329,7 @@ class StableDiffusionGGML {
329329
clip_backend = backend;
330330
bool use_t5xxl = false;
331331
if (sd_version_is_dit(version)) {
332+
// TODO: check if t5 is actually loaded?
332333
use_t5xxl = true;
333334
}
334335
if (!ggml_backend_is_cpu(backend) && use_t5xxl && conditioner_wtype != GGML_TYPE_F32) {

0 commit comments

Comments
 (0)