Skip to content

Commit f78b4fc

Browse files
committed
fix np to torch in bumpy torch
1 parent 4708b84 commit f78b4fc

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

src/optimize_2d_momentum_bumpy_torch.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def bumpy_function(pos: th.Tensor) -> th.Tensor:
3333
import matplotlib.pyplot as plt
3434
from matplotlib import cm
3535

36-
# TODO: use jax to find the gradient.
36+
# TODO: use torch to find the gradient.
3737

3838
nx, ny = (1001, 1001)
3939
x = th.linspace(-3, 3, nx)
@@ -57,7 +57,7 @@ def bumpy_function(pos: th.Tensor) -> th.Tensor:
5757
step_total = 100
5858

5959
pos_list = [start_pos]
60-
velocity_vec = np.array((0.0, 0.0))
60+
velocity_vec = th.tensor((0.0, 0.0))
6161
# TODO: Implement gradient descent with momentum.
6262

6363
for pos in pos_list:
@@ -69,5 +69,5 @@ def bumpy_function(pos: th.Tensor) -> th.Tensor:
6969
np.array(my),
7070
np.array(mz),
7171
pos_list,
72-
"writer_grad_bumpy_plot_jax",
72+
"writer_grad_bumpy_plot_torch",
7373
)

0 commit comments

Comments
 (0)