diff --git a/src/recode.jl b/src/recode.jl index 0242a4c4..8493741e 100644 --- a/src/recode.jl +++ b/src/recode.jl @@ -33,6 +33,8 @@ recode!(dest::AbstractArray, src::AbstractArray, pairs::Pair...) = # To fix ambiguity recode!(dest::CategoricalArray, src::AbstractArray, pairs::Pair...) = recode!(dest, src, nothing, pairs...) +recode!(dest::AbstractArray, src::CategoricalArray, pairs::Pair...) = + recode!(dest, src, nothing, pairs...) recode!(dest::CategoricalArray, src::CategoricalArray, pairs::Pair...) = recode!(dest, src, nothing, pairs...) @@ -52,6 +54,29 @@ A user defined type could override this method to define an appropriate test fun optimize_pair(pair::Pair) = pair optimize_pair(pair::Pair{<:AbstractArray}) = Set(pair.first) => pair.second +function missing_check(value) + ismissing(value) && throw(MissingException("missing value found, but dest does not support them: " * + "recode them to a supported value")) + value +end + +function recode!(dest::AbstractArray{T}, src::CategoricalArray, default::Any, pairs::Pair...) where {T} + if length(dest) != length(src) + throw(DimensionMismatch("dest and src must be of the same length (got $(length(dest)) and $(length(src)))")) + end + + pairs = map(pairs) do p + p.first => convert(T, p.second) + end + recoded = recode(src, default, pairs...) + if T >: Missing + dest .= unwrap.(recoded) + else + dest .= missing_check.(unwrap.(recoded)) + end + dest +end + function recode!(dest::AbstractArray{T}, src::AbstractArray, default::Any, pairs::Pair...) where {T} if length(dest) != length(src) throw(DimensionMismatch("dest and src must be of the same length (got $(length(dest)) and $(length(src)))")) @@ -59,15 +84,17 @@ function recode!(dest::AbstractArray{T}, src::AbstractArray, default::Any, pairs opt_pairs = map(optimize_pair, pairs) - @inbounds for i in eachindex(dest, src) - x = src[i] + map!(dest, src) do x - for j in 1:length(opt_pairs) - p = opt_pairs[j] - # we use isequal and recode_in because we cannot really distinguish scalars from collections - if x ≅ p.first || recode_in(x, p.first) - dest[i] = p.second - @goto nextitem + # we use isequal and recode_in because we cannot really distinguish scalars from collections + for p in opt_pairs + if x ≅ p.first + return p.second + end + end + for p in opt_pairs + if recode_in(x, p.first) + return p.second end end @@ -76,10 +103,10 @@ function recode!(dest::AbstractArray{T}, src::AbstractArray, default::Any, pairs eltype(dest) >: Missing || throw(MissingException("missing value found, but dest does not support them: " * "recode them to a supported value")) - dest[i] = missing + return missing elseif default isa Nothing try - dest[i] = x + return convert(T, x) catch err isa(err, MethodError) || rethrow(err) throw(ArgumentError("cannot `convert` value $(repr(x)) (of type $(typeof(x))) to type of recoded levels ($T). " * @@ -87,10 +114,8 @@ function recode!(dest::AbstractArray{T}, src::AbstractArray, default::Any, pairs "(i.e. some are preserved) and their type is incompatible with that of recoded levels.")) end else - dest[i] = default + return default end - - @label nextitem end dest