Skip to content

Commit 96a95c2

Browse files
authored
Fix keyword argument add_batch_dim (#8)
1 parent 67eb0f1 commit 96a95c2

File tree

5 files changed

+52
-1
lines changed

5 files changed

+52
-1
lines changed

src/XAIBase.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@ module XAIBase
33
using TextHeatmaps
44
using VisionHeatmaps
55

6+
include("compat.jl")
7+
include("utils.jl")
8+
69
# Abstract super type of all XAI methods.
710
# Is expected that all methods are callable types that return an `Explanation`:
811
#

src/analyze.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11

22
const BATCHDIM_MISSING = ArgumentError(
33
"""The input is a 1D vector and therefore missing the required batch dimension.
4-
Call `analyze` with the keyword argument `add_batch_dim=false`."""
4+
Call `analyze` with the keyword argument `add_batch_dim=true`."""
55
)
66

77
"""

src/compat.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
# https://github.com/JuliaLang/julia/pull/39794
2+
if VERSION < v"1.7.0-DEV.793"
3+
export Returns
4+
5+
struct Returns{V} <: Function
6+
value::V
7+
Returns{V}(value) where {V} = new{V}(value)
8+
Returns(value) = new{Core.Typeof(value)}(value)
9+
end
10+
11+
(obj::Returns)(args...; kw...) = obj.value
12+
function Base.show(io::IO, obj::Returns)
13+
show(io, typeof(obj))
14+
print(io, "(")
15+
show(io, obj.value)
16+
return print(io, ")")
17+
end
18+
end

src/utils.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
"""
2+
batch_dim_view(A)
3+
4+
Return a view onto the array `A` that contains an extra singleton batch dimension at the end.
5+
This avoids allocating a new array.
6+
7+
## Example
8+
```juliarepl
9+
julia> A = [1 2; 3 4]
10+
2×2 Matrix{Int64}:
11+
1 2
12+
3 4
13+
14+
julia> batch_dim_view(A)
15+
2×2×1 view(::Array{Int64, 3}, 1:2, 1:2, :) with eltype Int64:
16+
[:, :, 1] =
17+
1 2
18+
3 4
19+
```
20+
"""
21+
batch_dim_view(A::AbstractArray{T,N}) where {T,N} = view(A, ntuple(Returns(:), N + 1)...)

test/test_api.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,11 @@ expl = analyze(input, analyzer)
2020
expl = analyzer(input)
2121
@test expl.val == val
2222

23+
# Max activation + add_batch_dim
24+
input_vec = [1, 2, 3]
25+
expl = analyzer(input_vec; add_batch_dim=true)
26+
@test expl.val == val[:, 1:1]
27+
2328
# Ouput selection
2429
output_neuron = 2
2530
val = [2 30; 4 25; 6 20]
@@ -29,3 +34,7 @@ expl = analyze(input, analyzer, output_neuron)
2934
@test isnothing(expl.extras)
3035
expl = analyzer(input, output_neuron)
3136
@test expl.val == val
37+
38+
# Ouput selection + add_batch_dim
39+
expl = analyzer(input_vec, output_neuron; add_batch_dim=true)
40+
@test expl.val == val[:, 1:1]

0 commit comments

Comments
 (0)