Skip to content

Commit de7d3b6

Browse files
author
jax authors
committed
Merge pull request #20816 from superbobry:int4
PiperOrigin-RevId: 626131730
2 parents c2d4373 + a13efc2 commit de7d3b6

File tree

5 files changed

+33
-22
lines changed

5 files changed

+33
-22
lines changed

jax/_src/test_util.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1289,15 +1289,17 @@ def integer(self):
12891289

12901290
@_cached_property
12911291
def all_integer(self):
1292-
return self.supported([np.int8, np.int16, np.int32, np.int64])
1292+
return self.supported([
1293+
_dtypes.int4, np.int8, np.int16, np.int32, np.int64])
12931294

12941295
@_cached_property
12951296
def unsigned(self):
12961297
return self.supported([np.uint32, np.uint64])
12971298

12981299
@_cached_property
12991300
def all_unsigned(self):
1300-
return self.supported([np.uint8, np.uint16, np.uint32, np.uint64])
1301+
return self.supported([
1302+
_dtypes.uint4, np.uint8, np.uint16, np.uint32, np.uint64])
13011303

13021304
@_cached_property
13031305
def complex(self):

jax/tools/jax_to_ir.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -238,12 +238,14 @@ def parse_shape_str(s):
238238
shape = ()
239239
return jax.core.ShapedArray(shape, dtype)
240240

241-
_DT = {'pred': jnp.bool_,
242-
'u8': jnp.uint8, 'u16': jnp.uint16, 'u32': jnp.uint32, 'u64': jnp.uint64,
243-
's8': jnp.int8, 's16': jnp.int16, 's32': jnp.int32, 's64': jnp.int64,
244-
'bf16': jnp.bfloat16,
245-
'f16': jnp.float16, 'f32': jnp.float32, 'f64': jnp.float64,
246-
'c64': jnp.complex64, 'c128': jnp.complex128}
241+
_DT = {
242+
'pred': jnp.bool_,
243+
'u4': jnp.uint4, 'u8': jnp.uint8, 'u16': jnp.uint16, 'u32': jnp.uint32, 'u64': jnp.uint64,
244+
's4': jnp.int4, 's8': jnp.int8, 's16': jnp.int16, 's32': jnp.int32, 's64': jnp.int64,
245+
'bf16': jnp.bfloat16,
246+
'f16': jnp.float16, 'f32': jnp.float32, 'f64': jnp.float64,
247+
'c64': jnp.complex64, 'c128': jnp.complex128
248+
}
247249
_SHAPE_RE = re.compile(f"^({'|'.join(_DT)})\\[\\s*(\\d*[\\s*,\\d+]*)\\s*\\]$")
248250

249251

tests/dtypes_test.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,8 @@
6767
all_dtypes = (bool_dtypes + signed_dtypes + unsigned_dtypes + float_dtypes +
6868
complex_dtypes)
6969

70-
scalar_types = [jnp.bool_, jnp.int8, jnp.int16, jnp.int32, jnp.int64,
71-
jnp.uint8, jnp.uint16, jnp.uint32, jnp.uint64,
70+
scalar_types = [jnp.bool_, jnp.int4, jnp.int8, jnp.int16, jnp.int32, jnp.int64,
71+
jnp.uint4, jnp.uint8, jnp.uint16, jnp.uint32, jnp.uint64,
7272
jnp.bfloat16, jnp.float16, jnp.float32, jnp.float64,
7373
jnp.complex64, jnp.complex128]
7474

@@ -94,6 +94,7 @@
9494
_EXPECTED_CANONICALIZE_X32[np.longlong] = np.int32
9595

9696
UINT_DTYPES = {
97+
4: jnp.uint4,
9798
8: np.uint8,
9899
16: np.uint16,
99100
32: np.uint32,
@@ -284,13 +285,15 @@ def testIsSubdtype(self):
284285
self.assertTrue(dtypes.issubdtype(np.dtype(t).type, t))
285286
self.assertTrue(dtypes.issubdtype(t, np.dtype(t).type))
286287
self.assertTrue(dtypes.issubdtype(t, np.dtype(t)))
287-
if t != jnp.bfloat16:
288-
for category in [np.generic, jnp.inexact, jnp.integer, jnp.signedinteger,
289-
jnp.unsignedinteger, jnp.floating, jnp.complexfloating]:
290-
self.assertEqual(dtypes.issubdtype(t, category),
291-
np.issubdtype(np.dtype(t).type, category))
292-
self.assertEqual(dtypes.issubdtype(t, category),
293-
np.issubdtype(np.dtype(t).type, category))
288+
if t in [jnp.int4, jnp.uint4, jnp.bfloat16]:
289+
# These dtype have no equivalent in NumPy.
290+
continue
291+
for category in [np.generic, jnp.inexact, jnp.integer, jnp.signedinteger,
292+
jnp.unsignedinteger, jnp.floating, jnp.complexfloating]:
293+
self.assertEqual(dtypes.issubdtype(t, category),
294+
np.issubdtype(np.dtype(t).type, category))
295+
self.assertEqual(dtypes.issubdtype(t, category),
296+
np.issubdtype(np.dtype(t).type, category))
294297

295298
def testIsSubdtypeExtended(self):
296299
self.assertTrue(dtypes.issubdtype(dtypes.extended, dtypes.extended))

tests/jax_to_ir_test.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,10 +119,12 @@ def test_parse_shape_str(self):
119119
self.assertParsedShape('f32[]', [], jnp.float32)
120120
self.assertParsedShape('f32[1,2,3]', [1, 2, 3], jnp.float32)
121121
self.assertParsedShape('pred[1]', [1], jnp.bool_)
122+
self.assertParsedShape('s4[1]', [1], jnp.int4)
122123
self.assertParsedShape('s8[1]', [1], jnp.int8)
123124
self.assertParsedShape('s16[1]', [1], jnp.int16)
124125
self.assertParsedShape('s32[1]', [1], jnp.int32)
125126
self.assertParsedShape('s64[1]', [1], jnp.int64)
127+
self.assertParsedShape('u4[1]', [1], jnp.uint4)
126128
self.assertParsedShape('u8[1]', [1], jnp.uint8)
127129
self.assertParsedShape('u16[1]', [1], jnp.uint16)
128130
self.assertParsedShape('u32[1]', [1], jnp.uint32)

tests/lax_numpy_test.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -195,11 +195,13 @@ def testUnstack(self, shape, axis, dtype):
195195

196196

197197
@parameterized.parameters(
198-
[dtype for dtype in [jnp.bool, jnp.uint8, jnp.uint16, jnp.uint32,
199-
jnp.uint64, jnp.int8, jnp.int16, jnp.int32, jnp.int64,
200-
jnp.bfloat16, jnp.float16, jnp.float32, jnp.float64,
201-
jnp.complex64, jnp.complex128]
202-
if dtype == dtypes.canonicalize_dtype(dtype)])
198+
[dtype for dtype in [
199+
jnp.bool,
200+
jnp.uint4, jnp.uint8, jnp.uint16, jnp.uint32, jnp.uint64,
201+
jnp.int4, jnp.int8, jnp.int16, jnp.int32, jnp.int64,
202+
jnp.bfloat16, jnp.float16, jnp.float32, jnp.float64,
203+
jnp.complex64, jnp.complex128]
204+
if dtype == dtypes.canonicalize_dtype(dtype)])
203205
def testDtypeWrappers(self, dtype):
204206
arr = dtype(0)
205207
self.assertIsInstance(arr, jax.Array)

0 commit comments

Comments
 (0)