Skip to content

Commit e4633cc

Browse files
authored
(0.95.15) Fix bugs in ReactantExt for OrthogonalSphericalShellGrid (#4151)
* Fix bugs in ReactantExt for OrthogonalSphericalShellGrid * Fix type info * Add test for OSSG deconcretization * Fix Bounded missing bug * Thrashing around a bit * Clean up architectures a bit * Generalize types for OSSG and LLG plus bugfixes * Bump to 0.95.15 * Make OSSG more friendly, fix bugs * Mystical journey to fix broken dispatch
1 parent 18ebabd commit e4633cc

12 files changed

+258
-212
lines changed

ext/OceananigansReactantExt/Architectures.jl

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@ import Oceananigans.Architectures: device, architecture, array_type, on_architec
88
const ReactantKernelAbstractionsExt = Base.get_extension(
99
Reactant, :ReactantKernelAbstractionsExt
1010
)
11+
1112
const ReactantBackend = ReactantKernelAbstractionsExt.ReactantBackend
13+
1214
device(::ReactantState) = ReactantBackend()
1315

1416
architecture(::Reactant.AnyConcretePJRTArray) = ReactantState
@@ -21,18 +23,14 @@ to_reactant_sharding(s::Sharding.AbstractSharding) = s
2123
to_reactant_sharding(::T) where {T} = error("Unsupported sharding type $T")
2224

2325
on_architecture(::ReactantState, a::Reactant.AnyTracedRArray) = a
24-
function on_architecture(r::ReactantState, a::Array)
25-
return Reactant.to_rarray(a; sharding=to_reactant_sharding(r.sharding))
26-
end
27-
function on_architecture(r::ReactantState, a::Reactant.AnyConcretePJRTArray)
28-
return Reactant.to_rarray(a; sharding=to_reactant_sharding(r.sharding))
29-
end
30-
function on_architecture(r::ReactantState, a::BitArray)
31-
return Reactant.to_rarray(a; sharding=to_reactant_sharding(r.sharding))
32-
end
33-
function on_architecture(r::ReactantState, a::SubArray{<:Any,<:Any,<:Array})
34-
return Reactant.to_rarray(a; sharding=to_reactant_sharding(r.sharding))
35-
end
26+
27+
const ArraysToRArray = Union{Array,
28+
Reactant.AnyConcretePJRTArray,
29+
BitArray,
30+
SubArray{<:Any, <:Any, <:Array}}
31+
32+
on_architecture(r::ReactantState, a::ArraysToRArray) =
33+
Reactant.to_rarray(a; sharding=to_reactant_sharding(r.sharding))
3634

3735
unified_array(::ReactantState, a) = a
3836

ext/OceananigansReactantExt/Grids.jl

Lines changed: 49 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@ using Oceananigans.Fields: Field
88
using Oceananigans.ImmersedBoundaries: GridFittedBottom
99

1010
import ..OceananigansReactantExt: deconcretize
11-
import Oceananigans.Grids: LatitudeLongitudeGrid, RectilinearGrid
11+
import Oceananigans.Grids: LatitudeLongitudeGrid, RectilinearGrid, OrthogonalSphericalShellGrid
12+
import Oceananigans.OrthogonalSphericalShellGrids: RotatedLatitudeLongitudeGrid
1213
import Oceananigans.ImmersedBoundaries: ImmersedBoundaryGrid
1314

1415
const ReactantGrid{FT, TX, TY, TZ} = AbstractGrid{FT, TX, TY, TZ, <:ReactantState}
@@ -40,29 +41,57 @@ function RectilinearGrid(arch::ReactantState, FT::DataType; kw...)
4041
return RectilinearGrid{TX, TY, TZ}(arch, other_properties...)
4142
end
4243

44+
function OrthogonalSphericalShellGrid(arch::ReactantState, FT::DataType; kw...)
45+
cpu_grid = OrthogonalSphericalShellGrid(CPU(), FT; kw...)
46+
other_names = propertynames(cpu_grid)[2:end] # exclude architecture
47+
other_properties = Tuple(getproperty(cpu_grid, name) for name in other_names)
48+
TX, TY, TZ = Oceananigans.Grids.topology(cpu_grid)
49+
return OrthogonalSphericalShellGrid{TX, TY, TZ}(arch, other_properties...)
50+
end
51+
52+
# This is a kind of OrthogonalSphericalShellGrid
53+
function RotatedLatitudeLongitudeGrid(arch::ReactantState, FT::DataType; kw...)
54+
cpu_grid = RotatedLatitudeLongitudeGrid(CPU(), FT; kw...)
55+
other_names = propertynames(cpu_grid)[2:end] # exclude architecture
56+
other_properties = Tuple(getproperty(cpu_grid, name) for name in other_names)
57+
TX, TY, TZ = Oceananigans.Grids.topology(cpu_grid)
58+
return OrthogonalSphericalShellGrid{TX, TY, TZ}(arch, other_properties...)
59+
end
60+
61+
# This low-level constructor supports the external package OrthogonalSphericalShellGrids.jl.
4362
function OrthogonalSphericalShellGrid{TX, TY, TZ}(arch::ReactantState,
4463
Nx, Ny, Nz, Hx, Hy, Hz,
4564
Lz :: FT,
46-
λᶜᶜᵃ :: A, λᶠᶜᵃ :: A, λᶜᶠᵃ :: A, λᶠᶠᵃ :: A,
47-
φᶜᶜᵃ :: A, φᶠᶜᵃ :: A, φᶜᶠᵃ :: A, φᶠᶠᵃ :: A,
48-
z :: Z,
49-
Δxᶜᶜᵃ :: A, Δxᶠᶜᵃ :: A, Δxᶜᶠᵃ :: A, Δxᶠᶠᵃ :: A,
50-
Δyᶜᶜᵃ :: A, Δyᶜᶠᵃ :: A, Δyᶠᶜᵃ :: A, Δyᶠᶠᵃ :: A,
51-
Azᶜᶜᵃ :: A, Azᶠᶜᵃ :: A, Azᶜᶠᵃ :: A, Azᶠᶠᵃ :: A,
65+
λᶜᶜᵃ :: CC, λᶠᶜᵃ :: FC, λᶜᶠᵃ :: CF, λᶠᶠᵃ :: FF,
66+
φᶜᶜᵃ :: CC, φᶠᶜᵃ :: FC, φᶜᶠᵃ :: CF, φᶠᶠᵃ :: FF, z :: Z,
67+
Δxᶜᶜᵃ :: CC, Δxᶠᶜᵃ :: FC, Δxᶜᶠᵃ :: CF, Δxᶠᶠᵃ :: FF,
68+
Δyᶜᶜᵃ :: CC, Δyᶠᶜᵃ :: FC, Δyᶜᶠᵃ :: CF, Δyᶠᶠᵃ :: FF,
69+
Azᶜᶜᵃ :: CC, Azᶠᶜᵃ :: FC, Azᶜᶠᵃ :: CF, Azᶠᶠᵃ :: FF,
5270
radius :: FT,
53-
conformal_mapping :: C) where {TX, TY, TZ, FT, Z, A, C}
54-
55-
args = (λᶜᶜᵃ, λᶠᶜᵃ, λᶜᶠᵃ, λᶠᶠᵃ,
56-
φᶜᶜᵃ, φᶠᶜᵃ, φᶜᶠᵃ, φᶠᶠᵃ,
57-
z,
58-
Δxᶜᶜᵃ, Δxᶠᶜᵃ, Δxᶜᶠᵃ, Δxᶠᶠᵃ,
59-
Δyᶜᶜᵃ, Δyᶜᶠᵃ, Δyᶠᶜᵃ, Δyᶠᶠᵃ,
60-
Azᶜᶜᵃ, Azᶠᶜᵃ, Azᶜᶠᵃ, Azᶠᶠᵃ)
61-
62-
dargs = Tuple(deconcretize(a) for a in args)
63-
64-
return OrthogonalSphericalShellGrid{FT, TX, TY, TZ, CZ, A, C, Arch}(arch, Nx, Ny, Nz, Hx, Hy, Hz, Lz,
65-
dargs..., radius, conformal_mapping)
71+
conformal_mapping :: Map) where {TX, TY, TZ, FT, Z, Map,
72+
CC, FC, CF, FF, C}
73+
74+
args1 = (λᶜᶜᵃ, λᶠᶜᵃ, λᶜᶠᵃ, λᶠᶠᵃ,
75+
φᶜᶜᵃ, φᶠᶜᵃ, φᶜᶠᵃ, φᶠᶠᵃ)
76+
77+
args2 = (Δxᶜᶜᵃ, Δxᶠᶜᵃ, Δxᶜᶠᵃ, Δxᶠᶠᵃ,
78+
Δyᶜᶜᵃ, Δyᶠᶜᵃ, Δyᶜᶠᵃ, Δyᶠᶠᵃ,
79+
Azᶜᶜᵃ, Azᶠᶜᵃ, Azᶜᶠᵃ, Azᶠᶠᵃ)
80+
81+
dargs1 = Tuple(deconcretize(a) for a in args1)
82+
dz = deconcretize(z)
83+
dargs2 = Tuple(deconcretize(a) for a in args2)
84+
85+
Arch = typeof(arch)
86+
DCC = typeof(dargs1[1]) # deconcretized
87+
DFC = typeof(dargs1[2]) # deconcretized
88+
DCF = typeof(dargs1[3]) # deconcretized
89+
DFF = typeof(dargs1[4]) # deconcretized
90+
DZ = typeof(dz) # deconcretized
91+
92+
return OrthogonalSphericalShellGrid{FT, TX, TY, TZ, DZ, Map,
93+
DCC, DFC, DCF, DFF, Arch}(arch, Nx, Ny, Nz, Hx, Hy, Hz, Lz,
94+
dargs1..., dz, dargs2..., radius, conformal_mapping)
6695
end
6796

