Commit 16e2d0a

Add support for bmm and to for fbgemm Tensor (#2337)
Summary: This PR adds support for running quantized bmm with the fbgemm configs. The quantized bmm kernels for int4 and fp8 (with dynamic activation quantization) require the weights to be transposed in order to run, so a transpose_input option is added to the convert function to transpose the weights first. The fbgemm tensor subclasses also gain a to() implementation so quantized modules can be moved across devices.

Test Plan:
python test/dtypes/test_fbgemm_fp8.py -k test_bmm
python test/dtypes/test_fbgemm_int4.py -k test_bmm
1 parent 769ffa5 commit 16e2d0a
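
A minimal usage sketch mirroring the new test_bmm tests, showing how the bmm path is exercised end to end. This is a sketch, not part of the commit: the import path torchao.quantization.quant_api and the use of torch.float8_e4m3fn in place of the tests' e4m3_dtype alias are assumptions.

import torch
from torchao.quantization.quant_api import FbgemmConfig, quantize_


class BMM(torch.nn.Module):
    def __init__(self, weight):
        super().__init__()
        self.weight = weight

    def forward(self, x):
        return torch.bmm(x, self.weight)


# transpose_input=True tells the convert step to transpose the [B, K, N] bmm
# weight into the layout the batched kernels expect.
config = FbgemmConfig(
    input_dtype=torch.float8_e4m3fn,  # assumed stand-in for the tests' e4m3_dtype
    weight_dtype=torch.float8_e4m3fn,
    output_dtype=torch.bfloat16,
    transpose_input=True,
)

weight = torch.randn(10, 128, 256, dtype=torch.bfloat16, device="cuda")
m = BMM(weight).eval()
# BMM is not an nn.Linear, so pass a filter that accepts every module
quantize_(m, config, filter_fn=lambda module, fqn: True)
out = m(torch.randn(10, 32, 128, dtype=torch.bfloat16, device="cuda"))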

6 files changed: 173 additions & 12 deletions

test/dtypes/test_fbgemm_fp8.py

Lines changed: 40 additions & 0 deletions
@@ -34,6 +34,13 @@ def setUp(self):
             weight_dtype=e4m3_dtype,
             output_dtype=torch.bfloat16,
         )
+        self.bmm_config = FbgemmConfig(
+            input_dtype=e4m3_dtype,
+            weight_dtype=e4m3_dtype,
+            output_dtype=torch.bfloat16,
+            transpose_input=True,
+        )
+        self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else []

     def test_linear(self):
         dtype = torch.bfloat16
@@ -106,6 +113,39 @@ def test_slice_and_copy_(self):
         # making sure param.data is updated
         assert param.data.float8_data[0][0] != orig_value

+    def test_bmm(self):
+        class M(torch.nn.Module):
+            def __init__(self, weight):
+                super().__init__()
+                self.weight = weight
+
+            def forward(self, x):
+                return torch.bmm(x, self.weight)
+
+        dtype = torch.bfloat16
+        device = "cuda"
+        input = torch.randn(10, 32, 128, dtype=dtype, device=device)
+        weight = torch.randn(10, 128, 256, dtype=dtype, device=device)
+        m = M(weight).eval()
+        original = m(input)
+        quantize_(m, self.bmm_config, filter_fn=lambda x, fqn: True)
+        quantized = m(input)
+        self.assertTrue(compute_error(original, quantized) > 20)
+
+    def test_to_device(self):
+        for device in self.GPU_DEVICES:
+            linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
+            quantize_(linear, self.config)
+            linear.to(device)
+
+            linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
+            quantize_(linear, self.config)
+            linear.to(device=device)
+
+            linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
+            quantize_(linear, self.config)
+            linear.to(device)
+

 if __name__ == "__main__":
     run_tests()

test/dtypes/test_fbgemm_int4.py

Lines changed: 41 additions & 0 deletions
@@ -34,6 +34,14 @@ def setUp(self):
             output_dtype=torch.bfloat16,
             block_size=[1, 128],
         )
+        self.bmm_config = FbgemmConfig(
+            input_dtype=torch.bfloat16,
+            weight_dtype=torch.int4,
+            output_dtype=torch.bfloat16,
+            block_size=[1, 1, 128],
+            transpose_input=True,
+        )
+        self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else []

     def test_linear(self):
         dtype = torch.bfloat16
@@ -111,6 +119,39 @@ def test_slice_and_copy_(self):
         # making sure param.data is updated
         assert param.data.packed_weight[0][0] != orig_value

+    def test_bmm(self):
+        class M(torch.nn.Module):
+            def __init__(self, weight):
+                super().__init__()
+                self.weight = weight
+
+            def forward(self, x):
+                return torch.bmm(x, self.weight)
+
+        dtype = torch.bfloat16
+        device = "cuda"
+        input = torch.randn(10, 32, 128, dtype=dtype, device=device)
+        weight = torch.randn(10, 128, 256, dtype=dtype, device=device)
+        m = M(weight).eval()
+        original = m(input)
+        quantize_(m, self.bmm_config, filter_fn=lambda x, fqn: True)
+        quantized = m(input)
+        self.assertTrue(compute_error(original, quantized) > 18)
+
+    def test_to_device(self):
+        for device in self.GPU_DEVICES:
+            linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
+            quantize_(linear, self.config)
+            linear.to(device)
+
+            linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
+            quantize_(linear, self.config)
+            linear.to(device=device)
+
+            linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
+            quantize_(linear, self.config)
+            linear.to(device)
+

 if __name__ == "__main__":
     run_tests()

torchao/dtypes/__init__.py

Lines changed: 4 additions & 2 deletions
@@ -8,8 +8,8 @@
     to_affine_quantized_intx,
     to_affine_quantized_intx_static,
 )
-from .fbgemm_fp8_tensor import to_fbgemm_fp8
-from .fbgemm_int4_tensor import to_fbgemm_int4
+from .fbgemm_fp8_tensor import FbgemmFp8Tensor, to_fbgemm_fp8
+from .fbgemm_int4_tensor import FbgemmInt4Tensor, to_fbgemm_int4
 from .floatx import (
     CutlassSemiSparseLayout,
     Float8Layout,
@@ -64,5 +64,7 @@
     "to_affine_quantized_packed_linear_int8_dynamic_activation_intx_weight",
     "Int4XPULayout",
     "to_fbgemm_int4",
+    "FbgemmInt4Tensor",
     "to_fbgemm_fp8",
+    "FbgemmFp8Tensor",
 ]

torchao/dtypes/fbgemm_fp8_tensor.py

Lines changed: 45 additions & 5 deletions
@@ -18,6 +18,7 @@

 __all__ = [
     "to_fbgemm_fp8",
+    "FbgemmFp8Tensor",
 ]

 aten = torch.ops.aten
@@ -74,11 +75,22 @@ def __repr__(self):
     def _quantization_type(self):
         return f"shape={self.shape}, activation_scale_ub={self.activation_scale_ub}, device={self.device}"

+    def to(self, *args, **kwargs):
+        kwargs = self._get_to_kwargs(*args, **kwargs)
+        device = kwargs.pop("device")
+        return self.__class__(
+            self.float8_data.to(device),
+            self.scale.to(device),
+            self.activation_scale_ub.to(device),
+            self.dtype,
+        )
+
     @classmethod
     def from_float(
         cls,
         w: torch.Tensor,
         activation_scale_ub: Optional[float] = None,
+        transpose_input: bool = False,
     ):
         if activation_scale_ub is None:
             activation_scale_ub = 1200.0
@@ -88,6 +100,12 @@ def from_float(
             dtype=torch.float,
             device=w.device,
         )
+        if transpose_input:
+            if w.ndim == 3:
+                w = w.transpose(-1, -2)
+            else:
+                w = w.t()
+
         wq, w_scale = torch.ops.triton.quantize_fp8_row(w)
         # wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_row(w)
         dtype = w.dtype
@@ -110,11 +128,6 @@ def _(func, types, args, kwargs):
         args[1],
         args[2] if len(args) > 2 else None,
     )
-    if not input_tensor.is_floating_point():
-        raise NotImplementedError(
-            f"{func} is not implemented for non floating point input"
-        )
-
     orig_act_size = input_tensor.size()
     orig_out_features = weight_tensor.shape[-2]

@@ -141,6 +154,33 @@ def _(func, types, args, kwargs):
     return res


+@implements(torch.bmm)
+def _(func, types, args, kwargs):
+    input_tensor, weight_tensor = (
+        args[0],
+        args[1],
+    )
+    orig_act_size = input_tensor.size()
+    # not used
+    num_tokens = torch.empty([input_tensor.size(0)], device=input_tensor.device)
+    xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(
+        input_tensor, num_tokens, weight_tensor.activation_scale_ub
+    )
+
+    a_data = xq
+    b_data = weight_tensor.float8_data
+    orig_out_features = b_data.shape[-2]
+
+    res = torch.ops.fbgemm.f8f8bf16_rowwise_batched(
+        a_data,
+        b_data,
+        x_scale,
+        weight_tensor.scale,
+    )
+    res = res.reshape(*orig_act_size[:-1], orig_out_features)
+    return res
+
+
 @implements([aten.detach.default, aten.alias.default])
 def _(func, types, args, kwargs):
     return return_and_correct_aliasing(
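
As the diff suggests, torch.bmm consumes the weight as [B, K, N], while the quantized tensor stores it pre-transposed as [B, N, K] (hence transpose_input at quantization time and shape[-2] as the output-feature count in the bmm override). A plain-PyTorch sketch of that shape bookkeeping follows; names are illustrative and no fbgemm kernel is called.

import torch

B, M, K, N = 10, 32, 128, 256
x = torch.randn(B, M, K)
w = torch.randn(B, K, N)           # layout torch.bmm expects
w_stored = w.transpose(-1, -2)     # [B, N, K], layout kept by the quantized tensor

out_features = w_stored.shape[-2]  # N, matching orig_out_features in the override
ref = torch.bmm(x, w)
# emulate the dispatch: multiply against the stored layout, then restore the shape
res = torch.bmm(x, w_stored.transpose(-1, -2)).reshape(*x.shape[:-1], out_features)

assert res.shape == ref.shape == (B, M, N)
assert torch.allclose(ref, res)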

torchao/dtypes/fbgemm_int4_tensor.py

Lines changed: 38 additions & 5 deletions
@@ -19,6 +19,7 @@

 __all__ = [
     "to_fbgemm_int4",
+    "FbgemmInt4Tensor",
 ]

 aten = torch.ops.aten
@@ -77,18 +78,36 @@ def __repr__(self):
     def _quantization_type(self):
         return f"shape={self.shape}, group_size={self.group_size}, device={self.device}"

+    def to(self, *args, **kwargs):
+        kwargs = self._get_to_kwargs(*args, **kwargs)
+        device = kwargs.pop("device")
+        return self.__class__(
+            self.packed_weight.to(device),
+            self.scale.to(device),
+            self.zero_point.to(device),
+            self.group_size,
+            self.shape,
+        )
+
     @classmethod
     def from_float(
         cls,
         w: torch.Tensor,
         block_size: List[int],
+        transpose_input: bool = False,
     ):
         assert len(block_size) == w.ndim, (
             f"Expecting the length of block_size to be equal to the dimension of the weight, got {block_size=} and {w.ndim=}"
         )
         if int4_row_quantize_zp is None:
             raise ImportError("Requires fbgemm-gpu-genai >= 1.2.0")

+        if transpose_input:
+            if w.ndim == 3:
+                w = w.transpose(-1, -2)
+            else:
+                w = w.t()
+
         group_size = block_size[-1]
         original_shape = w.shape

@@ -126,11 +145,6 @@ def _(func, types, args, kwargs):
         args[1],
         args[2] if len(args) > 2 else None,
     )
-    if not input_tensor.is_floating_point():
-        raise NotImplementedError(
-            f"{func} is not implemented for non floating point input"
-        )
-
     orig_act_size = input_tensor.size()
     orig_out_features = weight_tensor.shape[-2]

@@ -146,6 +160,25 @@ def _(func, types, args, kwargs):
     return res


+@implements(torch.bmm)
+def _(func, types, args, kwargs):
+    input_tensor, weight_tensor = (
+        args[0],
+        args[1],
+    )
+    orig_act_size = input_tensor.size()
+    orig_out_features = weight_tensor.shape[-2]
+
+    res = torch.ops.fbgemm.bf16i4bf16_rowwise_batched(
+        input_tensor,
+        weight_tensor.packed_weight.contiguous(),
+        weight_tensor.scale,
+        weight_tensor.zero_point,
+    )
+    res = res.reshape(*orig_act_size[:-1], orig_out_features)
+    return res
+
+
 @implements([aten.detach.default, aten.alias.default])
 def _(func, types, args, kwargs):
     return return_and_correct_aliasing(
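
The transpose_input handling is the same in both from_float implementations: 3D (bmm) weights flip their last two dimensions, 2D (linear) weights use .t(). A small standalone sketch of that branch; the helper name is hypothetical, the diff inlines the logic.

import torch

def _maybe_transpose(w: torch.Tensor, transpose_input: bool) -> torch.Tensor:
    # mirror the added branch: flip the last two dims for 3D, .t() for 2D
    if not transpose_input:
        return w
    return w.transpose(-1, -2) if w.ndim == 3 else w.t()

assert _maybe_transpose(torch.randn(10, 128, 256), True).shape == (10, 256, 128)
assert _maybe_transpose(torch.randn(256, 128), True).shape == (128, 256)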

torchao/quantization/quant_api.py

Lines changed: 5 additions & 0 deletions
@@ -1991,6 +1991,7 @@ class FbgemmConfig(AOBaseConfig):
     output_dtype: torch.dtype
     block_size: Optional[List[int]] = None
     activation_scale_ub: Optional[float] = None
+    transpose_input: bool = False


 @register_quantize_module_handler(FbgemmConfig)
@@ -2018,9 +2019,11 @@ def _(module: torch.nn.Module, config: FbgemmConfig) -> torch.nn.Module:
         weight = to_fbgemm_int4(
             module.weight,
             config.block_size,
+            config.transpose_input,
         )
         module.weight = torch.nn.Parameter(weight, requires_grad=False)
         module.extra_repr = types.MethodType(_linear_extra_repr, module)
+        return module
     elif (
         (config.input_dtype == e4m3_dtype)
         and (config.weight_dtype == e4m3_dtype)
@@ -2029,9 +2032,11 @@ def _(module: torch.nn.Module, config: FbgemmConfig) -> torch.nn.Module:
         weight = to_fbgemm_fp8(
             module.weight,
             config.activation_scale_ub,
+            config.transpose_input,
         )
         module.weight = torch.nn.Parameter(weight, requires_grad=False)
         module.extra_repr = types.MethodType(_linear_extra_repr, module)
+        return module
     else:
         raise NotImplementedError(
             f"{config} is not supported. supported input, weight, output kernel dtypes are: {_SUPPORTED_DTYPES}"
