
Commit babaa16

Make CUDA tensors contiguous and don't share grads in check.py
1 parent 9f1b3d3 commit babaa16

File tree: 3 files changed, 28 additions & 20 deletions

check.py

Lines changed: 7 additions & 3 deletions
@@ -29,6 +29,10 @@ def zero_grad(variables):
         variable.grad.zero_()
 
 
+def get_grads(variables):
+    return [var.grad.clone() for var in variables]
+
+
 def check_forward(variables, with_cuda, verbose):
     baseline_values = python.lltm_baseline.LLTMFunction.apply(*variables)
     cpp_values = cpp.lltm.LLTMFunction.apply(*variables)
@@ -47,13 +51,13 @@ def check_forward(variables, with_cuda, verbose):
 def check_backward(variables, with_cuda, verbose):
     baseline_values = python.lltm_baseline.LLTMFunction.apply(*variables)
     (baseline_values[0] + baseline_values[1]).sum().backward()
-    grad_baseline = [var.grad for var in variables]
+    grad_baseline = get_grads(variables)
 
     zero_grad(variables)
 
     cpp_values = cpp.lltm.LLTMFunction.apply(*variables)
     (cpp_values[0] + cpp_values[1]).sum().backward()
-    grad_cpp = [var.grad for var in variables]
+    grad_cpp = get_grads(variables)
 
     print('Backward: Baseline (Python) vs. C++ ... ', end='')
     check_equal(grad_baseline, grad_cpp, verbose)
@@ -63,7 +67,7 @@ def check_backward(variables, with_cuda, verbose):
         zero_grad(variables)
         cuda_values = cuda.lltm.LLTMFunction.apply(*variables)
         (cuda_values[0] + cuda_values[1]).sum().backward()
-        grad_cuda = [var.grad for var in variables]
+        grad_cuda = get_grads(variables)
 
         print('Backward: Baseline (Python) vs. CUDA ... ', end='')
         check_equal(grad_baseline, grad_cuda, verbose)
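
Background on the get_grads helper: PyTorch accumulates gradients into the same .grad tensor in place, so keeping bare references and then calling zero_grad would also zero the saved baseline gradients before check_equal runs. Cloning takes an independent snapshot, which is the "don't share grads" part of this commit. A minimal standalone sketch of the pitfall (illustration only, not the repo's code):

    import torch

    w = torch.ones(3, requires_grad=True)
    (w * 2).sum().backward()

    saved = w.grad             # bare reference: aliases w.grad
    snapshot = w.grad.clone()  # independent copy, like get_grads()

    w.grad.zero_()             # what zero_grad() in check.py does

    print(saved)     # tensor([0., 0., 0.]) -- the shared reference was zeroed too
    print(snapshot)  # tensor([2., 2., 2.]) -- the clone still holds the real gradient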

cuda/lltm.py

Lines changed: 3 additions & 2 deletions
@@ -20,8 +20,9 @@ def forward(ctx, input, weights, bias, old_h, old_cell):
 
     @staticmethod
     def backward(ctx, grad_h, grad_cell):
-        d_old_h, d_input, d_weights, d_bias, d_old_cell = lltm_cuda.backward(
-            grad_h, grad_cell, *ctx.saved_variables)
+        outputs = lltm_cuda.backward(
+            grad_h.contiguous(), grad_cell.contiguous(), *ctx.saved_variables)
+        d_old_h, d_input, d_weights, d_bias, d_old_cell, d_gates = outputs
         return d_input, d_weights, d_bias, d_old_h, d_old_cell
 
 
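
Why the .contiguous() calls on the incoming gradients: autograd does not guarantee that grad_h and grad_cell arrive densely packed in row-major order, while the CUDA kernel indexes raw pointers as if they were. .contiguous() returns the tensor unchanged when the layout is already dense and copies it into a dense buffer otherwise. A generic illustration, not specific to this extension:

    import torch

    g = torch.randn(4, 6)[:, ::2]  # strided view: every other column
    print(g.is_contiguous())       # False
    print(g.stride())              # (6, 2) instead of the dense (3, 1)

    gc = g.contiguous()            # copies into a dense, row-major buffer
    print(gc.is_contiguous())      # True
    print(gc.stride())             # (3, 1), safe for flat pointer arithmetic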

cuda/lltm_cuda_kernel.cu

Lines changed: 18 additions & 15 deletions
@@ -72,19 +72,22 @@ __global__ void lltm_cuda_backward_kernel(
     size_t state_size) {
   const int column = blockIdx.x * blockDim.x + threadIdx.x;
   const int index = blockIdx.y * state_size + column;
+  const int gates_row = blockIdx.y * (state_size * 3);
   if (column < state_size) {
     const auto d_output_gate = tanh(new_cell[index]) * grad_h[index];
     const auto d_tanh_new_cell = output_gate[index] * grad_h[index];
     const auto d_new_cell =
         d_tanh(new_cell[index]) * d_tanh_new_cell + grad_cell[index];
 
+
     d_old_cell[index] = d_new_cell;
     const auto d_candidate_cell = input_gate[index] * d_new_cell;
     const auto d_input_gate = candidate_cell[index] * d_new_cell;
 
-    const auto input_gate_index = index;
-    const auto output_gate_index = state_size + index;
-    const auto candidate_cell_index = 2 * state_size + index;
+
+    const auto input_gate_index = gates_row + column;
+    const auto output_gate_index = gates_row + state_size + column;
+    const auto candidate_cell_index = gates_row + 2 * state_size + column;
 
     d_gates[input_gate_index] =
         d_input_gate * d_sigmoid(gate_weights[input_gate_index]);
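
The indexing change matters because gate_weights and d_gates have rows of width 3 * state_size, three times wider than new_cell or grad_h, so reusing index (whose row stride is state_size) only lands on the right element for the first batch row. gates_row rebuilds the offset with the correct stride. A quick Python sketch of the arithmetic, with made-up sizes:

    # Flat row-major indexing into buffers of different row widths,
    # mirroring what each CUDA thread computes. Sizes are arbitrary.
    batch_size, state_size = 4, 5

    for row in range(batch_size):          # plays the role of blockIdx.y
        for column in range(state_size):   # the thread's column
            index = row * state_size + column      # valid for [batch, state_size] buffers
            gates_row = row * (state_size * 3)     # row offset in the [batch, 3 * state_size] buffer
            old_gate_index = index                 # pre-fix formula
            new_gate_index = gates_row + column    # post-fix formula
            if row > 0:
                assert old_gate_index != new_gate_index  # the old formula read the wrong row
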
@@ -105,8 +108,8 @@ std::vector<at::Tensor> lltm_cuda_forward(
   auto X = at::cat({old_h, input}, /*dim=*/1);
   auto gates = at::addmm(bias, X, weights.transpose(0, 1));
 
-  const size_t batch_size = old_cell.size(0);
-  const size_t state_size = old_cell.size(1);
+  const auto batch_size = old_cell.size(0);
+  const auto state_size = old_cell.size(1);
 
   auto new_h = at::zeros_like(old_cell);
   auto new_cell = at::zeros_like(old_cell);
@@ -119,8 +122,8 @@ std::vector<at::Tensor> lltm_cuda_forward(
 
   AT_DISPATCH_FLOATING_TYPES(gates.type(), "lltm_forward_cuda", ([&] {
     lltm_cuda_forward_kernel<scalar_t><<<blocks, threads>>>(
-        gates.data<scalar_t>(),
-        old_cell.data<scalar_t>(),
+        gates.contiguous().data<scalar_t>(),
+        old_cell.contiguous().data<scalar_t>(),
         new_h.data<scalar_t>(),
         new_cell.data<scalar_t>(),
         input_gate.data<scalar_t>(),
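
For reference, the buffer whose 3 * state_size row stride the backward kernel now respects is produced by the at::cat / at::addmm lines above. A Python rendering of those two lines, with hypothetical sizes and assuming the LLTM convention that weights has 3 * state_size rows, shows the shape:

    import torch

    batch_size, state_size, input_features = 4, 5, 7  # hypothetical sizes

    old_h = torch.randn(batch_size, state_size)
    input = torch.randn(batch_size, input_features)
    weights = torch.randn(3 * state_size, input_features + state_size)
    bias = torch.randn(3 * state_size)

    X = torch.cat([old_h, input], dim=1)                    # [batch_size, input_features + state_size]
    gates = torch.addmm(bias, X, weights.transpose(0, 1))   # [batch_size, 3 * state_size]

    print(gates.shape)     # torch.Size([4, 15])
    print(gates.stride())  # (15, 1): the row stride that gates_row accounts for
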
@@ -155,13 +158,13 @@ std::vector<at::Tensor> lltm_cuda_backward(
     lltm_cuda_backward_kernel<scalar_t><<<blocks, threads>>>(
         d_old_cell.data<scalar_t>(),
         d_gates.data<scalar_t>(),
-        grad_h.data<scalar_t>(),
-        grad_cell.data<scalar_t>(),
-        new_cell.data<scalar_t>(),
-        input_gate.data<scalar_t>(),
-        output_gate.data<scalar_t>(),
-        candidate_cell.data<scalar_t>(),
-        gate_weights.data<scalar_t>(),
+        grad_h.contiguous().data<scalar_t>(),
+        grad_cell.contiguous().data<scalar_t>(),
+        new_cell.contiguous().data<scalar_t>(),
+        input_gate.contiguous().data<scalar_t>(),
+        output_gate.contiguous().data<scalar_t>(),
+        candidate_cell.contiguous().data<scalar_t>(),
+        gate_weights.contiguous().data<scalar_t>(),
         state_size);
   }));
 
@@ -172,5 +175,5 @@ std::vector<at::Tensor> lltm_cuda_backward(
   auto d_old_h = d_X.slice(/*dim=*/1, 0, state_size);
   auto d_input = d_X.slice(/*dim=*/1, state_size);
 
-  return {d_old_h, d_input, d_weights, d_bias, d_old_cell};
+  return {d_old_h, d_input, d_weights, d_bias, d_old_cell, d_gates};
 }
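
A side note on the extra d_gates return value: the Python wrapper above unpacks it but does not hand it back to autograd, because torch.autograd.Function.backward must return exactly one gradient (or None) per argument of forward. A toy sketch of that contract, with unrelated names:

    import torch

    class Scale(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, factor):
            ctx.factor = factor
            return x * factor

        @staticmethod
        def backward(ctx, grad_out):
            # One return value per forward argument: a gradient for x, None for factor.
            # Extra diagnostic tensors (like d_gates above) have to stay local.
            return grad_out * ctx.factor, None

    x = torch.randn(3, requires_grad=True)
    Scale.apply(x, 2.0).sum().backward()
    print(torch.allclose(x.grad, torch.full((3,), 2.0)))  # True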
