@@ -661,47 +661,108 @@ struct SD3CLIPEmbedder : public Conditioner {
     std::shared_ptr<CLIPTextModelRunner> clip_l;
     std::shared_ptr<CLIPTextModelRunner> clip_g;
     std::shared_ptr<T5Runner> t5;
+    bool use_clip_l = false;
+    bool use_clip_g = false;
+    bool use_t5     = false;

     SD3CLIPEmbedder(ggml_backend_t backend,
                     std::map<std::string, enum ggml_type>& tensor_types,
                     int clip_skip = -1)
         : clip_g_tokenizer(0) {
-        clip_l = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "text_encoders.clip_l.transformer.text_model", OPENAI_CLIP_VIT_L_14, false);
-        clip_g = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "text_encoders.clip_g.transformer.text_model", OPEN_CLIP_VIT_BIGG_14, false);
-        t5     = std::make_shared<T5Runner>(backend, tensor_types, "text_encoders.t5xxl.transformer");
+        if (clip_skip <= 0) {
+            clip_skip = 2;
+        }
+
+        for (auto pair : tensor_types) {
+            if (pair.first.find("text_encoders.clip_l") != std::string::npos) {
+                use_clip_l = true;
+            } else if (pair.first.find("text_encoders.clip_g") != std::string::npos) {
+                use_clip_g = true;
+            } else if (pair.first.find("text_encoders.t5xxl") != std::string::npos) {
+                use_t5 = true;
+            }
+        }
+        if (!use_clip_l && !use_clip_g && !use_t5) {
+            LOG_WARN("IMPORTANT NOTICE: No text encoders provided, cannot process prompts!");
+            return;
+        }
+        if (use_clip_l) {
+            clip_l = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "text_encoders.clip_l.transformer.text_model", OPENAI_CLIP_VIT_L_14, clip_skip, false);
+        } else {
+            LOG_WARN("clip_l text encoder not found! Prompt adherence might be degraded.");
+        }
+        if (use_clip_g) {
+            clip_g = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "text_encoders.clip_g.transformer.text_model", OPEN_CLIP_VIT_BIGG_14, clip_skip, false);
+        } else {
+            LOG_WARN("clip_g text encoder not found! Prompt adherence might be degraded.");
+        }
+        if (use_t5) {
+            t5 = std::make_shared<T5Runner>(backend, tensor_types, "text_encoders.t5xxl.transformer");
+        } else {
+            LOG_WARN("t5xxl text encoder not found! Prompt adherence might be degraded.");
+        }
         set_clip_skip(clip_skip);
     }

     void set_clip_skip(int clip_skip) {
         if (clip_skip <= 0) {
             clip_skip = 2;
         }
-        clip_l->set_clip_skip(clip_skip);
-        clip_g->set_clip_skip(clip_skip);
+        if (use_clip_l) {
+            clip_l->set_clip_skip(clip_skip);
+        }
+        if (use_clip_g) {
+            clip_g->set_clip_skip(clip_skip);
+        }
     }

     void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) {
-        clip_l->get_param_tensors(tensors, "text_encoders.clip_l.transformer.text_model");
-        clip_g->get_param_tensors(tensors, "text_encoders.clip_g.transformer.text_model");
-        t5->get_param_tensors(tensors, "text_encoders.t5xxl.transformer");
+        if (use_clip_l) {
+            clip_l->get_param_tensors(tensors, "text_encoders.clip_l.transformer.text_model");
+        }
+        if (use_clip_g) {
+            clip_g->get_param_tensors(tensors, "text_encoders.clip_g.transformer.text_model");
+        }
+        if (use_t5) {
+            t5->get_param_tensors(tensors, "text_encoders.t5xxl.transformer");
+        }
     }

     void alloc_params_buffer() {
-        clip_l->alloc_params_buffer();
-        clip_g->alloc_params_buffer();
-        t5->alloc_params_buffer();
+        if (use_clip_l) {
+            clip_l->alloc_params_buffer();
+        }
+        if (use_clip_g) {
+            clip_g->alloc_params_buffer();
+        }
+        if (use_t5) {
+            t5->alloc_params_buffer();
+        }
     }

     void free_params_buffer() {
-        clip_l->free_params_buffer();
-        clip_g->free_params_buffer();
-        t5->free_params_buffer();
+        if (use_clip_l) {
+            clip_l->free_params_buffer();
+        }
+        if (use_clip_g) {
+            clip_g->free_params_buffer();
+        }
+        if (use_t5) {
+            t5->free_params_buffer();
+        }
     }

     size_t get_params_buffer_size() {
-        size_t buffer_size = clip_l->get_params_buffer_size();
-        buffer_size += clip_g->get_params_buffer_size();
-        buffer_size += t5->get_params_buffer_size();
+        size_t buffer_size = 0;
+        if (use_clip_l) {
+            buffer_size += clip_l->get_params_buffer_size();
+        }
+        if (use_clip_g) {
+            buffer_size += clip_g->get_params_buffer_size();
+        }
+        if (use_t5) {
+            buffer_size += t5->get_params_buffer_size();
+        }
         return buffer_size;
     }

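For reference, here is a minimal standalone sketch of the encoder auto-detection introduced in this hunk: the constructor scans the loaded tensor names and enables each text encoder only if at least one tensor carries its prefix. A plain std::map<std::string, int> stands in for tensor_types (the real map values are ggml_type), and the tensor names are illustrative, not taken from an actual checkpoint.

#include <iostream>
#include <map>
#include <string>

int main() {
    // Illustrative tensor names; only clip_g and t5xxl weights are "present" here.
    std::map<std::string, int> tensor_types = {
        {"text_encoders.clip_g.transformer.text_model.embeddings.token_embedding.weight", 0},
        {"text_encoders.t5xxl.transformer.encoder.block.0.layer.0.SelfAttention.k.weight", 0},
    };

    bool use_clip_l = false, use_clip_g = false, use_t5 = false;
    for (const auto& pair : tensor_types) {
        if (pair.first.find("text_encoders.clip_l") != std::string::npos) {
            use_clip_l = true;
        } else if (pair.first.find("text_encoders.clip_g") != std::string::npos) {
            use_clip_g = true;
        } else if (pair.first.find("text_encoders.t5xxl") != std::string::npos) {
            use_t5 = true;
        }
    }
    std::cout << "clip_l=" << use_clip_l << " clip_g=" << use_clip_g
              << " t5=" << use_t5 << std::endl;  // prints: clip_l=0 clip_g=1 t5=1
}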
@@ -733,23 +794,32 @@ struct SD3CLIPEmbedder : public Conditioner {
         for (const auto& item : parsed_attention) {
             const std::string& curr_text = item.first;
             float curr_weight = item.second;
-
-            std::vector<int> curr_tokens = clip_l_tokenizer.encode(curr_text, on_new_token_cb);
-            clip_l_tokens.insert(clip_l_tokens.end(), curr_tokens.begin(), curr_tokens.end());
-            clip_l_weights.insert(clip_l_weights.end(), curr_tokens.size(), curr_weight);
-
-            curr_tokens = clip_g_tokenizer.encode(curr_text, on_new_token_cb);
-            clip_g_tokens.insert(clip_g_tokens.end(), curr_tokens.begin(), curr_tokens.end());
-            clip_g_weights.insert(clip_g_weights.end(), curr_tokens.size(), curr_weight);
-
-            curr_tokens = t5_tokenizer.Encode(curr_text, true);
-            t5_tokens.insert(t5_tokens.end(), curr_tokens.begin(), curr_tokens.end());
-            t5_weights.insert(t5_weights.end(), curr_tokens.size(), curr_weight);
+            if (use_clip_l) {
+                std::vector<int> curr_tokens = clip_l_tokenizer.encode(curr_text, on_new_token_cb);
+                clip_l_tokens.insert(clip_l_tokens.end(), curr_tokens.begin(), curr_tokens.end());
+                clip_l_weights.insert(clip_l_weights.end(), curr_tokens.size(), curr_weight);
+            }
+            if (use_clip_g) {
+                std::vector<int> curr_tokens = clip_g_tokenizer.encode(curr_text, on_new_token_cb);
+                clip_g_tokens.insert(clip_g_tokens.end(), curr_tokens.begin(), curr_tokens.end());
+                clip_g_weights.insert(clip_g_weights.end(), curr_tokens.size(), curr_weight);
+            }
+            if (use_t5) {
+                std::vector<int> curr_tokens = t5_tokenizer.Encode(curr_text, true);
+                t5_tokens.insert(t5_tokens.end(), curr_tokens.begin(), curr_tokens.end());
+                t5_weights.insert(t5_weights.end(), curr_tokens.size(), curr_weight);
+            }
         }

-        clip_l_tokenizer.pad_tokens(clip_l_tokens, clip_l_weights, max_length, padding);
-        clip_g_tokenizer.pad_tokens(clip_g_tokens, clip_g_weights, max_length, padding);
-        t5_tokenizer.pad_tokens(t5_tokens, t5_weights, NULL, max_length, padding);
+        if (use_clip_l) {
+            clip_l_tokenizer.pad_tokens(clip_l_tokens, clip_l_weights, max_length, padding);
+        }
+        if (use_clip_g) {
+            clip_g_tokenizer.pad_tokens(clip_g_tokens, clip_g_weights, max_length, padding);
+        }
+        if (use_t5) {
+            t5_tokenizer.pad_tokens(t5_tokens, t5_weights, NULL, max_length, padding);
+        }

         // for (int i = 0; i < clip_l_tokens.size(); i++) {
         //     std::cout << clip_l_tokens[i] << ":" << clip_l_weights[i] << ", ";
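The tokenize loop above duplicates each prompt segment's attention weight across all of that segment's tokens, then pads each stream to a fixed length. A self-contained sketch of that bookkeeping, with made-up token ids and a plain std::vector resize standing in for the tokenizers' pad_tokens helpers:

#include <iostream>
#include <utility>
#include <vector>

int main() {
    // (token ids for a segment, segment weight), as if produced by a tokenizer
    // after prompt-attention parsing. Ids are invented for illustration.
    std::vector<std::pair<std::vector<int>, float>> segments = {
        {{320, 1125, 539}, 1.0f},   // "a photo of"
        {{320, 2368}, 1.2f}};       // "(a cat:1.2)"

    std::vector<int> tokens;
    std::vector<float> weights;
    for (const auto& seg : segments) {
        tokens.insert(tokens.end(), seg.first.begin(), seg.first.end());
        weights.insert(weights.end(), seg.first.size(), seg.second);  // one weight per token
    }

    // Pad both streams to the 77-token chunk length used by the CLIP encoders.
    const size_t max_length = 77;
    tokens.resize(max_length, 0);      // 0 stands in for the pad token id
    weights.resize(max_length, 1.0f);  // padding gets a neutral weight

    std::cout << tokens.size() << " tokens, " << weights.size() << " weights\n";
}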
@@ -794,10 +864,10 @@ struct SD3CLIPEmbedder : public Conditioner {
         std::vector<float> hidden_states_vec;

         size_t chunk_len   = 77;
-        size_t chunk_count = clip_l_tokens.size() / chunk_len;
+        size_t chunk_count = std::max(std::max(clip_l_tokens.size(), clip_g_tokens.size()), t5_tokens.size()) / chunk_len;
         for (int chunk_idx = 0; chunk_idx < chunk_count; chunk_idx++) {
             // clip_l
-            {
+            if (use_clip_l) {
                 std::vector<int> chunk_tokens(clip_l_tokens.begin() + chunk_idx * chunk_len,
                                               clip_l_tokens.begin() + (chunk_idx + 1) * chunk_len);
                 std::vector<float> chunk_weights(clip_l_weights.begin() + chunk_idx * chunk_len,
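Since clip_l_tokens may now be empty, the chunk count is derived from the longest of the three token streams rather than from clip_l alone. A small sketch of that arithmetic, with assumed stream sizes:

#include <algorithm>
#include <cstddef>
#include <iostream>

int main() {
    const size_t chunk_len = 77;
    size_t clip_l_size = 0;    // clip_l absent: no tokens were collected
    size_t clip_g_size = 154;  // two padded CLIP chunks
    size_t t5_size     = 154;

    size_t chunk_count = std::max(std::max(clip_l_size, clip_g_size), t5_size) / chunk_len;
    std::cout << "chunk_count = " << chunk_count << "\n";  // prints 2
}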
@@ -842,10 +912,17 @@ struct SD3CLIPEmbedder : public Conditioner {
                                     &pooled_l,
                                     work_ctx);
                 }
+            } else {
+                chunk_hidden_states_l = ggml_new_tensor_2d(work_ctx, GGML_TYPE_F32, 768, chunk_len);
+                ggml_set_f32(chunk_hidden_states_l, 0.f);
+                if (chunk_idx == 0) {
+                    pooled_l = ggml_new_tensor_1d(work_ctx, GGML_TYPE_F32, 768);
+                    ggml_set_f32(pooled_l, 0.f);
+                }
             }

             // clip_g
-            {
+            if (use_clip_g) {
                 std::vector<int> chunk_tokens(clip_g_tokens.begin() + chunk_idx * chunk_len,
                                               clip_g_tokens.begin() + (chunk_idx + 1) * chunk_len);
                 std::vector<float> chunk_weights(clip_g_weights.begin() + chunk_idx * chunk_len,
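When an encoder is missing, this and the following hunks substitute zero-filled tensors of the expected width so the later concatenation of hidden states still lines up. A standalone ggml sketch of that placeholder pattern for the clip_l case; the 16 MB context size is an arbitrary choice for this example and not taken from the embedder code:

#include <cstdio>
#include "ggml.h"

int main() {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16 * 1024 * 1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context* ctx = ggml_init(params);

    const int64_t chunk_len = 77;
    // clip_l missing: 768-wide hidden states and pooled output, all zeros.
    struct ggml_tensor* hidden_l = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 768, chunk_len);
    ggml_set_f32(hidden_l, 0.f);
    struct ggml_tensor* pooled_l = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 768);
    ggml_set_f32(pooled_l, 0.f);

    printf("hidden_l: %lld x %lld (%lld elements)\n",
           (long long)hidden_l->ne[0], (long long)hidden_l->ne[1],
           (long long)ggml_nelements(hidden_l));

    ggml_free(ctx);
    return 0;
}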
@@ -891,10 +968,17 @@ struct SD3CLIPEmbedder : public Conditioner {
                                     &pooled_g,
                                     work_ctx);
                 }
+            } else {
+                chunk_hidden_states_g = ggml_new_tensor_2d(work_ctx, GGML_TYPE_F32, 1280, chunk_len);
+                ggml_set_f32(chunk_hidden_states_g, 0.f);
+                if (chunk_idx == 0) {
+                    pooled_g = ggml_new_tensor_1d(work_ctx, GGML_TYPE_F32, 1280);
+                    ggml_set_f32(pooled_g, 0.f);
+                }
             }

             // t5
-            {
+            if (use_t5) {
                 std::vector<int> chunk_tokens(t5_tokens.begin() + chunk_idx * chunk_len,
                                               t5_tokens.begin() + (chunk_idx + 1) * chunk_len);
                 std::vector<float> chunk_weights(t5_weights.begin() + chunk_idx * chunk_len,
@@ -922,6 +1006,8 @@ struct SD3CLIPEmbedder : public Conditioner {
                     float new_mean = ggml_tensor_mean(tensor);
                     ggml_tensor_scale(tensor, (original_mean / new_mean));
                 }
+            } else {
+                chunk_hidden_states_t5 = ggml_new_tensor_2d(work_ctx, GGML_TYPE_F32, 4096, 0);
             }

             auto chunk_hidden_states_lg_pad = ggml_new_tensor_3d(work_ctx,
@@ -964,11 +1050,19 @@ struct SD3CLIPEmbedder : public Conditioner {
                                      ((float*)chunk_hidden_states->data) + ggml_nelements(chunk_hidden_states));
         }

-        hidden_states = vector_to_ggml_tensor(work_ctx, hidden_states_vec);
-        hidden_states = ggml_reshape_2d(work_ctx,
-                                        hidden_states,
-                                        chunk_hidden_states->ne[0],
-                                        ggml_nelements(hidden_states) / chunk_hidden_states->ne[0]);
+        if (hidden_states_vec.size() > 0) {
+            hidden_states = vector_to_ggml_tensor(work_ctx, hidden_states_vec);
+            hidden_states = ggml_reshape_2d(work_ctx,
+                                            hidden_states,
+                                            chunk_hidden_states->ne[0],
+                                            ggml_nelements(hidden_states) / chunk_hidden_states->ne[0]);
+        } else {
+            hidden_states = ggml_new_tensor_2d(work_ctx, GGML_TYPE_F32, 4096, 0);
+        }
+        if (pooled == NULL) {
+            pooled = ggml_new_tensor_1d(work_ctx, GGML_TYPE_F32, 2048);
+            ggml_set_f32(pooled, 0.f);
+        }
         return SDCondition(hidden_states, pooled, NULL);
     }

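The final fallback allocates a zero pooled vector of width 2048, which matches the concatenation of the clip_l (768) and clip_g (1280) pooled outputs produced on the normal path. A one-line sketch of that dimension bookkeeping:

#include <iostream>
#include <vector>

int main() {
    const size_t clip_l_dim = 768, clip_g_dim = 1280;
    std::vector<float> pooled(clip_l_dim + clip_g_dim, 0.0f);  // 2048 zeros
    std::cout << "pooled dim = " << pooled.size() << "\n";     // prints 2048
}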