6897
deconcretize(gfb::GridFittedBottom) = GridFittedBottom(deconcretize(gfb.bottom_height),
@@ -89,8 +118,6 @@ ImmersedBoundaryGrid(grid::ReactantGrid, ib::GridFittedBottom; active_cells_map:
89118

90119
ImmersedBoundaryGrid(grid::ReactantGrid, ib; active_cells_map::Bool=true) =
91120
reactant_immersed_boundary_grid(grid, ib; active_cells_map)
92-
93-
94121

95122
end # module
96123

src/Grids/conformal_cubed_sphere_panel.jl

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ function conformal_cubed_sphere_panel(filepath::AbstractString, architecture = C
9999
φᶜᶜᵃ, φᶠᶜᵃ, φᶜᶠᵃ, φᶠᶠᵃ,
100100
z,
101101
Δxᶜᶜᵃ, Δxᶠᶜᵃ, Δxᶜᶠᵃ, Δxᶠᶠᵃ,
102-
Δyᶜᶜᵃ, Δyᶜᶠᵃ, Δyᶠᶜᵃ, Δyᶠᶠᵃ,
102+
Δyᶜᶜᵃ, Δyᶠᶜᵃ, Δyᶜᶠᵃ, Δyᶠᶠᵃ,
103103
Azᶜᶜᵃ, Azᶠᶜᵃ, Azᶜᶠᵃ, Azᶠᶠᵃ,
104104
radius,
105105
conformal_mapping)
@@ -213,23 +213,17 @@ function conformal_cubed_sphere_panel(architecture::AbstractArchitecture = CPU()
213213
halo = (1, 1, 1),
214214
rotation = nothing)
215215

216-
if architecture == GPU() && !has_cuda()
217-
throw(ArgumentError("Cannot create a GPU grid. No CUDA-enabled GPU was detected!"))
218-
end
219-
220216
radius = FT(radius)
221-
222217
TX, TY, TZ = topology
223218
Nξ, Nη, Nz = size
224219
Hx, Hy, Hz = halo
225220

226-
## Use a regular rectilinear grid for the face of the cube
227-
221+
# Use a regular rectilinear grid for the face of the cube
228222
ξη_grid_topology = (Bounded, Bounded, topology[3])
229223

230224
# construct the grid on CPU and convert to architecture later...
231225
ξη_grid = RectilinearGrid(CPU(), FT;
232-
size=(Nξ, Nη, Nz),
226+
size = (Nξ, Nη, Nz),
233227
topology = ξη_grid_topology,
234228
x=ξ, y=η, z, halo)
235229

@@ -511,7 +505,6 @@ function conformal_cubed_sphere_panel(architecture::AbstractArchitecture = CPU()
511505
Azᶜᶠᵃ[i, j] = 2 * spherical_area_quadrilateral(a, b, c, d) * radius^2
512506
end
513507

514-
515508
# Azᶠᶠᵃ
516509

517510
for j in 2:Nη, i in 2:
@@ -640,7 +633,7 @@ function conformal_cubed_sphere_panel(architecture::AbstractArchitecture = CPU()
640633
zc)
641634

642635
metric_arrays = (Δxᶜᶜᵃ, Δxᶠᶜᵃ, Δxᶜᶠᵃ, Δxᶠᶠᵃ,
643-
Δyᶜᶜᵃ, Δyᶜᶠᵃ, Δyᶠᶜᵃ, Δyᶠᶠᵃ,
636+
Δyᶜᶜᵃ, Δyᶠᶜᵃ, Δyᶜᶠᵃ, Δyᶠᶠᵃ,
644637
Azᶜᶜᵃ, Azᶠᶜᵃ, Azᶜᶠᵃ, Azᶠᶠᵃ)
645638

646639
conformal_mapping = CubedSphereConformalMapping(ξ, η, rotation)
@@ -660,7 +653,7 @@ function conformal_cubed_sphere_panel(architecture::AbstractArchitecture = CPU()
660653
grid.z)
661654

662655
metric_arrays = (grid.Δxᶜᶜᵃ, grid.Δxᶠᶜᵃ, grid.Δxᶜᶠᵃ, grid.Δxᶠᶠᵃ,
663-
grid.Δyᶜᶜᵃ, grid.Δyᶜᶠᵃ, grid.Δyᶠᶜᵃ, grid.Δyᶠᶠᵃ,
656+
grid.Δyᶜᶜᵃ, grid.Δyᶠᶜᵃ, grid.Δyᶜᶠᵃ, grid.Δyᶠᶠᵃ,
664657
grid.Azᶜᶜᵃ, grid.Azᶠᶜᵃ, grid.Azᶜᶠᵃ, grid.Azᶠᶠᵃ)
665658

666659
coordinate_arrays = map(a -> on_architecture(architecture, a), coordinate_arrays)

src/Grids/grid_generation.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ function generate_coordinate(FT, topo::AT, N, H, node_generator, coordinate_name
7878
end
7979

8080
Δᶜ = OffsetArray(on_architecture(arch, Δᶜ), -H)
81-
Δᶠ = OffsetArray(on_architecture(arch, Δᶠ), -H-1)
81+
Δᶠ = OffsetArray(on_architecture(arch, Δᶠ), -H - 1)
8282

8383
F = OffsetArray(F, -H)
8484
C = OffsetArray(C, -H)

0 commit comments

Comments
 (0)