File tree Expand file tree Collapse file tree 2 files changed +12
-11
lines changed Expand file tree Collapse file tree 2 files changed +12
-11
lines changed Original file line number Diff line number Diff line change 26
26
from torch .nn import functional as F
27
27
28
28
from neural_compressor .torch .utils import (
29
- Version ,
30
29
accelerator ,
31
30
can_pack_with_numba ,
32
- get_hpex_version ,
31
+ is_hpex_support_g_idx ,
33
32
logger ,
34
33
)
35
34
@@ -731,7 +730,7 @@ def __init__(
731
730
)
732
731
else :
733
732
self .g_idx = None
734
- self .support_g_idx = True if get_hpex_version () >= Version ( "1.23.0" ) else False
733
+ self .support_g_idx = is_hpex_support_g_idx ()
735
734
736
735
self .half_indim = self .in_features // 2
737
736
Original file line number Diff line number Diff line change 17
17
import importlib
18
18
import os
19
19
import sys
20
+ from functools import lru_cache
20
21
21
22
import torch
22
23
from packaging .version import Version
@@ -79,19 +80,20 @@ def is_hpu_available():
79
80
return get_accelerator ().name () == "hpu"
80
81
81
82
82
- def get_hpex_version ():
83
- """Return ipex version if ipex exists."""
83
+ @lru_cache (None )
84
+ def is_hpex_support_g_idx ():
85
+ """Check if HPEX supports group_index in the schema of hpu::convert_from_int4."""
84
86
if is_hpex_available ():
85
87
try :
86
88
import habana_frameworks .torch
89
+ import torch
87
90
88
- hpex_version = habana_frameworks .torch .__version__
89
- except ValueError as e : # pragma: no cover
90
- assert False , "Got an unknown version of habana_frameworks.torch: {}" .format (e )
91
- version = Version (hpex_version )
92
- return version
91
+ schema = torch ._C ._get_schema ("hpu::convert_from_int4" , "" )
92
+ return "group_index" in str (schema )
93
+ except : # pragma: no cover
94
+ return False
93
95
else :
94
- return None
96
+ return False
95
97
96
98
97
99
## check optimum
You can’t perform that action at this time.
0 commit comments