From bd5a8098f132490acf9fe0324386cfa9c405e95c Mon Sep 17 00:00:00 2001 From: GeorgeR227 <78235421+GeorgeR227@users.noreply.github.com> Date: Mon, 14 Oct 2024 20:24:38 -0400 Subject: [PATCH 1/3] Added naming based on input types --- src/types.jl | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/src/types.jl b/src/types.jl index 683f58d44..d2ab6e376 100644 --- a/src/types.jl +++ b/src/types.jl @@ -816,13 +816,18 @@ function show_ref(io, f, args) print(io, "]") end +import Base.nameof +Base.nameof(f, arg, args...) = nameof(f) + function show_call(io, f, args) - fname = iscall(f) ? Symbol(repr(f)) : nameof(f) + fname = nameof(f, symtype.(args)...) + frep = Symbol(repr(f)) len_args = length(args) - if Base.isunaryoperator(fname) && len_args == 1 + + if Base.isunaryoperator(frep) && len_args == 1 print(io, "$fname") print_arg(io, first(args), paren=true) - elseif Base.isbinaryoperator(fname) && len_args > 1 + elseif Base.isbinaryoperator(frep) && len_args > 1 for (i, t) in enumerate(args) i != 1 && print(io, " $fname ") print_arg(io, t, paren=true) @@ -831,12 +836,12 @@ function show_call(io, f, args) if issym(f) Base.show_unquoted(io, nameof(f)) else - Base.show(io, f) + Base.show_unquoted(io, fname) end print(io, "(") - for i=1:length(args) + for i=1:len_args print(io, args[i]) - i != length(args) && print(io, ", ") + i != len_args && print(io, ", ") end print(io, ")") end From 7034b18e2be39defe240f3850361ba0b35eaeb45 Mon Sep 17 00:00:00 2001 From: GeorgeR227 <78235421+GeorgeR227@users.noreply.github.com> Date: Tue, 15 Oct 2024 17:54:31 -0400 Subject: [PATCH 2/3] Added docs and started on tests Still need tests for decorated case. --- src/types.jl | 25 +++++++++++++++---------- test/basics.jl | 35 ++++++++++++++++++++++++++++++----- 2 files changed, 45 insertions(+), 15 deletions(-) diff --git a/src/types.jl b/src/types.jl index d2ab6e376..97d97b5a7 100644 --- a/src/types.jl +++ b/src/types.jl @@ -102,7 +102,7 @@ end """ $(SIGNATURES) -Returns the [numeric type](https://docs.julialang.org/en/v1/base/numbers/#Standard-Numeric-Types) +Returns the [numeric type](https://docs.julialang.org/en/v1/base/numbers/#Standard-Numeric-Types) of `x`. By default this is just `typeof(x)`. Define this for your symbolic types if you want [`SymbolicUtils.simplify`](@ref) to apply rules specific to numbers (such as commutativity of multiplication). Or such @@ -561,9 +561,9 @@ function TermInterface.maketerm(T::Type{<:BasicSymbolic}, head, args, metadata) st = symtype(T) pst = _promote_symtype(head, args) # Use promoted symtype only if not a subtype of the existing symtype of T. - # This is useful when calling `maketerm(BasicSymbolic{Number}, (==), [true, false])` - # Where the result would have a symtype of Bool. - # Please see discussion in https://github.com/JuliaSymbolics/SymbolicUtils.jl/pull/609 + # This is useful when calling `maketerm(BasicSymbolic{Number}, (==), [true, false])` + # Where the result would have a symtype of Bool. + # Please see discussion in https://github.com/JuliaSymbolics/SymbolicUtils.jl/pull/609 # TODO this should be optimized. new_st = if st <: AbstractArray st @@ -817,13 +817,22 @@ function show_ref(io, f, args) end import Base.nameof +# To fall through the `nameof` in the `show_call` below Base.nameof(f, arg, args...) = nameof(f) +""" + show_call(io, f, args) +Displays the function call with given args. There are different outputs if `f` +is unary, binary or otherwise. `f`'s output can also be decorated using +`Base.nameof` provided with the function as well as with the `symtype` +of `f`'s arguments. +""" function show_call(io, f, args) fname = nameof(f, symtype.(args)...) frep = Symbol(repr(f)) + len_args = length(args) - + if Base.isunaryoperator(frep) && len_args == 1 print(io, "$fname") print_arg(io, first(args), paren=true) @@ -833,11 +842,7 @@ function show_call(io, f, args) print_arg(io, t, paren=true) end else - if issym(f) - Base.show_unquoted(io, nameof(f)) - else - Base.show_unquoted(io, fname) - end + print(io, "$fname") print(io, "(") for i=1:len_args print(io, args[i]) diff --git a/test/basics.jl b/test/basics.jl index 024533c9e..99b8ec6bb 100644 --- a/test/basics.jl +++ b/test/basics.jl @@ -3,10 +3,11 @@ using SymbolicUtils using IfElse: ifelse using Setfield using Test +import Base.nameof @testset "@syms" begin let - @syms a b::Float64 f(::Real) g(p, h(q::Real))::Int + @syms a b::Float64 f(::Real) g(p, h(q::Real))::Int @test issym(a) && symtype(a) == Number @test a.name === :a @@ -233,6 +234,30 @@ end @test_reference "inspect_output/sub14.txt" sprint(io->SymbolicUtils.inspect(io, SymbolicUtils.pluck(ex, 14))) end +function nameof(::typeof(sq), arg) + if arg <: Real + return :sqrt_R + elseif arg <: Complex + return :sqrt_C + else + return :sqrt + end +end +@testset "call printing" begin + get_print(sym) = begin b = IOBuffer(); print(b, sym); String(take!(b)); end + + x,y,z = @syms x::Real y::Complex z + @syms e() f(x) g(x,y) h(x,y,z) + + @test get_print(e()) == "e()" + @test get_print(f(x)) == "f(x)" + @test get_print(g(x,y)) == "g(x, y)" + @test get_print(h(x,y,z)) == "h(x, y, z)" + + @nospecialize + sq(x) = return SymbolicUtils.Term{Number}(sq, [x]) +end + @testset "maketerm" begin @syms a b c @test isequal(SymbolicUtils.maketerm(typeof(b + c), +, [a, (b+c)], nothing).dict, Dict(a=>1,b=>1,c=>1)) @@ -247,7 +272,7 @@ end # test that maketerm sets metadata correctly metadata = Base.ImmutableDict{DataType, Any}(Ctx1, "meta_1") metadata2 = Base.ImmutableDict{DataType, Any}(Ctx2, "meta_2") - + d = b * c @set! d.metadata = metadata2 @@ -275,12 +300,12 @@ end @test symtype(new_expr) == Bool # Doesn't know return type, promoted symtype is Any - foo(x,y) = x^2 + x + foo(x,y) = x^2 + x new_expr = SymbolicUtils.maketerm(typeof(ref_expr), foo, [a, b], nothing) @test symtype(new_expr) == Number # Promoted symtype is a subtype of referred - @syms x::Int y::Int + @syms x::Int y::Int new_expr = SymbolicUtils.maketerm(typeof(ref_expr), (+), [x, y], nothing) @test symtype(new_expr) == Int64 @@ -382,5 +407,5 @@ end ax = adjoint(x) @test isequal(ax, x) @test ax === x - @test isequal(adjoint(y), conj(y)) + @test isequal(adjoint(y), conj(y)) end From 128c433fae34b7bbc36efcfda3273fbb8eb11fa9 Mon Sep 17 00:00:00 2001 From: GeorgeR227 <78235421+GeorgeR227@users.noreply.github.com> Date: Wed, 16 Oct 2024 10:35:46 -0400 Subject: [PATCH 3/3] Added test with typed printing --- test/basics.jl | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/test/basics.jl b/test/basics.jl index 99b8ec6bb..496dcc184 100644 --- a/test/basics.jl +++ b/test/basics.jl @@ -234,7 +234,11 @@ end @test_reference "inspect_output/sub14.txt" sprint(io->SymbolicUtils.inspect(io, SymbolicUtils.pluck(ex, 14))) end -function nameof(::typeof(sq), arg) +let + +sq(x) = return SymbolicUtils.Term{Number}(sq, [x]) + +function Base.nameof(::typeof(sq), arg) if arg <: Real return :sqrt_R elseif arg <: Complex @@ -243,6 +247,7 @@ function nameof(::typeof(sq), arg) return :sqrt end end + @testset "call printing" begin get_print(sym) = begin b = IOBuffer(); print(b, sym); String(take!(b)); end @@ -254,8 +259,11 @@ end @test get_print(g(x,y)) == "g(x, y)" @test get_print(h(x,y,z)) == "h(x, y, z)" - @nospecialize - sq(x) = return SymbolicUtils.Term{Number}(sq, [x]) + @test get_print(sq(x)) == "sqrt_R(x)" + @test get_print(sq(y)) == "sqrt_C(y)" + @test get_print(sq(z)) == "sqrt(z)" +end + end @testset "maketerm" begin