Error while running bash command: run_sample_video.sh | Error: "TypeError: missing a required argument: 'segment_ids'" #77

@samitm-123

I receive this error when I run this bash command: !bash LWM/scripts/run_sample_video.sh. I have followed all the directions listed in the repo.

/usr/local/lib/python3.10/dist-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
  warnings.warn(
You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/content/LWM/lwm/vision_generation.py", line 256, in <module>
    run(main)
  File "/usr/local/lib/python3.10/dist-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/usr/local/lib/python3.10/dist-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/content/LWM/lwm/vision_generation.py", line 92, in main
    model = FlaxVideoLLaMAForCausalLM(
  File "/content/LWM/lwm/vision_llama.py", line 141, in __init__
    super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
  File "/usr/local/lib/python3.10/dist-packages/transformers/modeling_flax_utils.py", line 224, in __init__
    params_shape_tree = jax.eval_shape(init_fn, self.key)
  File "/content/LWM/lwm/vision_llama.py", line 166, in init_weights
    random_params = self.module.init(rngs, input_ids, vision_masks, attention_mask, segment_ids, position_ids, return_dict=False)["params"]
  File "/content/LWM/lwm/vision_llama.py", line 396, in __call__
    outputs = self.transformer(
  File "/content/LWM/lwm/vision_llama.py", line 315, in __call__
    outputs = self.h(
  File "/content/LWM/lwm/llama.py", line 945, in __call__
    hidden_states, _ = nn.scan(
  File "/usr/local/lib/python3.10/dist-packages/flax/core/axes_scan.py", line 151, in scan_fn
    _, out_pvals, _ = pe.trace_to_jaxpr_nounits(f_flat, in_pvals)
  File "/usr/local/lib/python3.10/dist-packages/flax/core/axes_scan.py", line 123, in body_fn
    broadcast_out, c, ys = fn(broadcast_in, c, *xs)
  File "/content/LWM/lwm/llama.py", line 724, in __call__
    attn_outputs = self.attention(
  File "/content/LWM/lwm/llama.py", line 615, in __call__
    attn_output = ring_attention_sharded(
  File "/usr/lib/python3.10/inspect.py", line 3186, in bind
    return self._bind(args, kwargs)
  File "/usr/lib/python3.10/inspect.py", line 3101, in _bind
    raise TypeError(msg) from None
TypeError: missing a required argument: 'segment_ids'

Would appreciate some help here.
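
For anyone reading the traceback: the last two frames are in Python's inspect.py (bind and _bind), so the TypeError is raised while the arguments are being bound to a function signature, before the attention code itself runs. The callable invoked at llama.py line 615 declares a required segment_ids parameter that the call site does not supply. A mismatch between the installed dependency versions and the versions the repo expects is one possible cause, but that is an assumption, not something the log confirms. Below is a minimal sketch that reproduces the same message with the standard library only. ring_attention_stub and call_with_bind are hypothetical stand-ins, not the LWM or ring attention code.

```python
import inspect


def ring_attention_stub(query, key, value, attn_bias, segment_ids):
    """Hypothetical stand-in with a required `segment_ids` parameter."""
    return query


def call_with_bind(fn, *args, **kwargs):
    # Bind the arguments against the signature before calling, as
    # argument-validating wrappers commonly do. bind() raises TypeError
    # if a required parameter is not supplied.
    bound = inspect.signature(fn).bind(*args, **kwargs)
    return fn(*bound.args, **bound.kwargs)


def try_call(*args, **kwargs):
    # Return "ok" on success, or the TypeError message on failure.
    try:
        call_with_bind(ring_attention_stub, *args, **kwargs)
        return "ok"
    except TypeError as exc:
        return str(exc)


# Four positional arguments, no segment_ids: binding fails.
missing = try_call(1, 2, 3, None)
# Supplying segment_ids (even as None) satisfies the signature.
fixed = try_call(1, 2, 3, None, segment_ids=None)

print(missing)
print(fixed)
```

To see what the real function expects, printing inspect.signature of the ring_attention_sharded object in the installed environment and comparing it with the call at llama.py line 615 should show which argument is missing.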
