Skip to content

Commit 3da46da

Browse files
pls331facebook-github-bot
authored andcommitted
improve validation for input tensors to guard for case where inputs coming from different device (#1615)
Summary: Pull Request resolved: #1615 Reviewed By: jianyuh, houseroad Differential Revision: D43564925 fbshipit-source-id: 9d8db49df76889e56f70ebb3fb4984c292186edc
1 parent eec6fd2 commit 3da46da

File tree

2 files changed

+16
-10
lines changed

2 files changed

+16
-10
lines changed

fbgemm_gpu/codegen/embedding_forward_quantized_split_template.cu

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -620,20 +620,20 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{
620620
const int64_t fp8_exponent_bias
621621
) {
622622
TENSOR_ON_CUDA_GPU(dev_weights);
623-
TENSOR_ON_CUDA_GPU(uvm_weights);
624-
TENSOR_ON_CUDA_GPU(weights_placements);
625-
TENSOR_ON_CUDA_GPU(weights_offsets);
626-
TENSOR_ON_CUDA_GPU(weights_tys);
623+
TENSORS_ON_SAME_DEVICE(uvm_weights, dev_weights);
624+
TENSORS_ON_SAME_DEVICE(weights_placements, dev_weights);
625+
TENSORS_ON_SAME_DEVICE(weights_offsets, dev_weights);
626+
TENSORS_ON_SAME_DEVICE(weights_tys, dev_weights);
627627
{% if not nobag %}
628-
TENSOR_ON_CUDA_GPU(D_offsets);
628+
TENSORS_ON_SAME_DEVICE(D_offsets, dev_weights);
629629
{% endif %}
630-
TENSOR_ON_CUDA_GPU(indices);
631-
TENSOR_ON_CUDA_GPU(offsets);
630+
TENSORS_ON_SAME_DEVICE(indices, dev_weights);
631+
TENSORS_ON_SAME_DEVICE(offsets, dev_weights);
632632
{% if weighted %}
633-
TENSOR_EMPTY_OR_ON_CUDA_GPU(indice_weights);
633+
TENSORS_EMPTY_OR_ON_SAME_DEVICE(indice_weights, dev_weights);
634634
{% endif %}
635-
TENSOR_EMPTY_OR_ON_CUDA_GPU(lxu_cache_weights);
636-
TENSOR_EMPTY_OR_ON_CUDA_GPU(lxu_cache_locations);
635+
TENSORS_EMPTY_OR_ON_SAME_DEVICE(lxu_cache_weights, dev_weights);
636+
TENSORS_EMPTY_OR_ON_SAME_DEVICE(lxu_cache_locations, dev_weights);
637637

638638
at::cuda::OptionalCUDAGuard device_guard;
639639
device_guard.set_index(dev_weights.get_device());

fbgemm_gpu/include/fbgemm_gpu/sparse_ops_utils.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,12 @@ inline bool torch_tensor_empty_or_on_cpu_check(
111111
#x " must be empty or a CUDA tensor; it is currently on device ", \
112112
torch_tensor_device_name(x))
113113

114+
#define TENSORS_EMPTY_OR_ON_SAME_DEVICE(x, y) \
115+
TORCH_CHECK( \
116+
torch_tensor_on_same_device_check(x, y) || (x.numel() == 0), \
117+
#x " must be empty or a CUDA tensor; it is currently on device ", \
118+
torch_tensor_device_name(x))
119+
114120
#define TENSORS_ON_SAME_DEVICE(x, y) \
115121
TORCH_CHECK( \
116122
torch_tensor_on_same_device_check(x, y), \

0 commit comments

Comments
 (0)