@@ -45,12 +45,13 @@ __global__ void lltm_cuda_forward_kernel(
45
45
scalar_t * __restrict__ output_gate,
46
46
scalar_t * __restrict__ candidate_cell,
47
47
size_t state_size) {
48
- const auto column = blockIdx .x * blockDim .x + threadIdx .x ;
49
- const auto index = blockIdx .y * state_size + column;
48
+ const int column = blockIdx .x * blockDim .x + threadIdx .x ;
49
+ const int index = blockIdx .y * state_size + column;
50
+ const int gates_row = blockIdx .y * (state_size * 3 );
50
51
if (column < state_size) {
51
- input_gate[index] = sigmoid (gates[index ]);
52
- output_gate[index] = sigmoid (gates[state_size + index ]);
53
- candidate_cell[index] = elu (gates[2 * state_size + index ]);
52
+ input_gate[index] = sigmoid (gates[gates_row + column ]);
53
+ output_gate[index] = sigmoid (gates[gates_row + state_size + column ]);
54
+ candidate_cell[index] = elu (gates[gates_row + 2 * state_size + column ]);
54
55
new_cell[index] =
55
56
old_cell[index] + candidate_cell[index] * input_gate[index];
56
57
new_h[index] = tanh (new_cell[index]) * output_gate[index];
@@ -104,8 +105,8 @@ std::vector<at::Tensor> lltm_cuda_forward(
104
105
auto X = at::cat ({old_h, input}, /* dim=*/ 1 );
105
106
auto gates = at::addmm (bias, X, weights.transpose (0 , 1 ));
106
107
107
- const auto batch_size = old_cell.size (0 );
108
- const auto state_size = old_cell.size (1 );
108
+ const size_t batch_size = old_cell.size (0 );
109
+ const size_t state_size = old_cell.size (1 );
109
110
110
111
auto new_h = at::zeros_like (old_cell);
111
112
auto new_cell = at::zeros_like (old_cell);
@@ -114,7 +115,7 @@ std::vector<at::Tensor> lltm_cuda_forward(
114
115
auto candidate_cell = at::zeros_like (old_cell);
115
116
116
117
const int threads = 1024 ;
117
- const dim3 blocks (batch_size, (state_size + threads - 1 ) / threads);
118
+ const dim3 blocks ((state_size + threads - 1 ) / threads, batch_size );
118
119
119
120
AT_DISPATCH_FLOATING_TYPES (gates.type (), " lltm_forward_cuda" , ([&] {
120
121
lltm_cuda_forward_kernel<scalar_t ><<<blocks, threads>>> (
@@ -148,7 +149,7 @@ std::vector<at::Tensor> lltm_cuda_backward(
148
149
const auto state_size = new_cell.size (1 );
149
150
150
151
const int threads = 1024 ;
151
- const dim3 blocks (batch_size, (state_size + threads - 1 ) / threads);
152
+ const dim3 blocks ((state_size + threads - 1 ) / threads, batch_size );
152
153
153
154
AT_DISPATCH_FLOATING_TYPES (X.type (), " lltm_forward_cuda" , ([&] {
154
155
lltm_cuda_backward_kernel<scalar_t ><<<blocks, threads>>> (
0 commit comments