@@ -342,3 +342,165 @@ function rrule(::typeof(fill), x::Any, dims...)
342
342
fill_pullback (Ȳ) = (NoTangent (), project (sum (Ȳ)), nots... )
343
343
return fill (x, dims... ), fill_pullback
344
344
end
345
+
346
#####
##### `findmax`, `maximum`, etc.
#####
349
+
350
# Differentiation rules for `findmin` and `findmax`. The two functions share one
# implementation, so both rule sets are generated by a single loop.
for findm in (:findmin, :findmax)
    pullback_name = Symbol(findm, :_pullback)

    @eval begin
        # Forward rule: the tangent of the returned value is `xdot` read at the selected
        # index; the returned index gets `NoTangent()`.
        function frule((_, xdot), ::typeof($findm), x; dims=:)
            val, idx = $findm(x; dims=dims)
            primal = (val, idx)
            return primal, Tangent{typeof(primal)}(xdot[idx], NoTangent())
        end

        # Reverse rule: the cotangent of the value is written back at the selected index
        # of an otherwise-zero array, then passed through `ProjectTo(x)`.
        # This pullback is a lot like the one for `getindex`; ideally the two would
        # probably be combined.
        function rrule(::typeof($findm), x::AbstractArray; dims=:)
            val, idx = $findm(x; dims=dims)
            proj = ProjectTo(x)
            # The argument is destructured, so this accepts e.g.
            # `Tangent{Tuple{Float64, Int64}}(4.0, nothing)`; the index part is ignored.
            function $pullback_name((dval, _))
                if dval isa AbstractZero
                    return (NoTangent(), NoTangent())
                end
                lazy_dx = @thunk proj(_zerolike_writeat(x, unthunk(dval), dims, idx))
                # In-place variant: add the cotangent into an existing accumulator.
                accumulating_dx = InplaceableThunk(lazy_dx) do acc
                    slot = view(acc, idx)
                    if dims isa Colon
                        # `Ref` makes the cotangent behave as a scalar under broadcasting.
                        slot .= slot .+ Ref(unthunk(dval))
                    else
                        # This could be `.+=`, but not on Julia 1.0.
                        slot .= slot .+ unthunk(dval)
                    end
                    return acc
                end
                return (NoTangent(), accumulating_dx)
            end
            return (val, idx), $pullback_name
        end
    end
end
378
+
379
+ # This function is roughly `setindex!(zero(x), dy, inds...)`:
380
+
381
# Numeric-array method: allocate an array with the axes of `x` and the element type of
# `dy`, fill it with zeros, and write `dy` at `inds`. The `dims` argument is not used here.
function _zerolike_writeat(x::AbstractArray{<:Number}, dy, dims, inds...)
    # It's unfortunate to need `x` itself, but `similar(typeof(x), axes(x))` doesn't
    # allow `eltype(dy)`, nor does it work for many structured matrices.
    out = similar(x, eltype(dy), axes(x))
    fill!(out, 0)
    # The view may be 0-dimensional; broadcasting into it allows `dy::Number` and
    # `dy::Array`, and `out::CuArray`.
    view(out, inds...) .= dy
    return out
end
388
# Generic-array method: since `x` is available, `map(zero, x)` also handles arrays of
# arrays. `dims` only decides whether `dy` is one element or a collection of elements.
function _zerolike_writeat(x::AbstractArray, dy, dims, inds...)
    out = map(zero, x)
    # With `dims = :` the whole of `dy` is a single element, so it is wrapped in `Ref`
    # to stop broadcasting from iterating over it.
    source = dims isa Colon ? Ref(dy) : dy
    view(out, inds...) .= source
    return out
end
398
+
399
# Allow for second derivatives, by writing rules for `_zerolike_writeat`;
# these rules are the reason it takes a `dims` argument.
401
+
402
# Forward rule for `_zerolike_writeat`: the tangent is built by applying the same
# function to `dydot` in place of `dy`. Only the first three tangents are destructured,
# and of those only `dydot` is used.
function frule((_, _, dydot), ::typeof(_zerolike_writeat), x, dy, dims, inds...)
    primal = _zerolike_writeat(x, dy, dims, inds...)
    tangent = _zerolike_writeat(x, dydot, dims, inds...)
    return primal, tangent
end
405
+
406
# Reverse rule for `_zerolike_writeat`: the cotangent for `dy` is read back out of `dz`
# at `inds` and summed over `dims`; every other argument gets `NoTangent()`.
function rrule(::typeof(_zerolike_writeat), x, dy, dims, inds...)
    primal = _zerolike_writeat(x, dy, dims, inds...)
    function _zerolike_writeat_pullback(dz)
        picked = view(unthunk(dz), inds...)
        ddy = sum(picked; dims=dims)
        # One `NoTangent()` per index argument.
        dinds = map(_ -> NoTangent(), inds)
        # Order: the function, `x`, `dy`, `dims`, then each of `inds`.
        return (NoTangent(), NoTangent(), ddy, NoTangent(), dinds...)
    end
    return primal, _zerolike_writeat_pullback
