Skip to content

Commit 8819abe

Browse files
authored
Add insertdims (#830)
1 parent 1e623ad commit 8819abe

File tree

3 files changed

+65
-0
lines changed

3 files changed

+65
-0
lines changed

README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,8 @@ changes in `julia`.
7272

7373
## Supported features
7474

75+
* `insertdims(D; dims)` is the opposite of `dropdims` ([#45793]) (since Compat 4.16.0)
76+
7577
* `Compat.Fix{N}` which fixes an argument at the `N`th position ([#54653]) (since Compat 4.16.0)
7678

7779
* `chopprefix(s, prefix)` and `chopsuffix(s, suffix)` ([#40995]) (since Compat 4.15.0)
@@ -190,6 +192,7 @@ Note that you should specify the correct minimum version for `Compat` in the
190192
[#43852]: https://github.com/JuliaLang/julia/issues/43852
191193
[#45052]: https://github.com/JuliaLang/julia/issues/45052
192194
[#45607]: https://github.com/JuliaLang/julia/issues/45607
195+
[#45793]: https://github.com/JuliaLang/julia/issues/45793
193196
[#47354]: https://github.com/JuliaLang/julia/issues/47354
194197
[#47679]: https://github.com/JuliaLang/julia/pull/47679
195198
[#48038]: https://github.com/JuliaLang/julia/issues/48038

src/Compat.jl

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1122,6 +1122,37 @@ if VERSION < v"1.8.0-DEV.1016"
11221122
export chopprefix, chopsuffix
11231123
end
11241124

1125+
if VERSION < v"1.12.0-DEV.974" # contrib/commit-name.sh 2635dea
1126+
1127+
insertdims(A; dims) = _insertdims(A, dims)
1128+
1129+
function _insertdims(A::AbstractArray{T, N}, dims::NTuple{M, Int}) where {T, N, M}
1130+
for i in eachindex(dims)
1131+
1 dims[i] || throw(ArgumentError("the smallest entry in dims must be ≥ 1"))
1132+
dims[i] N+M || throw(ArgumentError("the largest entry in dims must be not larger than the dimension of the array and the length of dims added"))
1133+
for j = 1:i-1
1134+
dims[j] == dims[i] && throw(ArgumentError("inserted dims must be unique"))
1135+
end
1136+
end
1137+
1138+
# acc is a tuple, where the first entry is the final shape
1139+
# the second entry off acc is a counter for the axes of A
1140+
inds = Base._foldoneto((acc, i) ->
1141+
i dims
1142+
? ((acc[1]..., Base.OneTo(1)), acc[2])
1143+
: ((acc[1]..., axes(A, acc[2])), acc[2] + 1),
1144+
((), 1), Val(N+M))
1145+
new_shape = inds[1]
1146+
return reshape(A, new_shape)
1147+
end
1148+
1149+
_insertdims(A::AbstractArray, dim::Integer) = _insertdims(A, (Int(dim),))
1150+
1151+
export insertdims
1152+
else
1153+
using Base: insertdims, _insertdims
1154+
end
1155+
11251156
# https://github.com/JuliaLang/julia/pull/54653: add Fix
11261157
@static if !isdefined(Base, :Fix) # VERSION < v"1.12.0-DEV.981"
11271158
@static if !isdefined(Base, :_stable_typeof)

test/runtests.jl

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -908,6 +908,37 @@ end
908908
end
909909
end
910910

911+
# https://github.com/JuliaLang/julia/pull/45793
912+
@testset "insertdims" begin
913+
a = rand(8, 7)
914+
@test @inferred(insertdims(a, dims=1)) == @inferred(insertdims(a, dims=(1,))) == reshape(a, (1, 8, 7))
915+
@test @inferred(insertdims(a, dims=3)) == @inferred(insertdims(a, dims=(3,))) == reshape(a, (8, 7, 1))
916+
@test @inferred(insertdims(a, dims=(1, 3))) == reshape(a, (1, 8, 1, 7))
917+
@test @inferred(insertdims(a, dims=(1, 2, 3))) == reshape(a, (1, 1, 1, 8, 7))
918+
@test @inferred(insertdims(a, dims=(1, 4))) == reshape(a, (1, 8, 7, 1))
919+
@test @inferred(insertdims(a, dims=(1, 3, 5))) == reshape(a, (1, 8, 1, 7, 1))
920+
@test @inferred(insertdims(a, dims=(1, 2, 4, 6))) == reshape(a, (1, 1, 8, 1, 7, 1))
921+
@test @inferred(insertdims(a, dims=(1, 3, 4, 6))) == reshape(a, (1, 8, 1, 1, 7, 1))
922+
@test @inferred(insertdims(a, dims=(1, 4, 6, 3))) == reshape(a, (1, 8, 1, 1, 7, 1))
923+
@test @inferred(insertdims(a, dims=(1, 3, 5, 6))) == reshape(a, (1, 8, 1, 7, 1, 1))
924+
925+
@test_throws ArgumentError insertdims(a, dims=(1, 1, 2, 3))
926+
@test_throws ArgumentError insertdims(a, dims=(1, 2, 2, 3))
927+
@test_throws ArgumentError insertdims(a, dims=(1, 2, 3, 3))
928+
@test_throws UndefKeywordError insertdims(a)
929+
@test_throws ArgumentError insertdims(a, dims=0)
930+
@test_throws ArgumentError insertdims(a, dims=(1, 2, 1))
931+
@test_throws ArgumentError insertdims(a, dims=4)
932+
@test_throws ArgumentError insertdims(a, dims=6)
933+
934+
# insertdims and dropdims are inverses
935+
b = rand(1,1,1,5,1,1,7)
936+
for dims in [1, (1,), 2, (2,), 3, (3,), (1,3), (1,2,3), (1,2), (1,3,5), (1,2,5,6), (1,3,5,6), (1,3,5,6), (1,6,5,3)]
937+
@test dropdims(insertdims(a; dims); dims) == a
938+
@test insertdims(dropdims(b; dims); dims) == b
939+
end
940+
end
941+
911942
# https://github.com/JuliaLang/julia/pull/54653: add Fix
912943
@testset "Fix" begin
913944
function test_fix1(Fix1=Compat.Fix1)

0 commit comments

Comments
 (0)