Skip to content

Commit 8303d95

Browse files
committed
modify weighting example for the state weighted formulation
1 parent 81dda39 commit 8303d95

File tree

1 file changed

+20
-7
lines changed

1 file changed

+20
-7
lines changed

notebooks/weighted-cost-func.ipynb

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
"fhn = fhn.FitzHughNagumo()\n",
4040
"training_data = fhn.solve_ivps(\n",
4141
" initial_states=np.random.uniform(low=-2.0, high=2.0, size=(10, 2)),\n",
42-
" tspan=[0.0, 10.0],\n",
42+
" tspan=[0.0, 6.0],\n",
4343
" sampling_period=0.1\n",
4444
")"
4545
]
@@ -66,15 +66,27 @@
6666
" \n",
6767
"\n",
6868
" # weight good trajectory by its 1 norm\n",
69-
" w = np.sum(traj.abs().states, axis=1)\n",
69+
" #w = np.sum(traj.abs().states, axis=1)\n",
70+
" w = np.ones(traj.states.shape)\n",
7071
" weights.append(w)\n",
7172
"\n",
7273
" # weight garbage trajectory to zero\n",
73-
" w = np.zeros(len(traj.states))\n",
74+
" #w = np.zeros(len(traj.states))\n",
75+
" w = np.zeros(traj.states.shape)\n",
7476
" weights.append(w)\n",
7577
"\n",
7678
"# you can also use a dict to name the trajectories if using TrajectoriesData (numpy arrays are named by their index number)\n",
77-
"weights = {idx: w for idx, w in enumerate(weights)}"
79+
"#weights = {idx: w for idx, w in enumerate(weights)}"
80+
]
81+
},
82+
{
83+
"cell_type": "code",
84+
"execution_count": null,
85+
"id": "280b4bb3-4f7d-4a94-a983-663c6255bc83",
86+
"metadata": {},
87+
"outputs": [],
88+
"source": [
89+
"weights[1].shape"
7890
]
7991
},
8092
{
@@ -93,7 +105,7 @@
93105
" learning_weights=weights, # weight the eDMD algorithm objectives\n",
94106
" scoring_weights=weights, # pass weights as required for cost_func=\"weighted\"\n",
95107
" opt=\"grid\", # grid search to find best hyperparameters\n",
96-
" n_obs=200, # maximum number of observables to try\n",
108+
" n_obs=40, # maximum number of observables to try\n",
97109
" max_opt_iter=200, # maximum number of optimization iterations\n",
98110
" grid_param_slices=5, # for grid search, number of slices for each parameter\n",
99111
" n_splits=5, # k-folds validation for tuning, helps stabilize the scoring\n",
@@ -117,7 +129,7 @@
117129
" learning_weights=None, # don't use eDMD weighting\n",
118130
" scoring_weights=weights, # pass weights as required for cost_func=\"weighted\"\n",
119131
" opt=\"grid\", # grid search to find best hyperparameters\n",
120-
" n_obs=200, # maximum number of observables to try\n",
132+
" n_obs=40, # maximum number of observables to try\n",
121133
" max_opt_iter=200, # maximum number of optimization iterations\n",
122134
" grid_param_slices=5, # for grid search, number of slices for each parameter\n",
123135
" n_splits=5, # k-folds validation for tuning, helps stabilize the scoring\n",
@@ -178,9 +190,10 @@
178190
"plt.figure(figsize=(10, 6))\n",
179191
"\n",
180192
"# plot the results\n",
193+
"plt.plot(*true_trajectory.states.T, linewidth=2, label='Ground Truth')\n",
181194
"plt.plot(*trajectory.states.T, label='Weighted Trajectory Prediction')\n",
182195
"plt.plot(*trajectory_uw.states.T, label='Trajectory Prediction')\n",
183-
"plt.plot(*true_trajectory.states.T, label='Ground Truth')\n",
196+
"\n",
184197
"\n",
185198
"plt.xlabel(\"$x_1$\")\n",
186199
"plt.ylabel(\"$x_2$\")\n",

0 commit comments

Comments
 (0)