end
415
+
416
+ # These rules for `maximum` pick the same subgradient as `findmax`:
417
+
418
# Forward rule for `maximum`: uses `findmax` to locate the maximum, then reads the
# tangent `xdot` at that index, so the same subgradient is chosen as for `findmax`.
function frule((_, xdot), ::typeof(maximum), x; dims=:)
    val, idx = findmax(x; dims=dims)
    dval = xdot[idx]
    return val, dval
end
422
+
423
# Reverse rule for `maximum`: delegates to the `findmax` rule, keeping only the value
# from its primal output.
function rrule(::typeof(maximum), x::AbstractArray; dims=:)
    value_and_index, findmax_back = rrule(findmax, x; dims=dims)
    # The `findmax` pullback expects a (value, index) pair; the index slot is `nothing`.
    function maximum_pullback(dy)
        return findmax_back((dy, nothing))
    end
    return first(value_and_index), maximum_pullback
end
428
+
429
# Forward rule for `minimum`: uses `findmin` to locate the minimum, then reads the
# tangent `xdot` at that index.
function frule((_, xdot), ::typeof(minimum), x; dims=:)
    val, idx = findmin(x; dims=dims)
    dval = xdot[idx]
    return val, dval
end
433
+
434
# Reverse rule for `minimum`: delegates to the `findmin` rule, keeping only the value
# from its primal output.
function rrule(::typeof(minimum), x::AbstractArray; dims=:)
    value_and_index, findmin_back = rrule(findmin, x; dims=dims)
    # The `findmin` pullback expects a (value, index) pair; the index slot is `nothing`.
    function minimum_pullback(dy)
        return findmin_back((dy, nothing))
    end
    return first(value_and_index), minimum_pullback
end
439
+
440
#####
##### `extrema`
#####
443
+
444
# Reverse rule for `extrema` on numeric arrays. This method only selects the
# implementation: whole-array reduction for `dims = :`, per-slice reduction otherwise.
function rrule(::typeof(extrema), x::AbstractArray{<:Number}; dims=:)
    return dims isa Colon ? _extrema_colon(x) : _extrema_dims(x, dims)
end
451
+
452
# `extrema` rule for the whole array: returns `(min, max)` and a pullback that writes
# the two cotangents at the positions reported by `findmin` and `findmax`.
function _extrema_colon(x)
    lo, at_lo = findmin(x)
    hi, at_hi = findmax(x)
    proj = ProjectTo(x)
    function extrema_pullback((dlo, dhi))  # destructuring also accepts a Tangent
        if dlo isa AbstractZero && dhi isa AbstractZero
            return (NoTangent(), NoTangent())
        end
        # One argument may still be an AbstractZero here. Use `promote_op` because
        # `promote_type` allows for `*` as well as `+`, hence gives `Any`.
        T = Base.promote_op(+, typeof(dlo), typeof(dhi))
        # The gradient is built eagerly: wrapping this in `@thunk` doesn't infer, so the
        # lazy / in-place (`InplaceableThunk`) variant is not used.
        grad = fill!(similar(x, T, axes(x)), false)
        view(grad, at_lo) .= dlo
        # Accumulate rather than overwrite, since both positions may coincide.
        hi_slot = view(grad, at_hi)
        hi_slot .= hi_slot .+ dhi
        return (NoTangent(), proj(grad))
    end
    return (lo, hi), extrema_pullback
end
479
+
480
# `extrema` rule along `dims`: returns an array of `(min, max)` tuples and a pullback
# that scatters each tuple's two components to the `findmin` / `findmax` positions.
function _extrema_dims(x, dims)
    lo, at_lo = findmin(x; dims=dims)
    hi, at_hi = findmax(x; dims=dims)
    # A GPU-friendly version of `collect(zip(lo, hi))`.
    lohi = similar(lo, Tuple{eltype(lo), eltype(hi)})
    map!(tuple, lohi, lo, hi)
    proj = ProjectTo(x)
    function extrema_pullback_dims(dy_raw)
        dy = unthunk(dy_raw)
        @assert dy isa AbstractArray{<:Tuple{Any,Any}}
        # Open question: can we actually get Array{Tuple{Float64,ZeroTangent}} here?
        T = Base.promote_op(+, eltype(dy).parameters...)
        # The gradient is built eagerly: wrapping this in `@thunk` doesn't infer, so the
        # lazy / in-place (`InplaceableThunk`) variant is not used.
        grad = fill!(similar(x, T, axes(x)), false)
        view(grad, at_lo) .= first.(dy)
        # Accumulate rather than overwrite, since min and max positions may coincide.
        hi_slot = view(grad, at_hi)
        hi_slot .= hi_slot .+ last.(dy)
        return (NoTangent(), proj(grad))
    end
    return lohi, extrema_pullback_dims
end
0 commit comments