
Commit 07b1598

Merge pull request #32 from ClementPinard/fix-deprecated
Fix deprecated functions
2 parents eea6d31 + 4a86842 commit 07b1598

File tree

6 files changed (+124 / -122 lines)


benchmark.py

Lines changed: 10 additions & 9 deletions
@@ -17,6 +17,7 @@
 parser.add_argument('-r', '--runs', type=int, default=100)
 parser.add_argument('--scale', choices=['s', 'ms', 'us'], default='us')
 parser.add_argument('-c', '--cuda', action='store_true')
+parser.add_argument('-d', '--double', action='store_true')
 options = parser.parse_args()

 if options.example == 'py':
@@ -27,16 +28,16 @@
     from cuda.lltm import LLTM
     options.cuda = True

-X = torch.randn(options.batch_size, options.features)
-h = torch.randn(options.batch_size, options.state_size)
-C = torch.randn(options.batch_size, options.state_size)
-rnn = LLTM(options.features, options.state_size)
+device = torch.device("cuda") if options.cuda else torch.device("cpu")
+dtype = torch.float64 if options.double else torch.float32

-if options.cuda:
-    X = X.cuda()
-    h = h.cuda()
-    C = C.cuda()
-    rnn.cuda()
+kwargs = {'dtype': dtype,
+          'device': device,
+          'requires_grad': True}
+X = torch.randn(options.batch_size, options.features, **kwargs)
+h = torch.randn(options.batch_size, options.state_size, **kwargs)
+C = torch.randn(options.batch_size, options.state_size, **kwargs)
+rnn = LLTM(options.features, options.state_size).to(device, dtype)

 # Force CUDA initialization
 new_h, new_C = rnn(X, (h, C))
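Note: the benchmark change above swaps the deprecated allocate-then-`.cuda()` pattern for tensor factory keyword arguments, the idiom recommended since PyTorch 0.4. A minimal sketch of the two idioms (the shapes and the CUDA-availability check are illustrative, not taken from this repository):

    import torch

    # Deprecated idiom: create on CPU, then move each tensor by hand.
    # X = torch.randn(16, 32)
    # if use_cuda:
    #     X = X.cuda()

    # Current idiom: declare dtype, device and requires_grad at creation time.
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    X = torch.randn(16, 32, dtype=torch.float64, device=device, requires_grad=True)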

check.py

Lines changed: 14 additions & 14 deletions
@@ -5,8 +5,6 @@
 import numpy as np
 import torch

-from torch.autograd import Variable
-
 import python.lltm_baseline
 import cpp.lltm

@@ -85,21 +83,23 @@ def check_backward(variables, with_cuda, verbose):

 if options.cuda:
     import cuda.lltm
-    options.cuda = True
-
-X = torch.randn(options.batch_size, options.features)
-h = torch.randn(options.batch_size, options.state_size)
-C = torch.randn(options.batch_size, options.state_size)
-W = torch.randn(3 * options.state_size, options.features + options.state_size)
-b = torch.randn(1, 3 * options.state_size)
+    device = torch.device("cuda")
+else:
+    device = torch.device("cpu")
+
+kwargs = {'dtype': torch.float64,
+          'device': device,
+          'requires_grad': True}
+X = torch.randn(options.batch_size,
+                options.features,
+                **kwargs)
+h = torch.randn(options.batch_size, options.state_size, **kwargs)
+C = torch.randn(options.batch_size, options.state_size, **kwargs)
+W = torch.randn(3 * options.state_size, options.features + options.state_size, **kwargs)
+b = torch.randn(1, 3 * options.state_size, **kwargs)

 variables = [X, W, b, h, C]

-for i, var in enumerate(variables):
-    if options.cuda:
-        var = var.cuda()
-    variables[i] = Variable(var.double(), requires_grad=True)
-
 if 'forward' in options.direction:
     check_forward(variables, options.cuda, options.verbose)
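Note: the removed loop wrapped each tensor in `torch.autograd.Variable`, which has been a thin, deprecated alias since the Tensor/Variable merge in PyTorch 0.4; setting `requires_grad=True` at construction gives the same result. A small equivalence sketch (shapes are illustrative):

    import torch
    from torch.autograd import Variable  # deprecated; still importable, returns a plain Tensor

    t = torch.randn(2, 3)
    old = Variable(t.double(), requires_grad=True)                    # pre-0.4 style
    new = torch.randn(2, 3, dtype=torch.float64, requires_grad=True)  # current style
    assert isinstance(old, torch.Tensor) and old.requires_grad and new.requires_grad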

cpp/lltm.cpp

Lines changed: 29 additions & 29 deletions
@@ -1,42 +1,42 @@
-#include <torch/torch.h>
+#include <torch/extension.h>

 #include <vector>

 // s'(z) = (1 - s(z)) * s(z)
-at::Tensor d_sigmoid(at::Tensor z) {
-  auto s = at::sigmoid(z);
+torch::Tensor d_sigmoid(torch::Tensor z) {
+  auto s = torch::sigmoid(z);
   return (1 - s) * s;
 }

 // tanh'(z) = 1 - tanh^2(z)
-at::Tensor d_tanh(at::Tensor z) {
+torch::Tensor d_tanh(torch::Tensor z) {
   return 1 - z.tanh().pow(2);
 }

 // elu'(z) = relu'(z) + { alpha * exp(z) if (alpha * (exp(z) - 1)) < 0, else 0}
-at::Tensor d_elu(at::Tensor z, at::Scalar alpha = 1.0) {
+torch::Tensor d_elu(torch::Tensor z, torch::Scalar alpha = 1.0) {
   auto e = z.exp();
   auto mask = (alpha * (e - 1)) < 0;
   return (z > 0).type_as(z) + mask.type_as(z) * (alpha * e);
 }

-std::vector<at::Tensor> lltm_forward(
-    at::Tensor input,
-    at::Tensor weights,
-    at::Tensor bias,
-    at::Tensor old_h,
-    at::Tensor old_cell) {
-  auto X = at::cat({old_h, input}, /*dim=*/1);
+std::vector<torch::Tensor> lltm_forward(
+    torch::Tensor input,
+    torch::Tensor weights,
+    torch::Tensor bias,
+    torch::Tensor old_h,
+    torch::Tensor old_cell) {
+  auto X = torch::cat({old_h, input}, /*dim=*/1);

-  auto gate_weights = at::addmm(bias, X, weights.transpose(0, 1));
+  auto gate_weights = torch::addmm(bias, X, weights.transpose(0, 1));
   auto gates = gate_weights.chunk(3, /*dim=*/1);

-  auto input_gate = at::sigmoid(gates[0]);
-  auto output_gate = at::sigmoid(gates[1]);
-  auto candidate_cell = at::elu(gates[2], /*alpha=*/1.0);
+  auto input_gate = torch::sigmoid(gates[0]);
+  auto output_gate = torch::sigmoid(gates[1]);
+  auto candidate_cell = torch::elu(gates[2], /*alpha=*/1.0);

   auto new_cell = old_cell + candidate_cell * input_gate;
-  auto new_h = at::tanh(new_cell) * output_gate;
+  auto new_h = torch::tanh(new_cell) * output_gate;

   return {new_h,
           new_cell,
@@ -47,17 +47,17 @@ std::vector<at::Tensor> lltm_forward(
           gate_weights};
 }

-std::vector<at::Tensor> lltm_backward(
-    at::Tensor grad_h,
-    at::Tensor grad_cell,
-    at::Tensor new_cell,
-    at::Tensor input_gate,
-    at::Tensor output_gate,
-    at::Tensor candidate_cell,
-    at::Tensor X,
-    at::Tensor gate_weights,
-    at::Tensor weights) {
-  auto d_output_gate = at::tanh(new_cell) * grad_h;
+std::vector<torch::Tensor> lltm_backward(
+    torch::Tensor grad_h,
+    torch::Tensor grad_cell,
+    torch::Tensor new_cell,
+    torch::Tensor input_gate,
+    torch::Tensor output_gate,
+    torch::Tensor candidate_cell,
+    torch::Tensor X,
+    torch::Tensor gate_weights,
+    torch::Tensor weights) {
+  auto d_output_gate = torch::tanh(new_cell) * grad_h;
   auto d_tanh_new_cell = output_gate * grad_h;
   auto d_new_cell = d_tanh(new_cell) * d_tanh_new_cell + grad_cell;

@@ -71,7 +71,7 @@ std::vector<at::Tensor> lltm_backward(
   d_candidate_cell *= d_elu(gates[2]);

   auto d_gates =
-      at::cat({d_input_gate, d_output_gate, d_candidate_cell}, /*dim=*/1);
+      torch::cat({d_input_gate, d_output_gate, d_candidate_cell}, /*dim=*/1);

   auto d_weights = d_gates.t().mm(X);
   auto d_bias = d_gates.sum(/*dim=*/0, /*keepdim=*/true);
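Note: `<torch/extension.h>` is the umbrella header for C++ extensions; it brings in the ATen tensor library together with the pybind11 binding machinery and the `torch::` namespace used above, whereas `<torch/torch.h>` is aimed at the standalone C++ frontend. A minimal build sketch with `torch.utils.cpp_extension`; the module name and source path are assumptions for illustration, not necessarily this repository's own setup:

    # setup.py (sketch)
    from setuptools import setup
    from torch.utils.cpp_extension import BuildExtension, CppExtension

    setup(
        name='lltm_cpp',
        ext_modules=[CppExtension('lltm_cpp', ['lltm.cpp'])],
        cmdclass={'build_ext': BuildExtension})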

cuda/lltm_cuda.cpp

Lines changed: 33 additions & 33 deletions
@@ -1,26 +1,26 @@
-#include <torch/torch.h>
+#include <torch/extension.h>

 #include <vector>

 // CUDA forward declarations

-std::vector<at::Tensor> lltm_cuda_forward(
-    at::Tensor input,
-    at::Tensor weights,
-    at::Tensor bias,
-    at::Tensor old_h,
-    at::Tensor old_cell);
+std::vector<torch::Tensor> lltm_cuda_forward(
+    torch::Tensor input,
+    torch::Tensor weights,
+    torch::Tensor bias,
+    torch::Tensor old_h,
+    torch::Tensor old_cell);

-std::vector<at::Tensor> lltm_cuda_backward(
-    at::Tensor grad_h,
-    at::Tensor grad_cell,
-    at::Tensor new_cell,
-    at::Tensor input_gate,
-    at::Tensor output_gate,
-    at::Tensor candidate_cell,
-    at::Tensor X,
-    at::Tensor gate_weights,
-    at::Tensor weights);
+std::vector<torch::Tensor> lltm_cuda_backward(
+    torch::Tensor grad_h,
+    torch::Tensor grad_cell,
+    torch::Tensor new_cell,
+    torch::Tensor input_gate,
+    torch::Tensor output_gate,
+    torch::Tensor candidate_cell,
+    torch::Tensor X,
+    torch::Tensor gate_weights,
+    torch::Tensor weights);

 // C++ interface

@@ -29,12 +29,12 @@ std::vector<at::Tensor> lltm_cuda_backward(
 #define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
 #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

-std::vector<at::Tensor> lltm_forward(
-    at::Tensor input,
-    at::Tensor weights,
-    at::Tensor bias,
-    at::Tensor old_h,
-    at::Tensor old_cell) {
+std::vector<torch::Tensor> lltm_forward(
+    torch::Tensor input,
+    torch::Tensor weights,
+    torch::Tensor bias,
+    torch::Tensor old_h,
+    torch::Tensor old_cell) {
   CHECK_INPUT(input);
   CHECK_INPUT(weights);
   CHECK_INPUT(bias);
@@ -44,16 +44,16 @@ std::vector<at::Tensor> lltm_forward(
   return lltm_cuda_forward(input, weights, bias, old_h, old_cell);
 }

-std::vector<at::Tensor> lltm_backward(
-    at::Tensor grad_h,
-    at::Tensor grad_cell,
-    at::Tensor new_cell,
-    at::Tensor input_gate,
-    at::Tensor output_gate,
-    at::Tensor candidate_cell,
-    at::Tensor X,
-    at::Tensor gate_weights,
-    at::Tensor weights) {
+std::vector<torch::Tensor> lltm_backward(
+    torch::Tensor grad_h,
+    torch::Tensor grad_cell,
+    torch::Tensor new_cell,
+    torch::Tensor input_gate,
+    torch::Tensor output_gate,
+    torch::Tensor candidate_cell,
+    torch::Tensor X,
+    torch::Tensor gate_weights,
+    torch::Tensor weights) {
   CHECK_INPUT(grad_h);
   CHECK_INPUT(grad_cell);
   CHECK_INPUT(input_gate);
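Note: only the declared types change in this file; the `CHECK_INPUT` macros still require every argument to be a contiguous CUDA tensor before dispatching to the kernels. Roughly what that obligation looks like from the Python side (a sketch that needs a CUDA device; names and shapes are illustrative):

    import torch

    x = torch.randn(4, 8, device="cuda")
    x_t = x.t()                        # transposed view: still on the GPU, but not contiguous
    assert x.is_cuda and x.is_contiguous()
    assert not x_t.is_contiguous()     # would trip CHECK_CONTIGUOUS in the extension
    x_ok = x_t.contiguous()            # materialize a packed copy before passing it in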

cuda/lltm_cuda_kernel.cu

Lines changed: 26 additions & 26 deletions
@@ -1,4 +1,4 @@
-#include <ATen/ATen.h>
+#include <torch/extension.h>

 #include <cuda.h>
 #include <cuda_runtime.h>
@@ -99,23 +99,23 @@ __global__ void lltm_cuda_backward_kernel(
 }
 } // namespace

-std::vector<at::Tensor> lltm_cuda_forward(
-    at::Tensor input,
-    at::Tensor weights,
-    at::Tensor bias,
-    at::Tensor old_h,
-    at::Tensor old_cell) {
-  auto X = at::cat({old_h, input}, /*dim=*/1);
-  auto gates = at::addmm(bias, X, weights.transpose(0, 1));
+std::vector<torch::Tensor> lltm_cuda_forward(
+    torch::Tensor input,
+    torch::Tensor weights,
+    torch::Tensor bias,
+    torch::Tensor old_h,
+    torch::Tensor old_cell) {
+  auto X = torch::cat({old_h, input}, /*dim=*/1);
+  auto gates = torch::addmm(bias, X, weights.transpose(0, 1));

   const auto batch_size = old_cell.size(0);
   const auto state_size = old_cell.size(1);

-  auto new_h = at::zeros_like(old_cell);
-  auto new_cell = at::zeros_like(old_cell);
-  auto input_gate = at::zeros_like(old_cell);
-  auto output_gate = at::zeros_like(old_cell);
-  auto candidate_cell = at::zeros_like(old_cell);
+  auto new_h = torch::zeros_like(old_cell);
+  auto new_cell = torch::zeros_like(old_cell);
+  auto input_gate = torch::zeros_like(old_cell);
+  auto output_gate = torch::zeros_like(old_cell);
+  auto candidate_cell = torch::zeros_like(old_cell);

   const int threads = 1024;
   const dim3 blocks((state_size + threads - 1) / threads, batch_size);
@@ -135,18 +135,18 @@ std::vector<at::Tensor> lltm_cuda_forward(
   return {new_h, new_cell, input_gate, output_gate, candidate_cell, X, gates};
 }

-std::vector<at::Tensor> lltm_cuda_backward(
-    at::Tensor grad_h,
-    at::Tensor grad_cell,
-    at::Tensor new_cell,
-    at::Tensor input_gate,
-    at::Tensor output_gate,
-    at::Tensor candidate_cell,
-    at::Tensor X,
-    at::Tensor gate_weights,
-    at::Tensor weights) {
-  auto d_old_cell = at::zeros_like(new_cell);
-  auto d_gates = at::zeros_like(gate_weights);
+std::vector<torch::Tensor> lltm_cuda_backward(
+    torch::Tensor grad_h,
+    torch::Tensor grad_cell,
+    torch::Tensor new_cell,
+    torch::Tensor input_gate,
+    torch::Tensor output_gate,
+    torch::Tensor candidate_cell,
+    torch::Tensor X,
+    torch::Tensor gate_weights,
+    torch::Tensor weights) {
+  auto d_old_cell = torch::zeros_like(new_cell);
+  auto d_gates = torch::zeros_like(gate_weights);

   const auto batch_size = new_cell.size(0);
   const auto state_size = new_cell.size(1);
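Note: with the single `<torch/extension.h>` include, the kernel file uses the same `torch::` factory functions as the C++ code, and `torch::zeros_like(old_cell)` allocates its output with the same device and dtype as its argument. One way to compile and try the CUDA sources without a setup.py is the JIT helper in `torch.utils.cpp_extension`; the paths below mirror the file tree of this commit, and the module name is made up for the example:

    from torch.utils.cpp_extension import load

    lltm_cuda = load(
        name='lltm_cuda_jit',
        sources=['cuda/lltm_cuda.cpp', 'cuda/lltm_cuda_kernel.cu'],
        verbose=True)  # prints the compilation steps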

grad_check.py

Lines changed: 12 additions & 11 deletions
@@ -3,8 +3,7 @@

 import argparse
 import torch
-
-from torch.autograd import Variable, gradcheck
+from torch.autograd import gradcheck

 parser = argparse.ArgumentParser()
 parser.add_argument('example', choices=['py', 'cpp', 'cuda'])
@@ -22,18 +21,20 @@
     from cuda.lltm import LLTMFunction
     options.cuda = True

-X = torch.randn(options.batch_size, options.features)
-h = torch.randn(options.batch_size, options.state_size)
-C = torch.randn(options.batch_size, options.state_size)
-W = torch.randn(3 * options.state_size, options.features + options.state_size)
-b = torch.randn(1, 3 * options.state_size)
+device = torch.device("cuda") if options.cuda else torch.device("cpu")
+
+kwargs = {'dtype': torch.float64,
+          'device': device,
+          'requires_grad': True}
+
+X = torch.randn(options.batch_size, options.features, **kwargs)
+h = torch.randn(options.batch_size, options.state_size, **kwargs)
+C = torch.randn(options.batch_size, options.state_size, **kwargs)
+W = torch.randn(3 * options.state_size, options.features + options.state_size, **kwargs)
+b = torch.randn(1, 3 * options.state_size, **kwargs)

 variables = [X, W, b, h, C]

-for i, var in enumerate(variables):
-    if options.cuda:
-        var = var.cuda()
-    variables[i] = Variable(var.double(), requires_grad=True)

 if gradcheck(LLTMFunction.apply, variables):
     print('Ok')
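Note: `gradcheck` compares the analytic gradients of a function against finite-difference estimates, so it wants double-precision inputs with `requires_grad=True`, which is why the tensors above are now created with `dtype=torch.float64` directly instead of being wrapped in `Variable(... .double() ...)`. A self-contained sketch on a toy function:

    import torch
    from torch.autograd import gradcheck

    def f(a, b):
        return (a * b).sum()

    inputs = (torch.randn(3, 3, dtype=torch.float64, requires_grad=True),
              torch.randn(3, 3, dtype=torch.float64, requires_grad=True))
    print(gradcheck(f, inputs))  # True when analytic and numeric gradients agree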
