Skip to content

Conversation

@Cui-yshoho
Copy link
Contributor

@Cui-yshoho Cui-yshoho commented Oct 20, 2025

What does this PR do?

Add

  • Bria with diffusers master.
    • mindone.diffusers.BriaTransformer2DModel
    • mindone.diffusers.BriaPipeline

Usage

import mindspore as ms
from mindone.diffusers import BriaPipeline

pipe = BriaPipeline.from_pretrained("briaai/BRIA-3.2", mindspore_dtype=ms.bfloat16)
# BRIA's T5 text encoder is sensitive to precision. We need to cast it to bfloat16 and keep the final layer in float32.

pipe.text_encoder = pipe.text_encoder.to(dtype=ms.bfloat16)
for block in pipe.text_encoder.encoder.block:
    block.layer[-1].DenseReluDense.wo.set_dtype(ms.float32)
    # BRIA's VAE is not supported in mixed precision, so we use float32.

if pipe.vae.config.shift_factor == 0:
    pipe.vae.to(dtype=ms.float32)

prompt = "Photorealistic food photography of a stack of fluffy pancakes on a white plate, with maple syrup being poured over them. On top of the pancakes are the words 'BRIA 3.2' in bold, yellow, 3D letters. The background is dark and out of focus."
image = pipe(prompt)[0][0]
image.save("bria.png")

Performance

Experiments are tested on Ascend Atlas 800T A2 machines with mindspore 2.7.0.

pipeline mode speed
BriaPipeline pynative 1.17 it/s
BriaPipeline jit 1.39 it/s

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline?
  • Did you make sure to update the documentation with your changes? E.g. record bug fixes or new features in What's New. Here are the
    documentation guidelines
  • Did you build and run the code without any errors?
  • Did you report the running environment (NPU type/MS version) and performance in the doc? (better record it for data loading, model inference, or training tasks)
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@xxx

@Cui-yshoho Cui-yshoho requested a review from vigo999 as a code owner October 20, 2025 06:37
@gemini-code-assist
Copy link
Contributor

Summary of Changes

Hello @Cui-yshoho, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request integrates the Bria 3.2 text-to-image model and its associated pipeline into the mindone.diffusers framework. The changes encompass the core model architecture, the inference pipeline, and comprehensive documentation, all tailored for MindSpore compatibility. This addition expands the library's capabilities by offering a new, efficient, and commercially-ready text-to-image generation option, complete with usage examples and performance notes for Ascend devices.

Highlights

  • New Model Integration: Introduced the BriaTransformer2DModel, a modified Flux Transformer model, into the mindone.diffusers library, enabling the use of the Bria 3.2 text-to-image model.
  • New Pipeline Implementation: Added the BriaPipeline for seamless text-to-image generation using the Bria 3.2 model, including specific handling for T5 text encoder precision and VAE mixed precision.
  • Comprehensive Documentation: New documentation files have been added for both the BriaTransformer2DModel and the Bria 3.2 pipeline, detailing their functionality, usage, and key features.
  • MindSpore Compatibility: The Bria model and pipeline are implemented with MindSpore compatibility, including specific adjustments for attn_mask handling in attention mechanisms and performance considerations on Ascend hardware.
  • Extensive Testing: New unit tests for BriaTransformer2DModel and both fast and slow integration tests for BriaPipeline have been added to ensure correctness and performance.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces the Bria model and pipeline, which is a significant feature addition. The implementation looks solid and follows the existing structure of the library. I've found a few issues, including a critical runtime error due to a missing import, a logic bug in prompt handling, and some code duplication. I've also pointed out several documentation typos and areas for minor code improvements to enhance robustness and maintainability. Once these points are addressed, the PR should be in great shape.

FlowMatchEulerDiscreteScheduler,
KarrasDiffusionSchedulers,
)
from ...utils import logging, scale_lora_layers, unscale_lora_layers
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The name USE_PEFT_BACKEND used on line 169 is not defined, which will cause a NameError at runtime. You need to import it from ...utils.

Suggested change
from ...utils import logging, scale_lora_layers, unscale_lora_layers
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers

Comment on lines +415 to +442
class BriaPosEmbed(nn.Cell):
# modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
def __init__(self, theta: int, axes_dim: List[int]):
super().__init__()
self.theta = theta
self.axes_dim = axes_dim

