Skip to content

formatter with Runic #1987

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .formatting/Project.toml
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please remove this file and all other files in this folder. It has to be maintained and there is no reason for users to do anything else than following the Runic documentation.

Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[deps]
Runic = "62bfec6d-59d7-401d-8490-b29ee721c001"
14 changes: 14 additions & 0 deletions .formatting/format_all.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
using Runic

project_path = Base.Filesystem.joinpath(Base.Filesystem.dirname(Base.source_path()), "..")

println("Formatting code with Runic...")

# Format all files in the project
not_formatted = Runic.main(["--inplace", project_path])
if not_formatted == 0
@info "Formatting completed successfully."
else
@warn "Formatting failed!"
end
exit(not_formatted)
18 changes: 18 additions & 0 deletions .formatting/format_check.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
using Runic

project_path = Base.Filesystem.joinpath(Base.Filesystem.dirname(Base.source_path()), "..")

println("Checking code formatting with Runic...")

# Check if files are properly formatted
not_formatted = Runic.main(["--check", "--diff", project_path])

if not_formatted == 0
println("✅ All files are properly formatted!")
exit(0)
else
println("❌ Formatting check failed!")
println("Some files are not properly formatted.")
println("To fix formatting, run: julia --project=.formatting -e 'using Pkg; Pkg.instantiate(); include(\".formatting/format_all.jl\")'")
exit(1)
end
2 changes: 2 additions & 0 deletions .git-blame-ignore-revs
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This won't work - this is referring to the commit in your fork but we have to ignore the commit that might finally land in the master branch. We have to add this file only once the formatting commit is in the master branch and its hash is known.

Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# Runic formatting commit
e0864bbfea8e3a766f4d49ec956dc8ea52f19e9e
23 changes: 19 additions & 4 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ on:
push:
branches:
- master
tags: '*'
tags: "*"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need to change this?

Suggested change
tags: "*"
tags: '*'

workflow_dispatch:
merge_group:

Expand All @@ -16,15 +16,30 @@ concurrency:
cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }}

jobs:
format:
name: Format Check
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v2
with:
version: "1"
show-versioninfo: true
- uses: julia-actions/cache@v2
- run: |
julia --project=.formatting -e '
using Pkg
Pkg.instantiate()
include(".formatting/format_check.jl")'
Comment on lines +19 to +33
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should keep it simple and just follow the recommended and documented workflow in the Runic docs:

Suggested change
format:
name: Format Check
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v2
with:
version: "1"
show-versioninfo: true
- uses: julia-actions/cache@v2
- run: |
julia --project=.formatting -e '
using Pkg
Pkg.instantiate()
include(".formatting/format_check.jl")'
runic:
name: Runic formatting
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v2
with:
version: '1'
- uses: julia-actions/cache@v2
- uses: fredrikekre/runic-action@v1
with:
version: '1'

Possibly it could be made a separate workflow as in e.g. PDMats (https://github.com/JuliaStats/PDMats.jl/blob/master/.github/workflows/Format.yml) but that doesn't matter too much IMO.

test:
name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ github.event_name }}
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
version:
- 'min'
- '1'
- "min"
- "1"
Comment on lines +41 to +42
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need to change these lines:

Suggested change
- "min"
- "1"
- 'min'
- '1'

- pre
os:
- ubuntu-latest
Expand Down Expand Up @@ -57,7 +72,7 @@ jobs:
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v2
with:
version: '1'
version: "1"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
version: "1"
version: '1'

