Open
Description
I receive the following error when I run `!bash LWM/scripts/run_sample_video.sh`. I have followed all the directions listed in the repo.
/usr/local/lib/python3.10/dist-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
  warnings.warn(
You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/content/LWM/lwm/vision_generation.py", line 256, in <module>
    run(main)
  File "/usr/local/lib/python3.10/dist-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/usr/local/lib/python3.10/dist-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/content/LWM/lwm/vision_generation.py", line 92, in main
    model = FlaxVideoLLaMAForCausalLM(
  File "/content/LWM/lwm/vision_llama.py", line 141, in __init__
    super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
  File "/usr/local/lib/python3.10/dist-packages/transformers/modeling_flax_utils.py", line 224, in __init__
    params_shape_tree = jax.eval_shape(init_fn, self.key)
  File "/content/LWM/lwm/vision_llama.py", line 166, in init_weights
    random_params = self.module.init(rngs, input_ids, vision_masks, attention_mask, segment_ids, position_ids, return_dict=False)["params"]
  File "/content/LWM/lwm/vision_llama.py", line 396, in __call__
    outputs = self.transformer(
  File "/content/LWM/lwm/vision_llama.py", line 315, in __call__
    outputs = self.h(
  File "/content/LWM/lwm/llama.py", line 945, in __call__
    hidden_states, _ = nn.scan(
  File "/usr/local/lib/python3.10/dist-packages/flax/core/axes_scan.py", line 151, in scan_fn
    _, out_pvals, _ = pe.trace_to_jaxpr_nounits(f_flat, in_pvals)
  File "/usr/local/lib/python3.10/dist-packages/flax/core/axes_scan.py", line 123, in body_fn
    broadcast_out, c, ys = fn(broadcast_in, c, *xs)
  File "/content/LWM/lwm/llama.py", line 724, in __call__
    attn_outputs = self.attention(
  File "/content/LWM/lwm/llama.py", line 615, in __call__
    attn_output = ring_attention_sharded(
  File "/usr/lib/python3.10/inspect.py", line 3186, in bind
    return self._bind(args, kwargs)
  File "/usr/lib/python3.10/inspect.py", line 3101, in _bind
    raise TypeError(msg) from None
TypeError: missing a required argument: 'segment_ids'
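The last two frames come from `inspect.Signature.bind`, which means the arguments passed to `ring_attention_sharded` at `llama.py` line 615 were checked against the signature of the attention function actually installed, and that signature requires a `segment_ids` parameter the call does not supply. One plausible (unconfirmed) cause is a version mismatch between the LWM code and the installed ring attention package. The JAX frames are filtered out, so the exact wrapper doing the binding is not visible. The sketch below reproduces the same error with plain `inspect`; `ring_attention_stub` and `call_via_bind` are hypothetical stand-ins, not the real library API.

```python
import inspect


def ring_attention_stub(q, k, v, attn_bias, segment_ids, *, axis_name="sp"):
    """Hypothetical stand-in whose signature requires `segment_ids`."""
    return q


def call_via_bind(fn, *args, **kwargs):
    # Resolve the arguments against fn's signature before calling it,
    # the way wrappers that use inspect.Signature.bind do.
    bound = inspect.signature(fn).bind(*args, **kwargs)
    return fn(*bound.args, **bound.kwargs)


try:
    # Caller written against an older signature: segment_ids is not passed.
    call_via_bind(ring_attention_stub, "q", "k", "v", None, axis_name="sp")
except TypeError as err:
    print(err)  # missing a required argument: 'segment_ids'

# Supplying segment_ids (here simply None) lets binding succeed.
print(call_via_bind(ring_attention_stub, "q", "k", "v", None, None, axis_name="sp"))
```

If this reading is right, the fix is to make the two sides agree: either install the attention package version the LWM code was written against, or pass `segment_ids` at the call site.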
I would appreciate some help with this.
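The log notes that JAX removed its internal frames and suggests setting `JAX_TRACEBACK_FILTERING=off` to include them, which would show which wrapper performs the signature check. A minimal sketch for rerunning the same script from Python with that variable set, assuming the script path from this report; the helper names are illustrative only.

```python
import os
import subprocess


def full_traceback_env(base=None):
    """Return a copy of the environment with JAX traceback filtering disabled."""
    env = dict(os.environ if base is None else base)
    env["JAX_TRACEBACK_FILTERING"] = "off"
    return env


def run_script_unfiltered(script="LWM/scripts/run_sample_video.sh"):
    # Rerun the sample script so JAX keeps its internal frames in the traceback.
    return subprocess.run(["bash", script], env=full_traceback_env(), check=False)
```

Calling `run_script_unfiltered()` from the notebook leaves the rest of the environment unchanged and only adds the one variable.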
leftouterjoins