
Commit a44db54

conditioner: make t5 optional for chroma
1 parent 4d72eca commit a44db54

File tree

1 file changed: +76 -50 lines

conditioner.hpp

Lines changed: 76 additions & 50 deletions
@@ -1015,7 +1015,6 @@ struct FluxCLIPEmbedder : public Conditioner {
     FluxCLIPEmbedder(ggml_backend_t backend,
                      std::map<std::string, enum ggml_type>& tensor_types,
                      int clip_skip = -1) {
-
        for (auto pair : tensor_types) {
            if (pair.first.find("text_encoders.clip_l") != std::string::npos) {
                use_clip_l = true;
@@ -1225,7 +1224,6 @@ struct FluxCLIPEmbedder : public Conditioner {
                 ggml_set_f32(chunk_hidden_states, 0.f);
             }

-
             int64_t t1 = ggml_time_ms();
             LOG_DEBUG("computing condition graph completed, taking %" PRId64 " ms", t1 - t0);
             if (force_zero_embeddings) {
@@ -1234,12 +1232,12 @@ struct FluxCLIPEmbedder : public Conditioner {
                     vec[i] = 0;
                 }
             }
-
+
             hidden_states_vec.insert(hidden_states_vec.end(),
-                                     (float*)chunk_hidden_states->data,
-                                     ((float*)chunk_hidden_states->data) + ggml_nelements(chunk_hidden_states));
+                                     (float*)chunk_hidden_states->data,
+                                     ((float*)chunk_hidden_states->data) + ggml_nelements(chunk_hidden_states));
         }
-
+
         if (hidden_states_vec.size() > 0) {
             hidden_states = vector_to_ggml_tensor(work_ctx, hidden_states_vec);
             hidden_states = ggml_reshape_2d(work_ctx,

(This hunk, like the two FluxCLIPEmbedder hunks above, is whitespace-only cleanup: stray blank lines dropped and trailing whitespace normalized, so the paired -/+ lines differ only in invisible whitespace.)
@@ -1294,35 +1292,54 @@ struct PixArtCLIPEmbedder : public Conditioner {
     bool use_mask = false;
     int mask_pad = 1;

+    bool use_t5 = false;
+
     PixArtCLIPEmbedder(ggml_backend_t backend,
                        std::map<std::string, enum ggml_type>& tensor_types,
                        int clip_skip = -1,
                        bool use_mask = false,
                        int mask_pad = 1)
         : use_mask(use_mask), mask_pad(mask_pad) {
-        t5 = std::make_shared<T5Runner>(backend, tensor_types, "text_encoders.t5xxl.transformer");
+        for (auto pair : tensor_types) {
+            if (pair.first.find("text_encoders.t5xxl") != std::string::npos) {
+                use_t5 = true;
+            }
+        }
+
+        if (!use_t5) {
+            LOG_WARN("IMPORTANT NOTICE: No text encoders provided, cannot process prompts!");
+            return;
+        } else {
+            t5 = std::make_shared<T5Runner>(backend, tensor_types, "text_encoders.t5xxl.transformer");
+        }
     }

     void set_clip_skip(int clip_skip) {
     }

     void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) {
-        t5->get_param_tensors(tensors, "text_encoders.t5xxl.transformer");
+        if (use_t5) {
+            t5->get_param_tensors(tensors, "text_encoders.t5xxl.transformer");
+        }
     }

     void alloc_params_buffer() {
-        t5->alloc_params_buffer();
+        if (use_t5) {
+            t5->alloc_params_buffer();
+        }
     }

     void free_params_buffer() {
-        t5->free_params_buffer();
+        if (use_t5) {
+            t5->free_params_buffer();
+        }
     }

     size_t get_params_buffer_size() {
         size_t buffer_size = 0;
-
-        buffer_size += t5->get_params_buffer_size();
-
+        if (use_t5) {
+            buffer_size += t5->get_params_buffer_size();
+        }
         return buffer_size;
     }

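The heart of the change is this constructor: instead of unconditionally constructing a T5Runner, it scans the checkpoint's tensor names for the text_encoders.t5xxl prefix and sets use_t5 only on a match; every accessor in the hunk then guards its t5-> call behind that flag. A minimal standalone sketch of the same detection idiom follows — the map contents and value type here are hypothetical stand-ins, not the project's loader types:

#include <iostream>
#include <map>
#include <string>

int main() {
    // Hypothetical tensor-name -> type map, standing in for the loader's
    // std::map<std::string, enum ggml_type>.
    std::map<std::string, int> tensor_types = {
        {"model.diffusion_model.input_blocks.0.weight", 0},
        // Present only when the checkpoint bundles the T5 encoder:
        {"text_encoders.t5xxl.transformer.shared.weight", 0},
    };

    bool use_t5 = false;
    for (const auto& pair : tensor_types) {
        if (pair.first.find("text_encoders.t5xxl") != std::string::npos) {
            use_t5 = true;  // any matching tensor name is enough
        }
    }
    std::cout << (use_t5 ? "T5 tensors found\n" : "no T5 tensors\n");
    return 0;
}

Because the constructor returns early when nothing matches, t5 is never constructed, which is why each later t5->... call needs the use_t5 guard rather than being merely defensive.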
@@ -1348,17 +1365,18 @@ struct PixArtCLIPEmbedder : public Conditioner {
         std::vector<int> t5_tokens;
         std::vector<float> t5_weights;
         std::vector<float> t5_mask;
-        for (const auto& item : parsed_attention) {
-            const std::string& curr_text = item.first;
-            float curr_weight = item.second;
-
-            std::vector<int> curr_tokens = t5_tokenizer.Encode(curr_text, true);
-            t5_tokens.insert(t5_tokens.end(), curr_tokens.begin(), curr_tokens.end());
-            t5_weights.insert(t5_weights.end(), curr_tokens.size(), curr_weight);
-        }
+        if (use_t5) {
+            for (const auto& item : parsed_attention) {
+                const std::string& curr_text = item.first;
+                float curr_weight = item.second;

-        t5_tokenizer.pad_tokens(t5_tokens, t5_weights, &t5_mask, max_length, padding);
+                std::vector<int> curr_tokens = t5_tokenizer.Encode(curr_text, true);
+                t5_tokens.insert(t5_tokens.end(), curr_tokens.begin(), curr_tokens.end());
+                t5_weights.insert(t5_weights.end(), curr_tokens.size(), curr_weight);
+            }

+            t5_tokenizer.pad_tokens(t5_tokens, t5_weights, &t5_mask, max_length, padding);
+        }
         return {t5_tokens, t5_weights, t5_mask};
     }

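For context on the tokenizer hunk: the loop keeps t5_tokens and t5_weights index-aligned by repeating each fragment's weight once per emitted token, and pad_tokens then brings tokens, weights, and mask up to max_length together. A runnable sketch of that bookkeeping, with a fake one-token-per-character tokenizer standing in for the real t5_tokenizer (names and padding values here are illustrative assumptions):

#include <cstdio>
#include <string>
#include <utility>
#include <vector>

int main() {
    // Hypothetical parsed attention fragments: (text, weight).
    std::vector<std::pair<std::string, float>> parsed_attention = {
        {"a photo of ", 1.0f},
        {"a cat", 1.2f},
    };

    std::vector<int> tokens;
    std::vector<float> weights;
    for (const auto& item : parsed_attention) {
        // Fake tokenizer: one token id per character, illustration only.
        std::vector<int> curr_tokens(item.first.begin(), item.first.end());
        tokens.insert(tokens.end(), curr_tokens.begin(), curr_tokens.end());
        // Repeat the fragment weight once per token to stay index-aligned.
        weights.insert(weights.end(), curr_tokens.size(), item.second);
    }

    // Pad all three vectors to a fixed chunk length; the mask marks
    // real tokens with 1 and padding with 0.
    const size_t max_length = 32;
    std::vector<float> mask(tokens.size(), 1.0f);
    tokens.resize(max_length, 0);
    weights.resize(max_length, 1.0f);
    mask.resize(max_length, 0.0f);

    printf("%zu tokens, %zu weights, %zu mask entries\n",
           tokens.size(), weights.size(), mask.size());
    return 0;
}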
@@ -1395,38 +1413,44 @@ struct PixArtCLIPEmbedder : public Conditioner {
         std::vector<float> hidden_states_vec;

         size_t chunk_count = t5_tokens.size() / chunk_len;
-
         for (int chunk_idx = 0; chunk_idx < chunk_count; chunk_idx++) {
             // t5
-            std::vector<int> chunk_tokens(t5_tokens.begin() + chunk_idx * chunk_len,
-                                          t5_tokens.begin() + (chunk_idx + 1) * chunk_len);
-            std::vector<float> chunk_weights(t5_weights.begin() + chunk_idx * chunk_len,
-                                             t5_weights.begin() + (chunk_idx + 1) * chunk_len);
-            std::vector<float> chunk_mask(t5_attn_mask_vec.begin() + chunk_idx * chunk_len,
-                                          t5_attn_mask_vec.begin() + (chunk_idx + 1) * chunk_len);
-
-            auto input_ids = vector_to_ggml_tensor_i32(work_ctx, chunk_tokens);
-            auto t5_attn_mask_chunk = use_mask ? vector_to_ggml_tensor(work_ctx, chunk_mask) : NULL;
-
-            t5->compute(n_threads,
-                        input_ids,
-                        t5_attn_mask_chunk,
-                        &chunk_hidden_states,
-                        work_ctx);
-            {
-                auto tensor = chunk_hidden_states;
-                float original_mean = ggml_tensor_mean(tensor);
-                for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
-                    for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
-                        for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
-                            float value = ggml_tensor_get_f32(tensor, i0, i1, i2);
-                            value *= chunk_weights[i1];
-                            ggml_tensor_set_f32(tensor, value, i0, i1, i2);
+
+            if (use_t5) {
+                std::vector<int> chunk_tokens(t5_tokens.begin() + chunk_idx * chunk_len,
+                                              t5_tokens.begin() + (chunk_idx + 1) * chunk_len);
+                std::vector<float> chunk_weights(t5_weights.begin() + chunk_idx * chunk_len,
+                                                 t5_weights.begin() + (chunk_idx + 1) * chunk_len);
+                std::vector<float> chunk_mask(t5_attn_mask_vec.begin() + chunk_idx * chunk_len,
+                                              t5_attn_mask_vec.begin() + (chunk_idx + 1) * chunk_len);
+
+                auto input_ids = vector_to_ggml_tensor_i32(work_ctx, chunk_tokens);
+                auto t5_attn_mask_chunk = use_mask ? vector_to_ggml_tensor(work_ctx, chunk_mask) : NULL;
+                t5->compute(n_threads,
+                            input_ids,
+                            t5_attn_mask_chunk,
+                            &chunk_hidden_states,
+                            work_ctx);
+                {
+                    auto tensor = chunk_hidden_states;
+                    float original_mean = ggml_tensor_mean(tensor);
+                    for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
+                        for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
+                            for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
+                                float value = ggml_tensor_get_f32(tensor, i0, i1, i2);
+                                value *= chunk_weights[i1];
+                                ggml_tensor_set_f32(tensor, value, i0, i1, i2);
+                            }
                         }
                     }
+                    float new_mean = ggml_tensor_mean(tensor);
+                    ggml_tensor_scale(tensor, (original_mean / new_mean));
                 }
-                float new_mean = ggml_tensor_mean(tensor);
-                ggml_tensor_scale(tensor, (original_mean / new_mean));
+            } else {
+                chunk_hidden_states = ggml_new_tensor_2d(work_ctx, GGML_TYPE_F32, 4096, chunk_len);
+                ggml_set_f32(chunk_hidden_states, 0.f);
+                t5_attn_mask = ggml_new_tensor_1d(work_ctx, GGML_TYPE_F32, chunk_len);
+                ggml_set_f32(t5_attn_mask, -HUGE_VALF);
            }

            int64_t t1 = ggml_time_ms();
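The weighting block that moved inside if (use_t5) is a mean-preserving rescale: each token's embedding row is multiplied by its prompt weight, then the whole tensor is scaled by original_mean / new_mean so the global activation level the encoder produced is restored. A plain-vector sketch of the same arithmetic (shapes and values are made up; std::vector stands in for the ggml tensor):

#include <cstdio>
#include <vector>

int main() {
    const int n_tokens = 3, n_dims = 4;
    std::vector<float> hidden(n_tokens * n_dims, 1.0f);  // fake embeddings
    std::vector<float> weights = {1.0f, 1.5f, 0.5f};     // per-token weights

    auto mean = [&hidden]() {
        float sum = 0.0f;
        for (float v : hidden) sum += v;
        return sum / hidden.size();
    };

    // Scale every dimension of token t by that token's weight...
    float original_mean = mean();
    for (int t = 0; t < n_tokens; t++)
        for (int d = 0; d < n_dims; d++)
            hidden[t * n_dims + d] *= weights[t];

    // ...then restore the original global mean, as the diff does with
    // ggml_tensor_scale(tensor, original_mean / new_mean).
    float new_mean = mean();
    for (float& v : hidden) v *= original_mean / new_mean;

    printf("mean before: %.3f, after rescale: %.3f\n", original_mean, mean());
    return 0;
}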
@@ -1450,8 +1474,10 @@ struct PixArtCLIPEmbedder : public Conditioner {
                                          chunk_hidden_states->ne[0],
                                          ggml_nelements(hidden_states) / chunk_hidden_states->ne[0]);
         } else {
-            hidden_states = ggml_new_tensor_2d(work_ctx, GGML_TYPE_F32, 4096, 256);
+            hidden_states = ggml_new_tensor_2d(work_ctx, GGML_TYPE_F32, 4096, chunk_len);
             ggml_set_f32(hidden_states, 0.f);
+            t5_attn_mask = ggml_new_tensor_1d(work_ctx, GGML_TYPE_F32, chunk_len);
+            ggml_set_f32(t5_attn_mask, -HUGE_VALF);
         }

         modify_mask_to_attend_padding(t5_attn_mask, ggml_nelements(t5_attn_mask), mask_pad);
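Both fallback paths fill t5_attn_mask with -HUGE_VALF, which relies on the additive-mask convention: the mask is added to the attention logits before softmax, so 0 leaves a position visible and -inf makes its weight exactly zero after exponentiation (modify_mask_to_attend_padding then re-opens up to mask_pad padding positions). A small sketch of that convention with a plain softmax, not the project's attention code:

#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    std::vector<float> logits = {2.0f, 1.0f, 0.5f};
    // Additive mask: 0 keeps a position, -HUGE_VALF suppresses it.
    std::vector<float> mask = {0.0f, -HUGE_VALF, 0.0f};

    std::vector<float> probs(logits.size());
    float denom = 0.0f;
    for (size_t i = 0; i < logits.size(); i++) {
        probs[i] = std::exp(logits[i] + mask[i]);  // exp(-inf) == 0
        denom += probs[i];
    }
    for (size_t i = 0; i < probs.size(); i++) {
        printf("p[%zu] = %.3f\n", i, probs[i] / denom);
    }
    return 0;
}

With no real tokens at all, an all--HUGE_VALF mask would make the softmax denominator zero, which is presumably why some padding positions are re-enabled afterwards.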
