Skip to content

Commit fa9fc14

Browse files
authored
Merge pull request #210 from MilesCranmer/fix-anonymous-edgecases
Fix anonymous function edge cases
2 parents db8ed6c + b3e148c commit fa9fc14

File tree

2 files changed

+80
-3
lines changed

2 files changed

+80
-3
lines changed

src/utils.jl

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,39 @@ isshortdef(ex) = (@capture(ex, (fcall_ = body_)) &&
259259

260260
function longdef1(ex)
261261
if @capture(ex, (arg_ -> body_))
262-
Expr(:function, arg isa Symbol ? :($arg,) : arg, body)
262+
263+
if isexpr(arg, :tuple) && length(arg.args) == 1 && isexpr(arg.args[1], :parameters)
264+
# Special case (; kws...) ->
265+
fcall = Expr(:tuple, arg.args[1])
266+
267+
Expr(:function, fcall, body)
268+
elseif isexpr(arg, :block) && any(a -> isexpr(a, :...) || isexpr(a, :(=)) || isexpr(a, :kw), arg.args)
269+
# Has keywords in a block
270+
pos_args = []
271+
kw_args = []
272+
for a in arg.args
273+
if !(a isa LineNumberNode)
274+
if isexpr(a, :...)
275+
push!(kw_args, a)
276+
elseif isexpr(a, :(=))
277+
# Transform = to :kw for keyword arguments
278+
push!(kw_args, Expr(:kw, a.args[1], a.args[2]))
279+
elseif isexpr(a, :kw)
280+
push!(kw_args, a)
281+
else
282+
push!(pos_args, a)
283+
end
284+
end
285+
end
286+
fcall = Expr(:tuple, Expr(:parameters, kw_args...), pos_args...)
287+
288+
Expr(:function, fcall, body)
289+
elseif isexpr(arg, :...)
290+
# Special case for a varargs argument
291+
Expr(:function, Expr(:tuple, arg), body)
292+
else
293+
Expr(:function, arg isa Symbol ? :($arg,) : arg, body)
294+
end
263295
elseif isshortdef(ex)
264296
@assert @capture(ex, (fcall_ = body_))
265297
Expr(:function, fcall, body)
@@ -324,8 +356,13 @@ function splitdef(fdef)
324356
(func_(args__)) |
325357
(func_(args__)::rtype_)))
326358
elseif isexpr(fcall_nowhere, :tuple)
327-
if length(fcall_nowhere.args) > 1 && isexpr(fcall_nowhere.args[1], :parameters)
328-
args = fcall_nowhere.args[2:end]
359+
if length(fcall_nowhere.args) > 0 && isexpr(fcall_nowhere.args[1], :parameters)
360+
# Handle both cases: parameters with args and parameters only
361+
if length(fcall_nowhere.args) > 1
362+
args = fcall_nowhere.args[2:end]
363+
else
364+
args = []
365+
end
329366
kwargs = fcall_nowhere.args[1].args
330367
else
331368
args = fcall_nowhere.args

test/split.jl

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,46 @@ let
8383
@test (@splitcombine function (x::T, y::Vector{U}) where T <: U where U
8484
(T, U)
8585
end)(1, Number[2.0]) == (Int, Number)
86+
87+
# Test for lambda expressions with keyword arguments
88+
@test (@splitcombine (a::Int; b=2) -> a + b)(1) === 3
89+
@test (@splitcombine (a::Int; b::Float64=2.0) -> Float64(a) + b)(1) === 3.0
90+
@test (@splitcombine (a::Int, x; b=2, c=3) -> a + b + c + x)(1, 4) === 10
91+
@test (@splitcombine (a::Int, x=2) -> a + x)(1) === 3
92+
@test (@splitcombine (a::Int, x=2; y) -> a + x + y)(1; y=3) === 6
93+
@test (@splitcombine (a, x::Int=2; y) -> a + x + y)(1; y=3) === 6
94+
@test (@splitcombine (a::Int, x::Int=2; y) -> a + x + y)(1; y=3) === 6
95+
96+
# With tuple unpacking
97+
@test (@splitcombine (((a, b)::Tuple{Int, Int}, c; d=1) -> a + b + c + d))((1, 2), 3; d=4) === 10
98+
@test (@splitcombine ((c, (a, b); d=1) -> a + b + c + d))(3, (1, 2); d=4) === 10
99+
@test (@splitcombine ((c, (a, b); d) -> a + b + c + d))(3, (1, 2); d=4) === 10
100+
101+
# Test for single varargs argument in lambda
102+
@test splitdef(Meta.parse("(args...) -> 0"))[:args] == [:(args...)]
103+
@test (@splitcombine (args...) -> sum(args))(1, 2, 3) == 6
104+
@test (@splitcombine (args::Int...) -> sum(args))(1, 2, 3) == 6
105+
@test (@splitcombine (args::Int...; y=2) -> sum(args) + y)(1, 2, 3) == 8
106+
@test (@splitcombine (arg, args::Int...; y=2) -> arg + sum(args) + y)(1, 2, 3) == 8
107+
@test (@splitcombine (::Int...) -> 1)(1, 2, 3) === 1
108+
109+
# Splatted keyword arguments
110+
@test (@splitcombine (a::Int; kws...) -> a + sum(values(kws)))(1; b=2, c=3) == 6
111+
@test (@splitcombine (; kws...) -> sum(values(kws)))(b=2, c=3) == 5
112+
@test (@splitcombine (a::Int; b, kws...) -> a + b + sum(values(kws)))(1; b=2, c=3) == 6
113+
@test (@splitcombine (a::Int; b=2, kws...) -> a + b + sum(values(kws)))(1; c=3) == 6
114+
115+
# Both splatted positional and keyword arguments
116+
@test (@splitcombine (a::Int, args::Int...; kws...) -> a + sum(args) + sum(values(kws)))(1, 2, 3; b=4, c=5) == 15
117+
@test (@splitcombine (a, ::Int...; b, kws...) -> a + sum(values(kws)))(1, 2, 3; b=4, c=5) == 1 + 5
118+
119+
# Issue with longdef
120+
ex = longdef(:((a::Int; b=2) -> a + b))
121+
any_kw(ex) = ex isa Expr ? (any_kw(ex.head) || any(any_kw, ex.args)) : ex == :kw
122+
@test any_kw(ex)
123+
## ^Ensure we get a :kw expression in the output AST
124+
@test eval(ex) isa Function
125+
## Shouldn't have issues evaluating
86126
end
87127

88128
@testset "combinestructdef, splitstructdef" begin

0 commit comments

Comments
 (0)