Add Float8Tensor #2463 (base: main)
@@ -0,0 +1,62 @@

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from torch.testing._internal import common_utils
from torch.testing._internal.common_utils import (
    TestCase,
    run_tests,
)
from transformers import AutoModelForCausalLM, AutoTokenizer

from torchao.utils import _is_fbgemm_genai_gpu_available, is_sm_at_least_90

_MODEL_NAMES = [
    "torchao-testing/opt-125m-float8dq-row-fbgemm",
```
Review comment (on the `_MODEL_NAMES` entry): I think the model name here should also specify the relevant versions. IMO this should be a toy model with a single layer, with matching done on the layer output, to make it 100x easier to debug when things do go wrong. It's fine to also have a real model and match tokens, but I think it's more important to have a toy model.

Author reply: A single linear for debuggability makes sense, although I'm not sure how we can get a toy model with a single linear in huggingface transformers. I can add the version, but will revisit getting a single-layer model.
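A minimal sketch of what the suggested single-linear check could look like (an illustration, not code from this PR: the `quantize_`, `Float8DynamicActivationFloat8WeightConfig`, and `PerRow` names are assumptions based on the current `torchao.quantization` API, and the tolerances are placeholders):

```python
# Hypothetical single-linear "toy model" BC check, per the review suggestion.
import torch
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    PerRow,
    quantize_,
)

torch.manual_seed(0)
linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
quantize_(linear, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()))

x = torch.randn(4, 128, dtype=torch.bfloat16, device="cuda")
actual = linear(x)

# In a real BC test, `expected` would be loaded from a checkpoint saved by an
# older torchao version rather than recomputed, so any numerics drift shows up
# directly on this one layer's output instead of in generated tokens:
# torch.testing.assert_close(actual, expected, atol=1e-2, rtol=1e-2)
```

Matching on a single layer's output localizes a failure immediately, whereas token matching on a full model only tells you that something somewhere diverged.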
```python
]


@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(not is_sm_at_least_90(), "Need sm90+")
@unittest.skip("temporary skip since we have some refactor next")
class TestSerializationBC(TestCase):
    """Test that we can still load and run models serialized with previous torchao versions.

    We commit to BC for 3 PyTorch releases.
    """

    @common_utils.parametrize("model_name", _MODEL_NAMES)
    def test_load_model_and_run(self, model_name):
        if "fbgemm" in model_name and not _is_fbgemm_genai_gpu_available():
            # TODO: this is not enabled in CI, enable this after new fbgemm releases
            print("can't run fbgemm model without fbgemm_genai_gpu installed")
            return
        # Load the already-quantized model from the hub
        quantized_model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype="bfloat16",
            device_map="cuda",
        )
        tokenizer = AutoTokenizer.from_pretrained(model_name)

        prompt = ("Hello, my name is",)

        inputs = tokenizer(
            prompt,
            return_tensors="pt",
        ).to("cuda")
        generated_ids = quantized_model.generate(**inputs, max_new_tokens=128)
        # make sure generation runs end to end
        _ = tokenizer.batch_decode(
            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )


common_utils.instantiate_parametrized_tests(TestSerializationBC)


if __name__ == "__main__":
    run_tests()
```
Review comment: What does that mean in practice? Is static_quant now broken after this PR?

Author reply: Static quant is not migrated yet, so it won't break.
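To make the reply concrete, a hedged sketch of the split: the dynamic float8 path is what moves to the new `Float8Tensor`, while the static path keeps its existing implementation. The config names and the `scale` argument below are assumptions based on the current `torchao.quantization` API, not definitions from this PR.

```python
# Sketch only: illustrates which float8 path the Float8Tensor migration touches.
# Config names / signatures are assumed from current torchao.quantization.
import torch
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,  # path migrated to Float8Tensor
    Float8StaticActivationFloat8WeightConfig,   # static path, not migrated yet
    quantize_,
)

dynamic_m = torch.nn.Linear(64, 64, dtype=torch.bfloat16, device="cuda")
quantize_(dynamic_m, Float8DynamicActivationFloat8WeightConfig())

static_m = torch.nn.Linear(64, 64, dtype=torch.bfloat16, device="cuda")
# Static quant takes a precomputed activation scale (e.g. from calibration);
# per the reply, its behavior is unchanged by this PR.
act_scale = torch.tensor(1.0, dtype=torch.float32, device="cuda")
quantize_(static_m, Float8StaticActivationFloat8WeightConfig(scale=act_scale))
```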