|
39 | 39 | "fhn = fhn.FitzHughNagumo()\n",
|
40 | 40 | "training_data = fhn.solve_ivps(\n",
|
41 | 41 | " initial_states=np.random.uniform(low=-2.0, high=2.0, size=(10, 2)),\n",
|
42 |
| - " tspan=[0.0, 10.0],\n", |
| 42 | + " tspan=[0.0, 6.0],\n", |
43 | 43 | " sampling_period=0.1\n",
|
44 | 44 | ")"
|
45 | 45 | ]
|
|
66 | 66 | " \n",
|
67 | 67 | "\n",
|
68 | 68 | " # weight good trajectory by its 1 norm\n",
|
69 |
| - " w = np.sum(traj.abs().states, axis=1)\n", |
| 69 | + " #w = np.sum(traj.abs().states, axis=1)\n", |
| 70 | + " w = np.ones(traj.states.shape)\n", |
70 | 71 | " weights.append(w)\n",
|
71 | 72 | "\n",
|
72 | 73 | " # weight garbage trajectory to zero\n",
|
73 |
| - " w = np.zeros(len(traj.states))\n", |
| 74 | + " #w = np.zeros(len(traj.states))\n", |
| 75 | + " w = np.zeros(traj.states.shape)\n", |
74 | 76 | " weights.append(w)\n",
|
75 | 77 | "\n",
|
76 | 78 | "# you can also use a dict to name the trajectories if using TrajectoriesData (numpy arrays are named by their index number)\n",
|
77 |
| - "weights = {idx: w for idx, w in enumerate(weights)}" |
| 79 | + "#weights = {idx: w for idx, w in enumerate(weights)}" |
| 80 | + ] |
| 81 | + }, |
| 82 | + { |
| 83 | + "cell_type": "code", |
| 84 | + "execution_count": null, |
| 85 | + "id": "280b4bb3-4f7d-4a94-a983-663c6255bc83", |
| 86 | + "metadata": {}, |
| 87 | + "outputs": [], |
| 88 | + "source": [ |
| 89 | + "weights[1].shape" |
78 | 90 | ]
|
79 | 91 | },
|
80 | 92 | {
|
|
93 | 105 | " learning_weights=weights, # weight the eDMD algorithm objectives\n",
|
94 | 106 | " scoring_weights=weights, # pass weights as required for cost_func=\"weighted\"\n",
|
95 | 107 | " opt=\"grid\", # grid search to find best hyperparameters\n",
|
96 |
| - " n_obs=200, # maximum number of observables to try\n", |
| 108 | + " n_obs=40, # maximum number of observables to try\n", |
97 | 109 | " max_opt_iter=200, # maximum number of optimization iterations\n",
|
98 | 110 | " grid_param_slices=5, # for grid search, number of slices for each parameter\n",
|
99 | 111 | " n_splits=5, # k-folds validation for tuning, helps stabilize the scoring\n",
|
|
117 | 129 | " learning_weights=None, # don't use eDMD weighting\n",
|
118 | 130 | " scoring_weights=weights, # pass weights as required for cost_func=\"weighted\"\n",
|
119 | 131 | " opt=\"grid\", # grid search to find best hyperparameters\n",
|
120 |
| - " n_obs=200, # maximum number of observables to try\n", |
| 132 | + " n_obs=40, # maximum number of observables to try\n", |
121 | 133 | " max_opt_iter=200, # maximum number of optimization iterations\n",
|
122 | 134 | " grid_param_slices=5, # for grid search, number of slices for each parameter\n",
|
123 | 135 | " n_splits=5, # k-folds validation for tuning, helps stabilize the scoring\n",
|
|
178 | 190 | "plt.figure(figsize=(10, 6))\n",
|
179 | 191 | "\n",
|
180 | 192 | "# plot the results\n",
|
| 193 | + "plt.plot(*true_trajectory.states.T, linewidth=2, label='Ground Truth')\n", |
181 | 194 | "plt.plot(*trajectory.states.T, label='Weighted Trajectory Prediction')\n",
|
182 | 195 | "plt.plot(*trajectory_uw.states.T, label='Trajectory Prediction')\n",
|
183 |
| - "plt.plot(*true_trajectory.states.T, label='Ground Truth')\n", |
| 196 | + "\n", |
184 | 197 | "\n",
|
185 | 198 | "plt.xlabel(\"$x_1$\")\n",
|
186 | 199 | "plt.ylabel(\"$x_2$\")\n",
|
|
0 commit comments