@@ -3061,6 +3061,64 @@ def prepare_tensors(self):
3061
3061
class Qwen3Model(Qwen2Model):
    """Converter for Qwen3 models; extends Qwen2Model with support for
    Qwen3-Reranker checkpoints (sequence-classification head on top of the LM)."""

    model_arch = gguf.MODEL_ARCH.QWEN3

    # extra logic for rerank models
    token_false_id: int | None = None   # vocab id of the "no" token (set only for rerank models)
    token_true_id: int | None = None    # vocab id of the "yes" token (set only for rerank models)
    sep_token_id: int = 0               # vocab id of "\n"; written to GGUF metadata
    is_tied_embeddings: bool = False    # whether lm_head shares weights with embed_tokens

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # a bit hacky, but currently the only way to detect if this is a rerank model:
        # rerank checkpoints ship a README titled "# Qwen3-Reranker..."
        readme_path = self.dir_model / "README.md"
        readme_text = ""
        if readme_path.exists():
            with readme_path.open("r", encoding="utf-8") as f:
                readme_text = f.read()
        if "# Qwen3-Reranker" in readme_text:
            self._find_rerank_config()

    def _find_rerank_config(self):
        """Resolve the rerank-specific token ids and embedding-tying flag from the tokenizer/hparams."""
        from transformers import AutoTokenizer
        tokenizer = AutoTokenizer.from_pretrained(self.dir_model)
        self.token_false_id = tokenizer.convert_tokens_to_ids("no")
        self.token_true_id = tokenizer.convert_tokens_to_ids("yes")
        # NOTE: the separator is the newline token ("\n"), not a literal backslash-n string
        self.sep_token_id = tokenizer.convert_tokens_to_ids("\n")  # unused, but needed for rerank check
        self.is_tied_embeddings = self.hparams.get("tie_word_embeddings", False)
        logger.info(f"gguf: token_false_id = {self.token_false_id}, token_true_id = {self.token_true_id}")
        logger.info(f"gguf: sep_token_id = {self.sep_token_id}")
        logger.info(f"gguf: is_tied_embeddings = {self.is_tied_embeddings}")

    def set_gguf_parameters(self):
        """Write base Qwen2 parameters, plus rank-pooling metadata when this is a rerank model."""
        super().set_gguf_parameters()
        is_rerank = self.token_false_id is not None and self.token_true_id is not None
        if is_rerank:
            self.gguf_writer.add_pooling_type(gguf.PoolingType.RANK)
            self.gguf_writer.add_sep_token_id(self.sep_token_id)
            # binary classifier: row 0 = "yes"/relevant, row 1 = "no"/not relevant
            self.gguf_writer.add_uint32(gguf.Keys.Classifier.OUTPUT_LABELS, 2)

    def _get_cls_out_tensor(self, data_torch: Tensor) -> Tensor:
        # extract "yes" and "no" tokens from the output lm_head tensor
        # assumes data_torch is indexable by vocab id along dim 0 (lm_head / embed_tokens weight)
        assert self.token_false_id is not None and self.token_true_id is not None
        false_row = data_torch[self.token_false_id]
        true_row = data_torch[self.token_true_id]
        # order matters: "yes" (true) first, then "no" (false)
        return torch.stack([true_row, false_row], dim=0)

    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
        """For rerank models, synthesize the CLS_OUT tensor from the lm_head (or tied
        embed_tokens) rows of the "yes"/"no" tokens; otherwise defer to Qwen2Model."""
        is_rerank = self.token_false_id is not None and self.token_true_id is not None

        if is_rerank:
            if self.is_tied_embeddings and "embed_tokens" in name:
                # tied weights: keep the embedding tensor AND derive cls_out from it
                return [
                    (gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.CLS_OUT] + ".weight", self._get_cls_out_tensor(data_torch)),
                    (self.map_tensor_name(name), data_torch),
                ]
            if not self.is_tied_embeddings and "lm_head" in name:
                # this is the lm_head tensor, we need to extract the cls_out tensor
                return [(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.CLS_OUT] + ".weight", self._get_cls_out_tensor(data_torch))]

        return super().modify_tensors(data_torch, name, bid)
3064
3122
3065
3123
@ModelBase .register ("Qwen3MoeForCausalLM" )
3066
3124
class Qwen3MoeModel (Qwen2MoeModel ):
0 commit comments