def construct(self, ids: ms.Tensor) -> ms.Tensor:
n_axes = ids.shape[-1]
cos_out = []
sin_out = []
pos = ids.float()
freqs_dtype = ms.float64
for i in range(n_axes):
cos, sin = get_1d_rotary_pos_embed(
self.axes_dim[i],
pos[:, i],
theta=self.theta,
repeat_interleave_real=True,
use_real=True,
freqs_dtype=freqs_dtype,
)
cos_out.append(cos)
sin_out.append(sin)
freqs_cos = mint.cat(cos_out, dim=-1)
freqs_sin = mint.cat(sin_out, dim=-1)
return freqs_cos, freqs_sin

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The class BriaPosEmbed is an exact duplicate of BriaEmbedND defined earlier in this file (lines 348-375). This code duplication increases maintenance overhead. You should remove BriaPosEmbed and use BriaEmbedND in its place.

Comment on lines +499 to +501
else:
attn_output, context_attn_output, ip_attn_output = None, None, None

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The else block here silently sets attn_output, context_attn_output, and ip_attn_output to None if attention_outputs does not have a length of 2 or 3. This could hide bugs if self.attn returns an unexpected number of outputs. It would be safer to raise a ValueError to make such issues immediately apparent.

Suggested change
else:
attn_output, context_attn_output, ip_attn_output = None, None, None
else:
raise ValueError(f"Unexpected number of attention outputs: {len(attention_outputs)}")

return_tensors="np",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="np").input_ids
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The logic for detecting and warning about prompt truncation appears to be incorrect. untruncated_ids is calculated using the entire prompt list, while text_input_ids is based on a single prompt p from the list within the loop. This comparison is invalid. You should calculate untruncated_ids for the single prompt p inside the loop.

Suggested change
untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="np").input_ids
untruncated_ids = tokenizer(p, padding="longest", return_tensors="np").input_ids

Original model checkpoints for Bria 3.2 can be found [here](https://huggingface.co/briaai/BRIA-3.2).
Github repo for Bria 3.2 can be found [here](https://github.com/Bria-AI/BRIA-3.2).

If you want to learn more about the Bria platform, and get free traril access, please visit [bria.ai](https://bria.ai).
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

There's a typo in the word 'traril'. It should be 'trial'.

Suggested change
If you want to learn more about the Bria platform, and get free traril access, please visit [bria.ai](https://bria.ai).
If you want to learn more about the Bria platform, and get free trial access, please visit [bria.ai](https://bria.ai).

Comment on lines +38 to +51
def _get_fused_projections(attn: "BriaAttention", hidden_states, encoder_hidden_states=None):
query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)

encoder_query = encoder_key = encoder_value = (None,)
if encoder_hidden_states is not None and hasattr(attn, "to_added_qkv"):
encoder_query, encoder_key, encoder_value = attn.to_added_qkv(encoder_hidden_states).chunk(3, dim=-1)

return query, key, value, encoder_query, encoder_key, encoder_value


def _get_qkv_projections(attn: "BriaAttention", hidden_states, encoder_hidden_states=None):
if attn.fused_projections:
return _get_fused_projections(attn, hidden_states, encoder_hidden_states)
return _get_projections(attn, hidden_states, encoder_hidden_states)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The function _get_fused_projections and the check for attn.fused_projections in _get_qkv_projections appear to be unused, as the BriaAttention class does not have a fused_projections attribute. This looks like dead code that could be removed to improve clarity and reduce maintenance overhead.

The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The docstring indicates that max_sequence_length defaults to 256, but the function signature on line 439 sets the default to 128. Please update the docstring to match the implementation for consistency.

Suggested change
max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`.
max_sequence_length (`int` defaults to 128): Maximum sequence length to use with the `prompt`.


# expand the latents if we are doing classifier free guidance
latent_model_input = mint.cat([latents] * 2) if self.do_classifier_free_guidance else latents
if type(self.scheduler) != FlowMatchEulerDiscreteScheduler:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Using type() for type checking is generally discouraged in favor of isinstance(), as isinstance() correctly handles inheritance. It's more robust to check if not isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler):.

Suggested change
if type(self.scheduler) != FlowMatchEulerDiscreteScheduler:
if not isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler):

@Cui-yshoho Cui-yshoho added the new model add new model to mindone label Oct 20, 2025
# While this behavior is consistent with HF Diffusers for now,
# it may still be a potential bug source worth validating.
if attn_mask is not None and 1.0 in attn_mask:
if attn_mask is not None and attn_mask.dtype != ms.bool_ and 1.0 in attn_mask:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个改动是否是对齐原实现的

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个改动是否是对齐原实现的

嗯嗯,发现的一个新问题,这里修改一下,主要是为了解决sdpa输入如果是float和ms的fa无法对齐的问题

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

new model add new model to mindone

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants