diff --git a/README.md b/README.md
index baad76a3..c593d9cd 100644
--- a/README.md
+++ b/README.md
@@ -54,6 +54,7 @@ Model2Vec is a technique to turn any sentence transformer into a really small st
- [Distillation](#distillation)
- [Inference](#inference)
- [Evaluation](#evaluation)
+ - [Integrations](#integrations)
- [Model List](#model-list)
- [Results](#results)
- [Related Work](#related-work)
@@ -356,6 +357,89 @@ print(make_leaderboard(task_scores))
```
+### Integrations
+
+#### Sentence Transformers
+
+Model2Vec can be used directly in [Sentence Transformers](https://github.com/UKPLab/sentence-transformers) using the `StaticEmbedding` module.
+
+The following code snippet shows how to load a Model2Vec model into a Sentence Transformer model:
+```python
+from sentence_transformers import SentenceTransformer
+from sentence_transformers.models import StaticEmbedding
+
+# Initialize a StaticEmbedding module
+static_embedding = StaticEmbedding.from_model2vec("minishlab/potion-base-8M")
+model = SentenceTransformer(modules=[static_embedding])
+embeddings = model.encode(["It's dangerous to go alone!", "It's a secret to everybody."])
+```
+
+The following code snippet shows how to distill a model directly into a Sentence Transformer model:
+
+```python
+from sentence_transformers import SentenceTransformer
+from sentence_transformers.models import StaticEmbedding
+
+static_embedding = StaticEmbedding.from_distillation("BAAI/bge-base-en-v1.5", device="cpu", pca_dims=256)
+model = SentenceTransformer(modules=[static_embedding])
+embeddings = model.encode(["It's dangerous to go alone!", "It's a secret to everybody."])
+```
+
+#### Transformers.js
+
+To use a Model2Vec model in [transformers.js](https://github.com/huggingface/transformers.js), the following code snippet can be used as a starting point:
+
+```javascript
+import { AutoModel, AutoTokenizer, Tensor } from '@huggingface/transformers';
+
+const modelName = 'minishlab/potion-base-8M';
+
+const modelConfig = {
+  config: { model_type: 'model2vec' },
+  dtype: 'fp32',
+  revision: 'refs/pr/1'
+};
+const tokenizerConfig = {
+  revision: 'refs/pr/2'
+};
+
+const model = await AutoModel.from_pretrained(modelName, modelConfig);
+const tokenizer = await AutoTokenizer.from_pretrained(modelName, tokenizerConfig);
+
+const texts = ['hello', 'hello world'];
+const { input_ids } = await tokenizer(texts, { add_special_tokens: false, return_tensor: false });
+
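+// The exported model takes a flat array of token ids plus per-sequence start offsets
+// (EmbeddingBag-style inputs), so compute the offsets with a cumulative sum.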
+const cumsum = arr => arr.reduce((acc, num, i) => [...acc, num + (acc[i - 1] || 0)], []);
+const offsets = [0, ...cumsum(input_ids.slice(0, -1).map(x => x.length))];
+
+const flattened_input_ids = input_ids.flat();
+const modelInputs = {
+  input_ids: new Tensor('int64', flattened_input_ids, [flattened_input_ids.length]),
+  offsets: new Tensor('int64', offsets, [offsets.length])
+};
+
+const { embeddings } = await model(modelInputs);
+console.log(embeddings.tolist()); // output matches the Python version
+```
+
+Note that this requires the Model2Vec model to have a `model.onnx` file and several tokenizer files. To generate these for a model that does not have them yet, the following script can be used:
+
+```bash
+python scripts/export_to_onnx.py --model_path <model-path> --save_path "<save-path>"
+```
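+
+The export script relies on the optional `onnx` dependency group added to `pyproject.toml` in this PR; assuming a local checkout of the repository, it can be installed with, for example:
+
+```bash
+pip install ".[onnx]"
+```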
+
## Model List
We provide a number of models that can be used out of the box. These models are available on the [HuggingFace hub](https://huggingface.co/collections/minishlab/model2vec-base-models-66fd9dd9b7c3b3c0f25ca90e) and can be loaded using the `from_pretrained` method. The models are listed below.
diff --git a/pyproject.toml b/pyproject.toml
index 7209a95d..625a5fd4 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -54,6 +54,8 @@ dev = [
]
distill = ["torch", "transformers", "scikit-learn"]
+onnx = ["onnx", "torch"]
+
[project.urls]
"Homepage" = "https://github.com/MinishLab"
"Bug Reports" = "https://github.com/MinishLab/model2vec/issues"
diff --git a/scripts/export_to_onnx.py b/scripts/export_to_onnx.py
new file mode 100644
index 00000000..aa1bfae0
--- /dev/null
+++ b/scripts/export_to_onnx.py
@@ -0,0 +1,208 @@
+from model2vec.utils import get_package_extras, importable
+
+# Define the optional dependency group name
+_REQUIRED_EXTRA = "onnx"
+
+# Check if each dependency for the "onnx" group is importable
+for extra_dependency in get_package_extras("model2vec", _REQUIRED_EXTRA):
+    importable(extra_dependency, _REQUIRED_EXTRA)
+
+import argparse
+import json
+import logging
+from pathlib import Path
+
+import torch
+from tokenizers import Tokenizer
+from transformers import AutoTokenizer, PreTrainedTokenizerFast
+
+from model2vec import StaticModel
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+class TorchStaticModel(torch.nn.Module):
+    def __init__(self, model: StaticModel) -> None:
+        """Initialize the TorchStaticModel with a StaticModel instance."""
+        super().__init__()
+        # Convert NumPy embeddings to a torch.nn.EmbeddingBag
+        embeddings = torch.tensor(model.embedding, dtype=torch.float32)
+        self.embedding_bag = torch.nn.EmbeddingBag.from_pretrained(embeddings, mode="mean", freeze=True)
+        self.normalize = model.normalize
+        # Save tokenizer attributes
+        self.tokenizer = model.tokenizer
+        self.unk_token_id = model.unk_token_id
+        self.median_token_length = model.median_token_length
+
+    def forward(self, input_ids: torch.Tensor, offsets: torch.Tensor) -> torch.Tensor:
+        """
+        Forward pass of the model.
+
+        :param input_ids: The input token ids.
+        :param offsets: The offsets to compute the mean pooling.
+        :return: The embeddings.
+        """
+        # Perform embedding lookup and mean pooling
+        embeddings = self.embedding_bag(input_ids, offsets)
+        # Normalize if required
+        if self.normalize:
+            embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=-1)
+        return embeddings
+
+    def tokenize(self, sentences: list[str], max_length: int | None = None) -> tuple[torch.Tensor, torch.Tensor]:
+        """
+        Tokenize the input sentences.
+
+        :param sentences: The input sentences.
+        :param max_length: The maximum length of the input_ids.
+        :return: The input_ids and offsets.
+        """
+        # Tokenization logic mirrors that of StaticModel
+        if max_length is not None:
+            m = max_length * self.median_token_length
+            sentences = [sentence[:m] for sentence in sentences]
+        encodings = self.tokenizer.encode_batch(sentences, add_special_tokens=False)
+        encodings_ids = [encoding.ids for encoding in encodings]
+        if self.unk_token_id is not None:
+            # Remove unknown tokens
+            encodings_ids = [
+                [token_id for token_id in token_ids if token_id != self.unk_token_id] for token_ids in encodings_ids
+            ]
+        if max_length is not None:
+            encodings_ids = [token_ids[:max_length] for token_ids in encodings_ids]
+        # Flatten input_ids and compute offsets
+        offsets = torch.tensor([0] + [len(ids) for ids in encodings_ids[:-1]], dtype=torch.long).cumsum(dim=0)
+        input_ids = torch.tensor(
+            [token_id for token_ids in encodings_ids for token_id in token_ids],
+            dtype=torch.long,
+        )
+        return input_ids, offsets
+
+
+def export_model_to_onnx(model_path: str, save_path: Path) -> None:
+    """
+    Export the StaticModel to ONNX format and save tokenizer files.
+
+    :param model_path: The path to the pretrained StaticModel.
+    :param save_path: The directory to save the model and related files.
+    """
+    save_path.mkdir(parents=True, exist_ok=True)
+
+    # Load the StaticModel
+    model = StaticModel.from_pretrained(model_path)
+    torch_model = TorchStaticModel(model)
+
+    # Save the model using save_pretrained
+    model.save_pretrained(save_path)
+
+    # Prepare dummy input data
+    texts = ["hello", "hello world"]
+    input_ids, offsets = torch_model.tokenize(texts)
+
+    # Export the model to ONNX
+    onnx_model_path = save_path / "onnx/model.onnx"
+    onnx_model_path.parent.mkdir(parents=True, exist_ok=True)
+    torch.onnx.export(
+        torch_model,
+        (input_ids, offsets),
+        str(onnx_model_path),
+        export_params=True,
+        opset_version=14,
+        do_constant_folding=True,
+        input_names=["input_ids", "offsets"],
+        output_names=["embeddings"],
+        dynamic_axes={
+            "input_ids": {0: "num_tokens"},
+            "offsets": {0: "batch_size"},
+            "embeddings": {0: "batch_size"},
+        },
+    )
+
+    logger.info(f"Model has been successfully exported to {onnx_model_path}")
+
+    # Save the tokenizer files required for transformers.js
+    save_tokenizer(model.tokenizer, save_path)
+    logger.info(f"Tokenizer files have been saved to {save_path}")
+
+
+def save_tokenizer(tokenizer: Tokenizer, save_directory: Path) -> None:
+    """
+    Save tokenizer files in a format compatible with Transformers.
+
+    :param tokenizer: The tokenizer from the StaticModel.
+    :param save_directory: The directory to save the tokenizer files.
+    :raises FileNotFoundError: If config.json is not found in save_directory.
+    :raises FileNotFoundError: If tokenizer_config.json is not found in save_directory.
+    :raises ValueError: If tokenizer_name is not found in config.json.
+    """
+    tokenizer_json_path = save_directory / "tokenizer.json"
+    tokenizer.save(str(tokenizer_json_path))
+
+    # Save vocab.txt
+    vocab = tokenizer.get_vocab()
+    vocab_path = save_directory / "vocab.txt"
+    with open(vocab_path, "w", encoding="utf-8") as vocab_file:
+        for token in sorted(vocab, key=vocab.get):
+            vocab_file.write(f"{token}\n")
+
+    # Load config.json to get tokenizer_name
+    config_path = save_directory / "config.json"
+    if config_path.exists():
+        with open(config_path, "r", encoding="utf-8") as f:
+            config = json.load(f)
+    else:
+        raise FileNotFoundError(f"config.json not found in {save_directory}")
+
+    tokenizer_name = config.get("tokenizer_name")
+    if not tokenizer_name:
+        raise ValueError("tokenizer_name not found in config.json")
+
+    # Load the original tokenizer
+    original_tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
+
+    # Extract special tokens and tokenizer class
+    special_tokens = original_tokenizer.special_tokens_map
+    tokenizer_class = original_tokenizer.__class__.__name__
+
+    # Load the tokenizer using PreTrainedTokenizerFast with special tokens
+    fast_tokenizer = PreTrainedTokenizerFast(
+        tokenizer_file=str(tokenizer_json_path),
+        **special_tokens,
+    )
+
+    # Save the tokenizer files
+    fast_tokenizer.save_pretrained(str(save_directory))
+    # Modify tokenizer_config.json to set the correct tokenizer_class
+    tokenizer_config_path = save_directory / "tokenizer_config.json"
+    if tokenizer_config_path.exists():
+        with open(tokenizer_config_path, "r", encoding="utf-8") as f:
+            tokenizer_config = json.load(f)
+    else:
+        raise FileNotFoundError(f"tokenizer_config.json not found in {save_directory}")
+
+    # Update the tokenizer_class field
+    tokenizer_config["tokenizer_class"] = tokenizer_class
+
+    # Write the updated tokenizer_config.json back to disk
+    with open(tokenizer_config_path, "w", encoding="utf-8") as f:
+        json.dump(tokenizer_config, f, indent=4, sort_keys=True)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Export StaticModel to ONNX format")
+    parser.add_argument(
+        "--model_path",
+        type=str,
+        required=True,
+        help="Path to the pretrained StaticModel",
+    )
+    parser.add_argument(
+        "--save_path",
+        type=str,
+        required=True,
+        help="Directory to save the exported model and files",
+    )
+    args = parser.parse_args()
+
+    export_model_to_onnx(args.model_path, Path(args.save_path))
diff --git a/uv.lock b/uv.lock
index b331640b..68a34068 100644
--- a/uv.lock
+++ b/uv.lock
@@ -531,6 +531,10 @@ distill = [
{ name = "torch" },
{ name = "transformers" },
]
+onnx = [
+ { name = "onnx" },
+ { name = "torch" },
+]
[package.metadata]
requires-dist = [
@@ -538,6 +542,7 @@ requires-dist = [
{ name = "ipython", marker = "extra == 'dev'" },
{ name = "mypy", marker = "extra == 'dev'" },
{ name = "numpy" },
+ { name = "onnx", marker = "extra == 'onnx'" },
{ name = "pre-commit", marker = "extra == 'dev'" },
{ name = "pytest", marker = "extra == 'dev'" },
{ name = "pytest-coverage", marker = "extra == 'dev'" },
@@ -548,6 +553,7 @@ requires-dist = [
{ name = "setuptools" },
{ name = "tokenizers", specifier = ">=0.20" },
{ name = "torch", marker = "extra == 'distill'" },
+ { name = "torch", marker = "extra == 'onnx'" },
{ name = "tqdm" },
{ name = "transformers", marker = "extra == 'distill'" },
]
@@ -799,6 +805,38 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/87/20/199b8713428322a2f22b722c62b8cc278cc53dffa9705d744484b5035ee9/nvidia_nvtx_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:781e950d9b9f60d8241ccea575b32f5105a5baf4c2351cab5256a24869f12a1a", size = 99144 },
]
+[[package]]
+name = "onnx"
+version = "1.17.0"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "numpy" },
+ { name = "protobuf" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/9a/54/0e385c26bf230d223810a9c7d06628d954008a5e5e4b73ee26ef02327282/onnx-1.17.0.tar.gz", hash = "sha256:48ca1a91ff73c1d5e3ea2eef20ae5d0e709bb8a2355ed798ffc2169753013fd3", size = 12165120 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/2e/29/57053ba7787788ac75efb095cfc1ae290436b6d3a26754693cd7ed1b4fac/onnx-1.17.0-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:38b5df0eb22012198cdcee527cc5f917f09cce1f88a69248aaca22bd78a7f023", size = 16645616 },
+ { url = "https://files.pythonhosted.org/packages/75/0d/831807a18db2a5e8f7813848c59272b904a4ef3939fe4d1288cbce9ea735/onnx-1.17.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d545335cb49d4d8c47cc803d3a805deb7ad5d9094dc67657d66e568610a36d7d", size = 15908420 },
+ { url = "https://files.pythonhosted.org/packages/dd/5b/c4f95dbe652d14aeba9afaceb177e9ffc48ac3c03048dd3f872f26f07e34/onnx-1.17.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3193a3672fc60f1a18c0f4c93ac81b761bc72fd8a6c2035fa79ff5969f07713e", size = 16046244 },
+ { url = "https://files.pythonhosted.org/packages/08/a9/c1f218085043dccc6311460239e253fa6957cf12ee4b0a56b82014938d0b/onnx-1.17.0-cp310-cp310-win32.whl", hash = "sha256:0141c2ce806c474b667b7e4499164227ef594584da432fd5613ec17c1855e311", size = 14423516 },
+ { url = "https://files.pythonhosted.org/packages/0e/d3/d26ebf590a65686dde6b27fef32493026c5be9e42083340d947395f93405/onnx-1.17.0-cp310-cp310-win_amd64.whl", hash = "sha256:dfd777d95c158437fda6b34758f0877d15b89cbe9ff45affbedc519b35345cf9", size = 14528496 },
+ { url = "https://files.pythonhosted.org/packages/e5/a9/8d1b1d53aec70df53e0f57e9f9fcf47004276539e29230c3d5f1f50719ba/onnx-1.17.0-cp311-cp311-macosx_12_0_universal2.whl", hash = "sha256:d6fc3a03fc0129b8b6ac03f03bc894431ffd77c7d79ec023d0afd667b4d35869", size = 16647991 },
+ { url = "https://files.pythonhosted.org/packages/7b/e3/cc80110e5996ca61878f7b4c73c7a286cd88918ff35eacb60dc75ab11ef5/onnx-1.17.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f01a4b63d4e1d8ec3e2f069e7b798b2955810aa434f7361f01bc8ca08d69cce4", size = 15908949 },
+ { url = "https://files.pythonhosted.org/packages/b1/2f/91092557ed478e323a2b4471e2081fdf88d1dd52ae988ceaf7db4e4506ff/onnx-1.17.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4a183c6178be001bf398260e5ac2c927dc43e7746e8638d6c05c20e321f8c949", size = 16048190 },
+ { url = "https://files.pythonhosted.org/packages/ac/59/9ea23fc22d0bb853133f363e6248e31bcbc6c1c90543a3938c00412ac02a/onnx-1.17.0-cp311-cp311-win32.whl", hash = "sha256:081ec43a8b950171767d99075b6b92553901fa429d4bc5eb3ad66b36ef5dbe3a", size = 14424299 },
+ { url = "https://files.pythonhosted.org/packages/51/a5/19b0dfcb567b62e7adf1a21b08b23224f0c2d13842aee4d0abc6f07f9cf5/onnx-1.17.0-cp311-cp311-win_amd64.whl", hash = "sha256:95c03e38671785036bb704c30cd2e150825f6ab4763df3a4f1d249da48525957", size = 14529142 },
+ { url = "https://files.pythonhosted.org/packages/b4/dd/c416a11a28847fafb0db1bf43381979a0f522eb9107b831058fde012dd56/onnx-1.17.0-cp312-cp312-macosx_12_0_universal2.whl", hash = "sha256:0e906e6a83437de05f8139ea7eaf366bf287f44ae5cc44b2850a30e296421f2f", size = 16651271 },
+ { url = "https://files.pythonhosted.org/packages/f0/6c/f040652277f514ecd81b7251841f96caa5538365af7df07f86c6018cda2b/onnx-1.17.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3d955ba2939878a520a97614bcf2e79c1df71b29203e8ced478fa78c9a9c63c2", size = 15907522 },
+ { url = "https://files.pythonhosted.org/packages/3d/7c/67f4952d1b56b3f74a154b97d0dd0630d525923b354db117d04823b8b49b/onnx-1.17.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4f3fb5cc4e2898ac5312a7dc03a65133dd2abf9a5e520e69afb880a7251ec97a", size = 16046307 },
+ { url = "https://files.pythonhosted.org/packages/ae/20/6da11042d2ab870dfb4ce4a6b52354d7651b6b4112038b6d2229ab9904c4/onnx-1.17.0-cp312-cp312-win32.whl", hash = "sha256:317870fca3349d19325a4b7d1b5628f6de3811e9710b1e3665c68b073d0e68d7", size = 14424235 },
+ { url = "https://files.pythonhosted.org/packages/35/55/c4d11bee1fdb0c4bd84b4e3562ff811a19b63266816870ae1f95567aa6e1/onnx-1.17.0-cp312-cp312-win_amd64.whl", hash = "sha256:659b8232d627a5460d74fd3c96947ae83db6d03f035ac633e20cd69cfa029227", size = 14530453 },
+ { url = "https://files.pythonhosted.org/packages/49/e1/c5301ff2afa4c473d32a4e9f1bed5c589cfc4947c79002a00183f4cc0fa1/onnx-1.17.0-cp39-cp39-macosx_12_0_universal2.whl", hash = "sha256:67e1c59034d89fff43b5301b6178222e54156eadd6ab4cd78ddc34b2f6274a66", size = 16645989 },
+ { url = "https://files.pythonhosted.org/packages/61/94/d753c230d56234dd01ad939590a2ed33221b57c61abe513ff6823a69af6e/onnx-1.17.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3e19fd064b297f7773b4c1150f9ce6213e6d7d041d7a9201c0d348041009cdcd", size = 15908316 },
+ { url = "https://files.pythonhosted.org/packages/3d/da/c19d0f20d310045f4701d75ecba4f765153251d48a32f27a5d6b0a7e3799/onnx-1.17.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8167295f576055158a966161f8ef327cb491c06ede96cc23392be6022071b6ed", size = 16046488 },
+ { url = "https://files.pythonhosted.org/packages/57/1a/79623a6cd305dfcd21888747364994109dfcb6194343157cb8653f1612dc/onnx-1.17.0-cp39-cp39-win32.whl", hash = "sha256:76884fe3e0258c911c749d7d09667fb173365fd27ee66fcedaf9fa039210fd13", size = 14423724 },
+ { url = "https://files.pythonhosted.org/packages/57/8e/ce0e20200bdf8e8b47679cd56efb1057aa218b29ccdf60a3b4fb6b91064c/onnx-1.17.0-cp39-cp39-win_amd64.whl", hash = "sha256:5ca7a0894a86d028d509cdcf99ed1864e19bfe5727b44322c11691d834a1c546", size = 14524172 },
+]
+
[[package]]
name = "packaging"
version = "24.1"
@@ -884,6 +922,22 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/a9/6a/fd08d94654f7e67c52ca30523a178b3f8ccc4237fce4be90d39c938a831a/prompt_toolkit-3.0.48-py3-none-any.whl", hash = "sha256:f49a827f90062e411f1ce1f854f2aedb3c23353244f8108b89283587397ac10e", size = 386595 },
]
+[[package]]
+name = "protobuf"
+version = "5.28.3"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/74/6e/e69eb906fddcb38f8530a12f4b410699972ab7ced4e21524ece9d546ac27/protobuf-5.28.3.tar.gz", hash = "sha256:64badbc49180a5e401f373f9ce7ab1d18b63f7dd4a9cdc43c92b9f0b481cef7b", size = 422479 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/d1/c5/05163fad52d7c43e124a545f1372d18266db36036377ad29de4271134a6a/protobuf-5.28.3-cp310-abi3-win32.whl", hash = "sha256:0c4eec6f987338617072592b97943fdbe30d019c56126493111cf24344c1cc24", size = 419624 },
+ { url = "https://files.pythonhosted.org/packages/9c/4c/4563ebe001ff30dca9d7ed12e471fa098d9759712980cde1fd03a3a44fb7/protobuf-5.28.3-cp310-abi3-win_amd64.whl", hash = "sha256:91fba8f445723fcf400fdbe9ca796b19d3b1242cd873907979b9ed71e4afe868", size = 431464 },
+ { url = "https://files.pythonhosted.org/packages/1c/f2/baf397f3dd1d3e4af7e3f5a0382b868d25ac068eefe1ebde05132333436c/protobuf-5.28.3-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:a3f6857551e53ce35e60b403b8a27b0295f7d6eb63d10484f12bc6879c715687", size = 414743 },
+ { url = "https://files.pythonhosted.org/packages/85/50/cd61a358ba1601f40e7d38bcfba22e053f40ef2c50d55b55926aecc8fec7/protobuf-5.28.3-cp38-abi3-manylinux2014_aarch64.whl", hash = "sha256:3fa2de6b8b29d12c61911505d893afe7320ce7ccba4df913e2971461fa36d584", size = 316511 },
+ { url = "https://files.pythonhosted.org/packages/5d/ae/3257b09328c0b4e59535e497b0c7537d4954038bdd53a2f0d2f49d15a7c4/protobuf-5.28.3-cp38-abi3-manylinux2014_x86_64.whl", hash = "sha256:712319fbdddb46f21abb66cd33cb9e491a5763b2febd8f228251add221981135", size = 316624 },
+ { url = "https://files.pythonhosted.org/packages/57/b5/ee3d918f536168def73b3f49edeba065429ab3a7e7b033d33e69c46ddff9/protobuf-5.28.3-cp39-cp39-win32.whl", hash = "sha256:135658402f71bbd49500322c0f736145731b16fc79dc8f367ab544a17eab4535", size = 419648 },
+ { url = "https://files.pythonhosted.org/packages/53/54/e1bdf6f1d29828ddb6aca0a83bf208ab1d5f88126f34e17e487b2cd20d93/protobuf-5.28.3-cp39-cp39-win_amd64.whl", hash = "sha256:70585a70fc2dd4818c51287ceef5bdba6387f88a578c86d47bb34669b5552c36", size = 431591 },
+ { url = "https://files.pythonhosted.org/packages/ad/c3/2377c159e28ea89a91cf1ca223f827ae8deccb2c9c401e5ca233cd73002f/protobuf-5.28.3-py3-none-any.whl", hash = "sha256:cee1757663fa32a1ee673434fcf3bf24dd54763c79690201208bafec62f19eed", size = 169511 },
+]
+
[[package]]
name = "ptyprocess"
version = "0.7.0"