|
| 1 | +############################################################################### |
| 2 | +# Type |
| 3 | +############################################################################### |
| 4 | + |
| 5 | +# TODO: consider constraining T<:AbstractTimeAxis |
"""
    TimeTable{T}(ta::T, vecs)

A table of named column vectors that share the time axis `ta`.

# Fields
- `ta`: the time axis.
- `vecs`: column name => column vector.
- `n`: number of rows, stored explicitly in case of an infinite time axis.

The inner constructor takes the row count from `length(ta)` when
`Base.haslength(T)` holds, and from the longest column otherwise. Every column
whose length differs from the row count is replaced, inside the given `vecs`,
by a copy padded with `missing`. A `vecs` without any column is accepted.

# Throws
- `DimensionMismatch` if a column is longer than a time axis of known length.
"""
mutable struct TimeTable{T} <: AbstractTimeSeries{T}
    ta::T
    vecs::OrderedDict{Symbol,AbstractVector}
    n::Int  # length, in case of infinite time axis

    function TimeTable{T}(ta::T, vecs) where {T}
        # `init = 0` keeps the reduction defined for a table without columns
        m = mapreduce(length, max, values(vecs); init = 0)
        n = if Base.haslength(T)
            n′ = length(ta)
            (n′ ≥ m) || throw(DimensionMismatch(
                "The vector length should less or equal than the one of time axis"))
            n′
        else
            m
        end

        # note that it will copy, if the length of a col differs from `n`
        for (k, v) in vecs
            (length(v) == n) && continue
            vecs[k] = collect(PaddedView(missing, v, (n,)))
        end

        new(ta, vecs, n)
    end
    # other design style:
    #   colnames::Vector{Symbol}
    #   cols::Vector{AbstractVector}
end
| 34 | + |
# Outer constructor: take the type parameter from the time axis itself.
function TimeTable(ta::T, vecs::OrderedDict{Symbol}) where T
    return TimeTable{T}(ta, vecs)
end
# Keyword constructor: each keyword argument becomes one column, stored in the
# order in which the keywords are iterated.
function TimeTable(ta::T; kw...) where T
    cols = OrderedDict{Symbol,AbstractVector}()
    for pair ∈ kw
        cols[first(pair)] = last(pair)
    end
    return TimeTable(ta, cols)
end
| 43 | + |
# Column name that refers to the time axis rather than to a stored column
# when indexing a `TimeTable` by `Symbol`.
const TimeTableTimeCol = :time
| 45 | + |
"""
    TimeTableRow{T,V}

One row taken out of a `TimeTable`.

# Fields
- `i`: integer row index.
- `t`: element of the time axis at that index.
- `v`: values of the columns at that index.
"""
struct TimeTableRow{T,V}
    i::Int
    t::T
    v::V
end
| 51 | + |
| 52 | + |
| 53 | +############################################################################### |
| 54 | +# Iterator interfaces |
| 55 | +############################################################################### |
| 56 | + |
# Shape of the table: (number of rows, number of columns).
function Base.size(tt::TimeTable)
    return (length(tt), length(_vecs(tt)))
end

# Extent along a single dimension; any dimension other than 1 or 2 reports 1.
function Base.size(tt::TimeTable, dim)
    if dim == 1
        return length(tt)
    elseif dim == 2
        return length(_vecs(tt))
    else
        return 1
    end
end
| 62 | + |
# Number of rows, read from the stored `n` field rather than from the time axis.
@inline function Base.length(tt::TimeTable)
    return getfield(tt, :n)
end
| 64 | + |
| 65 | + |
| 66 | +############################################################################### |
| 67 | +# Indexing |
| 68 | +############################################################################### |
| 69 | + |
# `tt[end]` resolves to the stored row count.
function Base.lastindex(tt::TimeTable)
    return getfield(tt, :n)
end
| 71 | + |
# A row index is valid when it lies within 1:lastindex(tt).
function Base.checkindex(::Type{Bool}, tt::TimeTable, i::Int)
    return (i ≥ 1) && (i ≤ lastindex(tt))
end
| 73 | + |
# Column lookup by name; `TimeTableTimeCol` yields the time axis instead of a
# stored column.
function Base.getindex(tt::TimeTable, s::Symbol)
    s ≡ TimeTableTimeCol && return getfield(tt, :ta)
    return getvec(tt, s)
