
convert : experimental support for --mmproj flag #13023


Merged (10 commits) on Apr 20, 2025
Changes from 1 commit
232 changes: 194 additions & 38 deletions convert_hf_to_gguf.py
@@ -67,14 +67,20 @@ class Model:
dir_model_card: Path
remote_hf_model_id: str | None

# for vision encoders
mmproj: bool
ignore_vision: bool = False # subclasses may overwrite this
mtmd_model: MultimodalModel | None = None
ngxson (Collaborator, Author) commented:

@compilade Currently, the GGUFWriter for the mmproj file is wrapped inside MultimodalModel. The consequence is that MultimodalModel is now an attribute of Model.

Another way would be to make MultimodalModel inherit from Model, but that seems a bit complicated to reason about. Not sure which way you prefer?

compilade (Collaborator) commented:

> Currently, the GGUFWriter for the mmproj file is wrapped inside MultimodalModel.

This means that when self.mmproj is true, self.gguf_writer is unused (but still created!), and another GGUFWriter is created in self.mtmd_model. It works because, since #7827, the output files are no longer opened/created as soon as the GGUFWriter is instantiated (but some unnecessary metadata keys are still set and then ignored).

There's probably some way to simplify this.

What seems to be needed (eventually, to make this cleaner) is some more general abstraction to convert submodels (unless I'm misunderstanding the problem).

A submodel is part of a model, and a model is one or more submodels. Not quite sure how that should interact with model architectures, though. Each submodel could have its own architecture and tensor mappings, but I don't know what the main model architecture would be (the first submodel? a meta-model? or maybe there doesn't need to be a main one).

Since model loading doesn't quite support sub-models yet (we'll need to figure out namespaces or other ideas from #13028), only one submodel can be exported at a time, but at least conceptually it might be simpler to adapt such an abstraction to actually include multiple submodels in a single GGUF file once we've figured that out.
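
Purely as an illustration of that idea (the class and field names below are hypothetical and do not exist in convert_hf_to_gguf.py or in this PR), such a submodel abstraction could look roughly like this:

# Hypothetical sketch only, to make the "submodel" idea concrete; none of
# these names are part of llama.cpp's conversion code.
from dataclasses import dataclass, field


@dataclass
class Submodel:
    arch: str                              # e.g. "llama" or "clip"
    tensor_prefixes: tuple[str, ...]       # which HF tensors belong to it
    metadata: dict[str, object] = field(default_factory=dict)

    def owns(self, tensor_name: str) -> bool:
        return tensor_name.startswith(self.tensor_prefixes)


@dataclass
class ComposedModel:
    submodels: list[Submodel]

    def route(self, tensor_name: str) -> Submodel | None:
        # Dispatch each HF tensor to the submodel that owns it.
        return next((s for s in self.submodels if s.owns(tensor_name)), None)


# Example: a multimodal checkpoint split into a text and a vision submodel.
model = ComposedModel([
    Submodel("llama", ("language_model.",)),
    Submodel("clip", ("vision_tower.", "multi_modal_projector.")),
])
assert model.route("language_model.embed_tokens.weight").arch == "llama"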

> Another way would be to make MultimodalModel inherit from Model, but that seems a bit complicated to reason about. Not sure which way you prefer?

I think I prefer the way you did it for now, because you're right that Model does a lot, and refactoring multimodal support will be simpler by duplicating some parts of Model in a smaller class like MultimodalModel until we figure out something cleaner.

(I also don't know how MultimodalModel could cleanly subclass Model in this case)

ngxson (Collaborator, Author) commented:

> A submodel is part of a model, and a model is one or more submodels. Not quite sure how that should interact with model architectures, though. Each submodel could have its own architecture and tensor mappings, but I don't know what the main model architecture would be (the first submodel? a meta-model? or maybe there doesn't need to be a main one).

Those are some interesting questions. What I'm thinking is:

  • A model currently corresponds to one HF repo. So for example, with multimodal, one model contains both the text model's tensors and the vision/audio/etc. model's tensors
  • A submodel corresponds to a llama_model or a clip_model, which only loads the tensors it needs
  • A submodel compatible with libllama is distinguished by its model arch, so currently each submodel has one submodel.arch. But this can be tricky for models targeting clip.cpp, which does not care about arch (the closest equivalent is the notion of "projector type")

> (I also don't know how MultimodalModel could cleanly subclass Model in this case)

So from my POV above, what I'm thinking is that a submodel is just a Model with a custom list of tensors and metadata.

One idea could be:

  • Having a generic Model that provides basic functionality like reading safetensors, the GGUFWriter, etc.
  • Moving LLM-specific logic (like vocab) into a base class TextModel that inherits Model
  • Finally, having MultimodalModel inherit Model as well

> Since model loading doesn't quite support sub-models yet (we'll need to figure out namespaces or other ideas from #13028)

Please note that the current PR and the mentioned issue are not quite related atm. The main problem is that mmproj is not yet supported by libllama, so it's currently not possible to bundle mmproj + LLM.

My current PR is made mostly for two purposes:

  • To provide a more intuitive way to get the mmproj file, since the current way requires first doing a surgery and then converting 2 models
  • To unify all the conversion scripts under examples/llava, since some of them are very hacky and are better abandoned

ngxson (Collaborator, Author) commented:

> • Having a generic Model that provides basic functionality like reading safetensors, the GGUFWriter, etc.
> • Moving LLM-specific logic (like vocab) into a base class TextModel that inherits Model
> • Finally, having MultimodalModel inherit Model as well

OK, so I ended up doing this, and it seems to be more generic (while being less hacky at the same time). Please have a look at this commit: ddd7920

The main idea is to have VisionModel and TextModel both inherit the Model superclass, while existing text models inherit TextModel (hence the many changed LOC in the commit; most of them just change Model --> TextModel).
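
A condensed sketch of that hierarchy (illustrative only; the real classes carry much more state, such as the GGUFWriter and tensor maps, and this is not the literal code of commit ddd7920):

# Illustrative sketch of the class layout described above.
class Model:
    """Generic base: reads hparams/safetensors, owns the GGUF writing logic."""
    def __init__(self, hparams: dict):
        self.hparams = hparams


class TextModel(Model):
    """LLM-specific logic (vocab, text hyperparameters)."""
    def set_vocab(self) -> None:
        ...


class VisionModel(Model):
    """Vision/mmproj-specific logic (CLIP hyperparameters, projector type)."""
    def set_gguf_parameters(self) -> None:
        ...


# Existing text architectures subclass TextModel instead of Model, which is
# why many lines change in the commit without changing behaviour.
class LlamaModel(TextModel):
    ...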

Btw, it would be nice if we could finalize this during the week, so I can go ahead and add SmolVLM support. The clip.cpp implementation should be very straightforward; the only thing blocking me is that the current mmproj conversion script is a nightmare to work with 😂 So it would be nice if we could finally use convert_hf_to_gguf to get the mmproj.

compilade (Collaborator) commented:

> The main idea is to have VisionModel and TextModel both inherit the Model superclass

Right, this does feel much better, especially with how quant types are overridden in the intended way.

LoraModel will likely need adaptation, though. Not sure if it should be based on TextModel or still on Model. (Does it make sense to have LoRA adapters of mmproj?)

ngxson (Collaborator, Author) commented:

> (Does it make sense to have LoRA adapters of mmproj?)

I haven't seen anyone doing this, so I guess it doesn't make sense in practice. In most (if not all) cases, people are interested in doing LoRA for the text model because it's easier to prepare the dataset.

And since LoraModel uses Model.from_model_architecture, which returns the TextModel subclass by default, I think it will continue to work as-is. Can you think of any cases that would need to be adapted?
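
To make that concrete, here is a minimal, self-contained sketch of the register / from_model_architecture pattern being discussed (illustrative only; the real implementation in convert_hf_to_gguf.py and convert_lora_to_gguf.py differs in detail):

# Minimal sketch of the architecture-registry pattern; illustrative only.
from typing import Callable


class Base:
    _registry: dict[str, type] = {}

    @classmethod
    def register(cls, *names: str) -> Callable[[type], type]:
        def wrap(sub: type) -> type:
            for name in names:
                cls._registry[name] = sub
            return sub
        return wrap

    @classmethod
    def from_model_architecture(cls, arch: str) -> type:
        # Returns whatever class was registered for the HF architecture name,
        # i.e. the text-model subclass in the layout discussed above.
        return cls._registry[arch]


@Base.register("LlamaForCausalLM")
class LlamaText(Base):
    pass


# A LoRA converter that resolves classes through this lookup keeps working,
# because the lookup still returns the text-model subclass.
assert Base.from_model_architecture("LlamaForCausalLM") is LlamaText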


# subclasses should define this!
model_arch: gguf.MODEL_ARCH

def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, is_big_endian: bool = False,
use_temp_file: bool = False, eager: bool = False,
metadata_override: Path | None = None, model_name: str | None = None,
split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False,
small_first_shard: bool = False, hparams: dict[str, Any] | None = None, remote_hf_model_id: str | None = None):
small_first_shard: bool = False, hparams: dict[str, Any] | None = None, remote_hf_model_id: str | None = None,
mmproj: bool = False):
if type(self) is Model:
raise TypeError(f"{type(self).__name__!r} should not be directly instantiated")

@@ -109,6 +115,7 @@ def get_remote_tensors() -> Iterator[tuple[str, Tensor]]:
self.metadata_override = metadata_override
self.model_name = model_name
self.dir_model_card = dir_model # overridden in convert_lora_to_gguf.py
self.mmproj = mmproj

# Apply heuristics to figure out typical tensor encoding based on first layer tensor encoding type
if self.ftype == gguf.LlamaFileType.GUESSED:
@@ -125,6 +132,28 @@ def get_remote_tensors() -> Iterator[tuple[str, Tensor]]:
self.gguf_writer = gguf.GGUFWriter(path=None, arch=gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file,
split_max_tensors=split_max_tensors, split_max_size=split_max_size, dry_run=dry_run, small_first_shard=small_first_shard)

# vision encoder
if mmproj:
vision_hparams = self.hparams.get("vision_config")
if vision_hparams is None:
raise ValueError("Vision config not found in model config")
elif self.ignore_vision:
raise ValueError("Vision config found, but mmproj conversion for this model is not supported yet")
else:
self.mtmd_model = MultimodalModel(
hparams=vision_hparams,
ftype=self.ftype,
fname_out=self.fname_out,
endianess=self.endianess,
use_temp_file=self.use_temp_file,
)

@classmethod
def add_prefix_to_filename(cls, path: Path, prefix: str) -> Path:
stem, suffix = path.stem, path.suffix
new_name = f"{prefix}{stem}{suffix}"
return path.with_name(new_name)

@classmethod
def __init_subclass__(cls):
# can't use an abstract property, because overriding it without type errors
@@ -272,8 +301,13 @@ def set_gguf_parameters(self):
self.gguf_writer.add_key_length(head_dim)
self.gguf_writer.add_value_length(head_dim)

self.gguf_writer.add_file_type(self.ftype)
logger.info(f"gguf: file type = {self.ftype}")
if not self.mmproj:
self.gguf_writer.add_file_type(self.ftype)
logger.info(f"gguf: file type = {self.ftype}")
else:
assert self.mtmd_model is not None
self.mtmd_model.set_gguf_parameters(n_embd_text=n_embd)
logger.info(f"mmproj: file type = {self.mtmd_model.ftype}")

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
del bid # unused
@@ -311,6 +345,10 @@ def prepare_tensors(self):
break

for new_name, data_torch in (self.modify_tensors(data_torch, name, bid)):
# skip adding tensor if we're working with a vision model
if self.mmproj:
continue

# TODO: why do we squeeze here?
# data = data_torch.squeeze().numpy()
data = data_torch.numpy()
@@ -455,12 +493,18 @@ def prepare_metadata(self, vocab_only: bool):
self.gguf_writer.add_quantization_version(gguf.GGML_QUANT_VERSION)

def write(self):
self.prepare_tensors()
self.prepare_metadata(vocab_only=False)
self.gguf_writer.write_header_to_file(path=self.fname_out)
self.gguf_writer.write_kv_data_to_file()
self.gguf_writer.write_tensors_to_file(progress=True)
self.gguf_writer.close()
if self.mtmd_model is not None:
self.prepare_tensors()
self.prepare_metadata(vocab_only=False)
logger.info("Writing vision model")
self.mtmd_model.write()
else:
self.prepare_tensors()
self.prepare_metadata(vocab_only=False)
self.gguf_writer.write_header_to_file(path=self.fname_out)
self.gguf_writer.write_kv_data_to_file()
self.gguf_writer.write_tensors_to_file(progress=True)
self.gguf_writer.close()

def write_vocab(self):
if len(self.gguf_writer.tensors) != 1:
@@ -485,7 +529,10 @@ def get_model_part_names(dir_model: Path, prefix: str, suffix: str) -> list[str]
@staticmethod
def load_hparams(dir_model: Path):
with open(dir_model / "config.json", "r", encoding="utf-8") as f:
return json.load(f)
hparams = json.load(f)
if "text_config" in hparams:
hparams = {**hparams, **hparams["text_config"]}
return hparams

@classmethod
def register(cls, *names: str) -> Callable[[AnyModel], AnyModel]:
@@ -1024,6 +1071,101 @@ def _set_vocab_builtin(self, model_name: Literal["gpt-neox", "llama-spm"], vocab
self.gguf_writer.add_add_eos_token(field.parts[-1].tolist()[0])


# for converting mmproj file
class MultimodalModel:
hparams: dict
dir_model: Path
ftype: gguf.LlamaFileType
fname_out: Path
tensor_map: gguf.TensorNameMap
gguf_writer: gguf.GGUFWriter

def __init__(self, hparams: dict, ftype: gguf.LlamaFileType, fname_out: Path, endianess: gguf.GGUFEndian, use_temp_file: bool):
self.hparams = hparams
self.ftype = ftype
self.fname_out = fname_out
self.tensor_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.CLIP_VISION, 128)
self.gguf_writer = gguf.GGUFWriter(path=None,
arch="clip",
endianess=endianess,
use_temp_file=use_temp_file)

def set_gguf_parameters(self, n_embd_text: int):
"""Function to be called by Model.set_gguf_parameters()"""
self.gguf_writer.add_type(gguf.GGUFType.CLIP_VISION)
self.gguf_writer.add_file_type(self.ftype)
self.gguf_writer.add_uint32(gguf.Keys.ClipVision.PROJECTION_DIM, n_embd_text)
self.gguf_writer.add_bool(gguf.Keys.ClipVision.HAS_VISION_ENCODER, True)

# vision config
self.gguf_writer.add_uint32(gguf.Keys.ClipVision.IMAGE_SIZE, self.find_hparam(["image_size"]))
self.gguf_writer.add_uint32(gguf.Keys.ClipVision.PATCH_SIZE, self.find_hparam(["patch_size"]))
self.gguf_writer.add_uint32(gguf.Keys.ClipVision.EMBEDDING_LENGTH, self.find_hparam(["hidden_size"]))
self.gguf_writer.add_uint32(gguf.Keys.ClipVision.FEED_FORWARD_LENGTH, self.find_hparam(["intermediate_size"]))
self.gguf_writer.add_uint32(gguf.Keys.ClipVision.BLOCK_COUNT, self.find_hparam(["num_hidden_layers"]))
self.gguf_writer.add_uint32(gguf.Keys.ClipVision.Attention.HEAD_COUNT, self.find_hparam(["num_attention_heads"]))
Comment on lines +1090 to +1095 by ngxson (Collaborator, Author) on Apr 19, 2025:

Note that I didn't add gguf_writer.add_* wrapper functions because I don't yet have the full list of keys. This can be done in a follow-up PR
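
For illustration, such wrappers could look like the sketch below (the helper names are hypothetical and not part of this PR or of the gguf package; only the raw add_uint32 calls and the gguf.Keys.ClipVision constants used above are from the actual change):

import gguf

# Hypothetical add_* helpers for a follow-up PR; they simply hide the
# ClipVision key constants, mirroring the existing text-model helpers.
def add_vision_image_size(writer: gguf.GGUFWriter, value: int) -> None:
    writer.add_uint32(gguf.Keys.ClipVision.IMAGE_SIZE, value)


def add_vision_patch_size(writer: gguf.GGUFWriter, value: int) -> None:
    writer.add_uint32(gguf.Keys.ClipVision.PATCH_SIZE, value)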


def find_hparam(self, keys: Iterable[str], optional: bool = False) -> Any:
key = next((k for k in keys if k in self.hparams), None)
if key is not None:
return self.hparams[key]
if optional:
return None
raise KeyError(f"could not find any of: {keys}")

def get_quantization(self, mapped_name: str, data_torch: Tensor) -> gguf.GGMLQuantizationType:
is_1d = len(data_torch.shape) == 1
is_embd = "_embd" in mapped_name
can_quantize = not is_1d and not is_embd
data_qtype = gguf.GGMLQuantizationType.F32
if can_quantize:
if self.ftype == gguf.LlamaFileType.ALL_F32:
data_qtype = gguf.GGMLQuantizationType.F32
elif self.ftype == gguf.LlamaFileType.MOSTLY_F16:
data_qtype = gguf.GGMLQuantizationType.F16
elif self.ftype == gguf.LlamaFileType.MOSTLY_BF16:
data_qtype = gguf.GGMLQuantizationType.BF16
elif self.ftype == gguf.LlamaFileType.MOSTLY_Q8_0:
data_qtype = gguf.GGMLQuantizationType.Q8_0
else:
raise ValueError(f"Unsupported file type: {self.ftype}")
return data_qtype

def add_tensor(self, original_name: str, data_torch: Tensor) -> None:
"""Function to be called inside Model.modify_tensors()"""
# name mapping
new_name = self.tensor_map.get_name(key=original_name, try_suffixes=(".weight", ".bias"))
if new_name is None:
raise ValueError(f"Can not map tensor {original_name!r}")

# process data
# old_dtype = data_torch.dtype
data_qtype = self.get_quantization(new_name, data_torch)
data = data_torch.numpy()
try:
data = gguf.quants.quantize(data, data_qtype)
except Exception as e:
logger.error(f"Error quantizing tensor '{new_name}': {e}, fallback to F16")
data_qtype = gguf.GGMLQuantizationType.F16
data = gguf.quants.quantize(data, data_qtype)

# reverse shape to make it similar to the internal ggml dimension order
# TODO: we don't print old_dtype because it's not correct, to be fixed later
old_dtype = ""
shape_str = f"{{{', '.join(str(n) for n in reversed(data_torch.shape))}}}"
logger.info(f"{f'%-32s' % f'{new_name},'} {old_dtype} --> {data_qtype.name}, shape = {shape_str}")

# add tensor
self.gguf_writer.add_tensor(new_name, data, raw_dtype=data_qtype)

def write(self):
"""Function to be called by Model.write()"""
self.gguf_writer.write_header_to_file(path=self.fname_out)
self.gguf_writer.write_kv_data_to_file()
self.gguf_writer.write_tensors_to_file(progress=True)
self.gguf_writer.close()


@Model.register("GPTNeoXForCausalLM")
class GPTNeoXModel(Model):
model_arch = gguf.MODEL_ARCH.GPTNEOX
@@ -1781,20 +1923,13 @@ def prepare_tensors(self):
@Model.register("Llama4ForConditionalGeneration")
class Llama4Model(LlamaModel):
model_arch = gguf.MODEL_ARCH.LLAMA4
has_vision: bool = False
undo_permute = False
ignore_vision = True
compilade (Collaborator) commented:

This isn't used anywhere

Suggested change (remove this line):
ignore_vision = True

ngxson (Collaborator, Author) commented on Apr 20, 2025:

Unrelated to this conversation, but Llama 4 vision support also seems to be low-hanging fruit. They no longer use cross-attention like in Llama 3; here it's just simple embeddings passed from the encoder to the decoder, so it would also be a nice thing to try out.

(Noting here so I remember)


# TODO @ngxson : avoid duplicate this code everywhere by at least support "text_config"
# same with llama, but we need to merge the text_config into the root level of hparams
def __init__(self, *args, **kwargs):
hparams = kwargs["hparams"] if "hparams" in kwargs else Model.load_hparams(args[0])
if "text_config" in hparams:
hparams = {**hparams, **hparams["text_config"]}
kwargs["hparams"] = hparams
super().__init__(*args, **kwargs)
if "vision_config" in hparams:
logger.info("Has vision encoder, but it will be ignored")
self.has_vision = True
# IMPORTANT: the normal "intermediate_size" is renamed to "intermediate_size_mlp", we need to undo this
self.hparams["intermediate_size_moe"] = self.hparams["intermediate_size"]
self.hparams["intermediate_size"] = self.hparams["intermediate_size_mlp"]
@@ -1824,7 +1959,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
name += ".weight"
data_torch = data_torch.transpose(-1, -2)

if "multi_modal_projector" in name or "vision_model" in name:
if "multi_modal_projector" in name or "mtmd_model" in name:
return []
return super().modify_tensors(data_torch, name, bid)

@@ -3474,24 +3609,9 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
@Model.register("Gemma3ForCausalLM", "Gemma3ForConditionalGeneration")
class Gemma3Model(Model):
model_arch = gguf.MODEL_ARCH.GEMMA3
has_vision: bool = False

# we need to merge the text_config into the root level of hparams
def __init__(self, *args, **kwargs):
hparams = kwargs["hparams"] if "hparams" in kwargs else Model.load_hparams(args[0])
if "text_config" in hparams:
hparams = {**hparams, **hparams["text_config"]}
kwargs["hparams"] = hparams
super().__init__(*args, **kwargs)
if "vision_config" in hparams:
logger.info("Has vision encoder, but it will be ignored")
self.has_vision = True

def write(self):
super().write()
if self.has_vision:
logger.info("NOTE: this script only convert the language model to GGUF")
logger.info(" for the vision model, please use gemma3_convert_encoder_to_gguf.py")

def set_vocab(self):
self._set_vocab_sentencepiece()
@@ -3524,15 +3644,42 @@ def set_gguf_parameters(self):
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
self.gguf_writer.add_rope_scaling_factor(hparams["rope_scaling"]["factor"])

if self.mtmd_model is not None:
self.mtmd_model.set_gguf_parameters(n_embd_text=hparams["hidden_size"])
vgguf = self.mtmd_model.gguf_writer
vgguf.add_string(gguf.Keys.ClipVision.PROJECTOR_TYPE, "gemma3")
# default values below are taken from HF transformers code
vgguf.add_float32(gguf.Keys.ClipVision.Attention.LAYERNORM_EPS, self.mtmd_model.hparams.get("layer_norm_eps", 1e-6))
vgguf.add_array(gguf.Keys.ClipVision.IMAGE_MEAN, [0.5, 0.5, 0.5])
vgguf.add_array(gguf.Keys.ClipVision.IMAGE_STD, [0.5, 0.5, 0.5])
vgguf.add_bool (gguf.Keys.ClipVision.USE_GELU, True)

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
del bid # unused

if name.startswith("language_model."):
name = name.replace("language_model.", "")

elif name.startswith("multi_modal_projector.") or name.startswith("vision_tower.") \
or name.startswith("multimodal_projector.") or name.startswith("vision_model."): # this is for old HF model, should be removed later
# ignore vision tensors
return []
or name.startswith("multimodal_projector.") or name.startswith("mtmd_model."):
if self.mmproj:
assert self.mtmd_model is not None
# process vision tensors
name = name.replace("_weight", ".weight")
if "fc1" in name:
name = name.replace("fc1", "fc2")
else:
name = name.replace("fc2", "fc1")

# correct norm value; only this "soft_emb_norm" needs to be corrected as it's part of the Gemma projector
# the other norm values are part of SigLIP model, and they are already correct
# ref code: Gemma3RMSNorm
if "soft_emb_norm.weight" in name:
logger.info(f"Correcting norm value for '{name}'")
data_torch = data_torch + 1

self.mtmd_model.add_tensor(name, data_torch)
return [] # vision tensor already handled

# remove OOV (out-of-vocabulary) rows in token_embd
if "embed_tokens.weight" in name:
@@ -5554,6 +5701,10 @@ def parse_args() -> argparse.Namespace:
"--remote", action="store_true",
help="(Experimental) Read safetensors file remotely without downloading to disk. Config and tokenizer files will still be downloaded. To use this feature, you need to specify Hugging Face model repo name instead of a local directory. For example: 'HuggingFaceTB/SmolLM2-1.7B-Instruct'. Note: To access gated repo, set HF_TOKEN environment variable to your Hugging Face token.",
)
parser.add_argument(
"--mmproj", action="store_true",
help="(Experimental) Export multimodal projector (mmproj) for vision models. This will only work on some vision models. A prefix 'mmproj-' will be added to the output file name.",
)

args = parser.parse_args()
if not args.print_supported_models and args.model is None:
@@ -5633,6 +5784,10 @@ def main() -> None:

hparams = Model.load_hparams(dir_model)

if args.mmproj:
if "mmproj" not in fname_out.name:
fname_out = Model.add_prefix_to_filename(fname_out, "mmproj-")

with torch.inference_mode():
output_type = ftype_map[args.outtype]
model_architecture = hparams["architectures"][0]
@@ -5649,7 +5804,8 @@
split_max_tensors=args.split_max_tensors,
split_max_size=split_str_to_n_bytes(args.split_max_size), dry_run=args.dry_run,
small_first_shard=args.no_tensor_first_split,
remote_hf_model_id=str(args.model) if args.remote else None)
remote_hf_model_id=str(args.model) if args.remote else None,
mmproj=args.mmproj)

if args.vocab_only:
logger.info("Exporting model vocab...")
3 changes: 0 additions & 3 deletions examples/llava/clip-impl.h
@@ -50,7 +50,6 @@
// tensor name constants
//

#define TN_TOKEN_EMBD "%s.token_embd.weight"
#define TN_POS_EMBD "%s.position_embd.weight"
#define TN_CLASS_EMBD "v.class_embd"
#define TN_PATCH_EMBD "v.patch_embd.weight" // not rename tensor with ".0" postfix for backwrad compat
@@ -66,8 +65,6 @@
#define TN_LN_2 "%s.blk.%d.ln2.%s"
#define TN_LN_PRE "%s.pre_ln.%s"
#define TN_LN_POST "%s.post_ln.%s"
#define TN_TEXT_PROJ "text_projection.weight"
#define TN_VIS_PROJ "visual_projection.weight"
#define TN_LLAVA_PROJ "mm.%d.%s"
#define TN_MVLM_PROJ_MLP "mm.model.mlp.%d.%s"
#define TN_MVLM_PROJ_BLOCK "mm.model.mb_block.%d.block.%d.%s"