Skip to content

Commit f248a6e

Browse files
Merge pull request #1886 from CliMA/ck/compare
Define rcompare and rprint_diff for FieldVectors
2 parents 331571c + bb4d8b3 commit f248a6e

File tree

3 files changed

+96
-23
lines changed

3 files changed

+96
-23
lines changed

src/Fields/fieldvector.jl

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -361,3 +361,80 @@ import ClimaComms
361361
ClimaComms.array_type(x::FieldVector) = promote_type(
362362
UnrolledFunctions.unrolled_map(ClimaComms.array_type, _values(x))...,
363363
)
364+
365+
function __rprint_diff(
366+
io::IO,
367+
x::T,
368+
y::T;
369+
pc,
370+
xname,
371+
yname,
372+
) where {T <: Union{FieldVector, Field, DataLayouts.AbstractData, NamedTuple}}
373+
for pn in propertynames(x)
374+
pc_full = (pc..., ".", pn)
375+
xi = getproperty(x, pn)
376+
yi = getproperty(y, pn)
377+
__rprint_diff(io, xi, yi; pc = pc_full, xname, yname)
378+
end
379+
end;
380+
381+
function __rprint_diff(io::IO, xi, yi; pc, xname, yname) # assume we can compute difference here
382+
if !(xi == yi)
383+
xs = xname * string(join(pc))
384+
ys = yname * string(join(pc))
385+
println(io, "==================== Difference found:")
386+
println(io, "$xs: ", xi)
387+
println(io, "$ys: ", yi)
388+
println(io, "($xs .- $ys): ", (xi .- yi))
389+
end
390+
return nothing
391+
end
392+
393+
"""
394+
rprint_diff(io::IO, ::T, ::T) where {T <: FieldVector}
395+
rprint_diff(::T, ::T) where {T <: FieldVector}
396+
397+
Recursively print differences in given `FieldVector`.
398+
"""
399+
_rprint_diff(io::IO, x::T, y::T, xname, yname) where {T <: FieldVector} =
400+
__rprint_diff(io, x, y; pc = (), xname, yname)
401+
_rprint_diff(x::T, y::T, xname, yname) where {T <: FieldVector} =
402+
_rprint_diff(stdout, x, y, xname, yname)
403+
404+
"""
405+
@rprint_diff(::T, ::T) where {T <: FieldVector}
406+
407+
Recursively print differences in given `FieldVector`.
408+
"""
409+
macro rprint_diff(x, y)
410+
return :(_rprint_diff(
411+
stdout,
412+
$(esc(x)),
413+
$(esc(y)),
414+
$(string(x)),
415+
$(string(y)),
416+
))
417+
end
418+
419+
420+
# Recursively compare contents of similar fieldvectors
421+
_rcompare(pass, x::T, y::T) where {T <: Field} =
422+
pass && _rcompare(pass, field_values(x), field_values(y))
423+
_rcompare(pass, x::T, y::T) where {T <: DataLayouts.AbstractData} =
424+
pass && (parent(x) == parent(y))
425+
_rcompare(pass, x::T, y::T) where {T} = pass && (x == y)
426+
427+
function _rcompare(pass, x::T, y::T) where {T <: FieldVector}
428+
for pn in propertynames(x)
429+
pass &= _rcompare(pass, getproperty(x, pn), getproperty(y, pn))
430+
end
431+
return pass
432+
end
433+
434+
"""
435+
rcompare(x::T, y::T) where {T <: FieldVector}
436+
437+
Recursively compare given fieldvectors via `==`.
438+
Returns `true` if `x == y` recursively.
439+
"""
440+
rcompare(x::T, y::T) where {T <: FieldVector} = _rcompare(true, x, y)

test/Fields/unit_field.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,21 @@ end
305305

306306
Y.k.z = 3.0
307307
@test Y.k.z === 3.0
308+
309+
@test Fields.rcompare(Y, Y)
310+
Ydc = deepcopy(Y)
311+
Ydc.k.z += 1
312+
@test !Fields.rcompare(Ydc, Y)
313+
# Fields.@rprint_diff(Ydc, Y)
314+
s = sprint(
315+
Fields._rprint_diff,
316+
Ydc,
317+
Y,
318+
"Ydc",
319+
"Y";
320+
context = IOContext(stdout),
321+
)
322+
@test occursin("==================== Difference found:", s)
308323
end
309324

310325
# https://github.com/CliMA/ClimaCore.jl/issues/1465

test/Fields/utils_field_multi_broadcast_fusion.jl

Lines changed: 4 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -83,27 +83,6 @@ function benchmark_kernel!(f!, X, Y, device)
8383
show(stdout, MIME("text/plain"), trial)
8484
end
8585

86-
function show_diff(A, B)
87-
for pn in propertynames(A)
88-
Ai = getproperty(A, pn)
89-
Bi = getproperty(B, pn)
90-
println("==================== Comparing $pn")
91-
@show Ai
92-
@show Bi
93-
@show abs.(Ai .- Bi)
94-
end
95-
end
96-
97-
function compare(A, B)
98-
pass = true
99-
for pn in propertynames(A)
100-
pass =
101-
pass &&
102-
all(parent(getproperty(A, pn)) .== parent(getproperty(B, pn)))
103-
end
104-
pass || show_diff(A, B)
105-
return pass
106-
end
10786
function test_kernel!(; fused!, unfused!, X, Y)
10887
for pn in propertynames(X)
10988
rand_field!(getproperty(X, pn))
@@ -122,8 +101,10 @@ function test_kernel!(; fused!, unfused!, X, Y)
122101
unfused!(X_unfused, Y_unfused)
123102
fused!(X_fused, Y_fused)
124103
@testset "Test correctness of $(nameof(typeof(fused!)))" begin
125-
@test compare(X_fused, X_unfused)
126-
@test compare(Y_fused, Y_unfused)
104+
Fields.@rprint_diff(X_fused, X_unfused)
105+
Fields.@rprint_diff(Y_fused, Y_unfused)
106+
@test Fields.rcompare(X_fused, X_unfused)
107+
@test Fields.rcompare(Y_fused, Y_unfused)
127108
end
128109
end
129110

0 commit comments

Comments
 (0)