Skip to content

Commit 3df1717

Browse files
authored
Access color schemes through symbols (#3)
1 parent c7ec885 commit 3df1717

File tree

3 files changed

+11
-3
lines changed

3 files changed

+11
-3
lines changed

src/VisionHeatmaps.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
module VisionHeatmaps
22

3-
using ColorSchemes: ColorScheme, get, seismic
3+
using ColorSchemes: ColorScheme, colorschemes, get, seismic
44
using ImageCore
55

66
include("heatmap.jl")

src/heatmap.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ Visualize 4D arrays as heatmaps, assuming the WHCN convention for input array di
99
(width, height, color channels, batch dimension).
1010
1111
## Keyword arguments
12-
- `colorscheme::ColorScheme`: Color scheme from ColorSchemes.jl.
12+
- `colorscheme::Union{ColorScheme,Symbol}`: Color scheme from ColorSchemes.jl.
1313
Defaults to `seismic`.
1414
- `reduce::Symbol`: Selects how color channels are reduced to a single number to apply a color scheme.
1515
The following methods can be selected, which are then applied over the color channels
@@ -32,14 +32,15 @@ Visualize 4D arrays as heatmaps, assuming the WHCN convention for input array di
3232
"""
3333
function heatmap(
3434
val::AbstractArray{T,N};
35-
colorscheme::ColorScheme=DEFAULT_COLORSCHEME,
35+
colorscheme::Union{ColorScheme,Symbol}=DEFAULT_COLORSCHEME,
3636
reduce::Symbol=DEFAULT_REDUCE,
3737
rangescale::Symbol=DEFAULT_RANGESCALE,
3838
permute::Bool=true,
3939
unpack_singleton::Bool=true,
4040
process_batch::Bool=false,
4141
) where {T,N}
4242
N != 4 && throw(InputDimensionError)
43+
colorscheme = get_colorscheme(colorscheme)
4344
if unpack_singleton && size(val, 4) == 1
4445
return single_heatmap(val[:, :, :, 1], colorscheme, reduce, rangescale, permute)
4546
end
@@ -58,6 +59,9 @@ const InputDimensionError = ArgumentError(
5859
Please reshape your input to match this format if your model doesn't adhere to this convention.",
5960
)
6061

62+
get_colorscheme(c::ColorScheme) = c
63+
get_colorscheme(s::Symbol)::ColorScheme = colorschemes[s]
64+
6165
# Lower level function, mapped along batch dimension
6266
function single_heatmap(
6367
val, colorscheme::ColorScheme, reduce::Symbol, rangescale::Symbol, permute::Bool

test/test_heatmap.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,10 @@ end
3030
@testset "ColorSchemes" begin
3131
h = heatmap(A; colorscheme=ColorSchemes.inferno)
3232
@test_reference "references/inferno.txt" h
33+
34+
# Test colorscheme symbols
35+
h = heatmap(A; colorscheme=:inferno)
36+
@test_reference "references/inferno.txt" h
3337
end
3438

3539
@testset "Error handling" begin

0 commit comments

Comments
 (0)