
Commit 554c55d

Author: Clément Pinard

Upgrade deprecated function calls

* include torch/extension.h instead of torch/torch.h or ATen/ATen.h
* replace all at:: calls with torch::

1 parent eea6d31 commit 554c55d
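The change is mechanical: in each translation unit, swap the include for <torch/extension.h> and the at:: prefix for torch::. A minimal sketch of the pattern (not part of this commit; the `scale_and_squash` function below is hypothetical, for illustration only):

// Before:
//   #include <torch/torch.h>   // or <ATen/ATen.h> in the CUDA kernel file
//   at::Tensor y = at::sigmoid(x);

// After: one header for all extension sources; it pulls in ATen plus the
// pybind11 machinery needed for Python bindings.
#include <torch/extension.h>

// Hypothetical op, illustrating only the torch:: spelling.
torch::Tensor scale_and_squash(torch::Tensor x, double a) {
  return torch::sigmoid(x * a);  // was: at::sigmoid(x * a)
}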

File tree

3 files changed: +88 additions, −88 deletions


cpp/lltm.cpp

Lines changed: 29 additions & 29 deletions

@@ -1,42 +1,42 @@
-#include <torch/torch.h>
+#include <torch/extension.h>
 
 #include <vector>
 
 // s'(z) = (1 - s(z)) * s(z)
-at::Tensor d_sigmoid(at::Tensor z) {
-  auto s = at::sigmoid(z);
+torch::Tensor d_sigmoid(torch::Tensor z) {
+  auto s = torch::sigmoid(z);
   return (1 - s) * s;
 }
 
 // tanh'(z) = 1 - tanh^2(z)
-at::Tensor d_tanh(at::Tensor z) {
+torch::Tensor d_tanh(torch::Tensor z) {
   return 1 - z.tanh().pow(2);
 }
 
 // elu'(z) = relu'(z) + { alpha * exp(z) if (alpha * (exp(z) - 1)) < 0, else 0}
-at::Tensor d_elu(at::Tensor z, at::Scalar alpha = 1.0) {
+torch::Tensor d_elu(torch::Tensor z, torch::Scalar alpha = 1.0) {
   auto e = z.exp();
   auto mask = (alpha * (e - 1)) < 0;
   return (z > 0).type_as(z) + mask.type_as(z) * (alpha * e);
 }
 
-std::vector<at::Tensor> lltm_forward(
-    at::Tensor input,
-    at::Tensor weights,
-    at::Tensor bias,
-    at::Tensor old_h,
-    at::Tensor old_cell) {
-  auto X = at::cat({old_h, input}, /*dim=*/1);
+std::vector<torch::Tensor> lltm_forward(
+    torch::Tensor input,
+    torch::Tensor weights,
+    torch::Tensor bias,
+    torch::Tensor old_h,
+    torch::Tensor old_cell) {
+  auto X = torch::cat({old_h, input}, /*dim=*/1);
 
-  auto gate_weights = at::addmm(bias, X, weights.transpose(0, 1));
+  auto gate_weights = torch::addmm(bias, X, weights.transpose(0, 1));
   auto gates = gate_weights.chunk(3, /*dim=*/1);
 
-  auto input_gate = at::sigmoid(gates[0]);
-  auto output_gate = at::sigmoid(gates[1]);
-  auto candidate_cell = at::elu(gates[2], /*alpha=*/1.0);
+  auto input_gate = torch::sigmoid(gates[0]);
+  auto output_gate = torch::sigmoid(gates[1]);
+  auto candidate_cell = torch::elu(gates[2], /*alpha=*/1.0);
 
   auto new_cell = old_cell + candidate_cell * input_gate;
-  auto new_h = at::tanh(new_cell) * output_gate;
+  auto new_h = torch::tanh(new_cell) * output_gate;
 
   return {new_h,
           new_cell,
@@ -47,17 +47,17 @@ std::vector<at::Tensor> lltm_forward(
           gate_weights};
 }
 
-std::vector<at::Tensor> lltm_backward(
-    at::Tensor grad_h,
-    at::Tensor grad_cell,
-    at::Tensor new_cell,
-    at::Tensor input_gate,
-    at::Tensor output_gate,
-    at::Tensor candidate_cell,
-    at::Tensor X,
-    at::Tensor gate_weights,
-    at::Tensor weights) {
-  auto d_output_gate = at::tanh(new_cell) * grad_h;
+std::vector<torch::Tensor> lltm_backward(
+    torch::Tensor grad_h,
+    torch::Tensor grad_cell,
+    torch::Tensor new_cell,
+    torch::Tensor input_gate,
+    torch::Tensor output_gate,
+    torch::Tensor candidate_cell,
+    torch::Tensor X,
+    torch::Tensor gate_weights,
+    torch::Tensor weights) {
+  auto d_output_gate = torch::tanh(new_cell) * grad_h;
   auto d_tanh_new_cell = output_gate * grad_h;
   auto d_new_cell = d_tanh(new_cell) * d_tanh_new_cell + grad_cell;
 
@@ -71,7 +71,7 @@ std::vector<at::Tensor> lltm_backward(
   d_candidate_cell *= d_elu(gates[2]);
 
   auto d_gates =
-      at::cat({d_input_gate, d_output_gate, d_candidate_cell}, /*dim=*/1);
+      torch::cat({d_input_gate, d_output_gate, d_candidate_cell}, /*dim=*/1);
 
   auto d_weights = d_gates.t().mm(X);
   auto d_bias = d_gates.sum(/*dim=*/0, /*keepdim=*/true);
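Only spellings change in this file; the numerics are untouched. Note that <torch/extension.h> is aimed at Python-extension builds (it also pulls in pybind11), while plain <torch/torch.h> remains the header for standalone C++ programs. A quick hedged check that the migrated analytic derivative agrees with autograd, as a standalone program against libtorch (the `main` below is illustrative, not part of the repo):

#include <torch/torch.h>
#include <iostream>

// Same definition as in cpp/lltm.cpp after this commit.
torch::Tensor d_sigmoid(torch::Tensor z) {
  auto s = torch::sigmoid(z);
  return (1 - s) * s;
}

int main() {
  auto z = torch::randn({4}, torch::requires_grad());
  // Autograd's ds/dz should match the analytic (1 - s) * s.
  torch::sigmoid(z).sum().backward();
  std::cout << z.grad() << "\n";
  std::cout << d_sigmoid(z.detach()) << "\n";
}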

cuda/lltm_cuda.cpp

Lines changed: 33 additions & 33 deletions

@@ -1,26 +1,26 @@
-#include <torch/torch.h>
+#include <torch/extension.h>
 
 #include <vector>
 
 // CUDA forward declarations
 
-std::vector<at::Tensor> lltm_cuda_forward(
-    at::Tensor input,
-    at::Tensor weights,
-    at::Tensor bias,
-    at::Tensor old_h,
-    at::Tensor old_cell);
+std::vector<torch::Tensor> lltm_cuda_forward(
+    torch::Tensor input,
+    torch::Tensor weights,
+    torch::Tensor bias,
+    torch::Tensor old_h,
+    torch::Tensor old_cell);
 
-std::vector<at::Tensor> lltm_cuda_backward(
-    at::Tensor grad_h,
-    at::Tensor grad_cell,
-    at::Tensor new_cell,
-    at::Tensor input_gate,
-    at::Tensor output_gate,
-    at::Tensor candidate_cell,
-    at::Tensor X,
-    at::Tensor gate_weights,
-    at::Tensor weights);
+std::vector<torch::Tensor> lltm_cuda_backward(
+    torch::Tensor grad_h,
+    torch::Tensor grad_cell,
+    torch::Tensor new_cell,
+    torch::Tensor input_gate,
+    torch::Tensor output_gate,
+    torch::Tensor candidate_cell,
+    torch::Tensor X,
+    torch::Tensor gate_weights,
+    torch::Tensor weights);
 
 // C++ interface
 
@@ -29,12 +29,12 @@ std::vector<at::Tensor> lltm_cuda_backward(
 #define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
 #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
 
-std::vector<at::Tensor> lltm_forward(
-    at::Tensor input,
-    at::Tensor weights,
-    at::Tensor bias,
-    at::Tensor old_h,
-    at::Tensor old_cell) {
+std::vector<torch::Tensor> lltm_forward(
+    torch::Tensor input,
+    torch::Tensor weights,
+    torch::Tensor bias,
+    torch::Tensor old_h,
+    torch::Tensor old_cell) {
   CHECK_INPUT(input);
   CHECK_INPUT(weights);
   CHECK_INPUT(bias);
@@ -44,16 +44,16 @@ std::vector<at::Tensor> lltm_forward(
   return lltm_cuda_forward(input, weights, bias, old_h, old_cell);
 }
 
-std::vector<at::Tensor> lltm_backward(
-    at::Tensor grad_h,
-    at::Tensor grad_cell,
-    at::Tensor new_cell,
-    at::Tensor input_gate,
-    at::Tensor output_gate,
-    at::Tensor candidate_cell,
-    at::Tensor X,
-    at::Tensor gate_weights,
-    at::Tensor weights) {
+std::vector<torch::Tensor> lltm_backward(
+    torch::Tensor grad_h,
+    torch::Tensor grad_cell,
+    torch::Tensor new_cell,
+    torch::Tensor input_gate,
+    torch::Tensor output_gate,
+    torch::Tensor candidate_cell,
+    torch::Tensor X,
+    torch::Tensor gate_weights,
+    torch::Tensor weights) {
   CHECK_INPUT(grad_h);
   CHECK_INPUT(grad_cell);
   CHECK_INPUT(input_gate);
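The CHECK_INPUT guards run before any CUDA work, so a CPU or non-contiguous tensor fails fast with a readable error instead of crashing inside the kernel. A self-contained sketch of the same pattern (the actual CHECK_CUDA definition sits just above this hunk and is not shown in the diff; the version below is an assumption written in the same style, and `double_it` is hypothetical):

#include <torch/extension.h>

// Assumed definitions, mirroring the CHECK_CONTIGUOUS shown above.
#define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

// Hypothetical entry point showing where the guards go.
torch::Tensor double_it(torch::Tensor x) {
  CHECK_INPUT(x);  // throws via AT_ASSERTM if x is on CPU or strided
  return x * 2;
}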

cuda/lltm_cuda_kernel.cu

Lines changed: 26 additions & 26 deletions

@@ -1,4 +1,4 @@
-#include <ATen/ATen.h>
+#include <torch/extension.h>
 
 #include <cuda.h>
 #include <cuda_runtime.h>
@@ -99,23 +99,23 @@ __global__ void lltm_cuda_backward_kernel(
 }
 } // namespace
 
-std::vector<at::Tensor> lltm_cuda_forward(
-    at::Tensor input,
-    at::Tensor weights,
-    at::Tensor bias,
-    at::Tensor old_h,
-    at::Tensor old_cell) {
-  auto X = at::cat({old_h, input}, /*dim=*/1);
-  auto gates = at::addmm(bias, X, weights.transpose(0, 1));
+std::vector<torch::Tensor> lltm_cuda_forward(
+    torch::Tensor input,
+    torch::Tensor weights,
+    torch::Tensor bias,
+    torch::Tensor old_h,
+    torch::Tensor old_cell) {
+  auto X = torch::cat({old_h, input}, /*dim=*/1);
+  auto gates = torch::addmm(bias, X, weights.transpose(0, 1));
 
   const auto batch_size = old_cell.size(0);
   const auto state_size = old_cell.size(1);
 
-  auto new_h = at::zeros_like(old_cell);
-  auto new_cell = at::zeros_like(old_cell);
-  auto input_gate = at::zeros_like(old_cell);
-  auto output_gate = at::zeros_like(old_cell);
-  auto candidate_cell = at::zeros_like(old_cell);
+  auto new_h = torch::zeros_like(old_cell);
+  auto new_cell = torch::zeros_like(old_cell);
+  auto input_gate = torch::zeros_like(old_cell);
+  auto output_gate = torch::zeros_like(old_cell);
+  auto candidate_cell = torch::zeros_like(old_cell);
 
   const int threads = 1024;
   const dim3 blocks((state_size + threads - 1) / threads, batch_size);
@@ -135,18 +135,18 @@ std::vector<at::Tensor> lltm_cuda_forward(
   return {new_h, new_cell, input_gate, output_gate, candidate_cell, X, gates};
 }
 
-std::vector<at::Tensor> lltm_cuda_backward(
-    at::Tensor grad_h,
-    at::Tensor grad_cell,
-    at::Tensor new_cell,
-    at::Tensor input_gate,
-    at::Tensor output_gate,
-    at::Tensor candidate_cell,
-    at::Tensor X,
-    at::Tensor gate_weights,
-    at::Tensor weights) {
-  auto d_old_cell = at::zeros_like(new_cell);
-  auto d_gates = at::zeros_like(gate_weights);
+std::vector<torch::Tensor> lltm_cuda_backward(
+    torch::Tensor grad_h,
+    torch::Tensor grad_cell,
+    torch::Tensor new_cell,
+    torch::Tensor input_gate,
+    torch::Tensor output_gate,
+    torch::Tensor candidate_cell,
+    torch::Tensor X,
+    torch::Tensor gate_weights,
+    torch::Tensor weights) {
+  auto d_old_cell = torch::zeros_like(new_cell);
+  auto d_gates = torch::zeros_like(gate_weights);
 
   const auto batch_size = new_cell.size(0);
   const auto state_size = new_cell.size(1);
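One launch-configuration detail worth noting in the forward path: the grid is computed with the ceiling-division idiom, so every state element gets a thread even when state_size is not a multiple of the block size, and any excess threads must be masked inside the kernel. A small illustration of that arithmetic in isolation (the sizes are assumptions, not values from the commit):

#include <cuda_runtime.h>

int main() {
  const int threads = 1024;
  const int state_size = 1500;  // illustrative, not from the commit
  const int batch_size = 16;    // illustrative, not from the commit
  // Ceiling division: (1500 + 1023) / 1024 == 2 blocks along x.
  const dim3 blocks((state_size + threads - 1) / threads, batch_size);
  // 2 * 1024 = 2048 x-threads cover 1500 columns; threads with
  // column >= state_size must do nothing inside the kernel, e.g.
  //   if (column < state_size) { ... }
  (void)blocks;  // silence unused-variable warning in this sketch
  return 0;
}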
