Skip to content

Commit a5ba53a

Browse files
authored
Adapt to tf-nightly (#1634)
* fixed tests in keras_layers_test * fixed image classifier test * fix bert tokenizer * multi branch arch adapt preprocessing layers * depending on tf-nightly * coverage Co-authored-by: Haifeng Jin <haifeng-jin@users.noreply.github.com>
1 parent a50dbc1 commit a5ba53a

File tree

6 files changed

+42
-24
lines changed

6 files changed

+42
-24
lines changed

.github/workflows/actions.yml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,10 @@ jobs:
1111
runs-on: ubuntu-latest
1212
steps:
1313
- uses: actions/checkout@v2
14-
- name: Set up Python 3.6
14+
- name: Set up Python 3.7
1515
uses: actions/setup-python@v1
1616
with:
17-
python-version: 3.6
17+
python-version: 3.7
1818
- name: Get pip cache dir
1919
id: pip-cache
2020
run: |
@@ -50,10 +50,10 @@ jobs:
5050
runs-on: ubuntu-latest
5151
steps:
5252
- uses: actions/checkout@v2
53-
- name: Set up Python 3.6
53+
- name: Set up Python 3.7
5454
uses: actions/setup-python@v1
5555
with:
56-
python-version: 3.6
56+
python-version: 3.7
5757
- name: Install dependencies
5858
run: |
5959
python -m pip install --upgrade pip setuptools
@@ -73,7 +73,7 @@ jobs:
7373
- name: Set up Python
7474
uses: actions/setup-python@v1
7575
with:
76-
python-version: 3.6
76+
python-version: 3.7
7777
- name: Install dependencies
7878
run: |
7979
python -m pip install --upgrade pip setuptools wheel twine

autokeras/engine/tuner.py

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import collections
1516
import copy
1617
import os
1718

@@ -112,26 +113,32 @@ def adapt(model, dataset):
112113
# TODO: Use Keras Tuner for preprocessing layers adapt.
113114
x = dataset.map(lambda x, y: x)
114115

115-
def get_output_layer(tensor):
116+
def get_output_layers(tensor):
117+
output_layers = []
116118
tensor = nest.flatten(tensor)[0]
117119
for layer in model.layers:
118120
if isinstance(layer, tf.keras.layers.InputLayer):
119121
continue
120122
input_node = nest.flatten(layer.input)[0]
121123
if input_node is tensor:
122-
if not isinstance(layer, preprocessing.PreprocessingLayer):
123-
break
124-
return layer
125-
return None
124+
if isinstance(layer, preprocessing.PreprocessingLayer):
125+
output_layers.append(layer)
126+
return output_layers
127+
128+
dq = collections.deque()
126129

127130
for index, input_node in enumerate(nest.flatten(model.input)):
128-
temp_x = x.map(lambda *args: nest.flatten(args)[index])
129-
layer = get_output_layer(input_node)
130-
while layer is not None:
131-
if isinstance(layer, preprocessing.PreprocessingLayer):
132-
layer.adapt(temp_x)
133-
temp_x = temp_x.map(layer)
134-
layer = get_output_layer(layer.output)
131+
in_x = x.map(lambda *args: nest.flatten(args)[index])
132+
for layer in get_output_layers(input_node):
133+
dq.append((layer, in_x))
134+
135+
while len(dq):
136+
layer, in_x = dq.popleft()
137+
layer.adapt(in_x)
138+
out_x = in_x.map(layer)
139+
for next_layer in get_output_layers(layer.output):
140+
dq.append((next_layer, out_x))
141+
135142
return model
136143

137144
def search(

autokeras/keras_layers.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,9 @@ def get_config(self):
4040
def call(self, inputs):
4141
return data_utils.cast_to_float32(inputs)
4242

43+
def adapt(self, data):
44+
return
45+
4346

4447
@tf.keras.utils.register_keras_serializable()
4548
class ExpandLastDim(preprocessing.PreprocessingLayer):
@@ -49,6 +52,9 @@ def get_config(self):
4952
def call(self, inputs):
5053
return tf.expand_dims(inputs, axis=-1)
5154

55+
def adapt(self, data):
56+
return
57+
5258

5359
@tf.keras.utils.register_keras_serializable()
5460
class MultiCategoryEncoding(preprocessing.PreprocessingLayer):
@@ -75,9 +81,7 @@ def __init__(self, encoding: List[str], **kwargs):
7581
# Set a temporary vocabulary to prevent the error of no
7682
# vocabulary when calling the layer to build the model. The
7783
# vocabulary would be reset by adapting the layer later.
78-
self.encoding_layers.append(
79-
preprocessing.StringLookup(vocabulary=["NONE"])
80-
)
84+
self.encoding_layers.append(preprocessing.StringLookup())
8185
elif encoding == ONE_HOT:
8286
self.encoding_layers.append(None)
8387

@@ -190,6 +194,9 @@ def bert_encode(self, input_tensor):
190194

191195
return input_word_ids
192196

197+
def adapt(self, data):
198+
return # pragma: no cover
199+
193200

194201
# TODO: Remove after KerasNLP is ready.
195202
@tf.keras.utils.register_keras_serializable()
@@ -685,6 +692,9 @@ def call(self, inputs):
685692

686693
return mask # pragma: no cover
687694

695+
def get_config(self):
696+
return super().get_config()
697+
688698

689699
@tf.keras.utils.register_keras_serializable()
690700
class Transformer(tf.keras.layers.Layer):

setup.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ filterwarnings =
1414
ignore::PendingDeprecationWarning
1515
ignore::FutureWarning
1616
ignore::numpy.VisibleDeprecationWarning
17-
ignore::tensorflow.python.keras.utils.generic_utils.CustomMaskWarning
17+
ignore::keras.utils.generic_utils.CustomMaskWarning
1818

1919
addopts=-v
2020
--durations=10

setup.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
install_requires=[
2020
"packaging",
2121
"keras-tuner>=1.0.2",
22-
"tensorflow<=2.5.0,>=2.3.0",
22+
"tf-nightly==2.8.0.dev20211016",
2323
"scikit-learn",
2424
"pandas",
2525
],
@@ -41,7 +41,6 @@
4141
"Intended Audience :: Education",
4242
"Intended Audience :: Science/Research",
4343
"License :: OSI Approved :: Apache Software License",
44-
"Programming Language :: Python :: 3.6",
4544
"Programming Language :: Python :: 3.7",
4645
"Programming Language :: Python :: 3.8",
4746
"Topic :: Scientific/Engineering :: Mathematics",

tests/unit_tests/keras_layers_test.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,9 +62,11 @@ def test_init_multi_one_hot_encode():
6262

6363

6464
def test_call_multi_with_single_column_return_right_shape():
65+
x_train = np.array([["a"], ["b"], ["a"]])
6566
layer = layer_module.MultiCategoryEncoding(encoding=[layer_module.INT])
67+
layer.adapt(tf.data.Dataset.from_tensor_slices(x_train).batch(32))
6668

67-
assert layer(np.array([["a"], ["b"], ["a"]])).shape == (3, 1)
69+
assert layer(x_train).shape == (3, 1)
6870

6971

7072
def get_text_data():

0 commit comments

Comments
 (0)