
Matmul errors out when one tensor is batched and the other isn't #1

@rationalism

Cool idea! Proud to submit a first bug report :)

This PyTorch code (Ubuntu, CUDA 12.1, PyTorch 2.2.2, NVIDIA RTX 4090):

>>> import cublas_ops
>>> import torch
>>> x = torch.ones([1, 2560, 8192], dtype=torch.float16, device="cuda:0")
>>> x
tensor([[[1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         ...,
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.]]], device='cuda:0',
       dtype=torch.float16)
>>> y = torch.ones([8192, 28672], dtype=torch.float16, device="cuda:0")
>>> z = cublas_ops.cublas_half_matmul_batched_simple(x, y)

fails with this stack trace (cuBLAS status 7 is CUBLAS_STATUS_INVALID_VALUE):

 ** On entry to HgemmStridedBatched parameter number 10 had an illegal value
cuBLAS API failed with status 7
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/cublas_ops/__init__.py", line 39, in cublas_half_matmul_batched_simple
    return _cublas_hgemm_batched_simple(a, b)
RuntimeError: cuBLAS API failed

but the same call works once y is given an explicit batch dimension of 1:

>>> y = torch.ones([1, 8192, 28672], dtype=torch.float16, device="cuda:0")
>>> z = cublas_ops.cublas_half_matmul_batched_simple(x, y)
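
Until the mixed 2D/3D case is handled inside the library, a workaround along these lines seems to avoid the error: give any 2D operand a leading batch dimension of 1 before calling into cublas_ops. This is only a sketch of the idea; the helper name batched_matmul_workaround is hypothetical and not part of cublas_ops.

import cublas_ops
import torch

def batched_matmul_workaround(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: promote any 2D operand to 3D with batch size 1
    # so both inputs reach cublas_half_matmul_batched_simple already batched.
    if a.dim() == 2:
        a = a.unsqueeze(0)  # [m, k] -> [1, m, k]
    if b.dim() == 2:
        b = b.unsqueeze(0)  # [k, n] -> [1, k, n]
    return cublas_ops.cublas_half_matmul_batched_simple(a, b)

x = torch.ones([1, 2560, 8192], dtype=torch.float16, device="cuda:0")
y = torch.ones([8192, 28672], dtype=torch.float16, device="cuda:0")
z = batched_matmul_workaround(x, y)  # same shapes as the failing call above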
