@@ -91,14 +91,6 @@ inline bool torch_tensor_empty_or_on_cpu_check(
       #x " must be empty or a CPU tensor; it is currently on device ", \
       torch_tensor_device_name(x))
 
-#define TENSORS_HAVE_SAME_TYPE(x, y)                       \
-  TORCH_CHECK(                                             \
-      (x).dtype() == (y).dtype(),                          \
-      #x " must have the same type as " #y " types were ", \
-      (x).dtype().name(),                                  \
-      " and ",                                             \
-      (y).dtype().name())
-
 #define TENSOR_ON_CUDA_GPU(x)            \
   TORCH_CHECK(                           \
       torch_tensor_on_cuda_gpu_check(x), \
@@ -144,6 +136,10 @@ inline bool torch_tensor_empty_or_on_cpu_check(
 #define TENSOR_CONTIGUOUS(x) \
   TORCH_CHECK((x).is_contiguous(), #x " must be contiguous")
 
+#define TENSOR_CONTIGUOUS_AND_ON_CPU(x) \
+  TENSOR_ON_CPU(x);                     \
+  TENSOR_CONTIGUOUS(x)
+
 #define TENSOR_CONTIGUOUS_AND_ON_CUDA_GPU(x) \
   TENSOR_ON_CUDA_GPU(x);                     \
   TENSOR_CONTIGUOUS(x)
@@ -156,6 +152,28 @@ inline bool torch_tensor_empty_or_on_cpu_check(
       "Found ",           \
       (ten).ndimension())
 
+#define TENSOR_TYPE_MUST_BE(ten, typ)                                      \
+  TORCH_CHECK(                                                             \
+      (ten).scalar_type() == typ,                                          \
+      "Tensor '" #ten "' must have scalar type " #typ " but it had type ", \
+      (ten).dtype().name())
+
+#define TENSOR_NDIM_EXCEEDS(ten, dims)               \
+  TORCH_CHECK(                                       \
+      (ten).dim() > (dims),                          \
+      "Tensor '" #ten "' must have more than " #dims \
+      " dimension(s). "                              \
+      "Found ",                                      \
+      (ten).ndimension())
+
+#define TENSORS_HAVE_SAME_NUMEL(x, y)                                  \
+  TORCH_CHECK(                                                         \
+      (x).numel() == (y).numel(),                                      \
+      #x " must have the same number of elements as " #y " They had ", \
+      (x).numel(),                                                     \
+      " and ",                                                         \
+      (y).numel())
+
 /// Determine an appropriate CUDA block count along the x axis
 ///
 /// When launching CUDA kernels the number of blocks B is often calculated
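
A quick usage sketch of the new checks. The two functions below are hypothetical and not part of this commit; they assume the header modified above is included and use only the macros as defined in this diff. Each macro expands to a `TORCH_CHECK`, so a failed check throws a `c10::Error` carrying the stringified tensor name:

```cpp
#include <ATen/ATen.h>
// Assumes the header modified above is included so the macros are visible.

// Hypothetical CPU operator: gather rows of `weights` by `indices`.
at::Tensor lookup_rows_cpu(const at::Tensor& weights, const at::Tensor& indices) {
  TENSOR_CONTIGUOUS_AND_ON_CPU(weights);  // on CPU *and* contiguous
  TENSOR_CONTIGUOUS_AND_ON_CPU(indices);
  TENSOR_TYPE_MUST_BE(indices, at::ScalarType::Long);  // index_select wants int64
  TENSOR_NDIM_EXCEEDS(weights, 1);  // i.e. weights.dim() > 1: at least 2-D
  return weights.index_select(0, indices);
}

// Hypothetical in-place y += alpha * x over flat buffers of equal size.
void axpy_cpu(double alpha, const at::Tensor& x, at::Tensor& y) {
  TENSOR_CONTIGUOUS_AND_ON_CPU(x);
  TENSOR_CONTIGUOUS_AND_ON_CPU(y);
  TENSOR_TYPE_MUST_BE(x, at::ScalarType::Float);
  TENSOR_TYPE_MUST_BE(y, at::ScalarType::Float);
  TENSORS_HAVE_SAME_NUMEL(x, y);  // element counts must match; shapes may differ
  // view_as is safe here: x is contiguous and has the same numel as y.
  y.add_(x.view_as(y), alpha);
}
```

One wrinkle worth noting: `TENSOR_CONTIGUOUS_AND_ON_CPU` (like the existing `_CUDA_GPU` variant) expands to two statements separated by `;`, so unlike a single `TORCH_CHECK` it cannot safely be the unbraced body of an `if` or `for`. Used at the top of a function body, as here, that is not an issue.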