
Commit 46021eb

Fixed linter (#115)

1 parent 5c86a13

File tree

3 files changed: +47 −53 lines changed

docs/JAX_machine_translation.ipynb

8 additions, 11 deletions
@@ -25,15 +25,12 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"import os\n",
-"\n",
 "import pathlib\n",
 "import random\n",
 "import string\n",
 "import re\n",
 "import numpy as np\n",
 "\n",
-"import jax\n",
 "import jax.numpy as jnp\n",
 "import optax\n",
 "\n",
@@ -280,13 +277,13 @@
 "        self.attention = nnx.MultiHeadAttention(num_heads=num_heads,\n",
 "                                                in_features=embed_dim,\n",
 "                                                decode=False,\n",
-"                                                rngs=rngs) \n",
+"                                                rngs=rngs)\n",
 "        self.dense_proj = nnx.Sequential(\n",
 "            nnx.Linear(embed_dim, dense_dim, rngs=rngs),\n",
 "            nnx.relu,\n",
 "            nnx.Linear(dense_dim, embed_dim, rngs=rngs),\n",
 "        )\n",
-"        \n",
+"\n",
 "        self.layernorm_1 = nnx.LayerNorm(embed_dim, rngs=rngs)\n",
 "        self.layernorm_2 = nnx.LayerNorm(embed_dim, rngs=rngs)\n",
 "\n",
@@ -412,7 +409,7 @@
 "    def __call__(self, encoder_inputs: jnp.array, decoder_inputs: jnp.array, mask: jnp.array = None, deterministic: bool = False):\n",
 "        x = self.positional_embedding(encoder_inputs)\n",
 "        encoder_outputs = self.encoder(x, mask=mask)\n",
-"        \n",
+"\n",
 "        x = self.positional_embedding(decoder_inputs)\n",
 "        decoder_outputs = self.decoder(x, encoder_outputs, mask=mask)\n",
 "        # per nnx.Dropout - disable (deterministic=True) for eval, keep (False) for training\n",
@@ -547,7 +544,7 @@
 "        loss=loss,\n",
 "        logits=logits,\n",
 "        labels=labels,\n",
-"    ) "
+"    )"
 ]
 },
 {
@@ -941,7 +938,7 @@
 "\n",
 "    input_sentence = custom_standardization(input_sentence)\n",
 "    tokenized_input_sentence = tokenize_and_pad(input_sentence, tokenizer, sequence_length)\n",
-"    \n",
+"\n",
 "    decoded_sentence = \"[start]\"\n",
 "    for i in range(sequence_length):\n",
 "        tokenized_target_sentence = tokenize_and_pad(decoded_sentence, tokenizer, sequence_length)[:-1]\n",
@@ -977,8 +974,8 @@
 "for _ in range(10):\n",
 "    input_sentence = random.choice(test_eng_texts)\n",
 "    translated = decode_sequence(input_sentence)\n",
-"    \n",
-"    test_result_pairs.append(f\"[Input]: {input_sentence} [Translation]: {translated}\") "
+"\n",
+"    test_result_pairs.append(f\"[Input]: {input_sentence} [Translation]: {translated}\")"
 ]
 },
 {
@@ -1057,7 +1054,7 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
-"version": "3.11.10"
+"version": "3.11.9"
 }
 },
 "nbformat": 4,

docs/JAX_machine_translation.md

8 additions, 11 deletions
@@ -5,7 +5,7 @@ jupytext:
     extension: .md
     format_name: myst
     format_version: 0.13
-    jupytext_version: 1.16.4
+    jupytext_version: 1.15.2
 kernelspec:
   display_name: Python 3 (ipykernel)
   language: python
@@ -21,15 +21,12 @@ Adapted from https://keras.io/examples/nlp/neural_machine_translation_with_trans
 We step through an encoder-decoder transformer in JAX and train a model for English->Spanish translation.
 
 ```{code-cell} ipython3
-import os
-
 import pathlib
 import random
 import string
 import re
 import numpy as np
 
-import jax
 import jax.numpy as jnp
 import optax
 
@@ -163,13 +160,13 @@ class TransformerEncoder(nnx.Module):
         self.attention = nnx.MultiHeadAttention(num_heads=num_heads,
                                                 in_features=embed_dim,
                                                 decode=False,
-                                                rngs=rngs) 
+                                                rngs=rngs)
         self.dense_proj = nnx.Sequential(
             nnx.Linear(embed_dim, dense_dim, rngs=rngs),
             nnx.relu,
             nnx.Linear(dense_dim, embed_dim, rngs=rngs),
         )
-        
+
         self.layernorm_1 = nnx.LayerNorm(embed_dim, rngs=rngs)
         self.layernorm_2 = nnx.LayerNorm(embed_dim, rngs=rngs)
 
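
This hunk only touches `__init__`, so as orientation here is a hedged sketch of how a block with these four submodules is commonly wired in Flax NNX; the residual-then-norm ordering is an assumption, not code from this commit:

```python
# Hypothetical forward pass for the TransformerEncoder above (sketch only;
# the real __call__ lies outside this diff's hunks).
def __call__(self, inputs, mask=None):
    # Self-attention, then a residual connection and layer norm.
    attention_output = self.attention(inputs, mask=mask)
    x = self.layernorm_1(inputs + attention_output)
    # Position-wise feed-forward block, again residual + norm.
    return self.layernorm_2(x + self.dense_proj(x))
```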
@@ -283,7 +280,7 @@ class TransformerModel(nnx.Module):
     def __call__(self, encoder_inputs: jnp.array, decoder_inputs: jnp.array, mask: jnp.array = None, deterministic: bool = False):
         x = self.positional_embedding(encoder_inputs)
         encoder_outputs = self.encoder(x, mask=mask)
-        
+
         x = self.positional_embedding(decoder_inputs)
         decoder_outputs = self.decoder(x, encoder_outputs, mask=mask)
         # per nnx.Dropout - disable (deterministic=True) for eval, keep (False) for training
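
The `deterministic` flag in this signature is what the `nnx.Dropout` comment refers to; a hypothetical call site (all variable names assumed, not from this commit) toggles it between the two modes:

```python
# Assumed usage: dropout active while training, disabled for evaluation.
train_logits = model(enc_tokens, dec_tokens, mask=mask, deterministic=False)
eval_logits = model(enc_tokens, dec_tokens, mask=mask, deterministic=True)
```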
@@ -382,7 +379,7 @@ def eval_step(model, batch, eval_metrics):
         loss=loss,
         logits=logits,
         labels=labels,
-    ) 
+    )
 ```
 
 Here, `nnx.MultiMetric` helps us keep track of general training statistics, while we make our own dictionaries to hold historical values
@@ -495,7 +492,7 @@ def decode_sequence(input_sentence):
 
     input_sentence = custom_standardization(input_sentence)
     tokenized_input_sentence = tokenize_and_pad(input_sentence, tokenizer, sequence_length)
-    
+
     decoded_sentence = "[start]"
     for i in range(sequence_length):
         tokenized_target_sentence = tokenize_and_pad(decoded_sentence, tokenizer, sequence_length)[:-1]
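
The hunk cuts the loop off after its first line. In the usual greedy-decoding pattern, which this notebook appears to follow (an assumption; the rest of the loop is outside the diff), each iteration scores the partial translation and appends the argmax token:

```python
# Hypothetical remainder of the loop body above; `model`, `index_to_word`,
# and the "[end]" stop token are assumed names not shown in this diff.
predictions = model(tokenized_input_sentence[None, :],
                    tokenized_target_sentence[None, :],
                    deterministic=True)
next_token = int(jnp.argmax(predictions[0, i, :]))
decoded_sentence += " " + index_to_word[next_token]
if index_to_word[next_token] == "[end]":
    break
```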
@@ -519,8 +516,8 @@ test_result_pairs = []
 for _ in range(10):
     input_sentence = random.choice(test_eng_texts)
     translated = decode_sequence(input_sentence)
-    
-    test_result_pairs.append(f"[Input]: {input_sentence} [Translation]: {translated}") 
+
+    test_result_pairs.append(f"[Input]: {input_sentence} [Translation]: {translated}")
 ```
 
 ## Test Results
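
Every hunk in this commit removes either trailing whitespace or an unused import, which is what the common pycodestyle rules W291/W293 (trailing whitespace on code and blank lines) and pyflakes F401 cover; a toy self-contained check in that spirit:

```python
# Illustrative only -- not the linter this commit actually ran, which the
# commit message does not name.
def trailing_ws_lines(text: str) -> list[int]:
    """Return 1-based line numbers that carry trailing whitespace."""
    return [i for i, line in enumerate(text.splitlines(), 1)
            if line != line.rstrip()]

print(trailing_ws_lines("clean\ndirty \n\t\n"))  # [2, 3]
```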
