Skip to content

Commit c9383e2

Browse files
Keras <> NNX integration (#21252)
* _valu * update variables * add nnx.jit * revert changes to JaxLayer * format fix * make variables subclass nnx.Variable * more tweaks * update init * refactor jax Variable class * code reformat * more cleanup * update flax version * update flax version * fix jax error * update Variables implementation * fix import * add a test * needs updates in operation * remove __new__ from JaxLayer * update base optimizers * code reformat+ model saving tests * add __hash__ * update variable value updates * sync value properly * update flag based routing between nnx and jax * clean up * fix circular import error * fix is nnx call enabled flag * attemptto fix circular import error * try again * fix import error * reformat# Please enter the commit message for your changes. Lines starting * This has to fix it * api gen * remove enable diisable configs -that does not work * adrress some comments * update conditional imports * fix tests * add github workflow for nnx * fix test * address comments * fix test * address comments * fix test * fix test -_- * put the set attr in operation * fix jax error * fix trace error * remove installation * import fixes * update jax version * ugh the jax version issue * update jax version * update installations * update jax utils * another requirents file fix * fix test * add back flax to req common * address review comments * fix tests * fix tests address more comments * fix tests * fix tests * fix jax tests * revert guide * fix code format * fix tests and jit * fix import test * try to fix memory error * revert memory fix * fix test * fix test * fix test * revert version back * Update functional.py * Update core.py * remove nnx workflow * code reformat * address gemini comments * address latest comments * remove hash function * update tests name * address latest comments * address review comments * fix actions * fix actions * skipt other tests in nnx backend * revert changes to basic flow * point installation to official JAX code * fix actions * revert basic flow test * move logic out of functional and into function class * revert functional.py * simplify init * FIX MODEL BUILD ERROR * revert changes to basic_full_flow.py * revert basic_full_flow.py * address review comments * add layer.py# modified: keras/src/layers/layer.py * revert variables change * nit * update init * ypou * resolve circukar import error * update if condition * resolve circular import error * revert change to variables.py * delete state comment
1 parent 3554825 commit c9383e2

File tree

17 files changed

+419
-21
lines changed

17 files changed

+419
-21
lines changed

.github/workflows/actions.yml

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
name: Tests
22

3+
# TODO: Consider enabling all tests (pytest, applications, etc.) with NNX in the future
4+
# Currently only basic flow tests run with NNX enabled
5+
36
on:
47
push:
58
branches: [ master ]
@@ -17,7 +20,12 @@ jobs:
1720
matrix:
1821
python-version: ['3.10']
1922
backend: [tensorflow, jax, torch, numpy, openvino]
20-
name: Run tests
23+
nnx_enabled: [false]
24+
include:
25+
- python-version: '3.10'
26+
backend: jax
27+
nnx_enabled: true
28+
name: ${{ matrix.backend == 'jax' && format('Run tests ({0}, {1}, nnx_enabled = {2})', matrix.python-version, matrix.backend, matrix.nnx_enabled) || format('Run tests ({0}, {1})', matrix.python-version, matrix.backend) }}
2129
runs-on: ubuntu-latest
2230
env:
2331
PYTHON: ${{ matrix.python-version }}
@@ -48,15 +56,18 @@ jobs:
4856
- name: Install dependencies
4957
run: |
5058
pip install -r requirements.txt --progress-bar off --upgrade
59+
if [ "${{ matrix.nnx_enabled }}" == "true" ]; then
60+
pip install --upgrade git+https://github.com/google/flax.git
61+
fi
5162
pip uninstall -y keras keras-nightly
5263
pip install -e "." --progress-bar off --upgrade
5364
- name: Test applications with pytest
54-
if: ${{ steps.filter.outputs.applications == 'true' }}
65+
if: ${{ steps.filter.outputs.applications == 'true' && matrix.nnx_enabled == false }}
5566
run: |
5667
pytest keras/src/applications --cov=keras/src/applications --cov-config=pyproject.toml
5768
coverage xml --include='keras/src/applications/*' -o apps-coverage.xml
5869
- name: Codecov keras.applications
59-
if: ${{ steps.filter.outputs.applications == 'true' }}
70+
if: ${{ steps.filter.outputs.applications == 'true' && matrix.nnx_enabled == false }}
6071
uses: codecov/codecov-action@v5
6172
with:
6273
env_vars: PYTHON,KERAS_HOME
@@ -65,14 +76,21 @@ jobs:
6576
token: ${{ secrets.CODECOV_TOKEN }}
6677
fail_ci_if_error: false
6778
- name: Test integrations
68-
if: ${{ matrix.backend != 'numpy'}}
79+
if: ${{ matrix.backend != 'numpy' && matrix.nnx_enabled == false }}
6980
run: |
7081
python integration_tests/import_test.py
7182
python integration_tests/numerical_test.py
7283
- name: Test JAX-specific integrations
73-
if: ${{ matrix.backend == 'jax'}}
84+
if: ${{ matrix.backend == 'jax' && matrix.nnx_enabled == false }}
7485
run: |
7586
python integration_tests/jax_custom_fit_test.py
87+
- name: Test basic flow with NNX
88+
if: ${{ matrix.nnx_enabled == true }}
89+
env:
90+
KERAS_NNX_ENABLED: true
91+
run: |
92+
python integration_tests/import_test.py
93+
python integration_tests/basic_full_flow.py
7694
- name: Test TF-specific integrations
7795
if: ${{ matrix.backend == 'tensorflow'}}
7896
run: |
@@ -84,6 +102,7 @@ jobs:
84102
pytest integration_tests/torch_workflow_test.py
85103
python integration_tests/torch_custom_fit_test.py
86104
- name: Test with pytest
105+
if: ${{ matrix.nnx_enabled == false }}
87106
run: |
88107
if [ "${{ matrix.backend }}" == "openvino" ]; then
89108
IGNORE_FILE="keras/src/backend/openvino/excluded_tests.txt"
@@ -94,10 +113,11 @@ jobs:
94113
pytest keras --ignore keras/src/applications --cov=keras --cov-config=pyproject.toml $IGNORE_ARGS
95114
coverage xml --omit='keras/src/applications/*,keras/api' -o core-coverage.xml
96115
- name: Codecov keras
116+
if: ${{ matrix.nnx_enabled == false }}
97117
uses: codecov/codecov-action@v5
98118
with:
99-
env_vars: PYTHON,KERAS_HOME
100-
flags: keras,keras-${{ matrix.backend }}
119+
env_vars: PYTHON,KERAS_HOME,KERAS_NNX_ENABLED
120+
flags: keras,keras-${{ matrix.backend }}${{ matrix.nnx_enabled == 'true' && '-nnx' || '' }}
101121
files: core-coverage.xml
102122
token: ${{ secrets.CODECOV_TOKEN }}
103123
fail_ci_if_error: false

.github/workflows/config/jax/keras.json

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,6 @@
22
"floatx": "float32",
33
"epsilon": 1e-07,
44
"backend": "jax",
5-
"image_data_format": "channels_last"
5+
"image_data_format": "channels_last",
6+
"nnx_enabled": false
67
}

