Skip to content

Commit 8ec2e28

Browse files
committed
Add disk cache infrastructure back with tests
1 parent da06a34 commit 8ec2e28

File tree

8 files changed

+129
-0
lines changed

8 files changed

+129
-0
lines changed

Manifest.toml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,12 @@ uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
114114
[[SHA]]
115115
uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"
116116

117+
[[Scratch]]
118+
deps = ["Dates"]
119+
git-tree-sha1 = "f94f779c94e58bf9ea243e77a37e16d9de9126bd"
120+
uuid = "6c6a2e73-6563-6170-7368-637461726353"
121+
version = "1.1.1"
122+
117123
[[Serialization]]
118124
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
119125

Project.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,10 @@ InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
99
LLVM = "929cbde3-209d-540e-8aea-75f648917ca0"
1010
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
1111
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
12+
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
13+
Scratch = "6c6a2e73-6563-6170-7368-637461726353"
14+
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
15+
TOML = "fa267f1f-6049-4f14-aa54-33bafae1ed76"
1216
TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
1317
UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
1418

src/GPUCompiler.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,17 @@ using ExprTools: splitdef, combinedef
99

1010
using Libdl
1111

12+
using Preferences
13+
using Scratch
14+
using Serialization
15+
16+
using TOML
17+
# Get the current version at compile-time, that's fine it's not going to change. ;)
18+
function get_version()
19+
return VersionNumber(TOML.parsefile(joinpath(dirname(@__DIR__), "Project.toml"))["version"])
20+
end
21+
const pkg_version = get_version()
22+
1223
const to = TimerOutput()
1324

1425
timings() = (TimerOutputs.print_timer(to); println())

src/cache.jl

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,40 @@ const specialization_counter = Ref{UInt}(0)
6262
return new_ci
6363
end
6464

65+
const disk_cache = parse(Bool, @load_preference("disk_cache", "false"))
66+
const cache_key = @load_preference("cache_key", "")
67+
68+
"""
69+
enable_cache!(state=true)
70+
71+
Activate the GPUCompiler disk cache in the current environment.
72+
You will need to restart your Julia environment for it to take effect.
73+
74+
!!! warning
75+
The disk cache is not automatically invalidated. It is sharded upon
76+
`cache_key` (see [`set_cache_key``](@ref)), the GPUCompiler version
77+
and your Julia version.
78+
"""
79+
function enable_cache!(state=true)
80+
@set_preferences!("disk_cache"=>state)
81+
end
82+
83+
"""
84+
set_cache_key(key)
85+
86+
If you are deploying an application it is recommended that you use your
87+
application name and version as a cache key. To minimize the risk of
88+
encountering spurios cache hits.
89+
"""
90+
function set_cache_key(key)
91+
@set_preferences!("cache_key"=>key)
92+
end
93+
94+
key(ver::VersionNumber) = "$(ver.major)_$(ver.minor)_$(ver.patch)"
95+
cache_path() = @get_scratch!(cache_key * "-kernels-" * key(VERSION) * "-" * key(pkg_version))
96+
clear_disk_cache!() = rm(cache_path(); recursive=true, force=true)
97+
98+
6599
const cache_lock = ReentrantLock()
66100
function cached_compilation(cache::AbstractDict,
67101
@nospecialize(job::CompilerJob),
@@ -81,13 +115,30 @@ function cached_compilation(cache::AbstractDict,
81115
if obj === nothing || force_compilation
82116
asm = nothing
83117

118+
# can we load from the disk cache?
119+
if disk_cache && !force_compilation
120+
path = joinpath(cache_path(), "$key.jls")
121+
if isfile(path)
122+
try
123+
asm = deserialize(path)
124+
@debug "Loading compiled kernel for $spec from $path"
125+
catch ex
126+
@warn "Failed to load compiled kernel at $path" exception=(ex, catch_backtrace())
127+
end
128+
end
129+
end
130+
84131
# compile
85132
if asm === nothing
86133
if compile_hook[] !== nothing
87134
compile_hook[](job)
88135
end
89136

90137
asm = compiler(job)
138+
139+
if disk_cache && !isfile(path)
140+
serialize(path, asm)
141+
end
91142
end
92143

93144
# link (but not if we got here because of forced compilation)

test/CacheEnv/LocalPreferences.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
[GPUCompiler]
2+
disk_cache = "true"
3+
cache_key = "test"

test/CacheEnv/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
[extras]
2+
GPUCompiler = "61eb1bfa-7361-4325-ad38-22787b887f55"

test/cache.jl

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
using GPUCompiler
2+
using Test
3+
4+
const TOTAL_KERNELS = 1
5+
6+
clear = parse(Bool, ARGS[1])
7+
8+
@test GPUCompiler.disk_cache == true
9+
10+
if clear
11+
GPUCompiler.clear_disk_cache!()
12+
@test length(readdir(GPUCompiler.cache_path())) == 0
13+
else
14+
@test length(readdir(GPUCompiler.cache_path())) == TOTAL_KERNELS
15+
end
16+
17+
using LLVM, LLVM.Interop
18+
19+
include("util.jl")
20+
include("definitions/native.jl")
21+
22+
kernel() = return
23+
24+
const runtime_cache = Dict{UInt, Any}()
25+
26+
function compiler(job)
27+
return GPUCompiler.compile(:asm, job)
28+
end
29+
30+
function linker(job, asm)
31+
asm
32+
end
33+
34+
let (job, kwargs) = native_job(kernel, Tuple{})
35+
GPUCompiler.cached_compilation(runtime_cache, job, compiler, linker)
36+
end
37+
38+
@test length(readdir(GPUCompiler.cache_path())) == TOTAL_KERNELS

test/runtests.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,4 +28,18 @@ include("examples.jl")
2828

2929
haskey(ENV, "CI") && GPUCompiler.timings()
3030

31+
@testset "Disk cache" begin
32+
@test GPUCompiler.disk_cache == false
33+
34+
cmd = Base.julia_cmd()
35+
if Base.JLOptions().project != C_NULL
36+
cmd = `$cmd --project=$(unsafe_string(Base.JLOptions().project))`
37+
end
38+
39+
withenv("JULIA_LOAD_PATH" => "$(get(ENV, "JULIA_LOAD_PATH", "")):$(joinpath(@__DIR__, "CacheEnv"))" do
40+
@test success(pipeline(`$cmd cache.jl true`, stderr=stderr, stdout=stdout))
41+
@test success(pipeline(`$cmd cache.jl false`, stderr=stderr, stdout=stdout))
42+
end
43+
end
44+
3145
end

0 commit comments

Comments
 (0)