Skip to content

Commit b38c3f0

Browse files
committed
Add Staggered Grid
1 parent 7ab6005 commit b38c3f0

File tree

3 files changed

+29
-7
lines changed

3 files changed

+29
-7
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "EarthSciMLBase"
22
uuid = "e53f1632-a13c-4728-9402-0c66d48804b0"
33
authors = ["EarthSciML Authors and Contributors"]
4-
version = "0.21.1"
4+
version = "0.21.2"
55

66
[deps]
77
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"

src/domaininfo.jl

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ struct DomainInfo{T}
5151
"The time offset for the domain."
5252
time_offset::T
5353

54-
function DomainInfo(pdfs, gs, icbc, sr, to::T) where T
54+
function DomainInfo(pdfs, gs, icbc, sr, to::T) where {T}
5555
new{T}(pdfs, gs, icbc, sr, to)
5656
end
5757
function DomainInfo(icbc::ICBCcomponent...; dtype=Float64, grid_spacing=nothing,
@@ -120,8 +120,10 @@ struct DomainInfo{T}
120120
end
121121
end
122122

123-
Base.size(d::DomainInfo) = (length(g) for g grid(d))
123+
Base.size(d::DomainInfo) = tuple((length(g) for g grid(d))...)
124+
Base.size(d::DomainInfo, staggering::NTuple{3, Bool}) = tuple((length(g) for g grid(d, staggering))...)
124125
Base.size(d::DomainInfo, i) = length(grid(d)[i])
126+
Base.size(d::DomainInfo, staggering::NTuple{3, Bool}, i) = length(grid(d, staggering)[i])
125127

126128
"""
127129
$(SIGNATURES)
@@ -137,22 +139,41 @@ $(SIGNATURES)
137139
Return the ranges representing the discretization of the partial independent
138140
variables for this domain, based on the discretization intervals given in `Δs`.
139141
"""
140-
function grid(d::DomainInfo{T}) where T
142+
function grid(d::DomainInfo{T}) where {T}
141143
if !((d.grid_spacing isa Base.AbstractVecOrTuple) &&
142-
(length(pvars(d)) == length(d.grid_spacing)))
144+
(length(pvars(d)) == length(d.grid_spacing)))
143145
throw(ArgumentError("The number of partial independent variables ($(length(pvars(d)))) must equal the number of grid spacings provided ($(d.grid_spacing))."))
144146
end
145147
endpts = endpoints(d)
146148
[s:d:e for ((s, e), d) in zip(endpts, d.grid_spacing)]
147149
end
150+
function grid(d::DomainInfo{T}, staggering) where {T}
151+
if !((d.grid_spacing isa Base.AbstractVecOrTuple) &&
152+
(length(pvars(d)) == length(d.grid_spacing)))
153+
throw(ArgumentError("The number of partial independent variables ($(length(pvars(d)))) must equal the number of grid spacings provided ($(d.grid_spacing))."))
154+
end
155+
endpts = endpoints(d)
156+
@assert length(staggering) == length(endpts) "The number of staggering values $(length(staggering)) must match the number of partial independent variables $(length(endpts))."
157+
@assert all(isa.(staggering, (Bool,))) "Staggering must be a vector of booleans."
158+
g_nostagger = [s:d:e for ((s, e), d) in zip(endpts, d.grid_spacing)]
159+
endpts = [staggering[i] ? (e[1] - d.grid_spacing[i] / 2, e[2] + d.grid_spacing[i] / 2) : e
160+
for (i, e) in enumerate(endpts)]
161+
g = [s:d:e for ((s, e), d) in zip(endpts, d.grid_spacing)]
162+
for (i, g) in enumerate(g)
163+
if staggering[i]
164+
@assert length(g) == length(g_nostagger[i]) + 1 "Staggering was not applied correctly. (Probably a rounding error.)"
165+
end
166+
end
167+
g
168+
end
148169

149170
"""
150171
$(SIGNATURES)
151172
152173
Return the endpoints of the partial independent
153174
variables for this domain.
154175
"""
155-
function endpoints(d::DomainInfo{T}) where T
176+
function endpoints(d::DomainInfo{T}) where {T}
156177
i = 1
157178
rngs = []
158179
for icbc d.icbc
@@ -274,7 +295,7 @@ function partialderivative_transform_vars(mtk_sys, di::DomainInfo)
274295
vs = []
275296
for (i, x) in enumerate(xs)
276297
n = Symbol("δ$(x)_transform")
277-
v = only(@variables $n(iv) [unit=ModelingToolkit.get_unit(ts[i]),
298+
v = only(@variables $n(iv) [unit = ModelingToolkit.get_unit(ts[i]),
278299
description = "Transform factor for $(x)"])
279300
push!(vs, v)
280301
end

test/domaininfo_test.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,7 @@ end
219219

220220
@test Symbol.(pvars(di)) == [:x, :y]
221221
@test grid(di) == [0.0:0.1:1.0, 0.0:0.1:2.0]
222+
@test grid(di, (true, false)) == [-0.05:0.1:1.05, 0.0:0.1:2.0]
222223
@test get_tspan(di) == (1.7040672e9, 1.704078e9)
223224
@test length(di.partial_derivative_funcs) == 0
224225
end

0 commit comments

Comments
 (0)