1+ using Statistics
2+ using StableRNGs
3+ using JSON
4+
5+ using Plots
6+ using POMDPs
7+ using POMDPTools
8+ using CompressedBeliefMDPs
9+
10+ using ExpFamilyPCA
11+
12+ rng = StableRNG (123 )
13+ logocolors = Colors. JULIA_LOGO_COLORS
14+
15+ function is_terminal_belief (belief)
16+ return all (x -> x == 0 , belief[1 : end - 1 ]) && belief[end ] == 1
17+ end
18+
19+ function drop_terminal_beliefs (B)
20+ B_numerical = make_numerical (B, pomdp)
21+ B_numerical = filter (row -> ! is_terminal_belief (row), eachrow (B_numerical)) # exclude belief in terminal state
22+ B_numerical = reduce (hcat, B_numerical)'
23+ B_numerical = B_numerical[:, 1 : end - 1 ]
24+ return B_numerical
25+ end
26+
27+ pomdp = CircularMaze (2 , 100 )
28+ policy = RandomPolicy (pomdp; rng = rng)
29+ sampler = PolicySampler (pomdp; policy = policy, rng = rng, n = 40 )
30+ beliefs = drop_terminal_beliefs (sampler (pomdp))
31+
32+ @show size (beliefs)
33+ plot (x, b1, linestyle= :solid , linewidth= 2 , legend= false )
34+ xlabel! (" State" )
35+ ylabel! (" Probability" )
36+ title! (" Initial Belief State" )
37+ savefig (" ExpFamilyPCA/scripts/init_belief.png" )
38+
39+ x = 1 : length (states (pomdp)) - 1
40+ b1 = beliefs[1 , :]
41+ n, indim = size (beliefs)
42+ outdim = 10
43+
44+ pca = GaussianEPCA (indim, outdim)
45+ A1 = ExpFamilyPCA. fit! (pca, beliefs; verbose= true )
46+
47+ recon1 = decompress (pca, A1)
48+ plot (x, b1,
49+ linestyle= :solid ,
50+ linewidth= 1 ,
51+ color= logocolors. red,
52+ label= " Original Belief"
53+ )
54+
55+ plot! (x, recon1[1 , :],
56+ linestyle= :dash ,
57+ linewidth= 3 ,
58+ color= logocolors. blue,
59+ label= " Reconstructed Belief"
60+ )
61+
62+ xlabel! (" State" , fontsize= 12 )
63+ ylabel! (" Probability" , fontsize= 12 )
64+ title! (" PCA Reconstruction" , fontsize= 14 )
65+
66+ xlims! (0 , 200 )
67+ ylims! (0.003 , 0.008 )
68+
69+ plot! (legend= :topright , legendfontsize= 10 , framestyle= :box )
70+ p1 = plot! (xtickfontsize= 10 , ytickfontsize= 10 )
71+ savefig (" ExpFamilyPCA/scripts/PCA_recon.png" )
72+
73+ epca = PoissonEPCA (indim, outdim)
74+ A2 = ExpFamilyPCA. fit! (epca, beliefs; verbose= true )
75+
76+ recon2 = decompress (epca, A2)
77+ plot (x, b1,
78+ linestyle= :solid ,
79+ linewidth= 1 ,
80+ color= logocolors. red,
81+ label= " Original" ,
82+ dpi= 600
83+ )
84+
85+ plot! (x, recon1[1 , :],
86+ linestyle= :solid ,
87+ linewidth= 2 ,
88+ color= logocolors. green,
89+ label= " PCA" ,
90+ dpi= 600
91+ )
92+
93+ plot! (x, recon2[1 , :],
94+ linestyle= :dash ,
95+ linewidth= 3 ,
96+ color= logocolors. blue,
97+ label= " EPCA" ,
98+ dpi= 600
99+ )
100+
101+ label_size = 8
102+ tick_size = 8
103+ xlabel! (" State" , fontsize= label_size)
104+ ylabel! (" Probability" , fontsize= label_size)
105+ title! (" Belief Reconstructions" , fontsize= 14 )
106+
107+ xlims! (0 , 200 )
108+ ylims! (0.003 , 0.008 )
109+
110+ plot! (legend= :topright , legendfontsize= 10 , framestyle= :box )
111+ figure2 = plot! (xtickfontsize= tick_size, ytickfontsize= tick_size)
112+ savefig (" ExpFamilyPCA/scripts/reconstructions.png" )
113+
114+ epca_data = Dict (parse (Int, k) => v for (k, v) in JSON. parsefile (" ExpFamilyPCA/scripts/jsons/epca_data.json" ))
115+ pca_data = Dict (parse (Int, k) => v for (k, v) in JSON. parsefile (" ExpFamilyPCA/scripts/jsons/pca_data.json" ))
116+
117+ kl_divs_poisson_epca = []
118+ kl_divs_gaussian_epca = []
119+
120+ outdims = 1 : 6
121+ for k in outdims
122+ push! (kl_divs_poisson_epca, mean (epca_data[k]))
123+ push! (kl_divs_gaussian_epca, mean (pca_data[k]))
124+ end
125+
126+ plot (
127+ outdims,
128+ kl_divs_poisson_epca,
129+ label= " EPCA" ,
130+ color= logocolors. blue,
131+ yscale= :log10 ,
132+ marker= :cross ,
133+ linestyle= :solid ,
134+ lw= 2 ,
135+ dpi= 600
136+ )
137+ plot! (
138+ outdims,
139+ kl_divs_gaussian_epca,
140+ label= " PCA" ,
141+ yscale= :log10 ,
142+ color= logocolors. green,
143+ marker= :x ,
144+ linestyle= :dash ,
145+ lw= 2 ,
146+ dpi= 600
147+
148+ )
149+ title! (" KL Divergence Across Bases" , fontsize= 14 )
150+ xlabel! (" Number of Bases" , fontsize= label_size)
151+ ylabel! (" Average KL Divergence" , fontsize= label_size)
152+ figure1 = plot! (
153+ legendfontsize = 10 ,
154+ xtickfontsize = tick_size,
155+ ytickfontsize = tick_size,
156+ framestyle = :box ,
157+ aspectration = :equal ,
158+ markerstrokewidth = 2 , # Makes markers more prominent
159+ markeralpha = 0.7 # Adds transparency to markers
160+ )
161+ savefig (" ExpFamilyPCA/scripts/kl_divergence_plot.png" )
162+
163+
164+ page_length_in_pixels = 11 * 72 # 11 inches at 72 DPI ~ 800 pixels
165+ combined_plot = plot (figure1, figure2, layout = (1 , 2 ), size = (page_length_in_pixels, 400 ), left_margin = 2 * Plots. mm, bottom_margin = 2 * Plots. mm, dpi= 800 )
166+ savefig (" ExpFamilyPCA/scripts/combo.png" )
0 commit comments