diff --git a/model2vec/model.py b/model2vec/model.py index 62255b8c..9a8d8e8d 100644 --- a/model2vec/model.py +++ b/model2vec/model.py @@ -150,6 +150,7 @@ def from_pretrained( path: PathLike, token: str | None = None, normalize: bool | None = None, + dimensionality: int | None = None, ) -> StaticModel: """ Load a StaticModel from a local path or huggingface hub path. @@ -159,12 +160,27 @@ def from_pretrained( :param path: The path to load your static model from. :param token: The huggingface token to use. :param normalize: Whether to normalize the embeddings. + :param dimensionality: The dimensionality to load the model at. If this is None, use the model's original dimensionality. + This is useful if you want to load a model with a lower dimensionality. + Note that this only applies if you have trained your model using MRL or PCA. :return: A StaticModel + :raises ValueError: If the requested dimensionality is greater than the model dimensionality. """ from model2vec.hf_utils import load_pretrained embeddings, tokenizer, config, metadata = load_pretrained(path, token=token, from_sentence_transformers=False) + if dimensionality is not None: + if dimensionality > embeddings.shape[1]: + raise ValueError( + f"Dimensionality {dimensionality} is greater than the model dimensionality {embeddings.shape[1]}" + ) + embeddings = embeddings[:, :dimensionality] + if config.get("apply_pca", None) is None: + logger.warning( + "You are reducing the dimensionality of the model, but we can't find an apply_pca key in the model config. This might not work as expected." 
+ ) + return cls( embeddings, tokenizer, diff --git a/tests/test_model.py b/tests/test_model.py index 810fed84..ebcf37fe 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -182,6 +182,36 @@ def test_load_pretrained( assert loaded_model.config == mock_config + + +def test_load_pretrained_dim( + tmp_path: Path, mock_vectors: np.ndarray, mock_tokenizer: Tokenizer, mock_config: dict[str, str] +) -> None: + """Test loading a pretrained model at a reduced dimensionality.""" + # Save the model to a temporary path + model = StaticModel(vectors=mock_vectors, tokenizer=mock_tokenizer, config=mock_config) + save_path = tmp_path / "saved_model" + model.save_pretrained(save_path) + + # Load the model back with a reduced dimensionality + loaded_model = StaticModel.from_pretrained(save_path, dimensionality=2) + + # Assert that the embeddings are truncated to the requested dimensionality + np.testing.assert_array_equal(loaded_model.embedding, mock_vectors[:, :2]) + assert loaded_model.tokenizer.get_vocab() == mock_tokenizer.get_vocab() + assert loaded_model.config == mock_config + + # Load the model back at its full dimensionality + loaded_model = StaticModel.from_pretrained(save_path, dimensionality=None) + + # Assert that the loaded model has the same properties as the original one + np.testing.assert_array_equal(loaded_model.embedding, mock_vectors) + assert loaded_model.tokenizer.get_vocab() == mock_tokenizer.get_vocab() + assert loaded_model.config == mock_config + + # Requesting a dimensionality larger than the model's should raise a ValueError + with pytest.raises(ValueError): + StaticModel.from_pretrained(save_path, dimensionality=3000) + + def test_initialize_normalize(mock_vectors: np.ndarray, mock_tokenizer: Tokenizer) -> None: """Tests whether the normalization initialization is correct.""" model = StaticModel(mock_vectors, mock_tokenizer, {}, normalize=None)