Commit 4a86842

Author: Clément Pinard (committed)
Upgrade check scripts to pytorch 1.0
* dismiss deprecated Variable semantics
* use dtype and device keywords
* add possibility to benchmark with float64
1 parent 554c55d commit 4a86842
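
As a quick illustration of the migration described above (a sketch written for this page, not code from the repository): since PyTorch 0.4/1.0, tensors are created directly with dtype, device and requires_grad keywords instead of being wrapped in the deprecated torch.autograd.Variable.

import torch

# Old, deprecated style (pre-0.4): wrap a tensor in Variable to track gradients.
# from torch.autograd import Variable
# x = Variable(torch.randn(16, 32).cuda().double(), requires_grad=True)

# PyTorch 1.0 style: dtype, device and requires_grad are plain tensor keywords.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
x = torch.randn(16, 32, dtype=torch.float64, device=device, requires_grad=True)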

File tree: 3 files changed, +36 -34 lines

benchmark.py

Lines changed: 10 additions & 9 deletions
@@ -17,6 +17,7 @@
 parser.add_argument('-r', '--runs', type=int, default=100)
 parser.add_argument('--scale', choices=['s', 'ms', 'us'], default='us')
 parser.add_argument('-c', '--cuda', action='store_true')
+parser.add_argument('-d', '--double', action='store_true')
 options = parser.parse_args()
 
 if options.example == 'py':
@@ -27,16 +28,16 @@
     from cuda.lltm import LLTM
     options.cuda = True
 
-X = torch.randn(options.batch_size, options.features)
-h = torch.randn(options.batch_size, options.state_size)
-C = torch.randn(options.batch_size, options.state_size)
-rnn = LLTM(options.features, options.state_size)
+device = torch.device("cuda") if options.cuda else torch.device("cpu")
+dtype = torch.float64 if options.double else torch.float32
 
-if options.cuda:
-    X = X.cuda()
-    h = h.cuda()
-    C = C.cuda()
-    rnn.cuda()
+kwargs = {'dtype': dtype,
+          'device': device,
+          'requires_grad': True}
+X = torch.randn(options.batch_size, options.features, **kwargs)
+h = torch.randn(options.batch_size, options.state_size, **kwargs)
+C = torch.randn(options.batch_size, options.state_size, **kwargs)
+rnn = LLTM(options.features, options.state_size).to(device, dtype)
 
 # Force CUDA initialization
 new_h, new_C = rnn(X, (h, C))
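
For context, a minimal sketch (not part of the commit) of the idiom the benchmark now relies on: Module.to(device, dtype) moves the parameters and casts them in a single call, replacing the separate .cuda() conversions. The nn.Linear layer here is only a stand-in for the repo's LLTM module.

import torch
import torch.nn as nn

# Stand-in module; the real benchmark uses LLTM(options.features, options.state_size).
layer = nn.Linear(32, 128)

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
dtype = torch.float64  # what benchmark.py selects when -d/--double is passed

layer = layer.to(device, dtype)  # move and cast the parameters in one call
x = torch.randn(16, 32, dtype=dtype, device=device, requires_grad=True)
y = layer(x)
print(y.dtype, y.device)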

check.py

Lines changed: 14 additions & 14 deletions
@@ -5,8 +5,6 @@
 import numpy as np
 import torch
 
-from torch.autograd import Variable
-
 import python.lltm_baseline
 import cpp.lltm
 
@@ -85,21 +83,23 @@ def check_backward(variables, with_cuda, verbose):
 
 if options.cuda:
     import cuda.lltm
-    options.cuda = True
-
-X = torch.randn(options.batch_size, options.features)
-h = torch.randn(options.batch_size, options.state_size)
-C = torch.randn(options.batch_size, options.state_size)
-W = torch.randn(3 * options.state_size, options.features + options.state_size)
-b = torch.randn(1, 3 * options.state_size)
+    device = torch.device("cuda")
+else:
+    device = torch.device("cpu")
+
+kwargs = {'dtype': torch.float64,
+          'device': device,
+          'requires_grad': True}
+X = torch.randn(options.batch_size,
+                options.features,
+                **kwargs)
+h = torch.randn(options.batch_size, options.state_size, **kwargs)
+C = torch.randn(options.batch_size, options.state_size, **kwargs)
+W = torch.randn(3 * options.state_size, options.features + options.state_size, **kwargs)
+b = torch.randn(1, 3 * options.state_size, **kwargs)
 
 variables = [X, W, b, h, C]
 
-for i, var in enumerate(variables):
-    if options.cuda:
-        var = var.cuda()
-    variables[i] = Variable(var.double(), requires_grad=True)
-
 if 'forward' in options.direction:
     check_forward(variables, options.cuda, options.verbose)
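
A minimal, repo-independent sketch of why the old per-variable loop is no longer needed: a tensor created with requires_grad=True is already a leaf that accumulates gradients, so there is nothing left for a Variable wrapper to do.

import torch

# Leaf tensor created in double precision with gradient tracking enabled,
# mirroring the kwargs dict that check.py now passes to torch.randn.
x = torch.randn(4, 3, dtype=torch.float64, requires_grad=True)
loss = (x ** 2).sum()
loss.backward()
print(x.grad.shape)  # torch.Size([4, 3])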

grad_check.py

Lines changed: 12 additions & 11 deletions
@@ -3,8 +3,7 @@
 
 import argparse
 import torch
-
-from torch.autograd import Variable, gradcheck
+from torch.autograd import gradcheck
 
 parser = argparse.ArgumentParser()
 parser.add_argument('example', choices=['py', 'cpp', 'cuda'])
@@ -22,18 +21,20 @@
     from cuda.lltm import LLTMFunction
     options.cuda = True
 
-X = torch.randn(options.batch_size, options.features)
-h = torch.randn(options.batch_size, options.state_size)
-C = torch.randn(options.batch_size, options.state_size)
-W = torch.randn(3 * options.state_size, options.features + options.state_size)
-b = torch.randn(1, 3 * options.state_size)
+device = torch.device("cuda") if options.cuda else torch.device("cpu")
+
+kwargs = {'dtype': torch.float64,
+          'device': device,
+          'requires_grad': True}
+
+X = torch.randn(options.batch_size, options.features, **kwargs)
+h = torch.randn(options.batch_size, options.state_size, **kwargs)
+C = torch.randn(options.batch_size, options.state_size, **kwargs)
+W = torch.randn(3 * options.state_size, options.features + options.state_size, **kwargs)
+b = torch.randn(1, 3 * options.state_size, **kwargs)
 
 variables = [X, W, b, h, C]
 
-for i, var in enumerate(variables):
-    if options.cuda:
-        var = var.cuda()
-    variables[i] = Variable(var.double(), requires_grad=True)
 
 if gradcheck(LLTMFunction.apply, variables):
     print('Ok')
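
For reference, a self-contained sketch (using torch.sigmoid rather than the repo's LLTMFunction) of how gradcheck is typically driven: it compares analytical gradients against finite differences, which is why the script builds its inputs as float64 leaves with requires_grad=True.

import torch
from torch.autograd import gradcheck

# Double-precision leaf inputs keep the finite-difference comparison numerically stable.
inputs = (torch.randn(8, 5, dtype=torch.float64, requires_grad=True),)
if gradcheck(torch.sigmoid, inputs):
    print('Ok')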
