Skip to content

Commit fe765e6

Browse files
Merge pull request #20 from DhairyaLGandhi/dg/abstract
Check for existing constructor when dealing with Arrays
2 parents f175b0a + b7e9500 commit fe765e6

File tree

3 files changed

+23
-12
lines changed

3 files changed

+23
-12
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,10 @@ authors = ["SciML"]
44
version = "1.2.0"
55

66
[deps]
7+
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
78

89
[compat]
10+
ArrayInterface = "7.11"
911
Aqua = "0.8"
1012
SafeTestsets = "0.1"
1113
Test = "1.10"

src/SciMLStructures.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
module SciMLStructures
22

3+
using ArrayInterface: has_trivial_array_constructor
4+
35
include("interface.jl")
46
include("array.jl")
57

src/array.jl

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,30 @@
1-
hasportion(::Tunable, ::Array) = true
2-
hasportion(::Constants, ::Array) = false
3-
hasportion(::Caches, ::Array) = false
4-
hasportion(::Discrete, ::Array) = false
1+
hasportion(::Tunable, ::AbstractArray) = true
2+
hasportion(::Constants, ::AbstractArray) = false
3+
hasportion(::Caches, ::AbstractArray) = false
4+
hasportion(::Discrete, ::AbstractArray) = false
55

6-
struct ArrayRepack{T}
6+
struct ArrayRepack{T, Ty}
77
sz::T
8+
type::Ty
89
end
910
function (f::ArrayRepack)(A)
1011
@assert length(A) == prod(f.sz)
11-
reshape(A, f.sz)
12+
A_ = if has_trivial_array_constructor(f.type, A)
13+
convert(f.type, A)
14+
else
15+
error("The original type $(typeof(f.type)) does not support the SciMLStructures interface via the AbstractArray `repack` rules. No method exists to take in a regular array and construct the parent type back. Please define the SciMLStructures interface for this type.")
16+
end
17+
reshape(A_, f.sz)
1218
end
1319

14-
canonicalize(::Tunable, p::Array) = vec(p), ArrayRepack(size(p)), true
15-
canonicalize(::Constants, p::Array) = nothing, nothing, nothing
16-
canonicalize(::Caches, p::Array) = nothing, nothing, nothing
17-
canonicalize(::Discrete, p::Array) = nothing, nothing, nothing
20+
canonicalize(::Tunable, p::AbstractArray) = vec(p), ArrayRepack(size(p), typeof(p)), true
21+
canonicalize(::Constants, p::AbstractArray) = nothing, nothing, nothing
22+
canonicalize(::Caches, p::AbstractArray) = nothing, nothing, nothing
23+
canonicalize(::Discrete, p::AbstractArray) = nothing, nothing, nothing
1824

19-
isscimlstructure(::Array) = true
25+
isscimlstructure(::AbstractArray) = true
2026

21-
function SciMLStructures.replace(::SciMLStructures.Tunable, arr::AbstractArray, new_arr::AbstractArray)
27+
function SciMLStructures.replace(
28+
::SciMLStructures.Tunable, arr::AbstractArray, new_arr::AbstractArray)
2229
reshape(new_arr, size(arr))
2330
end

0 commit comments

Comments
 (0)