
Commit bdaad8e

Guang Yang authored and facebook-github-bot committed
Add dtype arg to the script for exporting HuggingFace models (#5716)
Summary: As titled. This unblocks delegating to XNNPACK with float16 and bfloat16, which provides comparable perf data points against eager, torch.compile, AOTI, etc. Pull Request resolved: #5716 Reviewed By: kirklandsign Differential Revision: D63499648 Pulled By: guangy10 fbshipit-source-id: 5a06454f8af664e6d5f469dcf63869ca7c57a6ba
1 parent c1c5080 commit bdaad8e

File tree

1 file changed: +10 −1 lines changed


extension/export_util/export_hf_model.py

Lines changed: 10 additions & 1 deletion
@@ -27,6 +27,14 @@ def main() -> None:
         default=None,
         help="a valid huggingface model repo name",
     )
+    parser.add_argument(
+        "-d",
+        "--dtype",
+        type=str,
+        choices=["float32", "float16", "bfloat16"],
+        default="float32",
+        help="specify the dtype for loading the model",
+    )
     parser.add_argument(
         "-o",
         "--output_name",
@@ -39,7 +47,8 @@ def main() -> None:

     # Configs to HF model
     device = "cpu"
-    dtype = torch.float32
+    # TODO: remove getattr once https://github.com/huggingface/transformers/pull/33741 is merged
+    dtype = getattr(torch, args.dtype)
     batch_size = 1
     max_length = 123
     cache_implementation = "static"
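The pattern in this diff (accepting a dtype name as a CLI string and resolving it to a `torch.dtype` via `getattr`) can be sketched as a minimal standalone script. This is an illustrative sketch, not the full `export_hf_model.py`; only the `--dtype` flag and the `getattr` lookup come from the diff itself.

```python
import argparse

import torch

# Minimal sketch of the dtype flag added in this commit.
parser = argparse.ArgumentParser()
parser.add_argument(
    "-d",
    "--dtype",
    type=str,
    choices=["float32", "float16", "bfloat16"],
    default="float32",
    help="specify the dtype for loading the model",
)
# Parse a sample command line for demonstration.
args = parser.parse_args(["--dtype", "bfloat16"])

# getattr maps the string name onto the matching torch dtype object,
# e.g. "bfloat16" -> torch.bfloat16, "float32" -> torch.float32.
dtype = getattr(torch, args.dtype)
print(dtype)  # torch.bfloat16
```

Because `choices` restricts the input to valid dtype names, the `getattr` lookup cannot fail at runtime; the linked transformers PR would make this string-to-dtype mapping unnecessary on the library side.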

0 commit comments