integration_tests/import_test.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import subprocess
44

55
from keras.src import backend
6+
from keras.src.backend import config
67

78
# For torch, use index url to avoid installing nvidia drivers for the test.
89
BACKEND_REQ = {
@@ -65,6 +66,9 @@ def manage_venv_installs(whl_path):
6566
# Install `.whl` package
6667
"pip install " + whl_path,
6768
]
69+
# Install flax for JAX when NNX is enabled
70+
if backend.backend() == "jax" and config.is_nnx_enabled():
71+
install_setup.append("pip install flax>=0.10.1")
6872
run_commands_venv(install_setup)
6973

7074

keras/api/_tf_keras/keras/config/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from keras.src.backend.config import (
1818
is_flash_attention_enabled as is_flash_attention_enabled,
1919
)
20+
from keras.src.backend.config import is_nnx_enabled as is_nnx_enabled
2021
from keras.src.backend.config import max_epochs as max_epochs
2122
from keras.src.backend.config import max_steps_per_epoch as max_steps_per_epoch
2223
from keras.src.backend.config import set_epsilon as set_epsilon

keras/api/config/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from keras.src.backend.config import (
1818
is_flash_attention_enabled as is_flash_attention_enabled,
1919
)
20+
from keras.src.backend.config import is_nnx_enabled as is_nnx_enabled
2021
from keras.src.backend.config import max_epochs as max_epochs
2122
from keras.src.backend.config import max_steps_per_epoch as max_steps_per_epoch
2223
from keras.src.backend.config import set_epsilon as set_epsilon

