tf.keras.layers.Embedding() somehow passing a float to randrange #21285

Closed
@austin-threadbeast

Description

Hi all,

I'm developing a TF/Keras model on Vertex AI. I am able to train the model successfully (albeit slowly) on my Apple-silicon laptop, but when I package the code up in a container and run it on an a2-highgpu-1g instance in GCP, I receive the following error:

INFO    2025-05-14 16:59:22 -0700       workerpool0-0   DEBUGGING: Checking batch dtypes...
INFO    2025-05-14 16:59:22 -0700       workerpool0-0   Batch dtypes:
INFO    2025-05-14 16:59:22 -0700       workerpool0-0   field_value_string_id: <dtype: 'int32'>
INFO    2025-05-14 16:59:22 -0700       workerpool0-0   event_timestamp: <dtype: 'int64'>
INFO    2025-05-14 16:59:22 -0700       workerpool0-0   is_numeric: <dtype: 'int32'>
INFO    2025-05-14 16:59:22 -0700       workerpool0-0   event_type_id: <dtype: 'int32'>
INFO    2025-05-14 16:59:22 -0700       workerpool0-0   field_name_id: <dtype: 'int32'>
INFO    2025-05-14 16:59:22 -0700       workerpool0-0   event_idx: <dtype: 'int32'>
INFO    2025-05-14 16:59:22 -0700       workerpool0-0   field_value_numeric: <dtype: 'float32'>
INFO    2025-05-14 16:59:22 -0700       workerpool0-0   token_idx: <dtype: 'int32'>
INFO    2025-05-14 16:59:22 -0700       workerpool0-0   DEBUGGING: Finished checking batch dtypes.
INFO    2025-05-14 16:59:22 -0700       workerpool0-0   Building model...
INFO    2025-05-14 16:59:22 -0700       workerpool0-0   event_type_lookup: {'<PAD>': 0, '<UNKNOWN>': 1, 'charge_event': 2, 'static': 3, 'subscription_event': 4, 'box_feedback': 5}
INFO    2025-05-14 16:59:22 -0700       workerpool0-0   type(event_type_lookup): <class 'dict'>
INFO    2025-05-14 16:59:22 -0700       workerpool0-0   len(event_type_lookup): 6 <class 'int'>
INFO    2025-05-14 16:59:22 -0700       workerpool0-0   inputs['event_type_id'] dtype: <dtype: 'int32'>
ERROR   2025-05-14 16:59:22 -0700       workerpool0-0       main()
ERROR   2025-05-14 16:59:22 -0700       workerpool0-0     File "/root/scripts/train.py", line 286, in main
ERROR   2025-05-14 16:59:22 -0700       workerpool0-0       model = build_transformer_model(
ERROR   2025-05-14 16:59:22 -0700       workerpool0-0               ^^^^^^^^^^^^^^^^^^^^^^^^
ERROR   2025-05-14 16:59:22 -0700       workerpool0-0     File "/root/model/build.py", line 74, in build_transformer_model
ERROR   2025-05-14 16:59:22 -0700       workerpool0-0       event_type_embed = tf.keras.layers.Embedding(input_dim=int(len(event_type_lookup)), output_dim=embedding_dim, mask_zero=True, name='event_type_embed')(inputs['event_type_id'])
ERROR   2025-05-14 16:59:22 -0700       workerpool0-0                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR   2025-05-14 16:59:22 -0700       workerpool0-0     File "/usr/local/lib/python3.12/dist-packages/tf_keras/src/utils/traceback_utils.py", line 70, in error_handler
ERROR   2025-05-14 16:59:22 -0700       workerpool0-0       raise e.with_traceback(filtered_tb) from None
ERROR   2025-05-14 16:59:22 -0700       workerpool0-0     File "/usr/lib/python3.12/random.py", line 336, in randint
ERROR   2025-05-14 16:59:22 -0700       workerpool0-0       return self.randrange(a, b+1)
ERROR   2025-05-14 16:59:22 -0700       workerpool0-0              ^^^^^^^^^^^^^^^^^^^^^^
ERROR   2025-05-14 16:59:22 -0700       workerpool0-0     File "/usr/lib/python3.12/random.py", line 312, in randrange
ERROR   2025-05-14 16:59:22 -0700       workerpool0-0       istop = _index(stop)
ERROR   2025-05-14 16:59:22 -0700       workerpool0-0               ^^^^^^^^^^^^
ERROR   2025-05-14 16:59:22 -0700       workerpool0-0   TypeError: 'float' object cannot be interpreted as an integer
INFO    2025-05-14 17:00:43 -0700       service Finished tearing down training program.
INFO    2025-05-14 17:00:43 -0700       service Job failed.

The offending code snippet is this:

def build_transformer_model(input_dims, event_type_lookup, field_name_lookup, string_value_lookup,
                            num_heads=12, num_blocks=12, ff_dim=64, dropout_rate=0.1, embedding_dim=64,
                            position_dim=32, classifier_head_dim=64, class_0_count=None, class_1_count=None):
    
    print("event_type_lookup:", event_type_lookup)
    print("type(event_type_lookup):", type(event_type_lookup))
    print("len(event_type_lookup):", len(event_type_lookup), type(len(event_type_lookup)))
    
    # --- Input layers ---
    inputs = {}
    for feature_name, (shape, dtype) in input_dims.items():
        inputs[feature_name] = tf.keras.layers.Input(shape=shape, name=feature_name, dtype=dtype)

    print("inputs['event_type_id'] dtype:", inputs['event_type_id'].dtype)

    # --- Embedding layers ---
    event_type_embed = tf.keras.layers.Embedding(input_dim=int(len(event_type_lookup)), output_dim=embedding_dim, mask_zero=True, name='event_type_embed')(inputs['event_type_id'])
    field_name_embed = tf.keras.layers.Embedding(input_dim=int(len(field_name_lookup)), output_dim=embedding_dim, mask_zero=True, name='field_name_embed')(inputs['field_name_id'])
    string_value_embed = tf.keras.layers.Embedding(input_dim=int(len(string_value_lookup)), output_dim=embedding_dim, mask_zero=True, name='string_value_embed')(inputs['field_value_string_id'])

I am thoroughly puzzled as to how this is happening. My only thought is that there is some kind of version difference between what is running in my local Python environment and the packages loaded into the container.
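
One detail that jumps out of the traceback is the path /usr/local/lib/python3.12/dist-packages/tf_keras/, which suggests that the legacy tf-keras (Keras 2) package, not the keras==3.9.2 pinned in requirements.txt, is what tf.keras resolves to at runtime. A quick diagnostic sketch that could be added to the top of train.py to confirm which implementation is active (the exact strings printed will depend on the environment):

import os
import tensorflow as tf

print("TF version:          ", tf.__version__)
print("tf.keras version:    ", getattr(tf.keras, "__version__", "n/a"))  # a 2.x value would point to legacy tf_keras
print("Embedding defined in:", tf.keras.layers.Embedding.__module__)     # e.g. tf_keras.* vs keras.*
print("TF_USE_LEGACY_KERAS =", os.environ.get("TF_USE_LEGACY_KERAS"))

try:
    import keras
    print("standalone keras:    ", keras.__version__)
except ImportError:
    print("standalone keras:     not importable")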

Dockerfile:

FROM nvcr.io/nvidia/tensorflow:25.02-tf2-py3 AS base

# Install gcsfuse
RUN apt-get update && apt-get install -y \
    curl \
    gnupg \
    lsb-release \
    && echo "deb https://packages.cloud.google.com/apt gcsfuse-$(lsb_release -c -s) main" | tee /etc/apt/sources.list.d/gcsfuse.list \
    && curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | apt-key add - \
    && apt-get update \
    && apt-get install -y gcsfuse \
    && apt-get clean \
    && rm -rf /var/lib/apt/lists/*

# Create mount point directory
RUN mkdir -p /gcs

WORKDIR /root

COPY requirements.txt /root/

RUN pip install --upgrade pip
RUN pip install -r requirements.txt

FROM base

COPY . /root/

ENV PYTHONPATH=/root

ENTRYPOINT [ "python", "scripts/train.py" ]

requirements.txt:

google-cloud-bigquery==3.31.0
google-cloud-bigquery-storage==2.31.0
google-cloud-aiplatform[autologging]==1.92.0
tensorflow[and-cuda]==2.19.0
keras==3.9.2
tqdm==4.67.1
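
The nvcr.io/nvidia/tensorflow:25.02-tf2-py3 base image ships its own TensorFlow and Keras builds, so the packages pinned here may coexist with, or be shadowed by, whatever is preinstalled. A small, hedged check that can be logged from inside the container to see which distributions are actually installed (the tf-keras / tf_keras names are guesses at what the base image might provide):

from importlib.metadata import version, PackageNotFoundError

# Log the installed distributions that could supply "keras" at runtime.
for pkg in ("tensorflow", "keras", "tf-keras", "tf_keras"):
    try:
        print(f"{pkg}: {version(pkg)}")
    except PackageNotFoundError:
        print(f"{pkg}: not installed")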

How could a float end up in that randrange call?
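
For what it's worth, the last two frames of the traceback look exactly like what Python 3.12 produces when random.randint is handed a float bound: randrange no longer converts non-integer arguments, so _index(stop) raises a TypeError. A minimal illustration of the error's shape (not the actual call site inside tf_keras):

import random

# randint(a, b) calls randrange(a, b + 1); on Python 3.12, _index(stop)
# rejects the float bound -- the same two frames as in the traceback above.
try:
    random.randint(0, 1e9)
except TypeError as exc:
    print(type(exc).__name__, "-", exc)  # 'float' object cannot be interpreted as an integer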
