@@ -5,7 +5,7 @@ jupytext:
     extension: .md
     format_name: myst
     format_version: 0.13
-    jupytext_version: 1.16.4
+    jupytext_version: 1.15.2
 kernelspec:
   display_name: Python 3 (ipykernel)
   language: python
@@ -21,15 +21,12 @@ Adapted from https://keras.io/examples/nlp/neural_machine_translation_with_transformer/
 We step through an encoder-decoder transformer in JAX and train a model for English->Spanish translation.
 
 ```{code-cell} ipython3
-import os
-
 import pathlib
 import random
 import string
 import re
 import numpy as np
 
-import jax
 import jax.numpy as jnp
 import optax
 
@@ -163,13 +160,13 @@ class TransformerEncoder(nnx.Module):
         self.attention = nnx.MultiHeadAttention(num_heads=num_heads,
                                                 in_features=embed_dim,
                                                 decode=False,
-                                                 rngs=rngs)
+                                                rngs=rngs)
         self.dense_proj = nnx.Sequential(
             nnx.Linear(embed_dim, dense_dim, rngs=rngs),
             nnx.relu,
             nnx.Linear(dense_dim, embed_dim, rngs=rngs),
         )
-
+
         self.layernorm_1 = nnx.LayerNorm(embed_dim, rngs=rngs)
         self.layernorm_2 = nnx.LayerNorm(embed_dim, rngs=rngs)
 
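For readers skimming the hunk above: the encoder block wires `nnx.MultiHeadAttention` in self-attention mode (`decode=False`). A minimal smoke test of that call, with made-up dimensions (the notebook's real `embed_dim` and `num_heads` come from its hyperparameters), would look roughly like:

```python
import jax.numpy as jnp
from flax import nnx

# Toy dimensions for illustration only; the notebook's real values differ.
attn = nnx.MultiHeadAttention(num_heads=2, in_features=16,
                              decode=False, rngs=nnx.Rngs(0))
x = jnp.ones((1, 8, 16))  # (batch, seq_len, embed_dim)
y = attn(x)               # query = key = value = x, i.e. self-attention
print(y.shape)            # (1, 8, 16): output keeps the input shape
```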
@@ -283,7 +280,7 @@ class TransformerModel(nnx.Module):
     def __call__(self, encoder_inputs: jnp.array, decoder_inputs: jnp.array, mask: jnp.array = None, deterministic: bool = False):
         x = self.positional_embedding(encoder_inputs)
         encoder_outputs = self.encoder(x, mask=mask)
-
+
         x = self.positional_embedding(decoder_inputs)
         decoder_outputs = self.decoder(x, encoder_outputs, mask=mask)
         # per nnx.Dropout - disable (deterministic=True) for eval, keep (False) for training
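The comment in the final context line above is worth unpacking: `nnx.Dropout` takes a `deterministic` flag that switches it between training behavior (randomly zeroing activations and rescaling the rest) and an identity pass for evaluation. A small illustration, not taken from the notebook:

```python
import jax.numpy as jnp
from flax import nnx

drop = nnx.Dropout(rate=0.5, rngs=nnx.Rngs(0))
x = jnp.ones((2, 4))
y_train = drop(x, deterministic=False)  # zeros random entries, rescales the rest
y_eval = drop(x, deterministic=True)    # identity: returns x unchanged at eval time
```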
@@ -382,7 +379,7 @@ def eval_step(model, batch, eval_metrics):
         loss=loss,
         logits=logits,
         labels=labels,
-        )
+    )
 ```
 
 Here, `nnx.MultiMetric` helps us keep track of general training statistics, while we maintain our own dictionaries to hold their historical values.
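If it helps to see the pattern in isolation: `nnx.MultiMetric` bundles named metrics behind a single `update()`/`compute()` interface, which is what `eval_step` above feeds with `loss`, `logits`, and `labels`. A minimal sketch along the lines of the Flax NNX examples:

```python
import jax.numpy as jnp
from flax import nnx

metrics = nnx.MultiMetric(
    loss=nnx.metrics.Average('loss'),  # running mean of the 'loss' kwarg
    accuracy=nnx.metrics.Accuracy(),   # consumes 'logits' and 'labels'
)
metrics.update(loss=jnp.array(0.7),
               logits=jnp.array([[2.0, 0.5]]),
               labels=jnp.array([0]))
print(metrics.compute())  # {'accuracy': ..., 'loss': ...}
```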
@@ -495,7 +492,7 @@ def decode_sequence(input_sentence):
 
     input_sentence = custom_standardization(input_sentence)
     tokenized_input_sentence = tokenize_and_pad(input_sentence, tokenizer, sequence_length)
-
+
     decoded_sentence = "[start]"
     for i in range(sequence_length):
         tokenized_target_sentence = tokenize_and_pad(decoded_sentence, tokenizer, sequence_length)[:-1]
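For context on where this hunk sits: after tokenizing the partial target, the loop greedily appends the highest-scoring next token and stops at the end marker. The sketch below is our gloss of that step, not the notebook's verbatim code; `index_to_word` is a hypothetical id-to-token lookup standing in for whatever the tokenizer provides:

```python
# Sketch of one step inside the decoding loop (index_to_word is hypothetical).
logits = model(tokenized_input_sentence, tokenized_target_sentence,
               deterministic=True)                # disable dropout while decoding
next_token_id = int(jnp.argmax(logits[0, i, :]))  # greedy: best-scoring id at position i
next_token = index_to_word[next_token_id]
decoded_sentence += " " + next_token
if next_token == "[end]":                         # assumed end-of-sequence marker
    break
```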
@@ -519,8 +516,8 @@ test_result_pairs = []
 for _ in range(10):
     input_sentence = random.choice(test_eng_texts)
     translated = decode_sequence(input_sentence)
-
-     test_result_pairs.append(f"[Input]: {input_sentence} [Translation]: {translated}")
+
+    test_result_pairs.append(f"[Input]: {input_sentence} [Translation]: {translated}")
 ```
 
 ## Test Results