File tree Expand file tree Collapse file tree 1 file changed +3
-3
lines changed Expand file tree Collapse file tree 1 file changed +3
-3
lines changed Original file line number Diff line number Diff line change @@ -33,7 +33,7 @@ def bumpy_function(pos: th.Tensor) -> th.Tensor:
33
33
import matplotlib .pyplot as plt
34
34
from matplotlib import cm
35
35
36
- # TODO: use jax to find the gradient.
36
+ # TODO: use torch to find the gradient.
37
37
38
38
nx , ny = (1001 , 1001 )
39
39
x = th .linspace (- 3 , 3 , nx )
@@ -57,7 +57,7 @@ def bumpy_function(pos: th.Tensor) -> th.Tensor:
57
57
step_total = 100
58
58
59
59
pos_list = [start_pos ]
60
- velocity_vec = np . array ((0.0 , 0.0 ))
60
+ velocity_vec = th . tensor ((0.0 , 0.0 ))
61
61
# TODO: Implement gradient descent with momentum.
62
62
63
63
for pos in pos_list :
@@ -69,5 +69,5 @@ def bumpy_function(pos: th.Tensor) -> th.Tensor:
69
69
np .array (my ),
70
70
np .array (mz ),
71
71
pos_list ,
72
- "writer_grad_bumpy_plot_jax " ,
72
+ "writer_grad_bumpy_plot_torch " ,
73
73
)
You can’t perform that action at this time.
0 commit comments