Skip to content

Commit 245fb69

Browse files
converted notebook to script
1 parent 235ec0e commit 245fb69

File tree

4 files changed

+166
-0
lines changed

4 files changed

+166
-0
lines changed

scripts/PCA_recon.png

47.8 KB
Loading

scripts/belief_script.jl

Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
using Statistics
2+
using StableRNGs
3+
using JSON
4+
5+
using Plots
6+
using POMDPs
7+
using POMDPTools
8+
using CompressedBeliefMDPs
9+
10+
using ExpFamilyPCA
11+
12+
rng = StableRNG(123)
13+
logocolors = Colors.JULIA_LOGO_COLORS
14+
15+
function is_terminal_belief(belief)
16+
return all(x -> x == 0, belief[1:end-1]) && belief[end] == 1
17+
end
18+
19+
function drop_terminal_beliefs(B)
20+
B_numerical = make_numerical(B, pomdp)
21+
B_numerical = filter(row -> !is_terminal_belief(row), eachrow(B_numerical)) # exclude belief in terminal state
22+
B_numerical = reduce(hcat, B_numerical)'
23+
B_numerical = B_numerical[:, 1:end - 1]
24+
return B_numerical
25+
end
26+
27+
pomdp = CircularMaze(2, 100)
28+
policy = RandomPolicy(pomdp; rng = rng)
29+
sampler = PolicySampler(pomdp; policy = policy, rng = rng, n = 40)
30+
beliefs = drop_terminal_beliefs(sampler(pomdp))
31+
32+
@show size(beliefs)
33+
plot(x, b1, linestyle=:solid, linewidth=2, legend=false)
34+
xlabel!("State")
35+
ylabel!("Probability")
36+
title!("Initial Belief State")
37+
savefig("ExpFamilyPCA/scripts/init_belief.png")
38+
39+
x = 1:length(states(pomdp)) - 1
40+
b1 = beliefs[1, :]
41+
n, indim = size(beliefs)
42+
outdim = 10
43+
44+
pca = GaussianEPCA(indim, outdim)
45+
A1 = ExpFamilyPCA.fit!(pca, beliefs; verbose=true)
46+
47+
recon1 = decompress(pca, A1)
48+
plot(x, b1,
49+
linestyle=:solid,
50+
linewidth=1,
51+
color=logocolors.red,
52+
label="Original Belief"
53+
)
54+
55+
plot!(x, recon1[1, :],
56+
linestyle=:dash,
57+
linewidth=3,
58+
color=logocolors.blue,
59+
label="Reconstructed Belief"
60+
)
61+
62+
xlabel!("State", fontsize=12)
63+
ylabel!("Probability", fontsize=12)
64+
title!("PCA Reconstruction", fontsize=14)
65+
66+
xlims!(0, 200)
67+
ylims!(0.003, 0.008)
68+
69+
plot!(legend=:topright, legendfontsize=10, framestyle=:box)
70+
p1 = plot!(xtickfontsize=10, ytickfontsize=10)
71+
savefig("ExpFamilyPCA/scripts/PCA_recon.png")
72+
73+
epca = PoissonEPCA(indim, outdim)
74+
A2 = ExpFamilyPCA.fit!(epca, beliefs; verbose=true)
75+
76+
recon2 = decompress(epca, A2)
77+
plot(x, b1,
78+
linestyle=:solid,
79+
linewidth=1,
80+
color=logocolors.red,
81+
label="Original",
82+
dpi=600
83+
)
84+
85+
plot!(x, recon1[1, :],
86+
linestyle=:solid,
87+
linewidth=2,
88+
color=logocolors.green,
89+
label="PCA",
90+
dpi=600
91+
)
92+
93+
plot!(x, recon2[1, :],
94+
linestyle=:dash,
95+
linewidth=3,
96+
color=logocolors.blue,
97+
label="EPCA",
98+
dpi=600
99+
)
100+
101+
label_size = 8
102+
tick_size = 8
103+
xlabel!("State", fontsize=label_size)
104+
ylabel!("Probability", fontsize=label_size)
105+
title!("Belief Reconstructions", fontsize=14)
106+
107+
xlims!(0, 200)
108+
ylims!(0.003, 0.008)
109+
110+
plot!(legend=:topright, legendfontsize=10, framestyle=:box)
111+
figure2 = plot!(xtickfontsize=tick_size, ytickfontsize=tick_size)
112+
savefig("ExpFamilyPCA/scripts/reconstructions.png")
113+
114+
epca_data = Dict(parse(Int, k) => v for (k, v) in JSON.parsefile("ExpFamilyPCA/scripts/jsons/epca_data.json"))
115+
pca_data = Dict(parse(Int, k) => v for (k, v) in JSON.parsefile("ExpFamilyPCA/scripts/jsons/pca_data.json"))
116+
117+
kl_divs_poisson_epca = []
118+
kl_divs_gaussian_epca = []
119+
120+
outdims = 1:6
121+
for k in outdims
122+
push!(kl_divs_poisson_epca, mean(epca_data[k]))
123+
push!(kl_divs_gaussian_epca, mean(pca_data[k]))
124+
end
125+
126+
plot(
127+
outdims,
128+
kl_divs_poisson_epca,
129+
label="EPCA",
130+
color=logocolors.blue,
131+
yscale=:log10,
132+
marker=:cross,
133+
linestyle=:solid,
134+
lw=2,
135+
dpi=600
136+
)
137+
plot!(
138+
outdims,
139+
kl_divs_gaussian_epca,
140+
label="PCA",
141+
yscale=:log10,
142+
color=logocolors.green,
143+
marker=:x,
144+
linestyle=:dash,
145+
lw=2,
146+
dpi=600
147+
148+
)
149+
title!("KL Divergence Across Bases", fontsize=14)
150+
xlabel!("Number of Bases", fontsize=label_size)
151+
ylabel!("Average KL Divergence", fontsize=label_size)
152+
figure1 = plot!(
153+
legendfontsize = 10,
154+
xtickfontsize = tick_size,
155+
ytickfontsize = tick_size,
156+
framestyle = :box,
157+
aspectration = :equal,
158+
markerstrokewidth = 2, # Makes markers more prominent
159+
markeralpha = 0.7 # Adds transparency to markers
160+
)
161+
savefig("ExpFamilyPCA/scripts/kl_divergence_plot.png")
162+
163+
164+
page_length_in_pixels = 11 * 72 # 11 inches at 72 DPI ~ 800 pixels
165+
combined_plot = plot(figure1, figure2, layout = (1, 2), size = (page_length_in_pixels, 400), left_margin = 2 * Plots.mm, bottom_margin = 2 * Plots.mm, dpi=800)
166+
savefig("ExpFamilyPCA/scripts/combo.png")

scripts/init_belief.png

27.5 KB
Loading

scripts/kl_divergence_plot.png

-11.8 KB
Loading

0 commit comments

Comments
 (0)