end
| 76 | + |
# Row lookup by integer index: bundles the index, the time axis element and the
# value of every column at that position into a `TimeTableRow`.
function Base.getindex(tt::TimeTable, i::Int)
    @boundscheck checkbounds(tt, i)
    timestamp = _ta(tt)[i]
    rowvals = [col[i] for col in values(_vecs(tt))]
    return TimeTableRow(i, timestamp, rowvals)
end
| 81 | + |
# Row lookup by time: translate the time into an integer index first.
function Base.getindex(tt::TimeTable, t::TimeType)
    return tt[time2idx(tt, t)]
end

# Single-cell lookup; `TimeTableTimeCol` reads from the time axis.
function Base.getindex(tt::TimeTable, i::Int, s::Symbol)
    @boundscheck checkbounds(tt, i)
    source = (s ≡ TimeTableTimeCol) ? _ta(tt) : _vecs(tt)[s]
    return source[i]
end

# Single-cell lookup addressed by time instead of integer index.
function Base.getindex(tt::TimeTable, t::TimeType, s::Symbol)
    return tt[time2idx(tt, t), s]
end
| 86 | + |
# Generate `Base.findfirst` and `Base.findlast` for `TimeTable`: the search runs
# on the time axis, and a hit beyond the stored row count is reported as
# `nothing`.
for func ∈ (:findfirst, :findlast)
    @eval function Base.$func(f::Function, tt::TimeTable)
        idx = $func(f, _ta(tt))
        (isnothing(idx) || idx > getfield(tt, :n)) && return nothing
        return idx
    end

    # TODO: handle case of infinite timegrid for findlast
end
| 96 | + |
# Generate `Base.findprev` and `Base.findnext` for `TimeTable`: the search runs
# on the time axis starting from `j`, and a hit beyond the stored row count is
# reported as `nothing`.
for func ∈ (:findprev, :findnext)
    @eval function Base.$func(f::Function, tt::TimeTable, j::Int)
        idx = $func(f, _ta(tt), j)
        (isnothing(idx) || idx > getfield(tt, :n)) && return nothing
        return idx
    end
end
| 104 | + |
# Positional access to a row: 1 => index, 2 => time, 3 => values.
# Any other position throws a `BoundsError`.
function Base.getindex(r::TimeTableRow, i::Int)
    if i == 1
        return r.i
    elseif i == 2
        return r.t
    elseif i == 3
        return r.v
    end
    throw(BoundsError(r, i))
end
| 111 | + |
| 112 | +############################################################################### |
| 113 | +# Value modification |
| 114 | +############################################################################### |
| 115 | + |
# `tt.name = x` stores the vector `x` as column `name`.
# Throws `DimensionMismatch` unless `x` has exactly `length(tt)` elements.
function Base.setproperty!(tt::TimeTable, name::Symbol, x::AbstractVector)
    if length(tt) != length(x)
        throw(DimensionMismatch("length unmatched"))
    end
    cols = _vecs(tt)
    cols[name] = x
    return x
end
| 120 | + |
| 121 | +# TODO: support time axis modification |
# Overwrite one cell of column `s` at integer row `i`.
function Base.setindex!(tt::TimeTable, v, i::Int, s::Symbol)
    @boundscheck checkbounds(tt, i)
    return (_vecs(tt)[s][i] = v)
end

# Overwrite one cell addressed by time instead of integer row.
function Base.setindex!(tt::TimeTable, v, t::TimeType, s::Symbol)
    return (tt[time2idx(tt, t), s] = v)
end
| 125 | + |
# Resize every column to `n′` rows and record the new row count.
# The time axis itself is not touched here. Returns `tt`.
function Base.resize!(tt::TimeTable, n′::Int)
    if length(tt) != n′
        foreach(col -> resize!(col, n′), values(_vecs(tt)))
        setfield!(tt, :n, n′)
    end
    return tt
end
| 136 | + |
# Append one row given as a `NamedTuple` holding one value per column.
#
# Throws `DimensionMismatch` when the number of entries differs from the number
# of columns, and `ArgumentError` when an entry names a column that does not
# exist. On success each column grows by one element, the stored row count is
# incremented, the time axis is resized to match, and `tt` is returned.
function Base.push!(tt::TimeTable{<:TimeGrid}, x::NamedTuple)
    cols = _vecs(tt)
    if size(tt, 2) != length(x)
        throw(DimensionMismatch("input length unmatched"))
    end

    # every key of the input has to name an existing column
    for k ∈ keys(x)
        haskey(cols, k) || throw(ArgumentError("unknown column $k"))
    end

    for (name, col) ∈ cols
        push!(col, x[name])
    end

    newlen = length(tt) + 1
    setfield!(tt, :n, newlen)
    resize!(_ta(tt), newlen)

    return tt
