
Commit 0bef87a

jwfromm authored and facebook-github-bot committed
Utilities for slicing preshuffled tensors (#4396)
Summary:
X-link: facebookresearch/FBGEMM#1467

Pull Request resolved: #4396

Some integrations of fbgemm kernels and OSS systems like VLLM would be made simpler by the ability to slice preshuffled tensors. Prior to this diff, there were two blockers to doing that:

- Scales were required to be contiguous. This is easily addressed by more carefully setting the stride argument.
- Shuffled tensors have a non-trivial layout. We add a Python helper function for slicing int4 shuffled tensors. Notably, it involves some data copying that I believe is unavoidable. Hopefully it only needs to be done during model setup.

Reviewed By: jiawenliu64, jianyuh

Differential Revision: D77239566

fbshipit-source-id: ad8eea5eb153f851f1b1e297a566fd36c0ac6409
1 parent f22f361 commit 0bef87a
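The commit message notes that slicing should only need to happen during model setup. Below is a minimal sketch of what that setup-time call might look like; the import path, shapes, and sharding scheme are assumptions made for illustration, and the int8 tensor is only a stand-in for a real preshuffled int4 weight.

import torch

# NOTE: the import path below is an assumption based on the file touched in
# this commit; adjust it to wherever shuffle_slice is exposed in your build.
from fbgemm_gpu.experimental.gen_ai.quantize import shuffle_slice

# Stand-in for a preshuffled int4 weight packed as [N, K // 2] int8 bytes.
N, K = 4096, 8192
wq_packed = torch.randint(-128, 127, (N, K // 2), dtype=torch.int8)

# Hypothetical tensor-parallel sharding done once at model load time: take
# this rank's rows of the output dimension instead of re-slicing per forward.
tp_rank, tp_size = 0, 4
shard_rows = N // tp_size
wq_shard = shuffle_slice(
    wq_packed, dim=0, start=tp_rank * shard_rows, length=shard_rows, dtype="fp8"
)
print(wq_shard.shape)  # torch.Size([1024, 4096])

Because the helper materializes a contiguous copy internally, each slice is a fresh allocation, which is why doing it once at setup rather than per step matters.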

File tree

3 files changed: +59 -8 lines changed


fbgemm_gpu/experimental/gen_ai/gen_ai/quantize.py

Lines changed: 45 additions & 0 deletions
@@ -181,6 +181,51 @@ def _quantize(
     return wq, scales
 
 
+def shuffle_slice(
+    x: torch.Tensor, dim: int, start: int, length: int, dtype: str = "fp8"
+) -> torch.Tensor:
+    """
+    Helper function to slice a preshuffled int4 tensor. This is needed since the shuffling
+    reorders rows based on the size of the input. Slicing a tensor shuffled for a larger input
+    is no longer valid. We must reorder the tensor to the appropriate size then slice.
+    Args:
+        x (Tensor): [N, K // 2] Preshuffled int4 tensor.
+        dim (int): Dimension to slice.
+        start (int): Start of slice.
+        length (int): Number of elements to slice in the original [N, K] dimension.
+        dtype (str): Type of corresponding activations. Must be fp8 or bf16.
+    Returns:
+        sliced (Tensor): [stop-start, K // 2] Sliced tensor.
+    """
+    # Get the size of the input tensor.
+    assert dim in [x.ndim - 2, x.ndim - 1], "Only slicing along N or K is supported."
+    assert length % 16 == 0, "Slicing must be a multiple of 16."
+    orig_shape = x.shape
+    N = x.shape[-2]
+    K = x.shape[-1]
+    # Tile shape is based on the activation dtype.
+    assert dtype in ("fp8", "bf16"), "Only fp8 and bf16 activations supported."
+    # Handle slice along M
+    if dim == x.ndim - 2:
+        tile_shape = 8 if dtype == "fp8" else 16
+        block_size = N // length
+        # View the shape in terms of shuffled tiles then permute to allow slicing.
+        x_s = x.view(-1, tile_shape, block_size, length // tile_shape, K)
+        x_s = x_s.permute(0, 2, 1, 3, 4).contiguous().view(-1, N, K)
+        out_slice = x_s.narrow(1, start, length)
+        # Reshape back to original shape.
+        return out_slice.view(*orig_shape[:-2], length, K)
+    # Handle slice along K
+    else:
+        outer_dim = x.view(-1, N, K).shape[0]
+        x_s = x.view(outer_dim, -1, length // 2)
+        row_factor = x_s.shape[1] * (length // 2) // K
+        # Take slices of rows corresponding to column slice.
+        return x_s.narrow(1, start * 2 * K // length, row_factor).view(
+            *orig_shape[:-2], N, length // 2
+        )
+
+
 def scale_nvfp4_quant(
     input: torch.Tensor, input_global_scale: torch.Tensor
 ) -> Tuple[torch.Tensor, torch.Tensor]:
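As a quick, illustrative shape check for the K branch above (sizes are made up): `length` is given in units of the original [N, K] weight, while the packed input and the returned slice both carry K // 2 byte columns, so slicing 2048 original columns yields 1024 packed columns.

import torch

# Import path is an assumption; see the note in the earlier sketch.
from fbgemm_gpu.experimental.gen_ai.quantize import shuffle_slice

# Packed stand-in weight: original K = 8192 -> K // 2 = 4096 packed columns.
N, K = 4096, 8192
wq_packed = torch.randint(-128, 127, (N, K // 2), dtype=torch.int8)

# Slice the first 2048 elements of the *original* K dimension. The result
# keeps all N rows but only length // 2 = 1024 packed columns.
k_slice = shuffle_slice(wq_packed, dim=1, start=0, length=2048)
print(k_slice.shape)  # torch.Size([4096, 1024])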

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16i4bf16.cu

Lines changed: 7 additions & 2 deletions
@@ -260,8 +260,13 @@ at::Tensor bf16i4bf16_dispatch(
       "and be contiguous on GPU.");
   // Make sure group scales and zeros are in proper format.
   TORCH_CHECK(
-      w_scale_group.dim() == 2 && w_scale_group.size(1) == N,
-      "Group scales are expected to have shape [num_groups, N].");
+      w_scale_group.dim() == 2 && w_scale_group.size(1) == N &&
+          w_scale_group.is_cuda() && w_scale_group.is_contiguous(),
+      "Group scales are expected to have shape [num_groups, N] and be contiguous on GPU.");
+  TORCH_CHECK(
+      w_zero_group.dim() == 2 && w_zero_group.size(1) == N &&
+          w_zero_group.is_cuda() && w_zero_group.is_contiguous(),
+      "Group zeros are expected to have shape [num_groups, N] and be contiguous on GPU.");
 
   // Allocate output or return an empty tensor if input is empty.
   if (M == 0 || N == 0 || K == 0) {

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8i4bf16_shuffled.cu

Lines changed: 7 additions & 6 deletions
@@ -278,17 +278,18 @@ at::Tensor f8i4bf16_shuffled(
       "and be contiguous on GPU.");
   TORCH_CHECK(
       x_scale.numel() == M && x_scale.dtype() == at::kFloat &&
-          x_scale.is_cuda(),
-      "x_scale must be fp32 and have M total elements.");
+          x_scale.is_cuda() && x_scale.is_contiguous(),
+      "x_scale must be fp32 and have M total elements and be contiguous.");
   TORCH_CHECK(
       w_scale.numel() == N && w_scale.dtype() == at::kFloat &&
-          w_scale.is_cuda(),
-      "Weight row scale should have N elements and be on GPU.");
+          w_scale.is_cuda() && w_scale.is_contiguous(),
+      "Weight row scale should have N elements and be contiguous on GPU.");
   // Make sure w_scale_group is in proper format.
   TORCH_CHECK(
       w_scale_group.dtype() == at::kFloat8_e4m3fn && w_scale_group.dim() == 3 &&
-          w_scale_group.size(1) == 8 && w_scale_group.size(2) == N,
-      "Weights and group scales must be prepacked with preshuffle_i4. "
+          w_scale_group.size(1) == 8 && w_scale_group.size(2) == N &&
+          w_scale_group.is_contiguous(),
+      "Weights and group scales must be contiguous and prepacked with preshuffle_i4. "
       "Group scales are expected to be FP8 and have shape [num_groups, 8, N].");
 
   // Allocate output or return an empty tensor if input is empty.
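Both CUDA changes above tighten the argument checks to also require contiguous scale (and zero) tensors. A caller that shards group scales itself typically ends up with a non-contiguous view and must materialize it first; here is a minimal caller-side sketch of that pattern, with made-up sizes and an assumed sharding choice.

import torch

# Group scales in the layout f8i4bf16_shuffled expects: [num_groups, 8, N], FP8.
num_groups, N = 64, 4096
w_scale_group = torch.randn(num_groups, 8, N).to(torch.float8_e4m3fn)

# Slicing the N dimension yields a non-contiguous view ...
shard = w_scale_group[:, :, : N // 4]
print(shard.is_contiguous())  # False

# ... so it has to be copied into a contiguous buffer before being handed to
# the kernel, or the new TORCH_CHECK will fail.
shard = shard.contiguous()
print(shard.is_contiguous())  # True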
