
Commit c081299

Initial commit
0 parents  commit c081299

17 files changed: 721 additions & 0 deletions

README.md

Lines changed: 3 additions & 0 deletions
# pytorch-cpp-extension

An example of writing a C++ extension for PyTorch.

benchmark.py

Lines changed: 72 additions & 0 deletions
from __future__ import division
from __future__ import print_function

import argparse
import math
import time

import torch

TIME_SCALES = {'s': 1, 'ms': 1000, 'us': 1000000}

parser = argparse.ArgumentParser()
parser.add_argument('example', choices=['py', 'cpp', 'cuda'])
parser.add_argument('-b', '--batch-size', type=int, default=16)
parser.add_argument('-f', '--features', type=int, default=32)
parser.add_argument('-s', '--state-size', type=int, default=128)
parser.add_argument('-r', '--runs', type=int, default=100)
parser.add_argument('--scale', choices=['s', 'ms', 'us'], default='us')
parser.add_argument('-c', '--cuda', action='store_true')
options = parser.parse_args()

if options.example == 'py':
    from python.lltm import LLTM
elif options.example == 'cpp':
    from cpp.lltm import LLTM
else:
    from cuda.lltm import LLTM
    options.cuda = True

X = torch.randn(options.batch_size, options.features)
h = torch.randn(options.batch_size, options.state_size)
C = torch.randn(options.batch_size, options.state_size)
rnn = LLTM(options.features, options.state_size)

if options.cuda:
    X = X.cuda()
    h = h.cuda()
    C = C.cuda()
    rnn.cuda()

# Force CUDA initialization
new_h, new_C = rnn(X, (h, C))
(new_h.sum() + new_C.sum()).backward()

forward_min = math.inf
forward_time = 0
backward_min = math.inf
backward_time = 0
for _ in range(options.runs):
    rnn.zero_grad()

    start = time.time()
    new_h, new_C = rnn(X, (h, C))
    elapsed = time.time() - start
    forward_min = min(forward_min, elapsed)
    forward_time += elapsed

    start = time.time()
    (new_h.sum() + new_C.sum()).backward()
    elapsed = time.time() - start
    backward_min = min(backward_min, elapsed)
    backward_time += elapsed

scale = TIME_SCALES[options.scale]
forward_min *= scale
backward_min *= scale
forward_average = forward_time / options.runs * scale
backward_average = backward_time / options.runs * scale

print('Forward: {0:.3f}/{1:.3f} {4} | Backward {2:.3f}/{3:.3f} {4}'.format(
    forward_min, forward_average, backward_min, backward_average,
    options.scale))

cpp/__init__.py

Whitespace-only changes.

cpp/jit.py

Lines changed: 3 additions & 0 deletions
from torch.utils.cpp_extension import load
lltm_cpp = load(name="lltm_cpp", sources=["lltm.cpp"], verbose=True)
help(lltm_cpp)

cpp/lltm.cpp

Lines changed: 90 additions & 0 deletions
#include <torch/torch.h>

#include <vector>

// s'(z) = (1 - s(z)) * s(z)
at::Tensor d_sigmoid(at::Tensor z) {
  auto s = at::sigmoid(z);
  return (1 - s) * s;
}

// tanh'(z) = 1 - tanh^2(z)
at::Tensor d_tanh(at::Tensor z) {
  return 1 - z.tanh().pow(2);
}

// elu'(z) = relu'(z) + { alpha * exp(z) if (alpha * (exp(z) - 1)) < 0, else 0}
at::Tensor d_elu(at::Tensor z, at::Scalar alpha = 1.0) {
  auto e = z.exp();
  auto mask = (alpha * (e - 1)) < 0;
  return (z > 0).type_as(z) + mask.type_as(z) * (alpha * e);
}

std::vector<at::Tensor> lltm_forward(
    at::Tensor input,
    at::Tensor weights,
    at::Tensor bias,
    at::Tensor old_h,
    at::Tensor old_cell) {
  auto X = at::cat({old_h, input}, /*dim=*/1);

  auto gate_weights = at::addmm(bias, X, weights.transpose(0, 1));
  auto gates = gate_weights.chunk(3, /*dim=*/1);

  auto input_gate = at::sigmoid(gates[0]);
  auto output_gate = at::sigmoid(gates[1]);
  auto candidate_cell = at::elu(gates[2], /*alpha=*/1.0);

  auto new_cell = old_cell + candidate_cell * input_gate;
  auto new_h = at::tanh(new_cell) * output_gate;

  return {new_h,
          new_cell,
          input_gate,
          output_gate,
          candidate_cell,
          X,
          gate_weights};
}

std::vector<at::Tensor> lltm_backward(
    at::Tensor grad_h,
    at::Tensor grad_cell,
    at::Tensor new_cell,
    at::Tensor input_gate,
    at::Tensor output_gate,
    at::Tensor candidate_cell,
    at::Tensor X,
    at::Tensor gate_weights,
    at::Tensor weights) {
  auto d_output_gate = at::tanh(new_cell) * grad_h;
  auto d_tanh_new_cell = output_gate * grad_h;
  auto d_new_cell = d_tanh(new_cell) * d_tanh_new_cell + grad_cell;

  auto d_old_cell = d_new_cell;
  auto d_candidate_cell = input_gate * d_new_cell;
  auto d_input_gate = candidate_cell * d_new_cell;

  auto gates = gate_weights.chunk(3, /*dim=*/1);
  d_input_gate *= d_sigmoid(gates[0]);
  d_output_gate *= d_sigmoid(gates[1]);
  d_candidate_cell *= d_elu(gates[2]);

  auto d_gates =
      at::cat({d_input_gate, d_output_gate, d_candidate_cell}, /*dim=*/1);

  auto d_weights = d_gates.t().mm(X);
  auto d_bias = d_gates.sum(/*dim=*/0, /*keepdim=*/true);

  auto d_X = d_gates.mm(weights);
  const auto state_size = grad_h.size(1);
  auto d_old_h = d_X.slice(/*dim=*/1, 0, state_size);
  auto d_input = d_X.slice(/*dim=*/1, state_size);

  return {d_old_h, d_input, d_weights, d_bias, d_old_cell};
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &lltm_forward, "LLTM forward");
  m.def("backward", &lltm_backward, "LLTM backward");
}
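For reference, the forward pass above implements the following update (a summary sketch using the variable names from the C++ code: X is the previous hidden state concatenated with the input, W and b are the gate weights and bias, and the three chunks of the gate pre-activations are written g_i, g_o, g_c):

X = [h_{t-1},\, x_t], \qquad [g_i \;\; g_o \;\; g_c] = X W^\top + b

i = \sigma(g_i), \qquad o = \sigma(g_o), \qquad \tilde{c} = \mathrm{ELU}(g_c)

C_t = C_{t-1} + i \odot \tilde{c}, \qquad h_t = \tanh(C_t) \odot o

The backward function then applies the chain rule through these equations by hand, using d_sigmoid, d_tanh and d_elu for the elementwise derivatives.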

cpp/lltm.py

Lines changed: 44 additions & 0 deletions
import math
from torch import nn
from torch.autograd import Function
import torch

import lltm_cpp

torch.manual_seed(42)


class LLTMFunction(Function):
    @staticmethod
    def forward(ctx, input, weights, bias, old_h, old_cell):
        outputs = lltm_cpp.forward(input, weights, bias, old_h, old_cell)
        new_h, new_cell = outputs[:2]
        variables = outputs[1:] + [weights]
        ctx.save_for_backward(*variables)

        return new_h, new_cell

    @staticmethod
    def backward(ctx, grad_h, grad_cell):
        d_old_h, d_input, d_weights, d_bias, d_old_cell = lltm_cpp.backward(
            grad_h, grad_cell, *ctx.saved_variables)
        return d_input, d_weights, d_bias, d_old_h, d_old_cell


class LLTM(nn.Module):
    def __init__(self, input_features, state_size):
        super(LLTM, self).__init__()
        self.input_features = input_features
        self.state_size = state_size
        self.weights = nn.Parameter(
            torch.Tensor(3 * state_size, input_features + state_size))
        self.bias = nn.Parameter(torch.Tensor(3 * state_size))
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1.0 / math.sqrt(self.state_size)
        for weight in self.parameters():
            weight.data.uniform_(-stdv, +stdv)

    def forward(self, input, state):
        return LLTMFunction.apply(input, self.weights, self.bias, *state)
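The hand-written backward pass can be checked against autograd's numerical gradients with torch.autograd.gradcheck. The following is a minimal sketch, not part of this commit: it assumes the extension has already been built (for example via cpp/setup.py), the tensor sizes are arbitrary, and double precision is used because gradcheck compares against finite differences.

import torch
from torch.autograd import gradcheck

from cpp.lltm import LLTMFunction

# Small double-precision inputs with requires_grad so gradcheck can
# compare the analytical C++ gradients against finite differences.
kwargs = {'dtype': torch.float64, 'requires_grad': True}
batch, features, state = 3, 17, 5
X = torch.randn(batch, features, **kwargs)
h = torch.randn(batch, state, **kwargs)
C = torch.randn(batch, state, **kwargs)
W = torch.randn(3 * state, features + state, **kwargs)
b = torch.randn(3 * state, **kwargs)

# Argument order matches LLTMFunction.forward: input, weights, bias, old_h, old_cell.
if gradcheck(LLTMFunction.apply, (X, W, b, h, C), eps=1e-6, atol=1e-4):
    print('gradcheck passed')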

cpp/setup.py

Lines changed: 11 additions & 0 deletions
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CppExtension

setup(
    name='lltm_cpp',
    ext_modules=[
        CppExtension('lltm_cpp', ['lltm.cpp']),
    ],
    cmdclass={
        'build_ext': BuildExtension
    })
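Once the extension has been built and installed (for example with python setup.py install inside cpp/), the raw bindings can also be called directly, without the autograd wrapper. A minimal sketch with illustrative shapes; the weight layout (3 * state_size rows, input_features + state_size columns) follows what lltm_forward expects:

import torch
import lltm_cpp  # the compiled extension produced by this setup.py

batch, features, state = 4, 8, 16
x = torch.randn(batch, features)
h = torch.randn(batch, state)
C = torch.randn(batch, state)
W = torch.randn(3 * state, features + state)
b = torch.randn(3 * state)

# forward returns [new_h, new_cell, input_gate, output_gate,
# candidate_cell, X, gate_weights]; only the first two are the new state.
outputs = lltm_cpp.forward(x, W, b, h, C)
new_h, new_C = outputs[:2]
print(new_h.shape, new_C.shape)  # torch.Size([4, 16]) torch.Size([4, 16])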

cuda/__init__.py

Whitespace-only changes.

cuda/jit.py

Lines changed: 4 additions & 0 deletions
from torch.utils.cpp_extension import load
lltm_cuda = load(
    'lltm_cuda', ['lltm_cuda.cpp', 'lltm_cuda_kernel.cu'], verbose=True)
help(lltm_cuda)

cuda/lltm.py

Lines changed: 44 additions & 0 deletions
import math
from torch import nn
from torch.autograd import Function
import torch

import lltm_cuda

torch.manual_seed(42)


class LLTMFunction(Function):
    @staticmethod
    def forward(ctx, input, weights, bias, old_h, old_cell):
        outputs = lltm_cuda.forward(input, weights, bias, old_h, old_cell)
        new_h, new_cell = outputs[:2]
        variables = outputs[1:] + [weights]
        ctx.save_for_backward(*variables)

        return new_h, new_cell

    @staticmethod
    def backward(ctx, grad_h, grad_cell):
        d_old_h, d_input, d_weights, d_bias, d_old_cell = lltm_cuda.backward(
            grad_h, grad_cell, *ctx.saved_variables)
        return d_input, d_weights, d_bias, d_old_h, d_old_cell


class LLTM(nn.Module):
    def __init__(self, input_features, state_size):
        super(LLTM, self).__init__()
        self.input_features = input_features
        self.state_size = state_size
        self.weights = nn.Parameter(
            torch.Tensor(3 * state_size, input_features + state_size))
        self.bias = nn.Parameter(torch.Tensor(3 * state_size))
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1.0 / math.sqrt(self.state_size)
        for weight in self.parameters():
            weight.data.uniform_(-stdv, +stdv)

    def forward(self, input, state):
        return LLTMFunction.apply(input, self.weights, self.bias, *state)
