Skip to content

Commit 21ec468

Browse files
authored
Merge pull request #4 from trallard/trallard/patch-pr-110
Fix headings hierarchy
2 parents 9aeda02 + 8027acc commit 21ec468

File tree

2 files changed

+8
-8
lines changed

2 files changed

+8
-8
lines changed

docs/data_loaders_for_multi_device_setups_with_jax.ipynb

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@
5757
"id": "TsFdlkSZKp9S"
5858
},
5959
"source": [
60-
"### Checking TPU Availability for JAX"
60+
"## Checking TPU Availability for JAX"
6161
]
6262
},
6363
{
@@ -99,7 +99,7 @@
9999
"id": "qyJ_WTghDnIc"
100100
},
101101
"source": [
102-
"### Setting Hyperparameters and Initializing Parameters\n",
102+
"## Setting Hyperparameters and Initializing Parameters\n",
103103
"\n",
104104
"You'll define hyperparameters for your model and data loading, including layer sizes, learning rate, batch size, and the data directory. You'll also initialize the weights and biases for a fully-connected neural network."
105105
]
@@ -141,7 +141,7 @@
141141
"id": "rHLdqeI7D2WZ"
142142
},
143143
"source": [
144-
"### Model Prediction with Auto-Batching\n",
144+
"## Model Prediction with Auto-Batching\n",
145145
"\n",
146146
"In this section, you'll define the `predict` function for your neural network. This function computes the output of the network for a single input image.\n",
147147
"\n",
@@ -210,7 +210,7 @@
210210
"id": "rLqfeORsERek"
211211
},
212212
"source": [
213-
"### Utility and Loss Functions\n",
213+
"## Utility and Loss Functions\n",
214214
"\n",
215215
"You'll now define utility functions for:\n",
216216
"- One-hot encoding: Converts class indices to binary vectors.\n",

docs/data_loaders_for_multi_device_setups_with_jax.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ from jax.sharding import Mesh, PartitionSpec, NamedSharding
4444

4545
+++ {"id": "TsFdlkSZKp9S"}
4646

47-
### Checking TPU Availability for JAX
47+
## Checking TPU Availability for JAX
4848

4949
```{code-cell}
5050
---
@@ -58,7 +58,7 @@ jax.devices()
5858

5959
+++ {"id": "qyJ_WTghDnIc"}
6060

61-
### Setting Hyperparameters and Initializing Parameters
61+
## Setting Hyperparameters and Initializing Parameters
6262

6363
You'll define hyperparameters for your model and data loading, including layer sizes, learning rate, batch size, and the data directory. You'll also initialize the weights and biases for a fully-connected neural network.
6464

@@ -90,7 +90,7 @@ params = init_network_params(layer_sizes, random.PRNGKey(0))
9090

9191
+++ {"id": "rHLdqeI7D2WZ"}
9292

93-
### Model Prediction with Auto-Batching
93+
## Model Prediction with Auto-Batching
9494

9595
In this section, you'll define the `predict` function for your neural network. This function computes the output of the network for a single input image.
9696

@@ -139,7 +139,7 @@ sharding_spec = PartitionSpec('device')
139139

140140
+++ {"id": "rLqfeORsERek"}
141141

142-
### Utility and Loss Functions
142+
## Utility and Loss Functions
143143

144144
You'll now define utility functions for:
145145
- One-hot encoding: Converts class indices to binary vectors.

0 commit comments

Comments
 (0)