@@ -9,7 +9,7 @@ Visualize 4D arrays as heatmaps, assuming the WHCN convention for input array di
9
9
(width, height, color channels, batch dimension).
10
10
11
11
## Keyword arguments
12
- - `colorscheme::ColorScheme`: Color scheme from ColorSchemes.jl.
12
+ - `colorscheme::Union{ ColorScheme,Symbol} `: Color scheme from ColorSchemes.jl.
13
13
Defaults to `seismic`.
14
14
- `reduce::Symbol`: Selects how color channels are reduced to a single number to apply a color scheme.
15
15
The following methods can be selected, which are then applied over the color channels
@@ -32,14 +32,15 @@ Visualize 4D arrays as heatmaps, assuming the WHCN convention for input array di
32
32
"""
33
33
function heatmap (
34
34
val:: AbstractArray{T,N} ;
35
- colorscheme:: ColorScheme = DEFAULT_COLORSCHEME,
35
+ colorscheme:: Union{ ColorScheme,Symbol} = DEFAULT_COLORSCHEME,
36
36
reduce:: Symbol = DEFAULT_REDUCE,
37
37
rangescale:: Symbol = DEFAULT_RANGESCALE,
38
38
permute:: Bool = true ,
39
39
unpack_singleton:: Bool = true ,
40
40
process_batch:: Bool = false ,
41
41
) where {T,N}
42
42
N != 4 && throw (InputDimensionError)
43
+ colorscheme = get_colorscheme (colorscheme)
43
44
if unpack_singleton && size (val, 4 ) == 1
44
45
return single_heatmap (val[:, :, :, 1 ], colorscheme, reduce, rangescale, permute)
45
46
end
@@ -58,6 +59,9 @@ const InputDimensionError = ArgumentError(
58
59
Please reshape your input to match this format if your model doesn't adhere to this convention." ,
59
60
)
60
61
62
+ get_colorscheme (c:: ColorScheme ) = c
63
+ get_colorscheme (s:: Symbol ):: ColorScheme = colorschemes[s]
64
+
61
65
# Lower level function, mapped along batch dimension
62
66
function single_heatmap (
63
67
val, colorscheme:: ColorScheme , reduce:: Symbol , rangescale:: Symbol , permute:: Bool
0 commit comments