@@ -350,153 +350,3 @@ true
350
350
See also: [`frule`](@ref), [`AbstractRule`](@ref), [`@scalar_rule`](@ref)
351
351
"""
352
352
rrule (:: Any , :: Vararg{Any} ; kwargs... ) = nothing
353
-
354
-
355
- # ####
356
- # #### macros
357
- # ####
358
-
359
- """
360
- @scalar_rule(f(x₁, x₂, ...),
361
- @setup(statement₁, statement₂, ...),
362
- (∂f₁_∂x₁, ∂f₁_∂x₂, ...),
363
- (∂f₂_∂x₁, ∂f₂_∂x₂, ...),
364
- ...)
365
-
366
- A convenience macro that generates simple scalar forward or reverse rules using
367
- the provided partial derivatives. Specifically, generates the corresponding
368
- methods for `frule` and `rrule`:
369
-
370
- function ChainRulesCore.frule(::typeof(f), x₁::Number, x₂::Number, ...)
371
- Ω = f(x₁, x₂, ...)
372
- \$ (statement₁, statement₂, ...)
373
- return Ω, (Rule((Δx₁, Δx₂, ...) -> ∂f₁_∂x₁ * Δx₁ + ∂f₁_∂x₂ * Δx₂ + ...),
374
- Rule((Δx₁, Δx₂, ...) -> ∂f₂_∂x₁ * Δx₁ + ∂f₂_∂x₂ * Δx₂ + ...),
375
- ...)
376
- end
377
-
378
- function ChainRulesCore.rrule(::typeof(f), x₁::Number, x₂::Number, ...)
379
- Ω = f(x₁, x₂, ...)
380
- \$ (statement₁, statement₂, ...)
381
- return Ω, (Rule((ΔΩ₁, ΔΩ₂, ...) -> ∂f₁_∂x₁ * ΔΩ₁ + ∂f₂_∂x₁ * ΔΩ₂ + ...),
382
- Rule((ΔΩ₁, ΔΩ₂, ...) -> ∂f₁_∂x₂ * ΔΩ₁ + ∂f₂_∂x₂ * ΔΩ₂ + ...),
383
- ...)
384
- end
385
-
386
- If no type constraints in `f(x₁, x₂, ...)` within the call to `@scalar_rule` are
387
- provided, each parameter in the resulting `frule`/`rrule` definition is given a
388
- type constraint of `Number`.
389
- Constraints may also be explicitly be provided to override the `Number` constraint,
390
- e.g. `f(x₁::Complex, x₂)`, which will constrain `x₁` to `Complex` and `x₂` to
391
- `Number`.
392
-
393
- Note that the result of `f(x₁, x₂, ...)` is automatically bound to `Ω`. This
394
- allows the primal result to be conveniently referenced (as `Ω`) within the
395
- derivative/setup expressions.
396
-
397
- Note that the `@setup` argument can be elided if no setup code is need. In other
398
- words:
399
-
400
- @scalar_rule(f(x₁, x₂, ...),
401
- (∂f₁_∂x₁, ∂f₁_∂x₂, ...),
402
- (∂f₂_∂x₁, ∂f₂_∂x₂, ...),
403
- ...)
404
-
405
- is equivalent to:
406
-
407
- @scalar_rule(f(x₁, x₂, ...),
408
- @setup(nothing),
409
- (∂f₁_∂x₁, ∂f₁_∂x₂, ...),
410
- (∂f₂_∂x₁, ∂f₂_∂x₂, ...),
411
- ...)
412
-
413
- For examples, see ChainRulesCore' `rules` directory.
414
-
415
- See also: [`frule`](@ref), [`rrule`](@ref), [`AbstractRule`](@ref)
416
- """
417
- macro scalar_rule (call, maybe_setup, partials... )
418
- if Meta. isexpr (maybe_setup, :macrocall ) && maybe_setup. args[1 ] == Symbol (" @setup" )
419
- setup_stmts = map (esc, maybe_setup. args[3 : end ])
420
- else
421
- setup_stmts = (nothing ,)
422
- partials = (maybe_setup, partials... )
423
- end
424
- @assert Meta. isexpr (call, :call )
425
- f = esc (call. args[1 ])
426
- # Annotate all arguments in the signature as scalars
427
- inputs = map (call. args[2 : end ]) do arg
428
- esc (Meta. isexpr (arg, :(:: )) ? arg : Expr (:(:: ), arg, :Number ))
429
- end
430
- # Remove annotations and escape names for the call
431
- for (i, arg) in enumerate (call. args)
432
- if Meta. isexpr (arg, :(:: ))
433
- call. args[i] = esc (first (arg. args))
434
- else
435
- call. args[i] = esc (arg)
436
- end
437
- end
438
- if all (Meta. isexpr (partial, :tuple ) for partial in partials)
439
- forward_rules = Any[rule_from_partials (partial. args... ) for partial in partials]
440
- reverse_rules = Any[]
441
- for i in 1 : length (inputs)
442
- reverse_partials = [partial. args[i] for partial in partials]
443
- push! (reverse_rules, rule_from_partials (reverse_partials... ))
444
- end
445
- else
446
- @assert length (inputs) == 1 && all (! Meta. isexpr (partial, :tuple ) for partial in partials)
447
- forward_rules = Any[rule_from_partials (partial) for partial in partials]
448
- reverse_rules = Any[rule_from_partials (partials... )]
449
- end
450
- forward_rules = length (forward_rules) == 1 ? forward_rules[1 ] : Expr (:tuple , forward_rules... )
451
- reverse_rules = length (reverse_rules) == 1 ? reverse_rules[1 ] : Expr (:tuple , reverse_rules... )
452
- return quote
453
- function ChainRulesCore. frule (:: typeof ($ f), $ (inputs... ))
454
- $ (esc (:Ω )) = $ call
455
- $ (setup_stmts... )
456
- return $ (esc (:Ω )), $ forward_rules
457
- end
458
- function ChainRulesCore. rrule (:: typeof ($ f), $ (inputs... ))
459
- $ (esc (:Ω )) = $ call
460
- $ (setup_stmts... )
461
- return $ (esc (:Ω )), $ reverse_rules
462
- end
463
- end
464
- end
465
-
466
- function rule_from_partials (∂s... )
467
- wirtinger_indices = findall (x -> Meta. isexpr (x, :call ) && x. args[1 ] === :Wirtinger , ∂s)
468
- ∂s = map (esc, ∂s)
469
- Δs = [Symbol (string (:Δ , i)) for i in 1 : length (∂s)]
470
- Δs_tuple = Expr (:tuple , Δs... )
471
- if isempty (wirtinger_indices)
472
- ∂_mul_Δs = [:(mul (@thunk ($ (∂s[i])), $ (Δs[i]))) for i in 1 : length (∂s)]
473
- return :(Rule ($ Δs_tuple -> add ($ (∂_mul_Δs... ))))
474
- else
475
- ∂_mul_Δs_primal = Any[]
476
- ∂_mul_Δs_conjugate = Any[]
477
- ∂_wirtinger_defs = Any[]
478
- for i in 1 : length (∂s)
479
- if i in wirtinger_indices
480
- Δi = Δs[i]
481
- ∂i = Symbol (string (:∂ , i))
482
- push! (∂_wirtinger_defs, :($ ∂i = $ (∂s[i])))
483
- ∂f∂i_mul_Δ = :(mul (wirtinger_primal ($ ∂i), wirtinger_primal ($ Δi)))
484
- ∂f∂ī_mul_Δ̄ = :(mul (conj (wirtinger_conjugate ($ ∂i)), wirtinger_conjugate ($ Δi)))
485
- ∂f̄∂i_mul_Δ = :(mul (wirtinger_conjugate ($ ∂i), wirtinger_primal ($ Δi)))
486
- ∂f̄∂ī_mul_Δ̄ = :(mul (conj (wirtinger_primal ($ ∂i)), wirtinger_conjugate ($ Δi)))
487
- push! (∂_mul_Δs_primal, :(add ($ ∂f∂i_mul_Δ, $ ∂f∂ī_mul_Δ̄)))
488
- push! (∂_mul_Δs_conjugate, :(add ($ ∂f̄∂i_mul_Δ, $ ∂f̄∂ī_mul_Δ̄)))
489
- else
490
- ∂_mul_Δ = :(mul (@thunk ($ (∂s[i])), $ (Δs[i])))
491
- push! (∂_mul_Δs_primal, ∂_mul_Δ)
492
- push! (∂_mul_Δs_conjugate, ∂_mul_Δ)
493
- end
494
- end
495
- primal_rule = :(Rule ($ Δs_tuple -> add ($ (∂_mul_Δs_primal... ))))
496
- conjugate_rule = :(Rule ($ Δs_tuple -> add ($ (∂_mul_Δs_conjugate... ))))
497
- return quote
498
- $ (∂_wirtinger_defs... )
499
- WirtingerRule ($ primal_rule, $ conjugate_rule)
500
- end
501
- end
502
- end
0 commit comments