diff --git a/torchao/_models/_eval.py b/torchao/_models/_eval.py index 9f429278e3..48e3ced71a 100644 --- a/torchao/_models/_eval.py +++ b/torchao/_models/_eval.py @@ -183,18 +183,27 @@ def __init__( vocab_size=32000, pad_token=0, device="cpu", + online_access=True, ): - try: - super().__init__() - except TypeError: - # lm_eval 0.4.2 removed the default init - super().__init__("gpt2", device="cpu") + if online_access: + try: + super().__init__() + except TypeError: + # lm_eval 0.4.2 removed the default init + super().__init__("gpt2", device=device) + else: + # Create a minimal implementation when run offline + self._device = torch.device(device) + self._max_length = 2048 + self._max_gen_toks = 256 + self._batch_size = 1 self.tokenizer = tokenizer self._device = torch.device(device) self.vocab_size = vocab_size self._max_seq_length = calibration_seq_length self.calibration_seq_length = calibration_seq_length + self.online_access = online_access # need to take inps and convert to corrent input # for model @@ -257,10 +266,11 @@ def record_inputs( calibration_tasks, calibration_limit, ): - try: - lm_eval.tasks.initialize_tasks() - except: - pass + if self.online_access: + try: + lm_eval.tasks.initialize_tasks() + except: + pass task_dict = get_task_dict(calibration_tasks) print("Obtaining GPTQ calibration inputs on: ", calibration_tasks)