We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 4708b84 commit f78b4fcCopy full SHA for f78b4fc
src/optimize_2d_momentum_bumpy_torch.py
@@ -33,7 +33,7 @@ def bumpy_function(pos: th.Tensor) -> th.Tensor:
33
import matplotlib.pyplot as plt
34
from matplotlib import cm
35
36
- # TODO: use jax to find the gradient.
+ # TODO: use torch to find the gradient.
37
38
nx, ny = (1001, 1001)
39
x = th.linspace(-3, 3, nx)
@@ -57,7 +57,7 @@ def bumpy_function(pos: th.Tensor) -> th.Tensor:
57
step_total = 100
58
59
pos_list = [start_pos]
60
- velocity_vec = np.array((0.0, 0.0))
+ velocity_vec = th.tensor((0.0, 0.0))
61
# TODO: Implement gradient descent with momentum.
62
63
for pos in pos_list:
@@ -69,5 +69,5 @@ def bumpy_function(pos: th.Tensor) -> th.Tensor:
69
np.array(my),
70
np.array(mz),
71
pos_list,
72
- "writer_grad_bumpy_plot_jax",
+ "writer_grad_bumpy_plot_torch",
73
)
0 commit comments