[refactor] Wan single file implementation #11918

Open · wants to merge 38 commits into base: main
Conversation

a-r-r-o-w (Member) commented on Jul 14, 2025

Refactors most of the Wan modeling code into a single-file implementation, following transformers.

Requires #11916 to be merged first
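
For context, the "single file" pattern means the transformer's building blocks (attention, feed-forward, blocks) are defined in the model's own module rather than imported from shared layers. Below is a minimal sketch of the idea using placeholder names (`FeedForwardSketch`, `WanBlockSketch`) and plain PyTorch layers; it is not the actual diffusers code.

```python
# Illustrative sketch only: in a single-file layout, every layer the model
# needs lives next to the transformer itself (e.g. in transformer_wan.py).
import torch
import torch.nn as nn


class FeedForwardSketch(nn.Module):
    def __init__(self, dim: int, mult: int = 4):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim * mult), nn.GELU(), nn.Linear(dim * mult, dim)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


class WanBlockSketch(nn.Module):
    """Self-attention, cross-attention and feed-forward, all defined locally."""

    def __init__(self, dim: int, num_heads: int):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn1 = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        self.attn2 = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm3 = nn.LayerNorm(dim)
        self.ffn = FeedForwardSketch(dim)

    def forward(self, hidden_states, encoder_hidden_states):
        # Self-attention with residual connection
        h = self.norm1(hidden_states)
        hidden_states = hidden_states + self.attn1(h, h, h, need_weights=False)[0]
        # Cross-attention against the text/conditioning states
        h = self.norm2(hidden_states)
        hidden_states = hidden_states + self.attn2(
            h, encoder_hidden_states, encoder_hidden_states, need_weights=False
        )[0]
        # Feed-forward with residual connection
        return hidden_states + self.ffn(self.norm3(hidden_states))


# Toy usage with made-up shapes
block = WanBlockSketch(dim=24, num_heads=2)
out = block(torch.randn(1, 16, 24), torch.randn(1, 8, 24))
```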

a-r-r-o-w requested a review from DN6 on Jul 14, 2025 at 05:35
Co-Authored-By: Dhruv Nair <dhruv.nair@gmail.com>
a-r-r-o-w changed the title from "[refactor] Flux single file implementation" to "[refactor] Wan single file implementation" on Jul 14, 2025
HuggingFaceDocBuilderDev commented:

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

a-r-r-o-w (Member, Author) commented:

@DN6 could you approve #11920 too so I can run the final tests here before merging (and, of course, fix the VACE-related test issues)?

a-r-r-o-w (Member, Author) commented:

@sayakpaul Trying to run the compile tests, I see failures on both main and this branch:

Logs
(nightly-venv) aryan@hf-dgx-01:~/work/diffusers$ TORCH_LOGS=recompiles,dynamo RUN_SLOW=1 RUN_COMPILE=1 pytest -s tests/models/transformers/test_models_transformer_wan.py::WanTransformerCompileTests::test_torch_compile_repeated_blocks
============================================= test session starts ==============================================
platform linux -- Python 3.10.14, pytest-8.3.2, pluggy-1.5.0
rootdir: /home/aryan/work/diffusers
configfile: pyproject.toml
plugins: timeout-2.3.1, requests-mock-1.10.0, xdist-3.6.1, hydra-core-1.3.2, anyio-4.6.2.post1
collected 1 item                                                                                               

tests/models/transformers/test_models_transformer_wan.py I0723 14:46:54.177000 1671979 torch/_dynamo/__init__.py:112] torch._dynamo.reset
I0723 14:46:54.177000 1671979 torch/_dynamo/__init__.py:145] torch._dynamo.reset_code_caches
I0723 14:46:54.955000 1671979 torch/_dynamo/utils.py:1603] [0/0] ChromiumEventLogger initialized with id de26dbac-a5a9-4076-bf91-add317329cda
I0723 14:46:54.959000 1671979 torch/_dynamo/symbolic_convert.py:3322] [0/0] Step 1: torchdynamo start tracing forward /home/aryan/work/diffusers/src/diffusers/models/transformers/transformer_wan.py:305
I0723 14:46:54.960000 1671979 torch/fx/experimental/symbolic_shapes.py:3334] [0/0] create_env
I0723 14:46:55.434000 1671979 torch/_dynamo/symbolic_convert.py:3679] [0/0] Step 1: torchdynamo done tracing forward (RETURN_VALUE)
I0723 14:46:55.449000 1671979 torch/_dynamo/output_graph.py:1515] [0/0] Step 2: calling compiler function inductor
I0723 14:47:00.886000 1671979 torch/fx/experimental/symbolic_shapes.py:4734] [0/0] produce_guards
I0723 14:47:00.893000 1671979 torch/_dynamo/output_graph.py:1520] [0/0] Step 2: done compiler function inductor
I0723 14:47:00.925000 1671979 torch/fx/experimental/symbolic_shapes.py:4734] [0/0] produce_guards
I0723 14:47:00.978000 1671979 torch/_dynamo/pgo.py:660] [0/0] put_code_state: no cache key, skipping
I0723 14:47:00.979000 1671979 torch/_dynamo/convert_frame.py:1121] [0/0] run_gc_after_compile: running gc
V0723 14:47:00.993000 1671979 torch/_dynamo/guards.py:3006] [0/1] [__recompiles] Recompiling function forward in /home/aryan/work/diffusers/src/diffusers/models/transformers/transformer_wan.py:305
V0723 14:47:00.993000 1671979 torch/_dynamo/guards.py:3006] [0/1] [__recompiles]     triggered by the following guard failure(s):
V0723 14:47:00.993000 1671979 torch/_dynamo/guards.py:3006] [0/1] [__recompiles]     - 0/0: ___check_obj_id(self._modules['attn1'].processor, 139920780269056)
W0723 14:47:00.993000 1671979 torch/_dynamo/convert_frame.py:964] [0/1] torch._dynamo hit config.recompile_limit (1)
W0723 14:47:00.993000 1671979 torch/_dynamo/convert_frame.py:964] [0/1]    function: 'forward' (/home/aryan/work/diffusers/src/diffusers/models/transformers/transformer_wan.py:305)
W0723 14:47:00.993000 1671979 torch/_dynamo/convert_frame.py:964] [0/1]    last reason: 0/0: ___check_obj_id(self._modules['attn1'].processor, 139920780269056)
W0723 14:47:00.993000 1671979 torch/_dynamo/convert_frame.py:964] [0/1] To log all recompilation reasons, use TORCH_LOGS="recompiles".
W0723 14:47:00.993000 1671979 torch/_dynamo/convert_frame.py:964] [0/1] To diagnose recompilation issues, see https://pytorch.org/docs/main/torch.compiler_troubleshooting.html.
W0723 14:47:00.997000 1671979 torch/_inductor/utils.py:953] on error, temporary cache dir kept at /tmp/tmp8ie_n_zu
I0723 14:47:00.997000 1671979 torch/_dynamo/__init__.py:112] torch._dynamo.reset
I0723 14:47:00.998000 1671979 torch/_dynamo/__init__.py:145] torch._dynamo.reset_code_caches
F

=================================================== FAILURES ===================================================
________________________ WanTransformerCompileTests.test_torch_compile_repeated_blocks _________________________

self = <tests.models.transformers.test_models_transformer_wan.WanTransformerCompileTests testMethod=test_torch_compile_repeated_blocks>

    def test_torch_compile_repeated_blocks(self):
        if self.model_class._repeated_blocks is None:
            pytest.skip("Skipping test as the model class doesn't have `_repeated_blocks` set.")
    
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
    
        model = self.model_class(**init_dict).to(torch_device)
        model.compile_repeated_blocks(fullgraph=True)
    
        recompile_limit = 1
        if self.model_class.__name__ == "UNet2DConditionModel":
            recompile_limit = 2
    
        with (
            torch._inductor.utils.fresh_inductor_cache(),
            torch._dynamo.config.patch(recompile_limit=recompile_limit),
            torch.no_grad(),
        ):
>           _ = model(**inputs_dict)

tests/models/test_modeling_common.py:2000: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1751: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1762: in _call_impl
    return forward_call(*args, **kwargs)
src/diffusers/models/transformers/transformer_wan.py:488: in forward
    hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1749: in _wrapped_call_impl
    return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:655: in _fn
    return fn(*args, **kwargs)
/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1762: in _call_impl
    return forward_call(*args, **kwargs)
/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:1432: in __call__
    return self._torchdynamo_orig_callable(
/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:598: in __call__
    return _compile(
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

code = <code object forward at 0x7f41d580c190, file "/home/aryan/work/diffusers/src/diffusers/models/transformers/transformer_wan.py", line 305>
globals = {'Any': typing.Any, 'Attention': <class 'diffusers.models.attention_processor.Attention'>, 'CacheMixin': <class 'diffusers.models.cache_utils.CacheMixin'>, 'ConfigMixin': <class 'diffusers.configuration_utils.ConfigMixin'>, ...}
locals = {'encoder_hidden_states': tensor([[[-0.1011, -0.2938,  0.1531,  0.3172, -0.0294,  0.0226,  0.1603,
          -0.4711, ...=32, out_features=24, bias=True)
    )
  )
  (norm3): FP32LayerNorm((24,), eps=1e-06, elementwise_affine=False)
), ...}
builtins = {'ArithmeticError': <class 'ArithmeticError'>, 'AssertionError': <class 'AssertionError'>, 'AttributeError': <class 'AttributeError'>, 'BaseException': <class 'BaseException'>, ...}
closure = (), compiler_fn = <torch._dynamo.repro.after_dynamo.WrapBackendDebug object at 0x7f41d50dca00>
one_graph = True, export = False, export_constraints = None
hooks = Hooks(guard_export_fn=None, guard_fail_fn=None)
cache_entry = <torch._C._dynamo.eval_frame._CacheEntry object at 0x7f41d4364cf0>
cache_size = CacheSizeRelevantForFrame(num_cache_entries=1, num_cache_entries_with_same_id_matched_objs=1)
frame = <torch._C._dynamo.eval_frame._PyInterpreterFrame object at 0x7f41d5173750>, frame_state = {'_id': 0}

    def _compile(
        code: CodeType,
        globals: dict[str, object],
        locals: dict[str, object],
        builtins: dict[str, object],
        closure: tuple[CellType],
        compiler_fn: CompilerFn,
        one_graph: bool,
        export: bool,
        export_constraints: Optional[typing.Never],
        hooks: Hooks,
        cache_entry: Optional[CacheEntry],
        cache_size: CacheSizeRelevantForFrame,
        frame: Optional[DynamoFrameType] = None,
        frame_state: Optional[dict[str, Union[int, FrameStateSizeEntry]]] = None,
        *,
        compile_id: CompileId,
        skip: int = 0,
    ) -> ConvertFrameReturn:
        from torch.fx.experimental.validator import (
            bisect,
            BisectValidationException,
            translation_validation_enabled,
            ValidationException,
        )
    
        # Only nonlocal defs here please!
        # Time spent compiling this frame before restarting or failing analysis
        dynamo_time_before_restart: float = 0.0
        output: Optional[OutputGraph] = None
        tracer: Optional[InstructionTranslator] = None
    
        tf_mode_stack: list[torch.overrides.TorchFunctionMode] = (
            torch.overrides._get_current_function_mode_stack()
        )
    
        @preserve_global_state
        def transform(
            instructions: list[Instruction], code_options: dict[str, object]
        ) -> None:
            nonlocal output
            nonlocal tracer
            speculation_log.restart()
            exn_vt_stack = ExceptionStack()
            tracer = InstructionTranslator(
                instructions,
                code,
                locals,
                globals,
                builtins,
                closure,
                tf_mode_stack,
                code_options,
                compiler_fn,
                one_graph,
                export,
                export_constraints,
                frame_state=frame_state,
                speculation_log=speculation_log,
                exn_vt_stack=exn_vt_stack,
                distributed_state=distributed_state,
            )
    
            try:
                with tracing(tracer.output.tracing_context), tracer.set_current_tx():
                    tracer.run()
            except exc.UnspecializeRestartAnalysis:
                speculation_log.clear()
                raise
            except (
                exc.SpeculationRestartAnalysis,
                exc.TensorifyScalarRestartAnalysis,
                exc.SkipFrame,
            ):
                raise
            except Exception:
                if translation_validation_enabled():
                    bisect(tracer.output.shape_env)
                raise
            finally:
                tracer.output.call_cleanup_hooks()
    
            output = tracer.output
            assert output is not None
            assert output.output_instructions
            instructions[:] = output.output_instructions
            code_options.update(output.code_options)
            propagate_inst_exn_table_entries(instructions)
            check_inst_exn_tab_entries_valid(instructions)
            instructions[:] = remove_pointless_jumps(remove_dead_code(instructions))
    
        @compile_time_strobelight_meta(phase_name="compile_inner")
        def compile_inner(
            code: CodeType,
            one_graph: bool,
            hooks: Hooks,
            transform: Callable[[list[Instruction], dict[str, Any]], Any],
        ) -> ConvertFrameReturn:
            with contextlib.ExitStack() as stack:
                stack.enter_context(
                    dynamo_timed(
                        "_compile.compile_inner",
                        phase_name="entire_frame_compile",
                        dynamo_compile_column_us="dynamo_cumulative_compile_time_us",
                    )
                )
                stack.enter_context(
                    _WaitCounter("pytorch.wait_counter.dynamo_compile").guard()
                )
                stack.enter_context(torch._dynamo.callback_handler.install_callbacks())
                stack.enter_context(CompileTimeInstructionCounter.record())
                return _compile_inner(code, one_graph, hooks, transform)
    
            return (
                ConvertFrameReturn()
            )  # dead, but see https://github.com/python/mypy/issues/7577
    
        @maybe_cprofile
        def _compile_inner(
            code: CodeType,
            one_graph: bool,
            hooks: Hooks,
            transform: Callable[[list[Instruction], dict[str, Any]], Any],
        ) -> ConvertFrameReturn:
            nonlocal dynamo_time_before_restart
            last_attempt_start_time = start_time = time.time()
    
            def log_bytecode(
                prefix: str, name: str, filename: str, line_no: int, code: CodeType
            ) -> None:
                if bytecode_log.isEnabledFor(logging.DEBUG):
                    bytecode_log.debug(
                        format_bytecode(prefix, name, filename, line_no, code)
                    )
    
            log_bytecode(
                "ORIGINAL BYTECODE",
                code.co_name,
                code.co_filename,
                code.co_firstlineno,
                code,
            )
    
            out_code = None
            for attempt in itertools.count():
                CompileContext.get().attempt = attempt
                try:
                    out_code = transform_code_object(code, transform)
                    break
                except exc.RestartAnalysis as e:
                    if not isinstance(e, exc.TensorifyScalarRestartAnalysis):
                        TensorifyState.clear()
                    log.info(
                        "Restarting analysis due to %s",
                        LazyString(format_traceback_short, e.__traceback__),
                    )
                    # If restart reason is None just log the type of the exception
                    restart_reasons.add(e.restart_reason or str(type(e)))
                    # We now have a new "last attempt", reset the clock
                    last_attempt_start_time = time.time()
                    if attempt > 100:
                        unimplemented_v2(
                            gb_type="Excessive RestartAnalysis() calls",
                            context="",
                            explanation="Dynamo attempted to trace the same frame 100+ times. "
                            "Giving up on compiling as the compile time tradeoff is likely not "
                            "worth the performance gain.",
                            hints=[],
                        )
                except exc.SkipFrame as e:
                    if not isinstance(e, exc.TensorifyScalarRestartAnalysis):
                        TensorifyState.clear()
                    log.debug(
                        "Skipping frame %s %s \
                        %s %s",
                        e,
                        code.co_name,
                        code.co_filename,
                        code.co_firstlineno,
                    )
                    if one_graph:
                        log.debug("No graph captured with one_graph=True")
                    return ConvertFrameReturn()
    
            assert distributed_state is None or distributed_state.all_states is not None, (
                "compiler collective wasn't run before compilation completed"
            )
    
            assert out_code is not None
            log_bytecode(
                "MODIFIED BYTECODE",
                code.co_name,
                code.co_filename,
                code.co_firstlineno,
                out_code,
            )
    
            for hook in _bytecode_hooks.values():
                hook_output = hook(code, out_code)
                if hook_output is not None:
                    out_code = hook_output
    
            orig_code_map[out_code] = code
            output_codes.add(out_code)
            dynamo_time_before_restart = last_attempt_start_time - start_time
            assert output is not None
    
            # Tests for new code objects.
            # The rationale for these tests can be found in torch/csrc/dynamo/eval_frame.c
            # Only test once the code object is created.
            # They are not tested during runtime.
    
            def count_args(code: CodeType) -> int:
                import inspect
    
                return (
                    code.co_argcount
                    + code.co_kwonlyargcount
                    + bool(code.co_flags & inspect.CO_VARARGS)
                    + bool(code.co_flags & inspect.CO_VARKEYWORDS)
                )
    
            assert out_code is not None
    
            total_argcount_old = count_args(code)
            total_argcount_new = count_args(out_code)
            msg = "arg mismatch: "
            msg += f"old code object has args {code.co_varnames[:total_argcount_old]}, "
            msg += f"new code object has args {out_code.co_varnames[:total_argcount_new]}"
            assert (
                code.co_varnames[:total_argcount_old]
                == out_code.co_varnames[:total_argcount_new]
            ), msg
    
            msg = "free var mismatch: "
            msg += f"old code object has free var {code.co_freevars}, "
            msg += f"new code object has free var {out_code.co_freevars}"
            assert code.co_freevars == out_code.co_freevars, msg
    
            msg = "cell var mismatch: "
            msg += f"old code object has cell var {code.co_cellvars}, "
            msg += f"new code object has cell var {out_code.co_cellvars}"
            assert code.co_cellvars == out_code.co_cellvars, msg
    
            # Skipping Dynamo on a frame without any extracted graph.
            # This does not affect eager functionality. But this is necessary
            # for export for cases where Dynamo-reconstructed bytecode can create
            # new function frames, confusing export in thinking that there
            # are extra graphs now.
    
            if output.export and output.is_empty_graph():
                return ConvertFrameReturn()
    
            assert output.guards is not None
            CleanupManager.instance[out_code] = output.cleanups
            nonlocal cache_entry
            check_fn = CheckFunctionManager(
                code,
                output,
                cache_entry,
                hooks.guard_fail_fn if hooks else None,
            )
    
            compile_id_str = str(compile_id) if compile_id is not None else "Unknown"
            annotation_str = "Torch-Compiled Region: " + compile_id_str
            guarded_code = GuardedCode(
                out_code,
                check_fn.guard_manager,  # type: ignore[arg-type]
                compile_id,
                annotation_str,
            )
    
            if not output.is_empty_graph() and hooks.guard_export_fn is not None:
                # We should not run the guard_export_fn when Dynamo does not
                # generate any graph. This can happen in export when TorchDynamo
                # generated bytecode has some reconstruction logic for mutated
                # variables which can trigger TorchDynamo on the children frames but
                # they are benign and do not generate any new graphs.
                hooks.guard_export_fn(output.guards)
    
            return wrap_guarded_code(guarded_code)
    
        metrics_context = get_metrics_context()
        with (
            _use_lazy_graph_module(config.use_lazy_graph_module),
            compile_context(CompileContext(compile_id)),
            chromium_event_timed(
                "dynamo", reset_event_log_on_exit=True, log_pt2_compile_event=True
            ),
            metrics_context,
        ):
            restart_reasons: set[str] = set()
            # This is shared across restarts
            speculation_log = SpeculationLog()
            if compile_pg := get_compile_pg():
                distributed_state = DistributedState(compile_pg, LocalState())
            else:
                distributed_state = None
    
            # Check recompilations
            recompile_reason: Optional[str] = None
            if is_recompilation(cache_size) and frame:
                reasons = get_and_maybe_log_recompilation_reasons(cache_entry, frame)
                recompile_reason = (
                    "Unable to find recompilation reasons" if not reasons else reasons[0]
                )
            metrics_context.update_outer({"recompile_reason": recompile_reason})
    
            exceeded, limit_type = exceeds_recompile_limit(cache_size, compile_id)
            if exceeded:
    
                def format_func_info(code: CodeType) -> str:
                    return f"'{code.co_name}' ({code.co_filename}:{code.co_firstlineno})"
    
                log.warning(
                    "torch._dynamo hit config.%s (%s)\n"
                    "   function: %s\n"
                    "   last reason: %s\n"
                    'To log all recompilation reasons, use TORCH_LOGS="recompiles".\n'
                    "To diagnose recompilation issues, see %s.",
                    limit_type,
                    getattr(config, limit_type),
                    format_func_info(code),
                    recompile_reason,
                    troubleshooting_url,
                )
                if config.fail_on_recompile_limit_hit:
                    raise FailOnRecompileLimitHit(
                        f"{limit_type} reached, because fail_on_recompile_limit_hit = True this is a HARD failure"
                    )
                elif one_graph:
>                   raise FailOnRecompileLimitHit(
                        f"{limit_type} reached with one_graph=True. Excessive recompilations can degrade "
                        "performance due to the compilation overhead of each recompilation. To monitor "
                        "recompilations, enable TORCH_LOGS=recompiles. If recompilations are expected, consider "
                        "increasing torch._dynamo.config.cache_size_limit to an appropriate value."
E                       torch._dynamo.exc.FailOnRecompileLimitHit: recompile_limit reached with one_graph=True. Excessive recompilations can degrade performance due to the compilation overhead of each recompilation. To monitor recompilations, enable TORCH_LOGS=recompiles. If recompilations are expected, consider increasing torch._dynamo.config.cache_size_limit to an appropriate value.

/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:981: FailOnRecompileLimitHit
=============================================== warnings summary ===============================================
../../../../raid/aryan/nightly-venv/lib/python3.10/site-packages/triton/runtime/autotuner.py:108
../../../../raid/aryan/nightly-venv/lib/python3.10/site-packages/triton/runtime/autotuner.py:108
../../../../raid/aryan/nightly-venv/lib/python3.10/site-packages/triton/runtime/autotuner.py:108
../../../../raid/aryan/nightly-venv/lib/python3.10/site-packages/triton/runtime/autotuner.py:108
  /raid/aryan/nightly-venv/lib/python3.10/site-packages/triton/runtime/autotuner.py:108: DeprecationWarning: warmup, rep, and use_cuda_graph parameters are deprecated. See https://github.com/triton-lang/triton/pull/4496 for details.
    warnings.warn(("warmup, rep, and use_cuda_graph parameters are deprecated. See "

tests/models/transformers/test_models_transformer_wan.py::WanTransformerCompileTests::test_torch_compile_repeated_blocks
  /raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/_inductor/compile_fx.py:236: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting `torch.set_float32_matmul_precision('high')` for better performance.
    warnings.warn(

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
=========================================== short test summary info ============================================
FAILED tests/models/transformers/test_models_transformer_wan.py::WanTransformerCompileTests::test_torch_compile_repeated_blocks - torch._dynamo.exc.FailOnRecompileLimitHit: recompile_limit reached with one_graph=True. Excessive recompila...
======================================== 1 failed, 5 warnings in 8.08s =========================================
I0723 14:47:01.756000 1672591 torch/_dynamo/eval_frame.py:475] TorchDynamo attempted to trace the following frames: [
I0723 14:47:01.756000 1672591 torch/_dynamo/eval_frame.py:475] 
I0723 14:47:01.756000 1672591 torch/_dynamo/eval_frame.py:475] ]
I0723 14:47:01.762000 1672591 torch/_dynamo/utils.py:765] TorchDynamo compilation metrics:
I0723 14:47:01.762000 1672591 torch/_dynamo/utils.py:765] Function    Runtimes (s)
I0723 14:47:01.762000 1672591 torch/_dynamo/utils.py:765] ----------  --------------
I0723 14:47:02.346000 1671979 torch/_dynamo/eval_frame.py:475] TorchDynamo attempted to trace the following frames: [
I0723 14:47:02.346000 1671979 torch/_dynamo/eval_frame.py:475]   * forward /home/aryan/work/diffusers/src/diffusers/models/transformers/transformer_wan.py:305
I0723 14:47:02.346000 1671979 torch/_dynamo/eval_frame.py:475]   * forward /home/aryan/work/diffusers/src/diffusers/models/transformers/transformer_wan.py:305
I0723 14:47:02.346000 1671979 torch/_dynamo/eval_frame.py:475] ]
I0723 14:47:02.352000 1671979 torch/_dynamo/utils.py:765] TorchDynamo compilation metrics:
I0723 14:47:02.352000 1671979 torch/_dynamo/utils.py:765] Function    Runtimes (s)
I0723 14:47:02.352000 1671979 torch/_dynamo/utils.py:765] ----------  --------------

I'm guessing this is expected unless testing with 2.8 nightly?
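
As a side note (illustrative only, not part of this PR): when diagnosing a `FailOnRecompileLimitHit` like the one above, one way to surface every recompile reason locally is to enable the recompiles log artifact and loosen the cache limit while inspecting the guards. A minimal sketch, assuming a recent PyTorch where these knobs exist:

```python
# Debugging aid only; not a fix for the failing test.
import torch

torch._logging.set_logs(recompiles=True)    # same effect as TORCH_LOGS=recompiles
torch._dynamo.config.cache_size_limit = 8   # allow more recompiles while inspecting guards
```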

sayakpaul (Member) commented:

Yes, just confirmed it and opened #11979 as a mitigation.