show-versioninfo: true
- run: |
julia --project=docs -e '
Expand Down
17 changes: 16 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ Distributions.jl
[![](https://zenodo.org/badge/DOI/10.5281/zenodo.2647458.svg)](https://zenodo.org/record/2647458)
[![Coverage Status](https://coveralls.io/repos/JuliaStats/Distributions.jl/badge.svg?branch=master)](https://coveralls.io/r/JuliaStats/Distributions.jl?branch=master)
[![Aqua QA](https://raw.githubusercontent.com/JuliaTesting/Aqua.jl/master/badge.svg)](https://github.com/JuliaTesting/Aqua.jl)
[![code style: runic](https://img.shields.io/badge/code_style-%E1%9A%B1%E1%9A%A2%E1%9A%BE%E1%9B%81%E1%9A%B2-black)](https://github.com/fredrikekre/Runic.jl)

[![](https://img.shields.io/badge/docs-latest-blue.svg)](https://JuliaStats.github.io/Distributions.jl/latest/)
[![](https://img.shields.io/badge/docs-stable-blue.svg)](https://JuliaStats.github.io/Distributions.jl/stable/)
Expand Down Expand Up @@ -40,7 +41,7 @@ not been reported yet on the issues of the repository.
If not, you can file a new issue, add your version of the package
which you can get with this command in the Julia REPL:
```julia
julia> ]status Distributions
julia> ] status Distributions
```

Be exhaustive in your report, summarize the bug, and provide:
Expand All @@ -55,6 +56,20 @@ clone it and make modifications on a new branch,
Once your changes are made, push them on your fork and create the
Pull Request on the main repository.

To format the code, run the following command:
```bash
julia --project=.formatting -e 'using Pkg; Pkg.instantiate(); include(".formatting/format_all.jl")'
```

**Note:** Code formatting is automatically checked in CI using Runic.
The formatting command can be run locally with
```julia
julia --project=.formatting -e 'using Pkg; Pkg.instantiate(); include(".formatting/format_check.jl")'
```
The `.git-blame-ignore-revs` file contains commit hashes for mass formatting changes.
This allows `git blame` to show the actual authors of code changes rather than the formatting commit.
When viewing blame information, use `git blame --ignore-revs-file .git-blame-ignore-revs <filename>`.

Comment on lines +59 to +72
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we want to document it more prominently (not sure about it, users don't have to care and everyone who submits a PR will notice and learn it anyway), we should just refer to the Runic docs.

Suggested change
To format the code, run the following command:
```bash
julia --project=.formatting -e 'using Pkg; Pkg.instantiate(); include(".formatting/format_all.jl")'
```
**Note:** Code formatting is automatically checked in CI using Runic.
The formatting command can be run locally with
```julia
julia --project=.formatting -e 'using Pkg; Pkg.instantiate(); include(".formatting/format_check.jl")'
```
The `.git-blame-ignore-revs` file contains commit hashes for mass formatting changes.
This allows `git blame` to show the actual authors of code changes rather than the formatting commit.
When viewing blame information, use `git blame --ignore-revs-file .git-blame-ignore-revs <filename>`.

### Requirements

Distributions is a central package which many rely on,
Expand Down
10 changes: 6 additions & 4 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@ import Random: AbstractRNG, rand!

makedocs(;
sitename = "Distributions.jl",
modules = [Distributions],
format = Documenter.HTML(; prettyurls = get(ENV, "CI", nothing) == "true",
assets = ["assets/favicon.ico"]),
pages = [
modules = [Distributions],
format = Documenter.HTML(;
prettyurls = get(ENV, "CI", nothing) == "true",
assets = ["assets/favicon.ico"],
),
pages = [
"index.md",
"starting.md",
"types.md",
Expand Down
5 changes: 4 additions & 1 deletion ext/DistributionsChainRulesCoreExt/eachvariate.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
function ChainRulesCore.rrule(::Type{Distributions.EachVariate{V}}, x::AbstractArray{<:Real}) where {V}
function ChainRulesCore.rrule(
::Type{Distributions.EachVariate{V}},
x::AbstractArray{<:Real},
) where {V}
y = Distributions.EachVariate{V}(x)
size_x = size(x)
function EachVariate_pullback(Δ)
Expand Down
58 changes: 43 additions & 15 deletions ext/DistributionsChainRulesCoreExt/multivariate/dirichlet.jl
Original file line number Diff line number Diff line change
@@ -1,30 +1,54 @@
function ChainRulesCore.frule((_, Δalpha)::Tuple{Any,Any}, ::Type{DT}, alpha::AbstractVector{T}; check_args::Bool = true) where {T <: Real, DT <: Union{Dirichlet{T}, Dirichlet}}
d = DT(alpha; check_args=check_args)
function ChainRulesCore.frule(
(_, Δalpha)::Tuple{Any, Any},
::Type{DT},
alpha::AbstractVector{T};
check_args::Bool = true,
) where {T <: Real, DT <: Union{Dirichlet{T}, Dirichlet}}
d = DT(alpha; check_args = check_args)
∂alpha0 = sum(Δalpha)
digamma_alpha0 = SpecialFunctions.digamma(d.alpha0)
∂lmnB = sum(Broadcast.instantiate(Broadcast.broadcasted(Δalpha, alpha) do Δalphai, alphai
Δalphai * (SpecialFunctions.digamma(alphai) - digamma_alpha0)
end))
Δd = ChainRulesCore.Tangent{typeof(d)}(; alpha=Δalpha, alpha0=∂alpha0, lmnB=∂lmnB)
∂lmnB = sum(
Broadcast.instantiate(
Broadcast.broadcasted(Δalpha, alpha) do Δalphai, alphai
Δalphai * (SpecialFunctions.digamma(alphai) - digamma_alpha0)
end,
),
)
Δd = ChainRulesCore.Tangent{typeof(d)}(; alpha = Δalpha, alpha0 = ∂alpha0, lmnB = ∂lmnB)
return d, Δd
end

function ChainRulesCore.rrule(::Type{DT}, alpha::AbstractVector{T}; check_args::Bool = true) where {T <: Real, DT <: Union{Dirichlet{T}, Dirichlet}}
d = DT(alpha; check_args=check_args)
function ChainRulesCore.rrule(
::Type{DT},
alpha::AbstractVector{T};
check_args::Bool = true,
) where {T <: Real, DT <: Union{Dirichlet{T}, Dirichlet}}
d = DT(alpha; check_args = check_args)
digamma_alpha0 = SpecialFunctions.digamma(d.alpha0)
function Dirichlet_pullback(_Δd)
Δd = ChainRulesCore.unthunk(_Δd)
Δalpha = Δd.alpha .+ Δd.alpha0 .+ Δd.lmnB .* (SpecialFunctions.digamma.(alpha) .- digamma_alpha0)
Δalpha =
Δd.alpha .+ Δd.alpha0 .+
Δd.lmnB .* (SpecialFunctions.digamma.(alpha) .- digamma_alpha0)
return ChainRulesCore.NoTangent(), Δalpha
end
return d, Dirichlet_pullback
end

function ChainRulesCore.frule((_, Δd, Δx)::Tuple{Any,Any,Any}, ::typeof(Distributions._logpdf), d::Dirichlet, x::AbstractVector{<:Real})
function ChainRulesCore.frule(
(_, Δd, Δx)::Tuple{Any, Any, Any},
::typeof(Distributions._logpdf),
d::Dirichlet,
x::AbstractVector{<:Real},
)
Ω = Distributions._logpdf(d, x)
∂alpha = sum(Broadcast.instantiate(Broadcast.broadcasted(Δd.alpha, Δx, d.alpha, x) do Δalphai, Δxi, alphai, xi
StatsFuns.xlogy(Δalphai, xi) + (alphai - 1) * Δxi / xi
end))
∂alpha = sum(
Broadcast.instantiate(
Broadcast.broadcasted(Δd.alpha, Δx, d.alpha, x) do Δalphai, Δxi, alphai, xi
StatsFuns.xlogy(Δalphai, xi) + (alphai - 1) * Δxi / xi
end,
),
)
∂lmnB = -Δd.lmnB
ΔΩ = ∂alpha + ∂lmnB
if !isfinite(Ω)
Expand All @@ -33,15 +57,19 @@ function ChainRulesCore.frule((_, Δd, Δx)::Tuple{Any,Any,Any}, ::typeof(Distri
return Ω, ΔΩ
end

function ChainRulesCore.rrule(::typeof(Distributions._logpdf), d::T, x::AbstractVector{<:Real}) where {T<:Dirichlet}
function ChainRulesCore.rrule(
::typeof(Distributions._logpdf),
d::T,
x::AbstractVector{<:Real},
) where {T <: Dirichlet}
Ω = Distributions._logpdf(d, x)
isfinite_Ω = isfinite(Ω)
alpha = d.alpha
function _logpdf_Dirichlet_pullback(_ΔΩ)
ΔΩ = ChainRulesCore.unthunk(_ΔΩ)
∂alpha = _logpdf_Dirichlet_∂alphai.(x, ΔΩ, isfinite_Ω)
∂lmnB = isfinite_Ω ? -float(ΔΩ) : oftype(float(ΔΩ), NaN)
Δd = ChainRulesCore.Tangent{T}(; alpha=∂alpha, lmnB=∂lmnB)
Δd = ChainRulesCore.Tangent{T}(; alpha = ∂alpha, lmnB = ∂lmnB)
Δx = _logpdf_Dirichlet_Δxi.(ΔΩ, alpha, x, isfinite_Ω)
return ChainRulesCore.NoTangent(), Δd, Δx
end
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ function ChainRulesCore.rrule(::typeof(logpdf), d::Uniform, x::Real)
function logpdf_Uniform_pullback(Δ)
Δa = Δ / diff
Δd = if insupport
ChainRulesCore.Tangent{typeof(d)}(; a=Δa, b=-Δa)
ChainRulesCore.Tangent{typeof(d)}(; a = Δa, b = -Δa)
else
ChainRulesCore.Tangent{typeof(d)}(; a=zero(Δa), b=zero(Δa))
ChainRulesCore.Tangent{typeof(d)}(; a = zero(Δa), b = zero(Δa))
end
return ChainRulesCore.NoTangent(), Δd, ChainRulesCore.ZeroTangent()
end
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
## Callable struct to fix type inference issues caused by captured values
struct LogPDFNegativeBinomialPullback{D,T<:Real}
struct LogPDFNegativeBinomialPullback{D, T <: Real}
∂r::T
∂p::T
end

function (f::LogPDFNegativeBinomialPullback{D})(Δ) where {D}
Δr = Δ * f.∂r
Δp = Δ * f.∂p
Δd = ChainRulesCore.Tangent{D}(; r=Δr, p=Δp)
Δd = ChainRulesCore.Tangent{D}(; r = Δr, p = Δp)
return ChainRulesCore.NoTangent(), Δd, ChainRulesCore.NoTangent()
end

Expand All @@ -18,19 +18,24 @@ function ChainRulesCore.rrule(::typeof(logpdf), d::NegativeBinomial, k::Real)
if iszero(k)
Ω = z
∂r = oftype(z, log(p))
∂p = oftype(z, r/p)
∂p = oftype(z, r / p)
elseif insupport(d, k)
Ω = z - log(k + r) - SpecialFunctions.logbeta(r, k + 1)
∂r = oftype(z, log(p) - inv(k + r) - SpecialFunctions.digamma(r) + SpecialFunctions.digamma(r + k + 1))
∂p = oftype(z, r/p - k / (1 - p))
∂r = oftype(
z,
log(p) - inv(k + r) - SpecialFunctions.digamma(r) +
SpecialFunctions.digamma(r + k + 1),
)
∂p = oftype(z, r / p - k / (1 - p))
else
Ω = oftype(z, -Inf)
∂r = oftype(z, NaN)
∂p = oftype(z, NaN)
end

# Define pullback
logpdf_NegativeBinomial_pullback = LogPDFNegativeBinomialPullback{typeof(d),typeof(z)}(∂r, ∂p)
logpdf_NegativeBinomial_pullback =
LogPDFNegativeBinomialPullback{typeof(d), typeof(z)}(∂r, ∂p)

return Ω, logpdf_NegativeBinomial_pullback
end
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@ for f in (:poissonbinomial_pdf, :poissonbinomial_pdf_fft)
pullback = Symbol(f, :_pullback)
@eval begin
function ChainRulesCore.frule(
(_, Δp)::Tuple{<:Any,<:AbstractVector{<:Real}}, ::typeof(Distributions.$f), p::AbstractVector{<:Real}
)
(_, Δp)::Tuple{<:Any, <:AbstractVector{<:Real}},
::typeof(Distributions.$f),
p::AbstractVector{<:Real},
)
y = Distributions.$f(p)
A = Distributions.poissonbinomial_pdf_partialderivatives(p)
return y, A' * Δp
Expand Down
23 changes: 19 additions & 4 deletions ext/DistributionsDensityInterfaceExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,30 @@ for (di_func, d_func) in ((:logdensityof, :logpdf), (:densityof, :pdf))
DensityInterface.$di_func(d::Distribution, x) = $d_func(d, x)

function DensityInterface.$di_func(d::UnivariateDistribution, x::AbstractArray)
throw(ArgumentError("$(DensityInterface.$di_func) doesn't support multiple samples as an argument"))
throw(
ArgumentError(
"$(DensityInterface.$di_func) doesn't support multiple samples as an argument",
),
)
end

function DensityInterface.$di_func(d::MultivariateDistribution, x::AbstractMatrix)
throw(ArgumentError("$(DensityInterface.$di_func) doesn't support multiple samples as an argument"))
throw(
ArgumentError(
"$(DensityInterface.$di_func) doesn't support multiple samples as an argument",
),
)
end

function DensityInterface.$di_func(d::MatrixDistribution, x::AbstractArray{<:AbstractMatrix{<:Real}})
throw(ArgumentError("$(DensityInterface.$di_func) doesn't support multiple samples as an argument"))
function DensityInterface.$di_func(
d::MatrixDistribution,
x::AbstractArray{<:AbstractMatrix{<:Real}},
)
throw(
ArgumentError(
"$(DensityInterface.$di_func) doesn't support multiple samples as an argument",
),
)
end
end
end
Expand Down
Loading
Loading