Skip to content

Commit 5a3542b

Browse files
lkartheejames77777778fcholletlpizzinidevhertschuh
authored
mlx - merge master into mlx (#19657)
* Introduce float8 training (#19488) * Add float8 training support * Add tests for fp8 training * Add `quantize_and_dequantize` test * Fix bugs and add float8 correctness tests * Cleanup * Address comments and cleanup * Add docstrings and some minor refactoring * Add `QuantizedFloat8DTypePolicy` * Add dtype policy setter * Fix torch dynamo issue by using `self._dtype_policy` * Improve test coverage * Add LoRA to ConvND layers (#19516) * Add LoRA to `BaseConv` * Add tests * Fix typo * Fix tests * Fix tests * Add path to run keras on dm-tree when optree is not available. * feat(losses): add Tversky loss implementation (#19511) * feat(losses): add Tversky loss implementation * adjusted documentation * Update KLD docs * Models and layers now return owned metrics recursively. (#19522) - added `Layer.metrics` to return all metrics owned by the layer and its sub-layers recursively. - `Layer.metrics_variables` now returns variables from all metrics recursively, not just the layer and its direct sub-layers. - `Model.metrics` now returns all metrics recursively, not just the model level metrics. - `Model.metrics_variables` now returns variables from all metrics recursively, not just the model level metrics. - added test coverage to test metrics and variables 2 levels deep. This is consistent with the Keras 2 behavior and how `Model/Layer.variables` and `Model/Layer.weights` work. * Update IoU ignore_class handling * Fix `RandomBrightness`, Enhance `IndexLookup` Initialization and Expand Test Coverage for `Preprocessing Layers` (#19513) * Add tests for CategoryEncoding class in category_encoding_test.py * fix * Fix IndexLookup class initialization and add test cases * Add test case for IndexLookupLayerTest without vocabulary * Fix IndexLookup class initialization * Add normalization test cases * Add test cases for Hashing class * Fix value range validation error in RandomBrightness class * Refactor IndexLookup class initialization and add test cases * Reffix ndexLookup class initialization and afix est cases * Add test for spectral norm * Add missing test decorator * Fix torch test * Fix code format * Generate API (#19530) * API Generator for Keras * API Generator for Keras * Generates API Gen via api_gen.sh * Remove recursive import of _tf_keras * Generate API Files via api_gen.sh * Update APIs * Added metrics from custom `train_step`/`test_step` are now returned. (#19529) This works the same way as in Keras 2, whereby the metrics are returned directly from the logs if the set of keys doesn't match the model metrics. * Use temp dir and abs path in `api_gen.py` (#19533) * Use temp dir and abs path * Use temp dir and abs path * Update Readme * Update API * Fix gradient accumulation when using `overwrite_with_gradient` during float8 training (#19534) * Fix gradient accumulation with `overwrite_with_gradient` in float8 training * Add comments * Fix annotation * Update code path in ignore path (#19537) * Add operations per run (#19538) * Include input shapes in model visualization. * Add pad_to_aspect_ratio feature in ops.image.resize * Add pad_to_aspect_ratio feature in Resizing layer. * Fix incorrect usage of `quantize` (#19541) * Add logic to prevent double quantization * Add detailed info for double quantization error * Update error msg * Add eigh op. * Add keepdim in argmax/argmin. * Fix small bug in model.save_weights (#19545) * Update public APIs. * eigh should work on JAX GPU * Copy init to keras/__init__.py (#19551) * Revert "Copy init to keras/__init__.py (#19551)" (#19552) This reverts commit da9af61. * sum-reduce inlined losses * Remove the dependency on `tensorflow.experimental.numpy` and support negative indices for `take` and `take_along_axis` (#19556) * Remove `tfnp` * Update numpy api * Improve test coverage * Improve test coverage * Fix `Tri` and `Eye` and increase test converage * Update `round` test * Fix `jnp.round` * Fix `diag` bug for iou_metrics * Add op.select. * Add new API for select * Make `ops.abs` and `ops.absolute` consistent between backends. (#19563) - The TensorFlow implementation was missing `convert_to_tensor` - The sparse annotation was unnecessarily applied twice - Now `abs` calls `absolute` in all backends Also fixed TensorFlow `ops.select`. * Add pickle support for Keras model (#19555) * Implement unit tests for pickling * Reformat model_test * Reformat model_test * Rename depickle to unpickle * Rename depickle to unpickle * Reformat * remove a comment * Ellipsis Serialization and tests (#19564) * Serialization and tests * Serialization and tests * Serialization and tests * Make TF one_hot input dtype less strict. * Fix einsum `_int8_call` (#19570) * CTC Decoding for JAX and Tensorflow (#19366) * Tensorflow OP for CTC decoding * JAX op for CTC greedy decoding * Update CTC decoding documentation * Fix linting issues * Fix trailing whitespace * Simplify returns in tensorflow CTC wrapper * Fix CTC decoding error messages * Fix line too long * Bug fixes to JAX CTC greedy decoder * Force int typecast in TF CTC decoder * Unit tests for CTC greedy decoding * Add unit test for CTC beam search decoding * Fix mask index set location in JAX CTC decoding * CTC beam search decoding for JAX * Fix unhandled token repetitions in ctc_beam_search_decode * Fix merge_repeated bug in CTC beam search decode * Fix beam storage and repetition bugs in JAX ctc_decode * Remove trailing whitespace * Fix ordering bug for ties in JAX CTC beam search * Cast sequence lengths to integers in JAX ctc_decode * Remove line break in docstring * CTC beam search decoding for JAX * Fix unhandled token repetitions in ctc_beam_search_decode * Fix merge_repeated bug in CTC beam search decode * Fix beam storage and repetition bugs in JAX ctc_decode * Fix ordering bug for ties in JAX CTC beam search * Generate public api directory * Add not implemented errors for NumPy and Torch CTC decoding * Remove unused redefinition of JAX ctc_beam_search_decode * Docstring edits * Expand nan_to_num args. * Add vectorize op. * list insert requires index (#19575) * Add signature and exclude args to knp.vectorize. * Fix the apis of `dtype_polices` (#19580) * Fix api of `dtype_polices` * Update docstring * Increase test coverage * Fix format * Fix keys of `save_own_variables` and `load_own_variables` (#19581) * Fix JAX CTC test. * Fix loss_weights handling in single output case * Fix JAX vectorize. * Move _tf_keras directory to the root of the pip package. * One time fix to _tf_keras API. * Convert return type imdb.load_data to nparray (#19598) Convert return type imdb.load_data to Numpy array. Currently X_train and X-test returned as list. * Fix typo * fix api_gen.py for legacy (#19590) * fix api_gen.py for legacy * merge api and legacy for _tf_keras * Improve int8 for `Embedding` (#19595) * pin torch < 2.3.0 (#19603) * Clean up duplicated `inputs_quantizer` (#19604) * Cleanup duplicated `inputs_quantizer` and add type check for `input_spec` and `supports_masking` * Revert setter * output format changes and errors in github (#19608) * Provide write permission to action for cache management. (#19606) * Pickle support for all saveables (#19592) * Pickle support * Add keras pickleable mixin * Reformat * Implement pickle all over * reformat * Reformat * Keras saveable * Keras saveable * Keras saveable * Keras saveable * Keras saveable * obj_type * Update pickleable * Saveable logic touchups * Add slogdet op. * Update APIs * Remove unused import * Refactor CTC APIs (#19611) * Add `ctc_loss` and `ctc_decode` for numpy backend, improve imports and tests * Support "beam_search" strategy for torch's `ctc_decode` * Improve `ctc_loss` * Cleanup * Refactor `ctc_decode` * Update docstring * Update docstring * Add `CTCDecode` operation and ensure dtype inference of `ctc_decode` * Fix `name` of `losses.CTC` * update the namex version requirements (#19617) * Add `PSNR` API (#19616) * PSNR * Fix * Docstring format * Remove `PYTORCH_ENABLE_MPS_FALLBACK` flag requirement for mps (#19618) * Remove `PYTORCH_ENABLE_MPS_FALLBACK` flag requirement for mps * Formatting * Implement custom layer insertion in clone_model. (#19610) * Implement custom layer insertion in clone_model. * Add recursive arg and tests. * Add nested sequential cloning test * Fix bidir lstm saving issue. * Fix CI * Fix cholesky tracing with jax * made extract_patches dtype agnostic (#19621) * Simplify Bidirectional implementation * Add support for infinite `PyDataset`s. (#19624) `PyDataset` now uses the `num_batches` property instead of `__len__` to support `None`, which is how one indicates the dataset is infinite. Note that infinite datasets are not shuffled. Fixes #19528 Also added exception reporting when using multithreading / multiprocessing. Previously, the program would just hang with no error reported. * Fix dataset shuffling issue. * Update version string. * Minor fix * Restore version string resolution in pip_build. * Speed up `DataAdapter` tests by testing only the current backend. (#19625) There is no use case for using an iterator for a different backend than the current backend. Also: - limit the number of tests using multiprocessing, the threading tests give us good coverage. - fixed the `test_exception_reported` test, which was not actually exercising the multiprocessing / multithreading cases. - removed unused `init_pool` method. * feat(ops): support np.argpartition (#19588) * feat(ops): support np.argpartition * updated documentation, type-casting, and tf implementation * fixed tf implementation * added torch cast to int32 * updated torch type and API generated files * added torch output type cast * test(trainers): add test_errors implementation for ArrayDataAdapter class (#19626) * Fix torch GPU CI * Fix argmax/argmin keepdims with defined axis in TF * Misc fixes in TF backend ops. * Fix `argpartition` cuda bug in torch (#19634) * fix(ops): specify NonZero output dtype and add test coverage (#19635) * Fix `ops.ctc_decode` (#19633) * Fix greedy ctc decode * Remove print * Fix `tf.nn.ctc_beam_search_decoder` * Change default `mask_index` to `0` * Fix losses test * Update * Ensure the same rule applies for np arrays in autocasting (#19636) * Ensure the same rule applies for np arrays in autocasting * Trigger CI by adding docstring * Update * Update docstring * Fix `istft` and add class `TestMathErrors` in `ops/math_test.py` (#19594) * Fix and test math functions for jax backend * run /workspaces/keras/shell/format.sh * refix * fix * fix _get_complex_tensor_from_tuple * fix * refix * Fix istft function to handle inputs with less than 2 dimensions * fix * Fix ValueError in istft function for inputs with less than 2 dimensions * Return a tuple from `ops.shape` with the Torch backend. (#19640) With Torch, `x.shape` returns a `torch.Size`, which is a subclass of `tuple` but can cause different behaviors. In particular `convert_to_tensor` does not work on `torch.Size`. This fixes #18900 * support conv3d on cpu for TF (#19641) * Enable cudnn rnns when dropout is set (#19645) * Enable cudnn rnns when dropout is set * Fix * Fix plot_model for input dicts. * Fix deprecation warning in torch * Bump the github-actions group with 2 updates (#19653) Bumps the github-actions group with 2 updates: [actions/upload-artifact](https://github.com/actions/upload-artifact) and [github/codeql-action](https://github.com/github/codeql-action). Updates `actions/upload-artifact` from 4.3.1 to 4.3.3 - [Release notes](https://github.com/actions/upload-artifact/releases) - [Commits](actions/upload-artifact@5d5d22a...6546280) Updates `github/codeql-action` from 3.24.9 to 3.25.3 - [Release notes](https://github.com/github/codeql-action/releases) - [Changelog](https://github.com/github/codeql-action/blob/main/CHANGELOG.md) - [Commits](github/codeql-action@1b1aada...d39d31e) --- updated-dependencies: - dependency-name: actions/upload-artifact dependency-type: direct:production update-type: version-update:semver-patch dependency-group: github-actions - dependency-name: github/codeql-action dependency-type: direct:production update-type: version-update:semver-minor dependency-group: github-actions ... Signed-off-by: dependabot[bot] <support@github.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> * Bump the python group with 2 updates (#19654) Bumps the python group with 2 updates: torch and torchvision. Updates `torch` from 2.2.1+cu121 to 2.3.0+cu121 Updates `torchvision` from 0.17.1+cu121 to 0.18.0+cu121 --- updated-dependencies: - dependency-name: torch dependency-type: direct:production update-type: version-update:semver-minor dependency-group: python - dependency-name: torchvision dependency-type: direct:production update-type: version-update:semver-minor dependency-group: python ... Signed-off-by: dependabot[bot] <support@github.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> * Revert "Bump the python group with 2 updates (#19654)" (#19655) This reverts commit 09133f4. --------- Signed-off-by: dependabot[bot] <support@github.com> Co-authored-by: james77777778 <20734616+james77777778@users.noreply.github.com> Co-authored-by: Francois Chollet <francois.chollet@gmail.com> Co-authored-by: Luca Pizzini <lpizzini7@gmail.com> Co-authored-by: hertschuh <1091026+hertschuh@users.noreply.github.com> Co-authored-by: Faisal Alsrheed <47912291+Faisal-Alsrheed@users.noreply.github.com> Co-authored-by: Ramesh Sampath <1437573+sampathweb@users.noreply.github.com> Co-authored-by: Sachin Prasad <sachinprasad@google.com> Co-authored-by: Uwe Schmidt <uschmidt83@users.noreply.github.com> Co-authored-by: Luke Wood <LukeWood@users.noreply.github.com> Co-authored-by: Maanas Arora <maanasarora23@gmail.com> Co-authored-by: AlexanderLavelle <73360008+AlexanderLavelle@users.noreply.github.com> Co-authored-by: Surya <116063290+SuryanarayanaY@users.noreply.github.com> Co-authored-by: Shivam Mishra <124146945+shmishra99@users.noreply.github.com> Co-authored-by: Haifeng Jin <5476582+haifeng-jin@users.noreply.github.com> Co-authored-by: IMvision12 <88665786+IMvision12@users.noreply.github.com> Co-authored-by: Gabriel Rasskin <43894452+grasskin@users.noreply.github.com> Co-authored-by: Vachan V Y <109357590+VachanVY@users.noreply.github.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
1 parent 4c90dfb commit 5a3542b

File tree

107 files changed

+4396
-1034
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

107 files changed

+4396
-1034
lines changed

.github/workflows/actions.yml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,3 +126,12 @@ jobs:
126126
fi
127127
- name: Lint
128128
run: bash shell/lint.sh
129+
- name: Check for API changes
130+
run: |
131+
bash shell/api_gen.sh
132+
git status
133+
clean=$(git status | grep "nothing to commit")
134+
if [ -z "$clean" ]; then
135+
echo "Please run shell/api_gen.sh to generate API."
136+
exit 1
137+
fi

.github/workflows/nightly.yml

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,16 @@ jobs:
9292
fi
9393
- name: Lint
9494
run: bash shell/lint.sh
95+
- name: Check for API changes
96+
run: |
97+
bash shell/api_gen.sh
98+
git status
99+
clean=$(git status | grep "nothing to commit")
100+
if [ -z "$clean" ]; then
101+
echo "Please run shell/api_gen.sh to generate API."
102+
exit 1
103+
fi
104+
95105
96106
nightly:
97107
name: Build Wheel file and upload

.github/workflows/scorecard.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,14 +48,14 @@ jobs:
4848
# Upload the results as artifacts (optional). Commenting out will disable uploads of run results in SARIF
4949
# format to the repository Actions tab.
5050
- name: "Upload artifact"
51-
uses: actions/upload-artifact@5d5d22a31266ced268874388b861e4b58bb5c2f3 # v4.3.1
51+
uses: actions/upload-artifact@65462800fd760344b1a7b4382951275a0abb4808 # v4.3.3
5252
with:
5353
name: SARIF file
5454
path: results.sarif
5555
retention-days: 5
5656

5757
# Upload the results to GitHub's code scanning dashboard.
5858
- name: "Upload to code-scanning"
59-
uses: github/codeql-action/upload-sarif@1b1aada464948af03b950897e5eb522f92603cc2 # v3.24.9
59+
uses: github/codeql-action/upload-sarif@d39d31e687223d841ef683f52467bd88e9b21c14 # v3.25.3
6060
with:
6161
sarif_file: results.sarif

.github/workflows/stale-issue-pr.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ jobs:
1010
permissions:
1111
issues: write
1212
pull-requests: write
13+
actions: write
1314
steps:
1415
- name: Awaiting response issues
1516
uses: actions/stale@v9

SECURITY.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ Besides the virtual environment, the hardware (GPUs or TPUs) can also be attacke
5959

6060
## Reporting a Vulnerability
6161

62-
Beware that none of the topics under [Using Keras Securely](#using-Keras-securely) are considered vulnerabilities of Keras.
62+
Beware that none of the topics under [Using Keras Securely](#using-keras-securely) are considered vulnerabilities of Keras.
6363

6464
If you have discovered a security vulnerability in this project, please report it
6565
privately. **Do not disclose it as a public issue.** This gives us time to work with you

api_gen.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
"""
88

99
import os
10+
import re
1011
import shutil
1112

1213
import namex
@@ -78,8 +79,7 @@ def create_legacy_directory(package_dir):
7879
for path in os.listdir(os.path.join(src_dir, "legacy"))
7980
if os.path.isdir(os.path.join(src_dir, "legacy", path))
8081
]
81-
82-
for root, _, fnames in os.walk(os.path.join(package_dir, "_legacy")):
82+
for root, _, fnames in os.walk(os.path.join(api_dir, "_legacy")):
8383
for fname in fnames:
8484
if fname.endswith(".py"):
8585
legacy_fpath = os.path.join(root, fname)
@@ -110,6 +110,20 @@ def create_legacy_directory(package_dir):
110110
f"keras.api.{legacy_submodule}",
111111
f"keras.api._tf_keras.keras.{legacy_submodule}",
112112
)
113+
# Remove duplicate generated comments string.
114+
legacy_contents = re.sub(r"\n", r"\\n", legacy_contents)
115+
legacy_contents = re.sub('""".*"""', "", legacy_contents)
116+
legacy_contents = re.sub(r"\\n", r"\n", legacy_contents)
117+
# If the same module is in legacy and core_api, use legacy
118+
legacy_imports = re.findall(
119+
r"import (\w+)", legacy_contents
120+
)
121+
for import_name in legacy_imports:
122+
core_api_contents = re.sub(
123+
f"\n.* import {import_name}\n",
124+
r"\n",
125+
core_api_contents,
126+
)
113127
legacy_contents = core_api_contents + "\n" + legacy_contents
114128
with open(tf_keras_fpath, "w") as f:
115129
f.write(legacy_contents)

keras/api/_tf_keras/keras/__init__.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66

77
from keras.api import activations
88
from keras.api import applications
9-
from keras.api import backend
109
from keras.api import callbacks
1110
from keras.api import config
1211
from keras.api import constraints
@@ -15,21 +14,21 @@
1514
from keras.api import dtype_policies
1615
from keras.api import export
1716
from keras.api import initializers
18-
from keras.api import layers
1917
from keras.api import legacy
20-
from keras.api import losses
21-
from keras.api import metrics
2218
from keras.api import mixed_precision
2319
from keras.api import models
2420
from keras.api import ops
2521
from keras.api import optimizers
26-
from keras.api import preprocessing
2722
from keras.api import quantizers
2823
from keras.api import random
2924
from keras.api import regularizers
30-
from keras.api import saving
3125
from keras.api import tree
3226
from keras.api import utils
27+
from keras.api._tf_keras.keras import backend
28+
from keras.api._tf_keras.keras import layers
29+
from keras.api._tf_keras.keras import losses
30+
from keras.api._tf_keras.keras import metrics
31+
from keras.api._tf_keras.keras import preprocessing
3332
from keras.src.backend.common.keras_tensor import KerasTensor
3433
from keras.src.backend.common.stateless_scope import StatelessScope
3534
from keras.src.backend.exports import Variable

keras/api/_tf_keras/keras/backend/__init__.py

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,4 +17,127 @@
1717
from keras.src.backend.config import set_epsilon
1818
from keras.src.backend.config import set_floatx
1919
from keras.src.backend.config import set_image_data_format
20+
from keras.src.legacy.backend import abs
21+
from keras.src.legacy.backend import all
22+
from keras.src.legacy.backend import any
23+
from keras.src.legacy.backend import arange
24+
from keras.src.legacy.backend import argmax
25+
from keras.src.legacy.backend import argmin
26+
from keras.src.legacy.backend import batch_dot
27+
from keras.src.legacy.backend import batch_flatten
28+
from keras.src.legacy.backend import batch_get_value
29+
from keras.src.legacy.backend import batch_normalization
30+
from keras.src.legacy.backend import batch_set_value
31+
from keras.src.legacy.backend import bias_add
32+
from keras.src.legacy.backend import binary_crossentropy
33+
from keras.src.legacy.backend import binary_focal_crossentropy
34+
from keras.src.legacy.backend import cast
35+
from keras.src.legacy.backend import cast_to_floatx
36+
from keras.src.legacy.backend import categorical_crossentropy
37+
from keras.src.legacy.backend import categorical_focal_crossentropy
38+
from keras.src.legacy.backend import clip
39+
from keras.src.legacy.backend import concatenate
40+
from keras.src.legacy.backend import constant
41+
from keras.src.legacy.backend import conv1d
42+
from keras.src.legacy.backend import conv2d
43+
from keras.src.legacy.backend import conv2d_transpose
44+
from keras.src.legacy.backend import conv3d
45+
from keras.src.legacy.backend import cos
46+
from keras.src.legacy.backend import count_params
47+
from keras.src.legacy.backend import ctc_batch_cost
48+
from keras.src.legacy.backend import ctc_decode
49+
from keras.src.legacy.backend import ctc_label_dense_to_sparse
50+
from keras.src.legacy.backend import cumprod
51+
from keras.src.legacy.backend import cumsum
52+
from keras.src.legacy.backend import depthwise_conv2d
53+
from keras.src.legacy.backend import dot
54+
from keras.src.legacy.backend import dropout
55+
from keras.src.legacy.backend import dtype
56+
from keras.src.legacy.backend import elu
57+
from keras.src.legacy.backend import equal
58+
from keras.src.legacy.backend import eval
59+
from keras.src.legacy.backend import exp
60+
from keras.src.legacy.backend import expand_dims
61+
from keras.src.legacy.backend import eye
62+
from keras.src.legacy.backend import flatten
63+
from keras.src.legacy.backend import foldl
64+
from keras.src.legacy.backend import foldr
65+
from keras.src.legacy.backend import gather
66+
from keras.src.legacy.backend import get_value
67+
from keras.src.legacy.backend import gradients
68+
from keras.src.legacy.backend import greater
69+
from keras.src.legacy.backend import greater_equal
70+
from keras.src.legacy.backend import hard_sigmoid
71+
from keras.src.legacy.backend import in_top_k
72+
from keras.src.legacy.backend import int_shape
73+
from keras.src.legacy.backend import is_sparse
74+
from keras.src.legacy.backend import l2_normalize
75+
from keras.src.legacy.backend import less
76+
from keras.src.legacy.backend import less_equal
77+
from keras.src.legacy.backend import log
78+
from keras.src.legacy.backend import map_fn
79+
from keras.src.legacy.backend import max
80+
from keras.src.legacy.backend import maximum
81+
from keras.src.legacy.backend import mean
82+
from keras.src.legacy.backend import min
83+
from keras.src.legacy.backend import minimum
84+
from keras.src.legacy.backend import moving_average_update
85+
from keras.src.legacy.backend import name_scope
86+
from keras.src.legacy.backend import ndim
87+
from keras.src.legacy.backend import not_equal
88+
from keras.src.legacy.backend import one_hot
89+
from keras.src.legacy.backend import ones
90+
from keras.src.legacy.backend import ones_like
91+
from keras.src.legacy.backend import permute_dimensions
92+
from keras.src.legacy.backend import pool2d
93+
from keras.src.legacy.backend import pool3d
94+
from keras.src.legacy.backend import pow
95+
from keras.src.legacy.backend import prod
96+
from keras.src.legacy.backend import random_bernoulli
97+
from keras.src.legacy.backend import random_normal
98+
from keras.src.legacy.backend import random_normal_variable
99+
from keras.src.legacy.backend import random_uniform
100+
from keras.src.legacy.backend import random_uniform_variable
101+
from keras.src.legacy.backend import relu
102+
from keras.src.legacy.backend import repeat
103+
from keras.src.legacy.backend import repeat_elements
104+
from keras.src.legacy.backend import reshape
105+
from keras.src.legacy.backend import resize_images
106+
from keras.src.legacy.backend import resize_volumes
107+
from keras.src.legacy.backend import reverse
108+
from keras.src.legacy.backend import rnn
109+
from keras.src.legacy.backend import round
110+
from keras.src.legacy.backend import separable_conv2d
111+
from keras.src.legacy.backend import set_value
112+
from keras.src.legacy.backend import shape
113+
from keras.src.legacy.backend import sigmoid
114+
from keras.src.legacy.backend import sign
115+
from keras.src.legacy.backend import sin
116+
from keras.src.legacy.backend import softmax
117+
from keras.src.legacy.backend import softplus
118+
from keras.src.legacy.backend import softsign
119+
from keras.src.legacy.backend import sparse_categorical_crossentropy
120+
from keras.src.legacy.backend import spatial_2d_padding
121+
from keras.src.legacy.backend import spatial_3d_padding
122+
from keras.src.legacy.backend import sqrt
123+
from keras.src.legacy.backend import square
124+
from keras.src.legacy.backend import squeeze
125+
from keras.src.legacy.backend import stack
126+
from keras.src.legacy.backend import std
127+
from keras.src.legacy.backend import stop_gradient
128+
from keras.src.legacy.backend import sum
129+
from keras.src.legacy.backend import switch
130+
from keras.src.legacy.backend import tanh
131+
from keras.src.legacy.backend import temporal_padding
132+
from keras.src.legacy.backend import tile
133+
from keras.src.legacy.backend import to_dense
134+
from keras.src.legacy.backend import transpose
135+
from keras.src.legacy.backend import truncated_normal
136+
from keras.src.legacy.backend import update
137+
from keras.src.legacy.backend import update_add
138+
from keras.src.legacy.backend import update_sub
139+
from keras.src.legacy.backend import var
140+
from keras.src.legacy.backend import variable
141+
from keras.src.legacy.backend import zeros
142+
from keras.src.legacy.backend import zeros_like
20143
from keras.src.utils.naming import get_uid

keras/api/_tf_keras/keras/dtype_policies/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@
44
since your modifications would be overwritten.
55
"""
66

7+
from keras.src.dtype_policies import deserialize
8+
from keras.src.dtype_policies import get
9+
from keras.src.dtype_policies import serialize
710
from keras.src.dtype_policies.dtype_policy import DTypePolicy
811
from keras.src.dtype_policies.dtype_policy import FloatDTypePolicy
912
from keras.src.dtype_policies.dtype_policy import QuantizedDTypePolicy

keras/api/_tf_keras/keras/layers/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,6 @@
157157
from keras.src.layers.regularization.activity_regularization import (
158158
ActivityRegularization,
159159
)
160-
from keras.src.layers.regularization.alpha_dropout import AlphaDropout
161160
from keras.src.layers.regularization.dropout import Dropout
162161
from keras.src.layers.regularization.gaussian_dropout import GaussianDropout
163162
from keras.src.layers.regularization.gaussian_noise import GaussianNoise
@@ -190,6 +189,10 @@
190189
from keras.src.layers.rnn.simple_rnn import SimpleRNNCell
191190
from keras.src.layers.rnn.stacked_rnn_cells import StackedRNNCells
192191
from keras.src.layers.rnn.time_distributed import TimeDistributed
192+
from keras.src.legacy.layers import AlphaDropout
193+
from keras.src.legacy.layers import RandomHeight
194+
from keras.src.legacy.layers import RandomWidth
195+
from keras.src.legacy.layers import ThresholdedReLU
193196
from keras.src.utils.jax_layer import FlaxLayer
194197
from keras.src.utils.jax_layer import JaxLayer
195198
from keras.src.utils.torch_utils import TorchModuleWrapper

0 commit comments

Comments
 (0)