Skip to content

Commit ae78b97

Browse files
committed
conditionner: make text encoders optional for SD3.x
1 parent 6d84a30 commit ae78b97

File tree

1 file changed

+135
-41
lines changed

1 file changed

+135
-41
lines changed

conditioner.hpp

Lines changed: 135 additions & 41 deletions
Original file line number · Diff line number · Diff line change
@@ -661,47 +661,108 @@ struct SD3CLIPEmbedder : public Conditioner {
661661
std::shared_ptr<CLIPTextModelRunner> clip_l;
662662
std::shared_ptr<CLIPTextModelRunner> clip_g;
663663
std::shared_ptr<T5Runner> t5;
664+
bool use_clip_l = false;
665+
bool use_clip_g = false;
666+
bool use_t5 = false;
664667

665668
SD3CLIPEmbedder(ggml_backend_t backend,
666669
std::map<std::string, enum ggml_type>& tensor_types,
667670
int clip_skip = -1)
668671
: clip_g_tokenizer(0) {
669-
clip_l = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "text_encoders.clip_l.transformer.text_model", OPENAI_CLIP_VIT_L_14, false);
670-
clip_g = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "text_encoders.clip_g.transformer.text_model", OPEN_CLIP_VIT_BIGG_14, false);
671-
t5 = std::make_shared<T5Runner>(backend, tensor_types, "text_encoders.t5xxl.transformer");
672+
if (clip_skip <= 0) {
673+
clip_skip = 2;
674+
}
675+
676+
for (auto pair : tensor_types) {
677+
if (pair.first.find("text_encoders.clip_l") != std::string::npos) {
678+
use_clip_l = true;
679+
} else if (pair.first.find("text_encoders.clip_g") != std::string::npos) {
680+
use_clip_g = true;
681+
} else if (pair.first.find("text_encoders.t5xxl") != std::string::npos) {
682+
use_t5 = true;
683+
}
684+
}
685+
if (!use_clip_l && !use_clip_g && !use_t5) {
686+
LOG_WARN("IMPORTANT NOTICE: No text encoders provided, cannot process prompts!");
687+
return;
688+
}
689+
if (use_clip_l) {
690+
clip_l = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "text_encoders.clip_l.transformer.text_model", OPENAI_CLIP_VIT_L_14, clip_skip, false);
691+
} else {
692+
LOG_WARN("clip_l text encoder not found! Prompt adherence might be degraded.");
693+
}
694+
if (use_clip_g) {
695+
clip_g = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "text_encoders.clip_g.transformer.text_model", OPEN_CLIP_VIT_BIGG_14, clip_skip, false);
696+
} else {
697+
LOG_WARN("clip_g text encoder not found! Prompt adherence might be degraded.");
698+
}
699+
if (use_t5) {
700+
t5 = std::make_shared<T5Runner>(backend, tensor_types, "text_encoders.t5xxl.transformer");
701+
} else {
702+
LOG_WARN("t5xxl text encoder not found! Prompt adherence might be degraded.");
703+
}
672704
set_clip_skip(clip_skip);
673705
}
674706

675707
void set_clip_skip(int clip_skip) {
676708
if (clip_skip <= 0) {
677709
clip_skip = 2;
678710
}
679-
clip_l->set_clip_skip(clip_skip);
680-
clip_g->set_clip_skip(clip_skip);
711+
if (use_clip_l) {
712+
clip_l->set_clip_skip(clip_skip);
713+
}
714+
if (use_clip_g) {
715+
clip_g->set_clip_skip(clip_skip);
716+
}
681717
}
682718

683719
void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) {
684-
clip_l->get_param_tensors(tensors, "text_encoders.clip_l.transformer.text_model");
685-
clip_g->get_param_tensors(tensors, "text_encoders.clip_g.transformer.text_model");
686-
t5->get_param_tensors(tensors, "text_encoders.t5xxl.transformer");
720+
if (use_clip_l) {
721+
clip_l->get_param_tensors(tensors, "text_encoders.clip_l.transformer.text_model");
722+
}
723+
if (use_clip_g) {
724+
clip_g->get_param_tensors(tensors, "text_encoders.clip_g.transformer.text_model");
725+
}
726+
if (use_t5) {
727+
t5->get_param_tensors(tensors, "text_encoders.t5xxl.transformer");
728+
}
687729
}
688730

689731
void alloc_params_buffer() {
690-
clip_l->alloc_params_buffer();
691-
clip_g->alloc_params_buffer();
692-
t5->alloc_params_buffer();
732+
if (use_clip_l) {
733+
clip_l->alloc_params_buffer();
734+
}
735+
if (use_clip_g) {
736+
clip_g->alloc_params_buffer();
737+
}
738+
if (use_t5) {
739+
t5->alloc_params_buffer();
740+
}
693741
}
694742

