diff --git a/src/RecursiveApply/RecursiveApply.jl b/src/RecursiveApply/RecursiveApply.jl index f57b3105da..4ab5c6a43c 100755 --- a/src/RecursiveApply/RecursiveApply.jl +++ b/src/RecursiveApply/RecursiveApply.jl @@ -67,8 +67,30 @@ rmap(fn::F, X::Tuple, Y) where {F} = (rmap(fn, first(X), Y), rmap(fn, Base.tail(X), Y)...) function rmap(fn::F, X::NamedTuple, Y::NamedTuple) where {F} - @assert nt_names(X) === nt_names(Y) - return NamedTuple{nt_names(X)}(rmap(fn, Tuple(X), Tuple(Y))) + # @assert nt_names(X) === nt_names(Y) + # return NamedTuple{nt_names(X)}(rmap(fn, Tuple(X), Tuple(Y))) + x_names = nt_names(X) + y_names = nt_names(Y) + + # Check if Y names are a subset of X names + if !issubset(y_names, x_names) + throw(ArgumentError("Names in Y must be a subset of names in X. Y has names: $y_names, X has names: $x_names")) + end + + # Create a new NamedTuple with the same structure as X + # For matching names, apply fn to the corresponding values + # For non-matching names, keep the original values from X + result_values = map(x_names) do name + if name in y_names + # Apply fn to the matching values + rmap(fn, getproperty(X, name), getproperty(Y, name)) + else + # Keep the original value from X + getproperty(X, name) + end + end + + return NamedTuple{x_names}(result_values) end rmap(fn::F, X::NamedTuple, Y) where {F} = NamedTuple{nt_names(X)}(rmap(fn, Tuple(X), Y)) diff --git a/test/RecursiveApply/unit_recursive_apply.jl b/test/RecursiveApply/unit_recursive_apply.jl index 1ccc7cf49f..ee65183632 100644 --- a/test/RecursiveApply/unit_recursive_apply.jl +++ b/test/RecursiveApply/unit_recursive_apply.jl @@ -1,6 +1,7 @@ using JET using Test +import ClimaCore as CC using ClimaCore.RecursiveApply using ClimaCore.Geometry @@ -98,3 +99,53 @@ end @test rz.b.u == 2 @test rz.b.v == 4 end + +@testset "NamedTuple subset functionality" begin + # Test basic subset functionality + X = (a=1, b=2.0, d=[1, 2, 3]) + Y = (a=10, b=3.0) + + result = RecursiveApply.rmap(+, X, Y) + @test result.a == 11 # 1 + 10 + @test result.b == 5.0 # 2.0 + 3.0 + @test result.d == [1, 2, 3] # unchanged from X + + # Test with nested NamedTuples + X_nested = (a=(x=1, y=2), b=(z=3, w=4)) + Y_nested = (a=(x=10,),) + + result_nested = RecursiveApply.rmap(+, X_nested, Y_nested) + @test result_nested.a.x == 11 # 1 + 10 + @test result_nested.a.y == 2 # unchanged from X + @test result_nested.b.z == 3 # unchanged from X + @test result_nested.b.w == 4 # unchanged from X + + # Test error case (Y has names not in X) + Y_error = (a=1, e=5) # 'e' is not in X + @test_throws ArgumentError RecursiveApply.rmap(+, X, Y_error) + + # Test type stability + @test_opt RecursiveApply.rmap(+, X, Y) + @test_opt RecursiveApply.rmap(+, X_nested, Y_nested) + + FT = Float64 + domain = CC.Domains.IntervalDomain( + CC.Geometry.ZPoint{FT}(0), + CC.Geometry.ZPoint{FT}(1), + boundary_names = (:bottom, :top), + ) + mesh = CC.Meshes.IntervalMesh(domain, nelems = 2) + space = CC.Spaces.CenterFiniteDifferenceSpace(mesh) + coord = CC.Fields.coordinate_field(space) + + X_nt = (; a = 1.0, b = 2.0, c = 3.0) + X_field = map(Returns(X_nt), coord) + + Y_nt = (; a = 10.0, b = 3.0) + Y_field = map(Returns(Y_nt), coord) + + result_field = @. RecursiveApply.rmap(+, X_field, Y_field) + @test all(==(11.0), parent(result_field.a)) + @test all(==(5.0), parent(result_field.b)) + @test all(==(3.0), parent(result_field.c)) +end