This repository was archived by the owner on Aug 26, 2022. It is now read-only.

Commit 0cec28d

merge changes from functorch upstream, fix docs
1 parent: a86065d

File tree: 10 files changed (+270, -43 lines)


oslo/__version__.py

Lines changed: 1 addition & 1 deletion
@@ -1,3 +1,3 @@
 # Copyright 2021 TUNiB Inc.

-version = "2.0.0"
+version = "2.0.1"
Lines changed: 2 additions & 2 deletions
@@ -1,6 +1,6 @@
 """
 LAST UPSTREAM INFORMATION

-- date: 2022/02/14
-- commit: https://github.com/pytorch/functorch/commit/cd41d6ebc0402d94ae6af51f163ee728277a7aa4
+- date: 2022/02/21
+- commit: https://github.com/pytorch/functorch/commit/0c0f325ba3c83e70c215f231cfd810af68141767
 """

oslo/pytorch/kernel_fusion/mem_efficient/aot_autograd.py

Lines changed: 1 addition & 0 deletions
@@ -110,6 +110,7 @@ def _reshape_alias(x, shape, strides):
     return aten.view(x, shape)


+
 def create_aot_autograd_function(
     flat_fn, fw_compiler, bw_compiler, partition_fn, decompositions, grad_state
 ):

oslo/pytorch/kernel_fusion/mem_efficient/compilers.py

Lines changed: 13 additions & 3 deletions
@@ -13,8 +13,16 @@
 )


+def _canonicalize(fx_g):
+    for node in fx_g.graph.nodes:
+        if node.target == torch.ops.aten._s_where:
+            node.target = torch.ops.aten.where
+    fx_g.recompile()
+    return fx_g
+
+
 def ts_compile(fx_g, _):
-    # print(fx_g.code)
+    fx_g = _canonicalize(fx_g)
     for node in fx_g.graph.nodes:
         if node.target == torch.ops.aten.new_zeros:
             if node.args[1] == []:
@@ -215,6 +223,7 @@ def nop(f, _):


 def simple_ts_compile(fx_g, _):
+    fx_g = _canonicalize(fx_g)
     f = torch.jit.script(fx_g)
     f = torch.jit.freeze(f.eval())
     return f
@@ -284,12 +293,13 @@ def debug_compile(fx_g, inps):
 ##############################################################
 import torch
 import torch.fx as fx
-from torch.compile import minimizer, check_nvfuser_subprocess
+from functorch.compile import minifier, check_nvfuser_subprocess
 inps = {[(i.shape, i.dtype) for i in inps]}
+inps = [torch.ones(shape, dtype=dtype, device='cuda') for (shape, dtype) in inps]
 from foo import FxModule
 mod = FxModule().cuda()
 with torch.jit.fuser("fuser2"):
-    minimizer(fx.symbolic_trace(mod), inps, check_nvfuser_subprocess)
+    minifier(fx.symbolic_trace(mod), inps, check_nvfuser_subprocess)
 """
 )
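The _canonicalize helper added above retargets aten._s_where call nodes to aten.where on the FX graph before it is handed to TorchScript. A minimal standalone sketch of the same node-retargeting pattern on a plain torch.fx trace, using a torch.add to torch.mul swap purely as an illustrative stand-in:

import torch
import torch.fx as fx

def f(x, y):
    return torch.add(x, y)

gm = fx.symbolic_trace(f)

# Retarget matching call_function nodes, then recompile the generated code,
# mirroring how _canonicalize turns aten._s_where into aten.where.
for node in gm.graph.nodes:
    if node.op == "call_function" and node.target is torch.add:
        node.target = torch.mul  # stand-in replacement, for illustration only
gm.recompile()

print(gm(torch.tensor(2.0), torch.tensor(3.0)))  # tensor(6.), the graph now multiplies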

oslo/pytorch/kernel_fusion/mem_efficient/decompositions.py

Lines changed: 218 additions & 30 deletions
@@ -1,5 +1,5 @@
 from enum import Enum
-from typing import Optional, List
+from typing import Optional, List, Tuple

 import torch
 from torch import Tensor
@@ -125,14 +125,34 @@ def leaky_relu_backward_decomposition(


 @register_decomposition(aten.gelu_backward)
-def gelu_backward_decomposition(grad: Tensor, self: Tensor):
+def gelu_backward_decomposition(grad: Tensor, self: Tensor, approximate: str = "none"):
+    M_SQRT2 = 1.41421356237309504880
     M_SQRT1_2 = 0.70710678118654752440
     M_2_SQRTPI = 1.12837916709551257390
-    kAlpha = M_SQRT1_2
-    kBeta = M_2_SQRTPI * M_SQRT1_2 * 0.5
-    cdf = 0.5 * (1 + aten.erf(self * kAlpha))
-    pdf = kBeta * aten.exp(self * self * -0.5)
-    return grad * (cdf + self * pdf)
+    if approximate == "none":
+        kAlpha = M_SQRT1_2
+        kBeta = M_2_SQRTPI * M_SQRT1_2 * 0.5
+        cdf = 0.5 * (1 + aten.erf(self * kAlpha))
+        pdf = kBeta * aten.exp(self * self * -0.5)
+        return grad * (cdf + self * pdf)
+    else:
+        kBeta = M_SQRT2 * M_2_SQRTPI * 0.5
+        kKappa = 0.044715
+        x_sq = self * self
+        x_cube = x_sq * self
+        inner = kBeta * (self + kKappa * x_cube)
+        tanh_inner = aten.tanh(inner)
+
+        left = 0.5 * self
+        right = 1 + tanh_inner
+
+        left_derivative = 0.5 * right
+
+        tanh_derivative = 1 - tanh_inner * tanh_inner
+        inner_derivative = kBeta * (1 + 3 * kKappa * x_sq)
+        right_derivative = left * tanh_derivative * inner_derivative
+
+        return grad * (left_derivative + right_derivative)


 @register_decomposition(aten.mish_backward)
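For reference, the tanh branch added above is the derivative of the tanh-approximated GELU. Writing \beta = \sqrt{2/\pi} (kBeta), \kappa = 0.044715 (kKappa), and u = \beta(x + \kappa x^{3}) (inner):

    \mathrm{gelu}_{\tanh}(x) = \tfrac{1}{2}\,x\,\bigl(1 + \tanh u\bigr)

    \frac{d}{dx}\,\mathrm{gelu}_{\tanh}(x) = \underbrace{\tfrac{1}{2}\bigl(1 + \tanh u\bigr)}_{\text{left\_derivative}} + \underbrace{\tfrac{1}{2}\,x\,\bigl(1 - \tanh^{2}u\bigr)\,\beta\,\bigl(1 + 3\kappa x^{2}\bigr)}_{\text{right\_derivative}}

The decomposition returns grad times this expression; the "none" branch keeps the exact erf-based formula unchanged.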
@@ -152,16 +172,62 @@ def silu_backward(grad_output: Tensor, self: Tensor) -> Tensor:
 # whyyyy does log_sigmoid do 2 different things for CPU and CUDA >:(


+@register_decomposition(aten.softshrink_backward)
+def softshrink_backward(grad_output: Tensor, self: Tensor, lambd: float) -> Tensor:
+    return aten.where(
+        (self >= -lambd) & (self <= lambd), aten.new_zeros(grad_output, ()), grad_output
+    )
+
+
+@register_decomposition(aten.prelu_backward)
+def prelu_backward(
+    grad_output: Tensor, self: Tensor, weight: Tensor
+) -> Tuple[Tensor, Tensor]:
+    # Logic is more complicated than I would like. Basically, weight can either
+    # be a scalar or a vector of size [C], and in the forward pass it's
+    # broadcast against [N, C, ...]. So now, we need to do the corresponding
+    # reduction, which is harder than we'd like...
+    cur_weight = weight
+    for _ in range(2, grad_output.dim()):
+        cur_weight = cur_weight.unsqueeze(-1)
+    input_grad = aten.where(self > 0, grad_output, cur_weight * grad_output)
+    weight_grad_collector = aten.where(
+        self > 0, aten.new_zeros(grad_output, ()), self * grad_output
+    )
+    out = aten.sum_to_size(weight_grad_collector, cur_weight.shape)
+    while out.dim() > weight.dim():
+        out = out.squeeze(-1)
+    return (input_grad, out)
+
+
+@register_decomposition(aten.rrelu_with_noise_backward)
+def rrelu_with_noise_backward(
+    grad_output: Tensor,
+    self: Tensor,
+    noise: Tensor,
+    lower: float,
+    upper: float,
+    training: bool,
+    self_is_result: bool,
+) -> Tensor:
+    if training and upper - lower > 1e-6:
+        return grad_output.mul(noise)
+    else:
+        negative_slope = (lower + upper) / 2
+        return aten.leaky_relu_backward(
+            grad_output, self, negative_slope, self_is_result
+        )
+
+
 @register_decomposition(aten.log_sigmoid_backward)
 def log_sigmoid_backward(grad_output: Tensor, self: Tensor, buffer: Tensor) -> Tensor:
     in_negative = self < 0
     max_deriv = aten.where(in_negative, 1, 0)
     sign = aten.where(in_negative, 1, -1)
-    if grad_output.is_cuda:  # buffer is not used on CUDA
-        z = aten.exp(-aten.abs(self))
-        return grad_output * (max_deriv - sign * (z / (1 + z)))
-    else:
-        return (max_deriv - sign * (buffer / (1 + buffer))) * grad_output
+    z = aten.exp(-aten.abs(self))
+    return grad_output * (max_deriv - sign * (z / (1 + z)))
+    # CPU has a special formula that uses buffer, but disabled for convenience sake
+    # return (max_deriv - sign * (buffer / (1 + buffer))) * grad_output


 @register_decomposition(aten.mse_loss_backward)
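The log_sigmoid_backward rewrite above drops the CPU-only buffer path and uses the CUDA formula everywhere. A small standalone check of that formula against autograd, written with plain torch calls for illustration (grad_output is taken to be all ones):

import torch

x = torch.randn(8, requires_grad=True)
(auto_grad,) = torch.autograd.grad(torch.nn.functional.logsigmoid(x).sum(), x)

# Same algebra as the decomposition, with grad_output == 1
in_negative = x < 0
max_deriv = torch.where(in_negative, torch.ones_like(x), torch.zeros_like(x))
sign = torch.where(in_negative, torch.ones_like(x), -torch.ones_like(x))
z = torch.exp(-torch.abs(x))
manual_grad = max_deriv - sign * (z / (1 + z))

print(torch.allclose(auto_grad, manual_grad, atol=1e-6))  # True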
@@ -185,6 +251,22 @@ def huber_loss_backward(
     )


+@register_decomposition(aten.binary_cross_entropy_backward)
+def binary_cross_entropy_backward(
+    grad_output: Tensor,
+    self: Tensor,
+    target: Tensor,
+    weight: Optional[Tensor] = None,
+    reduction: int = Reduction.MEAN,
+) -> Tensor:
+    if weight is None:
+        weight = 1
+    result = weight * (self - target) / self / (1 - self)
+    if reduction == Reduction.MEAN:
+        result = result * (1.0 / self.numel())
+    return result * grad_output
+
+
 @register_decomposition(aten.slice_backward)
 def slice_backward(
     grad_output: Tensor,
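The binary_cross_entropy_backward decomposition added above follows directly from differentiating the (optionally weighted) BCE loss with respect to the input x = self, with target t:

    \ell(x, t) = -w\,\bigl(t\,\log x + (1 - t)\,\log(1 - x)\bigr)

    \frac{\partial \ell}{\partial x} = w\,\frac{x - t}{x\,(1 - x)}

which is exactly weight * (self - target) / self / (1 - self), scaled by 1/self.numel() when the reduction is a mean.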
@@ -252,6 +334,17 @@ def im2col_backward(
     return aten.col2im(grad_output, input_size, kernel_size, dilation, padding, stride)


+@register_decomposition(aten.col2im_backward)
+def col2im_backward(
+    grad_output: Tensor,
+    kernel_size: List[int],
+    dilation: List[int],
+    padding: List[int],
+    stride: List[int],
+) -> Tensor:
+    return aten.im2col(grad_output, kernel_size, dilation, padding, stride)
+
+
 @register_decomposition(aten.logit_backward)
 def logit_backward(
     grad_output: Tensor, self: Tensor, eps: Optional[float] = None
@@ -287,15 +380,114 @@ def _log_softmax(x: Tensor, dim: int, half_to_float: bool):
     return shifted - shifted_logsumexp


-@register_decomposition(aten.addmm)
-def addmm(self: Tensor, mat1: Tensor, mat2: Tensor, beta=1, alpha=1):
-    if not self.is_floating_point():
-        beta = int(beta)
-        alpha = int(alpha)
-    out = alpha * aten.mm(mat1, mat2)
-    if beta == 0:
-        return out
-    return beta * self + out
+@register_decomposition(aten.addcdiv)
+def addcdiv(self: Tensor, tensor1: Tensor, tensor2: Tensor, value: float = 1):
+    return self + value * (tensor1 / tensor2)
+
+
+@register_decomposition(aten.addcmul)
+def addcmul(self: Tensor, tensor1: Tensor, tensor2: Tensor, value: float = 1):
+    if self.is_floating_point():
+        return self + value * tensor1 * tensor2
+    else:
+        return self + int(value) * tensor1 * tensor2
+
+
+@register_decomposition(aten.embedding_dense_backward)
+def embedding_dense_backward(
+    grad_output: Tensor,
+    indices: Tensor,
+    num_weights: int,
+    padding_idx: int,
+    scale_grad_by_freq: bool,
+):
+    numel = indices.numel()
+    grad = grad_output.view(numel, grad_output.size(-1))
+    grad_weight = aten.new_zeros(grad_output, (num_weights, grad_output.shape[-1]))
+    indices_rank1 = indices.view(numel)
+    if scale_grad_by_freq:
+        counts = aten.new_zeros(indices, (num_weights,))
+        ones = aten.new_ones(indices, (numel,))
+        counts = aten.index_put(counts, [indices_rank1], ones, accumulate=True)
+        grad_weights_scale = aten.index(counts, [indices_rank1])
+        grad = grad / grad_weights_scale.unsqueeze(1)
+    skip_padding = (indices_rank1 != padding_idx).unsqueeze(1)
+    skip_padding = skip_padding.expand_as(grad)
+    zero_grad = aten.full_like(grad, 0)
+    return aten.index_put(
+        grad_weight,
+        [indices_rank1],
+        aten.where(skip_padding, grad, zero_grad),
+        accumulate=True,
+    )
+
+
+def prod(x):
+    r = 1
+    for i in x:
+        r *= i
+    return r
+
+
+@register_decomposition(aten.native_layer_norm)
+def native_layer_norm(
+    input: Tensor,
+    normalized_shape: List[int],
+    weight: Optional[Tensor],
+    bias: Optional[Tensor],
+    eps: float,
+) -> Tuple[Tensor, Tensor, Tensor]:
+    input_shape = input.shape
+    input_ndim = input.dim()
+
+    axis = input_ndim - len(normalized_shape)
+    M = prod(input_shape[:axis])
+
+    # Hmm... not sure how I get around this...
+    # Basically, native_batch_norm doesn't support 0-entry tensors, while
+    # native_layer_norm does (and is tested by OpInfos!)
+    if M > 0:
+        input_reshaped = input.view(1, M, -1)
+    else:
+        return (input, aten.new_empty(input, (0,)), aten.new_empty(input, (0,)))
+
+    # Unlike Batch Normalization, which applies scalar scale and bias for each
+    # entire channel/plane with the affine option, Layer Normalization applies
+    # per-element scale and bias. E.g. For input {N, C, H, W}, weight for
+    # batchnorm has shape {C} while weight for layernorm has shape {H, W} or {W}.
+    out, mean, rstd = aten.native_batch_norm(
+        input_reshaped,
+        weight=None,
+        bias=None,
+        running_mean=None,
+        running_var=None,
+        training=True,
+        momentum=0,
+        eps=eps,
+    )
+    out = out.view(input_shape)
+    if weight is not None:
+        out = out * weight
+    if bias is not None:
+        out = out + bias
+
+    stat_shape = list(input_shape[:axis])
+    for _ in range(axis, input.dim()):
+        stat_shape.append(1)
+    mean = mean.view(stat_shape)
+    rstd = rstd.view(stat_shape)
+    return (out, mean, rstd)
+
+
+# @register_decomposition(aten.addmm)
+# def addmm(self: Tensor, mat1: Tensor, mat2: Tensor, beta=1, alpha=1):
+#     if not self.is_floating_point():
+#         beta = int(beta)
+#         alpha = int(alpha)
+#     out = alpha * aten.mm(mat1, mat2)
+#     if beta == 0:
+#         return out
+#     return beta * self + out


 @register_decomposition(aten.clamp_min)
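The native_layer_norm decomposition above reuses native_batch_norm by flattening the leading dims into M pseudo-channels of a (1, M, -1) view. A standalone sketch of that equivalence with plain torch calls (shapes and eps are assumptions chosen for illustration):

import torch
import torch.nn.functional as F

x = torch.randn(4, 3, 8, 8)          # assumed (N, C, H, W) input
normalized_shape = (8, 8)
eps = 1e-5

axis = x.dim() - len(normalized_shape)
M = x.shape[0] * x.shape[1]          # product of the leading dims

# batch_norm in training mode over a (1, M, H*W) view normalizes each of the
# M pseudo-channels over its last dimension, i.e. per-sample layer norm.
out = F.batch_norm(
    x.reshape(1, M, -1),
    running_mean=None,
    running_var=None,
    training=True,
    momentum=0.0,
    eps=eps,
).reshape(x.shape)

print(torch.allclose(out, F.layer_norm(x, normalized_shape, eps=eps), atol=1e-5))  # True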
@@ -308,18 +500,14 @@ def clamp_max(self: Tensor, min: float):
     return aten.clamp(self, max=max)


-# @register_decomposition(aten._fused_dropout)
-# def _fused_dropout_decomposition(input, p, generator=None):
-#     mask = aten.to(aten.rand_like(input) < p, dtype=torch.uint8)
-#     res = mask.type_as(input) * input * (1./p)
-#     return [res, mask]
+@register_decomposition(aten._fused_dropout)
+def _fused_dropout_decomposition(input, p, generator=None):
+    mask = aten.to(aten.rand_like(input) < p, dtype=torch.uint8)
+    res = mask.type_as(input) * input * (1.0 / p)
+    return [res, mask]


 # Questionable decompositions
-@register_decomposition(aten._s_where)
-def _s_where_canonicalization(a, b, c):
-    return aten.where(a, b, c)
-

 # This is only valid if we're running the graph without autograd, such as if the backward pass has been traced.
 # Note that this decomposition causes issues with in-place ops
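The re-enabled _fused_dropout decomposition keeps an element when rand < p and rescales by 1/p, so the output matches the input in expectation. A quick illustration with plain torch calls (p is an assumed value following the same mask = rand < p convention as the code above):

import torch

torch.manual_seed(0)
x = torch.ones(1_000_000)
p = 0.8                                   # assumed keep probability, for illustration
mask = (torch.rand_like(x) < p).to(torch.uint8)
res = mask.type_as(x) * x * (1.0 / p)     # same arithmetic as the decomposition
print(res.mean().item())                  # close to 1.0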
