@@ -620,20 +620,20 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{
620
620
const int64_t fp8_exponent_bias
621
621
) {
622
622
TENSOR_ON_CUDA_GPU (dev_weights);
623
- TENSOR_ON_CUDA_GPU (uvm_weights);
624
- TENSOR_ON_CUDA_GPU (weights_placements);
625
- TENSOR_ON_CUDA_GPU (weights_offsets);
626
- TENSOR_ON_CUDA_GPU (weights_tys);
623
+ TENSORS_ON_SAME_DEVICE (uvm_weights, dev_weights );
624
+ TENSORS_ON_SAME_DEVICE (weights_placements, dev_weights );
625
+ TENSORS_ON_SAME_DEVICE (weights_offsets, dev_weights );
626
+ TENSORS_ON_SAME_DEVICE (weights_tys, dev_weights );
627
627
{% if not nobag %}
628
- TENSOR_ON_CUDA_GPU (D_offsets);
628
+ TENSORS_ON_SAME_DEVICE (D_offsets, dev_weights );
629
629
{% endif %}
630
- TENSOR_ON_CUDA_GPU (indices);
631
- TENSOR_ON_CUDA_GPU (offsets);
630
+ TENSORS_ON_SAME_DEVICE (indices, dev_weights );
631
+ TENSORS_ON_SAME_DEVICE (offsets, dev_weights );
632
632
{% if weighted %}
633
- TENSOR_EMPTY_OR_ON_CUDA_GPU (indice_weights);
633
+ TENSORS_EMPTY_OR_ON_SAME_DEVICE (indice_weights, dev_weights );
634
634
{% endif %}
635
- TENSOR_EMPTY_OR_ON_CUDA_GPU (lxu_cache_weights);
636
- TENSOR_EMPTY_OR_ON_CUDA_GPU (lxu_cache_locations);
635
+ TENSORS_EMPTY_OR_ON_SAME_DEVICE (lxu_cache_weights, dev_weights );
636
+ TENSORS_EMPTY_OR_ON_SAME_DEVICE (lxu_cache_locations, dev_weights );
637
637
638
638
at::cuda::OptionalCUDAGuard device_guard;
639
639
device_guard.set_index (dev_weights.get_device ());
0 commit comments