Skip to content

Commit 121c701

Browse files
Drop unnecessary torch warnings, allow numpy v2 support (#1595)
SUMMARY: Drop torch extraneous warnings, including warning about use of torch.compile on torch v2. Allow numpy v2 as a dependency (unclear why this was previously not allowed) - vllm pins torch==2.7.0 and has no restrictions on numpy - compressed-tensors has torch>=1.7.0 and no explicit numpy dependency - pytorch has no restrictions on numpy TEST PLAN: no net new src code --------- Signed-off-by: Brian Dellabetta <bdellabe@redhat.com>
1 parent 8680d95 commit 121c701

File tree

2 files changed

+3
-47
lines changed

2 files changed

+3
-47
lines changed

setup.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -112,10 +112,11 @@ def localversion_func(version: ScmVersion) -> str:
112112
install_requires=[
113113
"loguru",
114114
"pyyaml>=5.0.0",
115-
"numpy>=1.17.0,<2.0",
115+
"numpy>=1.17.0",
116116
"requests>=2.0.0",
117117
"tqdm>=4.0.0",
118-
"torch>=1.7.0",
118+
# torch 1.10 and 1.11 do not support quantized onnx export
119+
"torch>=1.7.0,!=1.10,!=1.11",
119120
"transformers>4.0,<4.53.0",
120121
"datasets",
121122
"accelerate>=0.20.3,!=1.1.0",

src/llmcompressor/pytorch/__init__.py

Lines changed: 0 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,45 +0,0 @@
1-
"""
2-
Functionality for working with and sparsifying Models in the PyTorch framework
3-
"""
4-
5-
import os
6-
import warnings
7-
8-
from packaging import version
9-
10-
try:
11-
import torch
12-
13-
_PARSED_TORCH_VERSION = version.parse(torch.__version__)
14-
15-
if _PARSED_TORCH_VERSION.major >= 2:
16-
torch_compile_func = torch.compile
17-
18-
def raise_torch_compile_warning(*args, **kwargs):
19-
warnings.warn(
20-
"torch.compile is not supported by llmcompressor for torch 2.0.x"
21-
)
22-
return torch_compile_func(*args, **kwargs)
23-
24-
torch.compile = raise_torch_compile_warning
25-
26-
_BYPASS = bool(int(os.environ.get("NM_BYPASS_TORCH_VERSION", "0")))
27-
if _PARSED_TORCH_VERSION.major == 1 and _PARSED_TORCH_VERSION.minor in [10, 11]:
28-
if not _BYPASS:
29-
raise RuntimeError(
30-
"llmcompressor does not support torch==1.10.* or 1.11.*. "
31-
f"Found torch version {torch.__version__}.\n\n"
32-
"To bypass this error, set environment variable "
33-
"`NM_BYPASS_TORCH_VERSION` to '1'.\n\n"
34-
"Bypassing may result in errors or "
35-
"incorrect behavior, so set at your own risk."
36-
)
37-
else:
38-
warnings.warn(
39-
"llmcompressor quantized onnx export does not work "
40-
"with torch==1.10.* or 1.11.*"
41-
)
42-
except ImportError:
43-
pass
44-
45-
# flake8: noqa

0 commit comments

Comments
 (0)