keras/src/backend/common/variables.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,11 @@ def __init__(
211211

212212
def _deferred_initialize(self):
213213
if self._value is not None:
214+
# If NNX is enabled, it's possible the variable was already
215+
# initialized by a concrete call. In this case, _deferred_initialize
216+
# returns early and does not raise an error.
217+
if config.is_nnx_enabled():
218+
return
214219
raise ValueError(f"Variable {self.path} is already initialized.")
215220

216221
if in_stateless_scope():

keras/src/backend/config.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@
1515
# Default backend: TensorFlow.
1616
_BACKEND = "tensorflow"
1717

18+
# Whether NNX is enabled.
19+
_NNX_ENABLED = False
20+
1821
# Cap run duration for debugging.
1922
_MAX_EPOCHS = None
2023
_MAX_STEPS_PER_EPOCH = None
@@ -230,6 +233,32 @@ def is_flash_attention_enabled():
230233
return global_state.get_global_attribute("flash_attention", default=None)
231234

232235

236+
@keras_export("keras.config.is_nnx_enabled")
237+
def is_nnx_enabled():
238+
"""Checks whether NNX specific features are enabled for the JAX backend.
239+
240+
Returns:
241+
bool: `True` if NNX backend features are enabled, `False` otherwise.
242+
Defaults to `False`.
243+
"""
244+
return _NNX_ENABLED
245+
246+
247+
def set_nnx_enabled(value):
248+
global _NNX_ENABLED
249+
from keras.src.backend.common import global_state
250+
251+
_NNX_ENABLED = bool(value)
252+
if _NNX_ENABLED:
253+
try:
254+
from flax import nnx # noqa F401
255+
except ImportError:
256+
raise ImportError(
257+
"To use NNX with the JAX backend, you must install `flax`."
258+
)
259+
global_state.set_global_attribute("nnx_enabled", bool(value))
260+
261+
233262
def standardize_data_format(data_format):
234263
if data_format is None:
235264
return image_data_format()
@@ -274,8 +303,11 @@ def keras_home():
274303
_backend = _config.get("backend", _BACKEND)
275304
_image_data_format = _config.get("image_data_format", image_data_format())
276305
assert _image_data_format in {"channels_last", "channels_first"}
306+
_nnx_enabled_config = _config.get("nnx_enabled", _NNX_ENABLED)
277307

308+
# Apply basic configs that don't cause circular import
278309
set_floatx(_floatx)
310+
_NNX_ENABLED = _nnx_enabled_config
279311
set_epsilon(_epsilon)
280312
set_image_data_format(_image_data_format)
281313
_BACKEND = _backend
@@ -313,6 +345,7 @@ def keras_home():
313345
if "KERAS_MAX_STEPS_PER_EPOCH" in os.environ:
314346
_MAX_STEPS_PER_EPOCH = int(os.environ["KERAS_MAX_STEPS_PER_EPOCH"])
315347

348+
316349
if _BACKEND != "tensorflow":
317350
# If we are not running on the tensorflow backend, we should stop tensorflow
318351
# from using all available GPU memory. See
@@ -403,3 +436,13 @@ def max_steps_per_epoch():
403436
`None`, no limit is applied.
404437
"""
405438
return _MAX_STEPS_PER_EPOCH
439+
440+
441+
if "KERAS_NNX_ENABLED" in os.environ:
442+
env_val = os.environ["KERAS_NNX_ENABLED"].lower()
443+
if env_val == "true" or env_val == "1":
444+
_NNX_ENABLED = True
445+
else:
446+
_NNX_ENABLED = False
447+
448+
set_nnx_enabled(_NNX_ENABLED)

keras/src/backend/jax/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from keras.src.backend.config import is_nnx_enabled
12
from keras.src.backend.jax import core
23
from keras.src.backend.jax import distribution_lib
34
from keras.src.backend.jax import image

0 commit comments

Comments
 (0)