From b68159758334d6ee6fc9aa302a7755207284f4de Mon Sep 17 00:00:00 2001 From: Chia-Jung Chang Date: Thu, 17 Apr 2025 13:06:34 -0700 Subject: [PATCH] Enable InputRecorder to run in offline mode and cuda device (#2067) Summary: Pull Request resolved: https://github.com/pytorch/ao/pull/2067 Enable `InputRecorder` to run in offline mode and cuda device. Specifically `InputRecorder` is called during Sandcastle CI tests, which do not allow us to access HuggingFace website. Differential Revision: D73179650 --- torchao/_models/_eval.py | 28 +++++++++++++++++++--------- 1 file changed, 19 insertions(+), 9 deletions(-) diff --git a/torchao/_models/_eval.py b/torchao/_models/_eval.py index 9f429278e3..48e3ced71a 100644 --- a/torchao/_models/_eval.py +++ b/torchao/_models/_eval.py @@ -183,18 +183,27 @@ def __init__( vocab_size=32000, pad_token=0, device="cpu", + online_access=True, ): - try: - super().__init__() - except TypeError: - # lm_eval 0.4.2 removed the default init - super().__init__("gpt2", device="cpu") + if online_access: + try: + super().__init__() + except TypeError: + # lm_eval 0.4.2 removed the default init + super().__init__("gpt2", device=device) + else: + # Create a minimal implementation when run offline + self._device = torch.device(device) + self._max_length = 2048 + self._max_gen_toks = 256 + self._batch_size = 1 self.tokenizer = tokenizer self._device = torch.device(device) self.vocab_size = vocab_size self._max_seq_length = calibration_seq_length self.calibration_seq_length = calibration_seq_length + self.online_access = online_access # need to take inps and convert to corrent input # for model @@ -257,10 +266,11 @@ def record_inputs( calibration_tasks, calibration_limit, ): - try: - lm_eval.tasks.initialize_tasks() - except: - pass + if self.online_access: + try: + lm_eval.tasks.initialize_tasks() + except: + pass task_dict = get_task_dict(calibration_tasks) print("Obtaining GPTQ calibration inputs on: ", calibration_tasks)