Commit 63f005f

Better handling of contiguous tensors

1 parent 88e2a9a

2 files changed (+24 -22 lines)

cuda/lltm_cuda.cpp (15 additions, 13 deletions)

@@ -25,18 +25,20 @@ std::vector<at::Tensor> lltm_cuda_backward(
 // C++ interface
 
 #define CHECK_CUDA(x) AT_ASSERT(x.type().is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_CONTIGUOUS(x) AT_ASSERT(x.is_contiguous(), #x " must be contiguous")
+#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
 
 std::vector<at::Tensor> lltm_forward(
     at::Tensor input,
     at::Tensor weights,
     at::Tensor bias,
     at::Tensor old_h,
     at::Tensor old_cell) {
-  CHECK_CUDA(input);
-  CHECK_CUDA(weights);
-  CHECK_CUDA(bias);
-  CHECK_CUDA(old_h);
-  CHECK_CUDA(old_cell);
+  CHECK_INPUT(input);
+  CHECK_INPUT(weights);
+  CHECK_INPUT(bias);
+  CHECK_INPUT(old_h);
+  CHECK_INPUT(old_cell);
 
   return lltm_cuda_forward(input, weights, bias, old_h, old_cell);
 }
@@ -51,14 +53,14 @@ std::vector<at::Tensor> lltm_backward(
     at::Tensor X,
     at::Tensor gate_weights,
     at::Tensor weights) {
-  CHECK_CUDA(grad_h);
-  CHECK_CUDA(grad_cell);
-  CHECK_CUDA(input_gate);
-  CHECK_CUDA(output_gate);
-  CHECK_CUDA(candidate_cell);
-  CHECK_CUDA(X);
-  CHECK_CUDA(gate_weights);
-  CHECK_CUDA(weights);
+  CHECK_INPUT(grad_h);
+  CHECK_INPUT(grad_cell);
+  CHECK_INPUT(input_gate);
+  CHECK_INPUT(output_gate);
+  CHECK_INPUT(candidate_cell);
+  CHECK_INPUT(X);
+  CHECK_INPUT(gate_weights);
+  CHECK_INPUT(weights);
 
   return lltm_cuda_backward(
       grad_h,
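
Since CHECK_INPUT just chains the two assertions, a rough sketch of what CHECK_INPUT(input) expands to after preprocessing (macro names taken from the diff above; the #x stringization and literal concatenation are written out as a single string):

// Approximate preprocessor expansion of CHECK_INPUT(input):
AT_ASSERT(input.type().is_cuda(), "input must be a CUDA tensor");
AT_ASSERT(input.is_contiguous(), "input must be contiguous");

With both properties asserted once at the C++ interface, the launch sites in the .cu file no longer need the defensive .contiguous() calls removed below.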

cuda/lltm_cuda_kernel.cu (9 additions, 9 deletions)

@@ -122,8 +122,8 @@ std::vector<at::Tensor> lltm_cuda_forward(
 
   AT_DISPATCH_FLOATING_TYPES(gates.type(), "lltm_forward_cuda", ([&] {
     lltm_cuda_forward_kernel<scalar_t><<<blocks, threads>>>(
-        gates.contiguous().data<scalar_t>(),
-        old_cell.contiguous().data<scalar_t>(),
+        gates.data<scalar_t>(),
+        old_cell.data<scalar_t>(),
         new_h.data<scalar_t>(),
         new_cell.data<scalar_t>(),
         input_gate.data<scalar_t>(),
@@ -158,13 +158,13 @@ std::vector<at::Tensor> lltm_cuda_backward(
     lltm_cuda_backward_kernel<scalar_t><<<blocks, threads>>>(
         d_old_cell.data<scalar_t>(),
         d_gates.data<scalar_t>(),
-        grad_h.contiguous().data<scalar_t>(),
-        grad_cell.contiguous().data<scalar_t>(),
-        new_cell.contiguous().data<scalar_t>(),
-        input_gate.contiguous().data<scalar_t>(),
-        output_gate.contiguous().data<scalar_t>(),
-        candidate_cell.contiguous().data<scalar_t>(),
-        gate_weights.contiguous().data<scalar_t>(),
+        grad_h.data<scalar_t>(),
+        grad_cell.data<scalar_t>(),
+        new_cell.data<scalar_t>(),
+        input_gate.data<scalar_t>(),
+        output_gate.data<scalar_t>(),
+        candidate_cell.data<scalar_t>(),
+        gate_weights.data<scalar_t>(),
         state_size);
   }));
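
Because the interface now guarantees contiguity, the kernels can index the raw data pointers with plain row-major arithmetic. A hypothetical caller-side sketch (not part of this commit; assumes a libtorch-style at::randn(sizes, options) factory, and the shapes are made up) of what the new checks mean for users:

// Hypothetical usage sketch: CHECK_INPUT rejects non-contiguous views,
// so any layout conversion happens explicitly in the caller.
at::Tensor w = at::randn({96, 64}, at::kCUDA);  // freshly allocated: contiguous
at::Tensor wt = w.t();                          // transposed view: NOT contiguous
// lltm_forward(..., wt, ...);                  // would now trip CHECK_CONTIGUOUS
at::Tensor dense = wt.contiguous();             // explicit dense copy
// lltm_forward(..., dense, ...);               // passes CHECK_INPUT

The design trade-off: instead of silently paying for a .contiguous() copy on every launch, the cost of any layout conversion is moved to the caller and made visible.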