end
| 156 | + |
| 157 | + |
| 158 | +############################################################################### |
| 159 | +# Time axis modification |
| 160 | +############################################################################### |
| 161 | + |
| 162 | +# TODO: add a `shrink` kwarg for shrinking length after lag/lead |
# Build a new table over the time axis shifted by `+n`.
# The column dictionary is passed on as is, not copied.
function lag(tt::TimeTable{<:TimeGrid}, n::Int)
    shifted = _ta(tt) + n
    return TimeTable(shifted, _vecs(tt))
end

# Build a new table over the time axis shifted by `-n`.
# The column dictionary is passed on as is, not copied.
function lead(tt::TimeTable{<:TimeGrid}, n::Int)
    shifted = _ta(tt) - n
    return TimeTable(shifted, _vecs(tt))
end
| 165 | + |
| 166 | +# TODO: reindex ? |
| 167 | + |
| 168 | + |
| 169 | +############################################################################### |
| 170 | +# Join |
| 171 | +############################################################################### |
| 172 | + |
| 173 | +# TODO: after DataAPI.jl v0.17 released, import method from it |
| 174 | + |
| 175 | +# TODO: support `on` kwarg |
"""
    innerjoin(x::TimeTable{<:TimeGrid}, y::TimeTable{<:TimeGrid})

Join `x` and `y` on their time axes, keeping the rows for which
`findall(_ta(x), _ta(y))` yields a non-missing index.

Columns of `x` keep their names. A column of `y` whose name already exists in
`x` is stored under that name with a trailing underscore. The time axis of the
result is a `Vector` holding the kept elements of `x`'s time axis.
"""
function innerjoin(x::TimeTable{<:TimeGrid}, y::TimeTable{<:TimeGrid})
    dx = _vecs(x)
    dy = _vecs(y)
    dz = OrderedDict{Symbol,AbstractVector}()

    tax = _ta(x)
    tay = _ta(y)

    idxx = Int[]
    idxy = Int[]
    # at most one match per row of `x`, so pre-size both index vectors
    sizehint!(idxx, length(x))
    sizehint!(idxy, length(x))
    # NOTE(review): presumably entry `i` of `findall(tax, tay)` is the position
    # in `tay` matching element `i` of `tax`, or `missing` — confirm
    for (i, j) ∈ enumerate(findall(tax, tay))
        ismissing(j) && continue
        push!(idxx, i)
        push!(idxy, j)
    end

    for (k, v) ∈ dx
        dz[k] = v[idxx]  # this will copy
    end

    ks = keys(dx)
    for (k, v) ∈ dy
        # disambiguate a name that already exists in `x` with a trailing `_`
        k′ = ifelse(k ∈ ks, Symbol(k, :_), k)
        dz[k′] = v[idxy]
    end

    ta′ = [tax[i] for i ∈ idxx]
    return TimeTable(ta′, dz)
end
| 207 | + |
| 208 | + |
| 209 | +############################################################################### |
| 210 | +# Private utils |
| 211 | +############################################################################### |
| 212 | + |
| 213 | + |
# Throw a `BoundsError` unless `i` is a valid row index of `tt`; returns
# `nothing` otherwise.
function checkbounds(tt::TimeTable, i::Int)
    checkindex(Bool, tt, i) || throw(BoundsError(tt, i))
    return nothing
end
| 216 | + |
# Column vector stored under the name `s`.
@inline function getvec(tt::TimeTable, s::Symbol)
    return _vecs(tt)[s]
end

# Raw field accessors. They go through `getfield` and so bypass any property
# overloading on `TimeTable`.
@inline function _vecs(tt::TimeTable)
    return getfield(tt, :vecs)
end

@inline function _ta(tt::TimeTable)
    return getfield(tt, :ta)
end

# Translate a time into a row index by indexing the time axis with it.
@inline function time2idx(tt::TimeTable, t::TimeType)
    return _ta(tt)[t]
end
0 commit comments