Skip to content

Commit 1cc864e

Browse files
xin3he and Copilot authored
better solution for checking g_idx support (#2251)
Signed-off-by: Xin He <xinhe3@habana.ai> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent 0b63176 commit 1cc864e

File tree

2 files changed

+12
-11
lines changed

2 files changed

+12
-11
lines changed

neural_compressor/torch/algorithms/weight_only/modules.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,9 @@
2626
from torch.nn import functional as F
2727

2828
from neural_compressor.torch.utils import (
29-
Version,
3029
accelerator,
3130
can_pack_with_numba,
32-
get_hpex_version,
31+
is_hpex_support_g_idx,
3332
logger,
3433
)
3534

@@ -731,7 +730,7 @@ def __init__(
731730
)
732731
else:
733732
self.g_idx = None
734-
self.support_g_idx = True if get_hpex_version() >= Version("1.23.0") else False
733+
self.support_g_idx = is_hpex_support_g_idx()
735734

736735
self.half_indim = self.in_features // 2
737736

neural_compressor/torch/utils/environ.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import importlib
1818
import os
1919
import sys
20+
from functools import lru_cache
2021

2122
import torch
2223
from packaging.version import Version
@@ -79,19 +80,20 @@ def is_hpu_available():
7980
return get_accelerator().name() == "hpu"
8081

8182

82-
def get_hpex_version():
83-
"""Return ipex version if ipex exists."""
83+
@lru_cache(None)
84+
def is_hpex_support_g_idx():
85+
"""Check if HPEX supports group_index in the schema of hpu::convert_from_int4."""
8486
if is_hpex_available():
8587
try:
8688
import habana_frameworks.torch
89+
import torch
8790

88-
hpex_version = habana_frameworks.torch.__version__
89-
except ValueError as e: # pragma: no cover
90-
assert False, "Got an unknown version of habana_frameworks.torch: {}".format(e)
91-
version = Version(hpex_version)
92-
return version
91+
schema = torch._C._get_schema("hpu::convert_from_int4", "")
92+
return "group_index" in str(schema)
93+
except: # pragma: no cover
94+
return False
9395
else:
94-
return None
96+
return False
9597

9698

9799
## check optimum

0 commit comments

Comments (0)