695743
void free_params_buffer() {
696-
clip_l->free_params_buffer();
697-
clip_g->free_params_buffer();
698-
t5->free_params_buffer();
744+
if (use_clip_l) {
745+
clip_l->free_params_buffer();
746+
}
747+
if (use_clip_g) {
748+
clip_g->free_params_buffer();
749+
}
750+
if (use_t5) {
751+
t5->free_params_buffer();
752+
}
699753
}
700754

701755
size_t get_params_buffer_size() {
702-
size_t buffer_size = clip_l->get_params_buffer_size();
703-
buffer_size += clip_g->get_params_buffer_size();
704-
buffer_size += t5->get_params_buffer_size();
756+
size_t buffer_size = 0;
757+
if (use_clip_l) {
758+
buffer_size += clip_l->get_params_buffer_size();
759+
}
760+
if (use_clip_g) {
761+
buffer_size += clip_g->get_params_buffer_size();
762+
}
763+
if (use_t5) {
764+
buffer_size += t5->get_params_buffer_size();
765+
}
705766
return buffer_size;
706767
}
707768

@@ -733,23 +794,32 @@ struct SD3CLIPEmbedder : public Conditioner {
733794
for (const auto& item : parsed_attention) {
734795
const std::string& curr_text = item.first;
735796
float curr_weight = item.second;
736-
737-
std::vector<int> curr_tokens = clip_l_tokenizer.encode(curr_text, on_new_token_cb);
738-
clip_l_tokens.insert(clip_l_tokens.end(), curr_tokens.begin(), curr_tokens.end());
739-
clip_l_weights.insert(clip_l_weights.end(), curr_tokens.size(), curr_weight);
740-
741-
curr_tokens = clip_g_tokenizer.encode(curr_text, on_new_token_cb);
742-
clip_g_tokens.insert(clip_g_tokens.end(), curr_tokens.begin(), curr_tokens.end());
743-
clip_g_weights.insert(clip_g_weights.end(), curr_tokens.size(), curr_weight);
744-
745-
curr_tokens = t5_tokenizer.Encode(curr_text, true);
746-
t5_tokens.insert(t5_tokens.end(), curr_tokens.begin(), curr_tokens.end());
747-
t5_weights.insert(t5_weights.end(), curr_tokens.size(), curr_weight);
797+
if (use_clip_l) {
798+
std::vector<int> curr_tokens = clip_l_tokenizer.encode(curr_text, on_new_token_cb);
799+
clip_l_tokens.insert(clip_l_tokens.end(), curr_tokens.begin(), curr_tokens.end());
800+
clip_l_weights.insert(clip_l_weights.end(), curr_tokens.size(), curr_weight);
801+
}
802+
if (use_clip_g) {
803+
std::vector<int> curr_tokens = clip_g_tokenizer.encode(curr_text, on_new_token_cb);
804+
clip_g_tokens.insert(clip_g_tokens.end(), curr_tokens.begin(), curr_tokens.end());
805+
clip_g_weights.insert(clip_g_weights.end(), curr_tokens.size(), curr_weight);
806+
}
807+
if (use_t5) {
808+
std::vector<int> curr_tokens = t5_tokenizer.Encode(curr_text, true);
809+
t5_tokens.insert(t5_tokens.end(), curr_tokens.begin(), curr_tokens.end());
810+
t5_weights.insert(t5_weights.end(), curr_tokens.size(), curr_weight);
811+
}
748812
}
749813

750-
clip_l_tokenizer.pad_tokens(clip_l_tokens, clip_l_weights, max_length, padding);
751-
clip_g_tokenizer.pad_tokens(clip_g_tokens, clip_g_weights, max_length, padding);
752-
t5_tokenizer.pad_tokens(t5_tokens, t5_weights, NULL, max_length, padding);
814+
if (use_clip_l) {
815+
clip_l_tokenizer.pad_tokens(clip_l_tokens, clip_l_weights, max_length, padding);
816+
}
817+
if (use_clip_g) {
818+
clip_g_tokenizer.pad_tokens(clip_g_tokens, clip_g_weights, max_length, padding);
819+
}
820+
if (use_t5) {
821+
t5_tokenizer.pad_tokens(t5_tokens, t5_weights, NULL, max_length, padding);
822+
}
753823

754824
// for (int i = 0; i < clip_l_tokens.size(); i++) {
755825
// std::cout << clip_l_tokens[i] << ":" << clip_l_weights[i] << ", ";
@@ -794,10 +864,10 @@ struct SD3CLIPEmbedder : public Conditioner {
794864
std::vector<float> hidden_states_vec;
795865

796866
size_t chunk_len = 77;
797-
size_t chunk_count = clip_l_tokens.size() / chunk_len;
867+
size_t chunk_count = std::max(std::max(clip_l_tokens.size(), clip_g_tokens.size()), t5_tokens.size()) / chunk_len;
798868
for (int chunk_idx = 0; chunk_idx < chunk_count; chunk_idx++) {
799869
// clip_l
800-
{
870+
if (use_clip_l) {
801871
std::vector<int> chunk_tokens(clip_l_tokens.begin() + chunk_idx * chunk_len,
802872
clip_l_tokens.begin() + (chunk_idx + 1) * chunk_len);
803873
std::vector<float> chunk_weights(clip_l_weights.begin() + chunk_idx * chunk_len,
@@ -842,10 +912,17 @@ struct SD3CLIPEmbedder : public Conditioner {
842912
&pooled_l,
843913
work_ctx);
844914
}
915+
} else {
916+
chunk_hidden_states_l = ggml_new_tensor_2d(work_ctx, GGML_TYPE_F32, 768, chunk_len);
917+
ggml_set_f32(chunk_hidden_states_l, 0.f);
918+
if (chunk_idx == 0) {
919+
pooled_l = ggml_new_tensor_1d(work_ctx, GGML_TYPE_F32, 768);
920+
ggml_set_f32(pooled_l, 0.f);
921+
}
845922
}
846923

847924
// clip_g
848-
{
925+
if (use_clip_g) {
849926
std::vector<int> chunk_tokens(clip_g_tokens.begin() + chunk_idx * chunk_len,
850927
clip_g_tokens.begin() + (chunk_idx + 1) * chunk_len);
851928
std::vector<float> chunk_weights(clip_g_weights.begin() + chunk_idx * chunk_len,
@@ -891,10 +968,17 @@ struct SD3CLIPEmbedder : public Conditioner {
891968
&pooled_g,
892969
work_ctx);
893970
}
971+
} else {
972+
chunk_hidden_states_g = ggml_new_tensor_2d(work_ctx, GGML_TYPE_F32, 1280, chunk_len);
973+
ggml_set_f32(chunk_hidden_states_g, 0.f);
974+
if (chunk_idx == 0) {
975+
pooled_g = ggml_new_tensor_1d(work_ctx, GGML_TYPE_F32, 1280);
976+
ggml_set_f32(pooled_g, 0.f);
977+
}
894978
}
895979

896980
// t5
897-
{
981+
if (use_t5) {
898982
std::vector<int> chunk_tokens(t5_tokens.begin() + chunk_idx * chunk_len,
899983
t5_tokens.begin() + (chunk_idx + 1) * chunk_len);
900984
std::vector<float> chunk_weights(t5_weights.begin() + chunk_idx * chunk_len,
@@ -922,6 +1006,8 @@ struct SD3CLIPEmbedder : public Conditioner {
9221006
float new_mean = ggml_tensor_mean(tensor);
9231007
ggml_tensor_scale(tensor, (original_mean / new_mean));
9241008
}
1009+
} else {
1010+
chunk_hidden_states_t5 = ggml_new_tensor_2d(work_ctx, GGML_TYPE_F32, 4096, 0);
9251011
}
9261012

9271013
auto chunk_hidden_states_lg_pad = ggml_new_tensor_3d(work_ctx,
@@ -964,11 +1050,19 @@ struct SD3CLIPEmbedder : public Conditioner {
9641050
((float*)chunk_hidden_states->data) + ggml_nelements(chunk_hidden_states));
9651051
}
9661052

967-
hidden_states = vector_to_ggml_tensor(work_ctx, hidden_states_vec);
968-
hidden_states = ggml_reshape_2d(work_ctx,
969-
hidden_states,
970-
chunk_hidden_states->ne[0],
971-
ggml_nelements(hidden_states) / chunk_hidden_states->ne[0]);
1053+
if (hidden_states_vec.size() > 0) {
1054+
hidden_states = vector_to_ggml_tensor(work_ctx, hidden_states_vec);
1055+
hidden_states = ggml_reshape_2d(work_ctx,
1056+
hidden_states,
1057+
chunk_hidden_states->ne[0],
1058+
ggml_nelements(hidden_states) / chunk_hidden_states->ne[0]);
1059+
} else {
1060+
hidden_states = ggml_new_tensor_2d(work_ctx, GGML_TYPE_F32, 4096, 0);
1061+
}
1062+
if (pooled == NULL) {
1063+
pooled = ggml_new_tensor_1d(work_ctx, GGML_TYPE_F32, 2048);
1064+
ggml_set_f32(pooled, 0.f);
1065+
}
9721066
return SDCondition(hidden_states, pooled, NULL);
9731067
}
9741068

0 commit comments

Comments (0)