Skip to content

Commit 91617a4

Browse files
committed
Fix CUDA kernel for batch size > 1
1 parent 5451cec commit 91617a4

File tree

1 file changed

+10
-9
lines changed

1 file changed

+10
-9
lines changed

cuda/lltm_cuda_kernel.cu

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -45,12 +45,13 @@ __global__ void lltm_cuda_forward_kernel(
4545
scalar_t* __restrict__ output_gate,
4646
scalar_t* __restrict__ candidate_cell,
4747
size_t state_size) {
48-
const auto column = blockIdx.x * blockDim.x + threadIdx.x;
49-
const auto index = blockIdx.y * state_size + column;
48+
const int column = blockIdx.x * blockDim.x + threadIdx.x;
49+
const int index = blockIdx.y * state_size + column;
50+
const int gates_row = blockIdx.y * (state_size * 3);
5051
if (column < state_size) {
51-
input_gate[index] = sigmoid(gates[index]);
52-
output_gate[index] = sigmoid(gates[state_size + index]);
53-
candidate_cell[index] = elu(gates[2 * state_size + index]);
52+
input_gate[index] = sigmoid(gates[gates_row + column]);
53+
output_gate[index] = sigmoid(gates[gates_row + state_size + column]);
54+
candidate_cell[index] = elu(gates[gates_row + 2 * state_size + column]);
5455
new_cell[index] =
5556
old_cell[index] + candidate_cell[index] * input_gate[index];
5657
new_h[index] = tanh(new_cell[index]) * output_gate[index];
@@ -104,8 +105,8 @@ std::vector<at::Tensor> lltm_cuda_forward(
104105
auto X = at::cat({old_h, input}, /*dim=*/1);
105106
auto gates = at::addmm(bias, X, weights.transpose(0, 1));
106107

107-
const auto batch_size = old_cell.size(0);
108-
const auto state_size = old_cell.size(1);
108+
const size_t batch_size = old_cell.size(0);
109+
const size_t state_size = old_cell.size(1);
109110

110111
auto new_h = at::zeros_like(old_cell);
111112
auto new_cell = at::zeros_like(old_cell);
@@ -114,7 +115,7 @@ std::vector<at::Tensor> lltm_cuda_forward(
114115
auto candidate_cell = at::zeros_like(old_cell);
115116

116117
const int threads = 1024;
117-
const dim3 blocks(batch_size, (state_size + threads - 1) / threads);
118+
const dim3 blocks((state_size + threads - 1) / threads, batch_size);
118119

119120
AT_DISPATCH_FLOATING_TYPES(gates.type(), "lltm_forward_cuda", ([&] {
120121
lltm_cuda_forward_kernel<scalar_t><<<blocks, threads>>>(
@@ -148,7 +149,7 @@ std::vector<at::Tensor> lltm_cuda_backward(
148149
const auto state_size = new_cell.size(1);
149150

150151
const int threads = 1024;
151-
const dim3 blocks(batch_size, (state_size + threads - 1) / threads);
152+
const dim3 blocks((state_size + threads - 1) / threads, batch_size);
152153

153154
AT_DISPATCH_FLOATING_TYPES(X.type(), "lltm_forward_cuda", ([&] {
154155
lltm_cuda_backward_kernel<scalar_t><<<blocks, threads>>>(

0 commit comments

Comments
 (0)