@@ -1,5 +1,5 @@
 from enum import Enum
-from typing import Optional, List
+from typing import Optional, List, Tuple
 
 import torch
 from torch import Tensor
@@ -125,14 +125,34 @@ def leaky_relu_backward_decomposition(
 
 
 @register_decomposition(aten.gelu_backward)
-def gelu_backward_decomposition(grad: Tensor, self: Tensor):
+def gelu_backward_decomposition(grad: Tensor, self: Tensor, approximate: str = "none"):
+    M_SQRT2 = 1.41421356237309504880
     M_SQRT1_2 = 0.70710678118654752440
     M_2_SQRTPI = 1.12837916709551257390
-    kAlpha = M_SQRT1_2
-    kBeta = M_2_SQRTPI * M_SQRT1_2 * 0.5
-    cdf = 0.5 * (1 + aten.erf(self * kAlpha))
-    pdf = kBeta * aten.exp(self * self * -0.5)
-    return grad * (cdf + self * pdf)
+    if approximate == "none":
+        kAlpha = M_SQRT1_2
+        kBeta = M_2_SQRTPI * M_SQRT1_2 * 0.5
+        cdf = 0.5 * (1 + aten.erf(self * kAlpha))
+        pdf = kBeta * aten.exp(self * self * -0.5)
+        return grad * (cdf + self * pdf)
+    else:
+        kBeta = M_SQRT2 * M_2_SQRTPI * 0.5
+        kKappa = 0.044715
+        x_sq = self * self
+        x_cube = x_sq * self
+        inner = kBeta * (self + kKappa * x_cube)
+        tanh_inner = aten.tanh(inner)
+
+        left = 0.5 * self
+        right = 1 + tanh_inner
+
+        left_derivative = 0.5 * right
+
+        tanh_derivative = 1 - tanh_inner * tanh_inner
+        inner_derivative = kBeta * (1 + 3 * kKappa * x_sq)
+        right_derivative = left * tanh_derivative * inner_derivative
+
+        return grad * (left_derivative + right_derivative)
 
 
 @register_decomposition(aten.mish_backward)
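
Aside on the tanh branch above, for readers checking the math: it is the product rule applied to the tanh approximation of GELU, with left = x/2 and right = 1 + tanh(inner):

    \mathrm{GELU}_{\tanh}(x) = \tfrac{x}{2}\Bigl(1 + \tanh\bigl(\sqrt{2/\pi}\,(x + 0.044715\,x^{3})\bigr)\Bigr)

    \frac{d}{dx}\,\mathrm{GELU}_{\tanh}(x)
      = \underbrace{\tfrac{1}{2}\bigl(1 + \tanh(\mathrm{inner})\bigr)}_{\texttt{left\_derivative}}
      + \underbrace{\tfrac{x}{2}\bigl(1 - \tanh^{2}(\mathrm{inner})\bigr)\,\sqrt{2/\pi}\,\bigl(1 + 3 \cdot 0.044715\,x^{2}\bigr)}_{\texttt{right\_derivative}}

where kBeta = M_SQRT2 * M_2_SQRTPI * 0.5 = sqrt(2/pi) and kKappa = 0.044715.
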
@@ -152,16 +172,62 @@ def silu_backward(grad_output: Tensor, self: Tensor) -> Tensor:
 # whyyyy does log_sigmoid do 2 different things for CPU and CUDA >:(
 
 
+@register_decomposition(aten.softshrink_backward)
+def softshrink_backward(grad_output: Tensor, self: Tensor, lambd: float) -> Tensor:
+    return aten.where(
+        (self >= -lambd) & (self <= lambd), aten.new_zeros(grad_output, ()), grad_output
+    )
+
+
+@register_decomposition(aten.prelu_backward)
+def prelu_backward(
+    grad_output: Tensor, self: Tensor, weight: Tensor
+) -> Tuple[Tensor, Tensor]:
+    # Logic is more complicated than I would like. Basically, weight can either
+    # be a scalar or a vector of size [C], and in the forward pass it's
+    # broadcast against [N, C, ...]. So now, we need to do the corresponding
+    # reduction, which is harder than we'd like...
+    cur_weight = weight
+    for _ in range(2, grad_output.dim()):
+        cur_weight = cur_weight.unsqueeze(-1)
+    input_grad = aten.where(self > 0, grad_output, cur_weight * grad_output)
+    weight_grad_collector = aten.where(
+        self > 0, aten.new_zeros(grad_output, ()), self * grad_output
+    )
+    out = aten.sum_to_size(weight_grad_collector, cur_weight.shape)
+    while out.dim() > weight.dim():
+        out = out.squeeze(-1)
+    return (input_grad, out)
+
+
+@register_decomposition(aten.rrelu_with_noise_backward)
+def rrelu_with_noise_backward(
+    grad_output: Tensor,
+    self: Tensor,
+    noise: Tensor,
+    lower: float,
+    upper: float,
+    training: bool,
+    self_is_result: bool,
+) -> Tensor:
+    if training and upper - lower > 1e-6:
+        return grad_output.mul(noise)
+    else:
+        negative_slope = (lower + upper) / 2
+        return aten.leaky_relu_backward(
+            grad_output, self, negative_slope, self_is_result
+        )
+
+
 @register_decomposition(aten.log_sigmoid_backward)
 def log_sigmoid_backward(grad_output: Tensor, self: Tensor, buffer: Tensor) -> Tensor:
     in_negative = self < 0
     max_deriv = aten.where(in_negative, 1, 0)
     sign = aten.where(in_negative, 1, -1)
-    if grad_output.is_cuda:  # buffer is not used on CUDA
-        z = aten.exp(-aten.abs(self))
-        return grad_output * (max_deriv - sign * (z / (1 + z)))
-    else:
-        return (max_deriv - sign * (buffer / (1 + buffer))) * grad_output
+    z = aten.exp(-aten.abs(self))
+    return grad_output * (max_deriv - sign * (z / (1 + z)))
+    # CPU has a special formula that uses buffer, but disabled for convenience sake
+    # return (max_deriv - sign * (buffer / (1 + buffer))) * grad_output
 
 
 @register_decomposition(aten.mse_loss_backward)
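
To make the shape gymnastics in prelu_backward above concrete: a weight of shape [C] is broadcast to [1, C, 1, ...] in the forward pass, so its gradient has to be reduced back down to [C]. A minimal standalone sanity check against autograd (hypothetical test code using public torch APIs, not part of the diff):

    import torch

    x = torch.randn(4, 3, 8, 8, requires_grad=True)  # input of shape [N, C, H, W]
    w = torch.randn(3, requires_grad=True)           # per-channel weight [C]
    out = torch.nn.functional.prelu(x, w)
    grad_out = torch.randn_like(out)
    gx, gw = torch.autograd.grad(out, (x, w), grad_out)

    # The weight gradient collects self * grad where self <= 0, then sums out
    # every broadcast dimension (N, H, W), mirroring the sum_to_size above:
    collector = torch.where(x > 0, torch.zeros((), dtype=x.dtype), x * grad_out)
    assert torch.allclose(gw, collector.sum(dim=(0, 2, 3)), atol=1e-5)
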
@@ -185,6 +251,22 @@ def huber_loss_backward(
     )
 
 
+@register_decomposition(aten.binary_cross_entropy_backward)
+def binary_cross_entropy_backward(
+    grad_output: Tensor,
+    self: Tensor,
+    target: Tensor,
+    weight: Optional[Tensor] = None,
+    reduction: int = Reduction.MEAN,
+) -> Tensor:
+    if weight is None:
+        weight = 1
+    result = weight * (self - target) / self / (1 - self)
+    if reduction == Reduction.MEAN:
+        result = result * (1.0 / self.numel())
+    return result * grad_output
+
+
 @register_decomposition(aten.slice_backward)
 def slice_backward(
     grad_output: Tensor,
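
For reference, the (self - target) / self / (1 - self) term in binary_cross_entropy_backward is just the derivative of the per-element loss:

    \ell(x, t) = -\bigl(t \log x + (1 - t)\log(1 - x)\bigr), \qquad
    \frac{\partial \ell}{\partial x} = -\frac{t}{x} + \frac{1 - t}{1 - x} = \frac{x - t}{x(1 - x)}

multiplied by the optional per-element weight, with the 1/numel() factor accounting for a mean reduction in the forward pass.
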
@@ -252,6 +334,17 @@ def im2col_backward(
     return aten.col2im(grad_output, input_size, kernel_size, dilation, padding, stride)
 
 
+@register_decomposition(aten.col2im_backward)
+def col2im_backward(
+    grad_output: Tensor,
+    kernel_size: List[int],
+    dilation: List[int],
+    padding: List[int],
+    stride: List[int],
+) -> Tensor:
+    return aten.im2col(grad_output, kernel_size, dilation, padding, stride)
+
+
 @register_decomposition(aten.logit_backward)
 def logit_backward(
     grad_output: Tensor, self: Tensor, eps: Optional[float] = None
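
col2im_backward works because im2col and col2im are linear maps that are transposes of each other: one gathers patch elements, the other scatter-adds them back. A standalone way to convince yourself, via the public unfold/fold wrappers (hypothetical check, not part of the diff):

    import torch
    import torch.nn.functional as F

    x = torch.randn(1, 2, 6, 6, requires_grad=True)
    cols = F.unfold(x, kernel_size=3, padding=1)  # im2col
    g = torch.randn_like(cols)
    (grad_x,) = torch.autograd.grad(cols, x, g)

    # The backward of unfold (im2col) is fold (col2im) applied to the cotangent:
    assert torch.allclose(grad_x, F.fold(g, output_size=(6, 6), kernel_size=3, padding=1))
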
@@ -287,15 +380,114 @@ def _log_softmax(x: Tensor, dim: int, half_to_float: bool):
     return shifted - shifted_logsumexp
 
 
-@register_decomposition(aten.addmm)
-def addmm(self: Tensor, mat1: Tensor, mat2: Tensor, beta=1, alpha=1):
-    if not self.is_floating_point():
-        beta = int(beta)
-        alpha = int(alpha)
-    out = alpha * aten.mm(mat1, mat2)
-    if beta == 0:
-        return out
-    return beta * self + out
+@register_decomposition(aten.addcdiv)
+def addcdiv(self: Tensor, tensor1: Tensor, tensor2: Tensor, value: float = 1):
+    return self + value * (tensor1 / tensor2)
+
+
+@register_decomposition(aten.addcmul)
+def addcmul(self: Tensor, tensor1: Tensor, tensor2: Tensor, value: float = 1):
+    if self.is_floating_point():
+        return self + value * tensor1 * tensor2
+    else:
+        return self + int(value) * tensor1 * tensor2
+
+
+@register_decomposition(aten.embedding_dense_backward)
+def embedding_dense_backward(
+    grad_output: Tensor,
+    indices: Tensor,
+    num_weights: int,
+    padding_idx: int,
+    scale_grad_by_freq: bool,
+):
+    numel = indices.numel()
+    grad = grad_output.view(numel, grad_output.size(-1))
+    grad_weight = aten.new_zeros(grad_output, (num_weights, grad_output.shape[-1]))
+    indices_rank1 = indices.view(numel)
+    if scale_grad_by_freq:
+        counts = aten.new_zeros(indices, (num_weights,))
+        ones = aten.new_ones(indices, (numel,))
+        counts = aten.index_put(counts, [indices_rank1], ones, accumulate=True)
+        grad_weights_scale = aten.index(counts, [indices_rank1])
+        grad = grad / grad_weights_scale.unsqueeze(1)
+    skip_padding = (indices_rank1 != padding_idx).unsqueeze(1)
+    skip_padding = skip_padding.expand_as(grad)
+    zero_grad = aten.full_like(grad, 0)
+    return aten.index_put(
+        grad_weight,
+        [indices_rank1],
+        aten.where(skip_padding, grad, zero_grad),
+        accumulate=True,
+    )
+
+
+def prod(x):
+    r = 1
+    for i in x:
+        r *= i
+    return r
+
+
+@register_decomposition(aten.native_layer_norm)
+def native_layer_norm(
+    input: Tensor,
+    normalized_shape: List[int],
+    weight: Optional[Tensor],
+    bias: Optional[Tensor],
+    eps: float,
+) -> Tuple[Tensor, Tensor, Tensor]:
+    input_shape = input.shape
+    input_ndim = input.dim()
+
+    axis = input_ndim - len(normalized_shape)
+    M = prod(input_shape[:axis])
+
+    # Hmm... not sure how I get around this...
+    # Basically, native_batch_norm doesn't support 0-entry tensors, while
+    # native_layer_norm does (and is tested by OpInfos!)
+    if M > 0:
+        input_reshaped = input.view(1, M, -1)
+    else:
+        return (input, aten.new_empty(input, (0,)), aten.new_empty(input, (0,)))
+
+    # Unlike Batch Normalization, which applies scalar scale and bias for each
+    # entire channel/plane with the affine option, Layer Normalization applies
+    # per-element scale and bias. E.g. For input {N, C, H, W}, weight for
+    # batchnorm has shape {C} while weight for layernorm has shape {H, W} or {W}.
+    out, mean, rstd = aten.native_batch_norm(
+        input_reshaped,
+        weight=None,
+        bias=None,
+        running_mean=None,
+        running_var=None,
+        training=True,
+        momentum=0,
+        eps=eps,
+    )
+    out = out.view(input_shape)
+    if weight is not None:
+        out = out * weight
+    if bias is not None:
+        out = out + bias
+
+    stat_shape = list(input_shape[:axis])
+    for _ in range(axis, input.dim()):
+        stat_shape.append(1)
+    mean = mean.view(stat_shape)
+    rstd = rstd.view(stat_shape)
+    return (out, mean, rstd)
+
+
+# @register_decomposition(aten.addmm)
+# def addmm(self: Tensor, mat1: Tensor, mat2: Tensor, beta=1, alpha=1):
+#     if not self.is_floating_point():
+#         beta = int(beta)
+#         alpha = int(alpha)
+#     out = alpha * aten.mm(mat1, mat2)
+#     if beta == 0:
+#         return out
+#     return beta * self + out
 
 
 @register_decomposition(aten.clamp_min)
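
The native_layer_norm decomposition leans on one trick: after input.view(1, M, -1), each of the M normalization groups becomes its own batch-norm "channel", and batch norm in training mode normalizes every channel over the remaining dimensions. A standalone illustration with the public functional wrappers (hypothetical check, not part of the diff):

    import torch
    import torch.nn.functional as F

    x = torch.randn(4, 5, 6)
    M = 4 * 5  # product of the dims *not* covered by normalized_shape=(6,)

    # One batch-norm "channel" per normalization group of size 6:
    out_bn = F.batch_norm(
        x.reshape(1, M, -1), running_mean=None, running_var=None,
        training=True, momentum=0.0, eps=1e-5,
    ).reshape(x.shape)

    assert torch.allclose(out_bn, F.layer_norm(x, (6,), eps=1e-5), atol=1e-6)
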
@@ -308,18 +500,14 @@ def clamp_max(self: Tensor, min: float):
     return aten.clamp(self, max=max)
 
 
-# @register_decomposition(aten._fused_dropout)
-# def _fused_dropout_decomposition(input, p, generator=None):
-#     mask = aten.to(aten.rand_like(input) < p, dtype=torch.uint8)
-#     res = mask.type_as(input) * input * (1./ p)
-#     return [res, mask]
+@register_decomposition(aten._fused_dropout)
+def _fused_dropout_decomposition(input, p, generator=None):
+    mask = aten.to(aten.rand_like(input) < p, dtype=torch.uint8)
+    res = mask.type_as(input) * input * (1.0 / p)
+    return [res, mask]
 
 
 # Questionable decompositions
-@register_decomposition(aten._s_where)
-def _s_where_canonicalization(a, b, c):
-    return aten.where(a, b, c)
-
 
 # This is only valid if we're running the graph without autograd, such as if the backward pass has been traced.
 # Note that this decomposition causes issues with in-place ops
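
One convention worth flagging in _fused_dropout_decomposition above: p acts as the keep probability (mask = rand < p keeps an element with probability p, and the 1/p rescale preserves the expectation), which appears to match how aten._fused_dropout is invoked from the dropout frontend. A quick empirical check of the expectation (hypothetical snippet, not part of the diff):

    import torch

    torch.manual_seed(0)
    x = torch.ones(1_000_000)
    p = 0.3  # keep probability under this convention
    mask = (torch.rand_like(x) < p).to(torch.uint8)
    res = mask.type_as(x) * x * (1.0 / p)
    # E[res] == x elementwise, since P(keep) * (1/p) = p * (1/p) = 1
    assert abs(res.mean().item() - 1.0) < 1e-2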