@@ -72,19 +72,22 @@ __global__ void lltm_cuda_backward_kernel(
    size_t state_size) {
  const int column = blockIdx.x * blockDim.x + threadIdx.x;
  const int index = blockIdx.y * state_size + column;
+  const int gates_row = blockIdx.y * (state_size * 3);
  if (column < state_size) {
    const auto d_output_gate = tanh(new_cell[index]) * grad_h[index];
    const auto d_tanh_new_cell = output_gate[index] * grad_h[index];
    const auto d_new_cell =
        d_tanh(new_cell[index]) * d_tanh_new_cell + grad_cell[index];

+
    d_old_cell[index] = d_new_cell;
    const auto d_candidate_cell = input_gate[index] * d_new_cell;
    const auto d_input_gate = candidate_cell[index] * d_new_cell;

-    const auto input_gate_index = index;
-    const auto output_gate_index = state_size + index;
-    const auto candidate_cell_index = 2 * state_size + index;
+
+    const auto input_gate_index = gates_row + column;
+    const auto output_gate_index = gates_row + state_size + column;
+    const auto candidate_cell_index = gates_row + 2 * state_size + column;

    d_gates[input_gate_index] =
        d_input_gate * d_sigmoid(gate_weights[input_gate_index]);
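
A minimal sanity check of the new indexing (not part of the patch): d_gates and gate_weights are assumed to be row-major [batch_size, 3 * state_size] tensors, so each batch row starts 3 * state_size elements after the previous one. The helper name gate_index below is hypothetical.

#include <cstdio>

// gate 0 = input, 1 = output, 2 = candidate, matching the kernel above.
int gate_index(int batch, int gate, int column, int state_size) {
  const int gates_row = batch * (state_size * 3);  // start of this batch's row
  return gates_row + gate * state_size + column;   // element within the row
}

int main() {
  // batch 1, state_size 4: the input-gate element in column 2 lives at
  // 1 * 12 + 0 * 4 + 2 = 14.
  std::printf("%d\n", gate_index(/*batch=*/1, /*gate=*/0, /*column=*/2, /*state_size=*/4));
  return 0;
}
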
@@ -105,8 +108,8 @@ std::vector<at::Tensor> lltm_cuda_forward(
  auto X = at::cat({old_h, input}, /*dim=*/1);
  auto gates = at::addmm(bias, X, weights.transpose(0, 1));

-  const size_t batch_size = old_cell.size(0);
-  const size_t state_size = old_cell.size(1);
+  const auto batch_size = old_cell.size(0);
+  const auto state_size = old_cell.size(1);

  auto new_h = at::zeros_like(old_cell);
  auto new_cell = at::zeros_like(old_cell);
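
A small sketch (not part of the patch), assuming at::Tensor::size() returns int64_t: with auto the sizes keep ATen's native signed type instead of being converted to size_t. The standalone main below is illustrative only.

#include <ATen/ATen.h>

int main() {
  auto old_cell = at::zeros({16, 32});
  const auto batch_size = old_cell.size(0);  // deduced as int64_t, value 16
  const auto state_size = old_cell.size(1);  // deduced as int64_t, value 32
  return (batch_size == 16 && state_size == 32) ? 0 : 1;
}
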
@@ -119,8 +122,8 @@ std::vector<at::Tensor> lltm_cuda_forward(

  AT_DISPATCH_FLOATING_TYPES(gates.type(), "lltm_forward_cuda", ([&] {
    lltm_cuda_forward_kernel<scalar_t><<<blocks, threads>>>(
-        gates.data<scalar_t>(),
-        old_cell.data<scalar_t>(),
+        gates.contiguous().data<scalar_t>(),
+        old_cell.contiguous().data<scalar_t>(),
        new_h.data<scalar_t>(),
        new_cell.data<scalar_t>(),
        input_gate.data<scalar_t>(),
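
For context, a minimal sketch (not part of the patch, using the older ATen API this file already relies on) of why .contiguous() matters before taking raw pointers: .data<T>() drops all stride information, so the kernels' flat row-major indexing is only valid for contiguous storage.

#include <ATen/ATen.h>

int main() {
  auto t = at::rand({4, 8});
  auto v = t.transpose(0, 1);          // a non-contiguous view over t's storage
  // .contiguous() is a no-op for tensors that are already contiguous and
  // otherwise materializes a packed copy whose raw pointer is safe to use
  // with flat row-major indexing.
  auto vc = v.contiguous();
  const float* p = vc.data<float>();
  return p != nullptr ? 0 : 1;
}
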
@@ -155,13 +158,13 @@ std::vector<at::Tensor> lltm_cuda_backward(
    lltm_cuda_backward_kernel<scalar_t><<<blocks, threads>>>(
        d_old_cell.data<scalar_t>(),
        d_gates.data<scalar_t>(),
-        grad_h.data<scalar_t>(),
-        grad_cell.data<scalar_t>(),
-        new_cell.data<scalar_t>(),
-        input_gate.data<scalar_t>(),
-        output_gate.data<scalar_t>(),
-        candidate_cell.data<scalar_t>(),
-        gate_weights.data<scalar_t>(),
+        grad_h.contiguous().data<scalar_t>(),
+        grad_cell.contiguous().data<scalar_t>(),
+        new_cell.contiguous().data<scalar_t>(),
+        input_gate.contiguous().data<scalar_t>(),
+        output_gate.contiguous().data<scalar_t>(),
+        candidate_cell.contiguous().data<scalar_t>(),
+        gate_weights.contiguous().data<scalar_t>(),
        state_size);
  }));

@@ -172,5 +175,5 @@ std::vector<at::Tensor> lltm_cuda_backward(
  auto d_old_h = d_X.slice(/*dim=*/1, 0, state_size);
  auto d_input = d_X.slice(/*dim=*/1, state_size);

-  return {d_old_h, d_input, d_weights, d_bias, d_old_cell};
+  return {d_old_h, d_input, d_weights, d_bias, d_old_cell, d_gates};
}