31 changes: 14 additions & 17 deletions examples/opensora_hpcai/opensora/models/layers/blocks.py
@@ -177,7 +177,7 @@ def construct(self, x, cond, mask=None):
# 2+: mask adaptation for multi-head attention
if mask is not None:
# flip mask, since ms FA treats 1 as discard, 0 as retain.
- mask = 1 - mask.to(ms.int32)
+ mask = 1 - mask

# 3. attn compute
if self.enable_flash_attention:
@@ -266,15 +266,15 @@ def construct(self, x: Tensor, cond: Tensor, mask: Optional[Tensor] = None) -> T
# 2+: mask adaptation for multi-head attention
if mask is not None:
# flip mask, since ms FA treats 1 as discard, 0 as retain.
- mask = 1 - mask.to(ms.int32)
+ mask = 1 - mask

# 3. attn compute
if self.enable_flash_attention:
if mask is not None:
# (b n_k) -> (b 1 1 n_k), will be broadcast according to qk sim, e.g. (b num_heads n_q n_k)
mask = mask[:, None, None, :]
# (b 1 1 n_k) -> (b 1 n_q n_k)
- mask = self.repeat_interleave(mask.to(ms.int32), int(q.shape[1]), axis=-2)
+ mask = self.repeat_interleave(mask, int(q.shape[1]), axis=-2)
x = self.flash_attention(q, k, v, mask=mask)

# FA attn_mask def: 0 indicates retention and 1 indicates discard. Input tensor of shape :math:`(B, N1, S1, S2)`, `(B, 1, S1, S2)` or `(S1, S2)`
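
For reference, a minimal, self-contained sketch of the mask adaptation this hunk performs, assuming an input mask of shape (b, n_k) where 1 marks valid tokens; the batch and length values below are illustrative:

```python
import mindspore as ms
from mindspore import ops

b, n_q, n_k = 2, 16, 16                           # illustrative sizes
mask = ops.ones((b, n_k), ms.int32)               # 1 = keep, as produced upstream
mask = 1 - mask                                   # flip: MindSpore FA treats 1 as discard
mask = mask[:, None, None, :]                     # (b, n_k) -> (b, 1, 1, n_k)
mask = ops.repeat_interleave(mask, n_q, axis=-2)  # (b, 1, 1, n_k) -> (b, 1, n_q, n_k)
print(mask.shape)                                 # (2, 1, 16, 16), broadcast over heads
```

Dropping the explicit `.to(ms.int32)` casts assumes the mask already arrives as an integer or boolean tensor; that is the PR's assumption, not something verified here.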
@@ -384,7 +384,7 @@ def construct(self, x, mask=None, freqs_cis: Optional[Tensor] = None):

# mask process
if mask is not None:
- mask = 1 - mask.to(ms.int32)
+ mask = 1 - mask

if self.enable_flash_attention:
if mask is not None:
@@ -500,8 +500,8 @@ def __init__(self, normalized_shape, eps=1e-5, elementwise_affine: bool = True,
self.gamma = Parameter(initializer("ones", normalized_shape, dtype=dtype))
self.beta = Parameter(initializer("zeros", normalized_shape, dtype=dtype))
else:
- self.gamma = ops.ones(normalized_shape, dtype=dtype)
- self.beta = ops.zeros(normalized_shape, dtype=dtype)
+ self.gamma = Tensor(np.ones(normalized_shape, dtype=np.float32))
+ self.beta = Tensor(np.zeros(normalized_shape, dtype=np.float32))

def construct(self, x: Tensor):
normalized_shape = x.shape[-1:]
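
A self-contained sketch of what the non-affine branch now amounts to: constant, numpy-backed gamma/beta applied around a plain last-axis normalization. `layer_norm_sketch` is an illustrative helper, not the module's actual construct:

```python
import numpy as np
from mindspore import Tensor, ops

def layer_norm_sketch(x, gamma, beta, eps=1e-5):
    # Normalize over the last axis, then apply the (here constant) affine parameters.
    mean = x.mean(axis=-1, keep_dims=True)
    var = ((x - mean) ** 2).mean(axis=-1, keep_dims=True)
    return (x - mean) / ops.sqrt(var + eps) * gamma + beta

hidden = 8
gamma = Tensor(np.ones(hidden, dtype=np.float32))   # fixed, numpy-backed constants
beta = Tensor(np.zeros(hidden, dtype=np.float32))
x = Tensor(np.random.randn(2, 4, hidden).astype(np.float32))
print(layer_norm_sketch(x, gamma, beta).shape)      # (2, 4, 8)
```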
@@ -592,10 +592,7 @@ def __init__(self, hidden_size, num_patch, out_channels, d_t=None, d_s=None):
self.norm_final = LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
# (1152, 4*8)
self.linear = nn.Dense(hidden_size, num_patch * out_channels, has_bias=True)
- # self.scale_shift_table = Parameter((ops.randn(2, hidden_size, dtype=ms.float32) / hidden_size**0.5).astype(ms.float32))
- self.scale_shift_table = Parameter(
-     ms.Tensor((np.random.randn(2, hidden_size) / hidden_size**0.5), dtype=ms.float32)
- )
+ self.scale_shift_table = Parameter(np.random.randn(2, hidden_size).astype(np.float32) / hidden_size**0.5)
self.out_channels = out_channels
self.d_t = d_t
self.d_s = d_s
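
For reference, `Parameter` accepts a numpy array directly, which is what the new one-liner relies on; a small sketch with an illustrative hidden size:

```python
import numpy as np
import mindspore as ms
from mindspore import Parameter, Tensor

hidden_size = 1152  # illustrative
table_np = np.random.randn(2, hidden_size).astype(np.float32) / hidden_size**0.5

p_old = Parameter(Tensor(table_np, dtype=ms.float32), name="old_style")  # previous form
p_new = Parameter(table_np, name="new_style")                            # numpy array accepted directly
assert p_old.dtype == p_new.dtype == ms.float32
assert p_old.shape == p_new.shape == (2, hidden_size)
```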
@@ -614,11 +611,13 @@ def construct(
T = self.d_t
if S is None:
S = self.d_s
- shift, scale = self.chunk(self.scale_shift_table[None] + t[:, None], 2, 1)

+ scale_shift_table = self.scale_shift_table.to(x.dtype)
+ shift, scale = self.chunk(scale_shift_table[None] + t[:, None], 2, 1)
x = t2i_modulate(self.norm_final(x), shift, scale)

if frames_mask is not None:
- shift_zero, scale_zero = self.chunk(self.scale_shift_table[None] + t0[:, None], 2, 1)
+ shift_zero, scale_zero = self.chunk(scale_shift_table[None] + t0[:, None], 2, 1)
x_zero = t2i_modulate(self.norm_final(x), shift_zero, scale_zero)
x = t_mask_select(frames_mask, x, x_zero, T, S)
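
`t2i_modulate` is not shown in this diff; the sketch below assumes the usual PixArt-style definition `x * (1 + scale) + shift` and illustrative shapes, to show what the chunked `shift`/`scale` are applied to:

```python
import numpy as np
from mindspore import Tensor, ops

def t2i_modulate_sketch(x, shift, scale):
    # assumed PixArt/OpenSora-style modulation
    return x * (1 + scale) + shift

B, N, C = 4, 16, 1152                                              # illustrative sizes
table = Tensor(np.random.randn(2, C).astype(np.float32) / C**0.5)  # stands in for scale_shift_table
t = Tensor(np.random.randn(B, C).astype(np.float32))               # timestep embedding
shift, scale = ops.chunk(table[None] + t[:, None], 2, axis=1)      # each (B, 1, C)
x = Tensor(np.random.randn(B, N, C).astype(np.float32))
print(t2i_modulate_sketch(x, shift, scale).shape)                  # (4, 16, 1152)
```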

@@ -639,9 +638,9 @@ def __init__(self, in_channels, hidden_size, uncond_prob, act_layer=nn.GELU, tok
in_features=in_channels, hidden_features=hidden_size, out_features=hidden_size, act_layer=act_layer, drop=0
)

- y_embedding = ops.randn(token_num, in_channels) / in_channels**0.5
+ y_embedding = np.random.randn(token_num, in_channels).astype(np.float32) / in_channels**0.5
# just for token dropping replacement, not learnable
- self.y_embedding = Parameter(Tensor(y_embedding, dtype=ms.float32), requires_grad=False)
+ self.y_embedding = Parameter(y_embedding, requires_grad=False)

self.uncond_prob = uncond_prob

@@ -656,9 +655,7 @@ def token_drop(self, caption, force_drop_ids=None):

# manually expand dims to avoid infer-shape bug in ms2.3 daily
caption = ops.where(
-     drop_ids[:, None, None, None],
-     self.y_embedding[None, None, :, :].to(caption.dtype),
-     caption,
+     drop_ids[:, None, None, None], self.y_embedding[None, None, :, :].to(caption.dtype), caption
)

return caption
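
A runnable sketch of the replacement step above; how `drop_ids` is produced is outside this hunk, so the uniform draw below is only one plausible choice:

```python
import numpy as np
from mindspore import Tensor, ops

B, token_num, in_channels = 4, 120, 4096  # illustrative
uncond_prob = 0.1

caption = Tensor(np.random.randn(B, 1, token_num, in_channels).astype(np.float32))
y_embedding = Tensor(np.random.randn(token_num, in_channels).astype(np.float32) / in_channels**0.5)

# drop_ids generation is not shown in the hunk; a uniform draw is one plausible way to get it.
drop_ids = Tensor(np.random.rand(B) < uncond_prob)
caption = ops.where(drop_ids[:, None, None, None], y_embedding[None, None, :, :], caption)
print(caption.shape)  # (4, 1, 120, 4096)
```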
(next file; path not shown in this view)
@@ -1,6 +1,6 @@
import mindspore as ms
from mindspore import mint, ops
- from mindspore.ops.function.array_func import chunk_ext, repeat_interleave_ext
+ from mindspore.ops.function.array_func import chunk_ext

use_dynamic_ops = False

@@ -56,9 +56,7 @@ def get_repeat_interleave_op():
# provide better performance for static shape in graph mode
return ops.repeat_interleave
else:
- # FIXME: check overflow for v2
- # return repeat_interleave_ext_v2
- return repeat_interleave_ext
+ return repeat_interleave_ext_v2


def get_chunk_op():
(next file; path not shown in this view)
@@ -28,7 +28,7 @@ def rotate_half(x: Tensor) -> Tensor:
def apply_rotary_emb(freqs: Parameter, t: Tensor, scale: float = 1.0, seq_dim: int = -2) -> Tensor:
# FIXME: start_index is always 0 in OS1.2 and ops.concat doesn't support empty elements. OS1.x future versions may need start_index > 0
# t, t_right = t[..., start_index:end_index], t[..., end_index:]
- t = (t * freqs.cos().astype(t.dtype) * scale) + (rotate_half(t) * freqs.sin().astype(t.dtype) * scale)
+ t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)

return t
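
A hedged sketch of the rotary application with the explicit casts removed here; `rotate_half_sketch` assumes the interleaved pair convention implied by the `repeat_interleave(freqs, 2, -1)` call further down:

```python
import numpy as np
from mindspore import Tensor, ops

def rotate_half_sketch(x):
    # assumed interleaved pairing: (x1, x2) -> (-x2, x1) on the last axis
    x = x.reshape(x.shape[:-1] + (-1, 2))
    x1, x2 = x[..., 0], x[..., 1]
    x = ops.stack((-x2, x1), axis=-1)
    return x.reshape(x.shape[:-2] + (-1,))

def apply_rotary_emb_sketch(freqs, t, scale=1.0):
    # both tensors are fp32 here, so no explicit astype is needed, mirroring the new code path
    return (t * freqs.cos() * scale) + (rotate_half_sketch(t) * freqs.sin() * scale)

n, d = 16, 64                                               # illustrative seq len and head dim
freqs = Tensor(np.random.randn(n, d).astype(np.float32))
t = Tensor(np.random.randn(2, 8, n, d).astype(np.float32))
print(apply_rotary_emb_sketch(freqs, t).shape)              # (2, 8, 16, 64)
```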

@@ -139,7 +139,7 @@ def get_axial_freqs(self, *dims):
raise NotImplementedError

def construct(self, t: Tensor, seq_len=None, offset=0) -> Tensor:
- freqs = t.astype(self.freqs.dtype)[..., None] * self.freqs
+ freqs = t[..., None] * self.freqs.to(t.dtype)
return self.repeat_interleave(freqs, 2, -1) # ... n -> ... (n r), r = 2


5 changes: 3 additions & 2 deletions examples/opensora_hpcai/opensora/models/stdit/stdit3.py
@@ -101,13 +101,14 @@ def construct(
) -> Tensor:
# prepare modulate parameters
B, N, C = x.shape
+ scale_shift_table = self.scale_shift_table.to(x.dtype)
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.chunk(
-     self.scale_shift_table[None] + t.reshape(B, 6, -1), 6, 1
+     scale_shift_table[None] + t.reshape(B, 6, -1), 6, 1
)

# frames mask branch
shift_msa_zero, scale_msa_zero, gate_msa_zero, shift_mlp_zero, scale_mlp_zero, gate_mlp_zero = self.chunk(
-     self.scale_shift_table[None] + t0.reshape(B, 6, -1), 6, 1
+     scale_shift_table[None] + t0.reshape(B, 6, -1), 6, 1
)

# modulate (attention)
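
The same cast-once pattern as in `T2IFinalLayer`, sketched with illustrative shapes: converting the fp32 table to the activation dtype before the add keeps all six chunks in `x.dtype` under mixed precision:

```python
import numpy as np
from mindspore import Tensor, ops

B, C = 2, 1152  # illustrative
scale_shift_table = Tensor(np.random.randn(6, C).astype(np.float32) / C**0.5)  # fp32 parameter stand-in
x = Tensor(np.random.randn(B, 16, C).astype(np.float16))                       # fp16 activations under AMP
t = Tensor(np.random.randn(B, 6 * C).astype(np.float16))

table = scale_shift_table.to(x.dtype)  # cast once, so every chunk below stays in x.dtype
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ops.chunk(
    table[None] + t.reshape(B, 6, -1), 6, axis=1
)
print(shift_msa.dtype, shift_msa.shape)  # Float16 (2, 1, 1152)
```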
13 changes: 6 additions & 7 deletions examples/opensora_hpcai/opensora/models/vae/vae.py
@@ -109,12 +109,11 @@ def __init__(
self.scale_factor = 0.18215

@staticmethod
- def rearrange_in(x):
-     B, C, T, H, W = x.shape
-     # (b c t h w) -> (b t c h w)
-     x = ops.transpose(x, (0, 2, 1, 3, 4))
+ def rearrange_in(x, transpose: bool = True):
+     if transpose:  # (b c t h w) -> (b t c h w)
+         x = ops.transpose(x, (0, 2, 1, 3, 4))
+     B, T, C, H, W = x.shape
x = ops.reshape(x, (B * T, C, H, W))

return x

@staticmethod
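
A self-contained version of the new `rearrange_in` contract, useful for checking the two layouts it now accepts (shapes are illustrative):

```python
import numpy as np
from mindspore import Tensor, ops

def rearrange_in_sketch(x, transpose: bool = True):
    # transpose=True: input is (B, C, T, H, W); transpose=False: input is already (B, T, C, H, W)
    if transpose:  # (b c t h w) -> (b t c h w)
        x = ops.transpose(x, (0, 2, 1, 3, 4))
    B, T, C, H, W = x.shape
    return ops.reshape(x, (B * T, C, H, W))

x = Tensor(np.zeros((2, 8, 4, 32, 32), dtype=np.float32))  # (B, T, C, H, W), frame-major already
print(rearrange_in_sketch(x, transpose=False).shape)       # (16, 4, 32, 32)
```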
@@ -139,8 +138,8 @@ def encode(self, x):
# is_video = (x.ndim == 5)

B = x.shape[0]
- # B C T H W -> (B T) C H W
- x = self.rearrange_in(x)
+ # B T C H W -> (B T) C H W
+ x = self.rearrange_in(x, transpose=False)

pad_num = None
if self.micro_batch_parallel:
(next file; path not shown in this view)
@@ -146,11 +146,10 @@ def construct(

with no_grad():
# 1. get image/video latents z using vae
- # (b f c h w) -> (b c f h w)
- x = ops.transpose(x, (0, 2, 1, 3, 4))

if not self.video_emb_cached:
x = self.get_latents(x)
+ else:
+     x = ops.transpose(x, (0, 2, 1, 3, 4))

# 2. get conditions
if not self.text_emb_cached:
8 changes: 3 additions & 5 deletions examples/opensora_hpcai/opensora/utils/model_utils.py
@@ -3,19 +3,17 @@

from mindspore import Model as MSModel
from mindspore import context, nn
- from mindspore.nn import GroupNorm, SiLU # GELU
+ from mindspore.nn import GroupNorm
from mindspore.train.callback import _CallbackManager

from ..models.layers.blocks import Attention, LayerNorm, LlamaRMSNorm, PositionEmbedding2D, SinusoidalEmbedding
from ..models.text_encoder.flan_t5_large.t5 import T5LayerNorm

- # SORA's whitelist (FP32) operators
- WHITELIST_OPS = [
+ # SORA's blacklist (FP32) operators for O2 AMP level
+ BLACKLIST_OPS = [
LayerNorm,
Attention,
LlamaRMSNorm,
- SiLU,
- # GELU,
GroupNorm,
PositionEmbedding2D,
SinusoidalEmbedding,
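
Usage stays the same as before the rename; a minimal sketch, assuming the `opensora` package from this example is importable and using a toy network in place of the real model:

```python
import mindspore as ms
from mindspore import nn

from opensora.utils.amp import auto_mixed_precision  # project helper, as called in scripts/inference.py
from opensora.utils.model_utils import BLACKLIST_OPS

# Toy stand-in; in this PR the network is the STDiT, the VAE, or the T5 text encoder.
net = nn.SequentialCell([nn.Dense(16, 32), nn.GELU(), nn.Dense(32, 16)])
# Under O2 the network is cast to bf16, except cell types listed in BLACKLIST_OPS, which stay in fp32.
net = auto_mixed_precision(net, amp_level="O2", dtype=ms.bfloat16, custom_fp32_cells=BLACKLIST_OPS)
```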
6 changes: 3 additions & 3 deletions examples/opensora_hpcai/scripts/inference.py
@@ -26,7 +26,7 @@
from opensora.pipelines import InferPipeline, InferPipelineFiTLike
from opensora.utils.amp import auto_mixed_precision
from opensora.utils.cond_data import get_references, read_captions_from_csv, read_captions_from_txt
- from opensora.utils.model_utils import WHITELIST_OPS, _check_cfgs_in_parser, str2bool
+ from opensora.utils.model_utils import BLACKLIST_OPS, _check_cfgs_in_parser, str2bool
from opensora.utils.util import IMG_FPS, apply_mask_strategy, process_mask_strategies, process_prompts

from mindone.data.data_split import distribute_samples
@@ -273,7 +273,7 @@ def main(args):

if args.dtype in ["fp16", "bf16"]:
latte_model = auto_mixed_precision(
- latte_model, amp_level=args.amp_level, dtype=dtype_map[args.dtype], custom_fp32_cells=WHITELIST_OPS
+ latte_model, amp_level=args.amp_level, dtype=dtype_map[args.dtype], custom_fp32_cells=BLACKLIST_OPS
)

if args.ckpt_path:
@@ -301,7 +301,7 @@ def main(args):
"T5 dtype is fp16, which may lead to video color vibration. Suggest to use bf16 or fp32."
)
text_encoder = auto_mixed_precision(
- text_encoder, amp_level="O2", dtype=dtype_map[args.t5_dtype], custom_fp32_cells=WHITELIST_OPS
+ text_encoder, amp_level="O2", dtype=dtype_map[args.t5_dtype], custom_fp32_cells=BLACKLIST_OPS
)
logger.info(f"Num tokens: {mask.asnumpy().sum(2)}")
else:
4 changes: 2 additions & 2 deletions examples/opensora_hpcai/scripts/train.py
@@ -39,7 +39,7 @@
from opensora.utils.callbacks import EMAEvalSwapCallback, PerfRecorderCallback
from opensora.utils.ema import EMA, save_ema_ckpts
from opensora.utils.metrics import BucketLoss
- from opensora.utils.model_utils import WHITELIST_OPS, Model
+ from opensora.utils.model_utils import BLACKLIST_OPS, Model
from opensora.utils.resume import flush_from_cache, get_resume_ckpt, get_resume_states, resume_train_net, save_train_net

from mindone.trainers.callback import EvalSaveCallback, OverflowMonitor, ProfilerCallbackEpoch, StopAtStepCallback
@@ -467,7 +467,7 @@ def main(args):
latte_model,
amp_level=args.amp_level,
dtype=dtype_map[args.dtype],
- custom_fp32_cells=WHITELIST_OPS,
+ custom_fp32_cells=BLACKLIST_OPS,
)
# load checkpoint
if len(args.pretrained_model_path) > 0:
(next file; path not shown in this view)
@@ -4,7 +4,7 @@
from opensora.acceleration.parallel_states import create_parallel_group, get_sequence_parallel_group
from opensora.models.stdit.stdit3 import STDiT3
from opensora.utils.amp import auto_mixed_precision
- from opensora.utils.model_utils import WHITELIST_OPS
+ from opensora.utils.model_utils import BLACKLIST_OPS

import mindspore as ms
import mindspore.nn as nn
@@ -82,7 +82,7 @@ def run_model(mode: int = 0, model_dtype: ms.dtype = ms.float32):
non_dist_model,
amp_level="O2",
dtype=model_dtype,
- custom_fp32_cells=WHITELIST_OPS,
+ custom_fp32_cells=BLACKLIST_OPS,
)

# sequence parallel model
@@ -95,7 +95,7 @@ def run_model(mode: int = 0, model_dtype: ms.dtype = ms.float32):
dist_model,
amp_level="O2",
dtype=model_dtype,
- custom_fp32_cells=WHITELIST_OPS,
+ custom_fp32_cells=BLACKLIST_OPS,
)

for (_, w0), (_, w1) in zip(non_dist_model.parameters_and_names(), dist_model.parameters_and_names()):
6 changes: 3 additions & 3 deletions examples/opensora_hpcai/tests/test_vae_1_2_parallelism.py
@@ -4,7 +4,7 @@
from opensora.acceleration.parallel_states import create_parallel_group
from opensora.models.vae.vae import OpenSoraVAE_V1_2
from opensora.utils.amp import auto_mixed_precision
- from opensora.utils.model_utils import WHITELIST_OPS
+ from opensora.utils.model_utils import BLACKLIST_OPS

import mindspore as ms
import mindspore.nn as nn
@@ -58,7 +58,7 @@ def run_model(mode: int = 0, model_dtype: ms.dtype = ms.float16):
non_dist_model,
amp_level="O2",
dtype=model_dtype,
- custom_fp32_cells=WHITELIST_OPS,
+ custom_fp32_cells=BLACKLIST_OPS,
)
non_dist_model.set_train(False)

@@ -72,7 +72,7 @@ def run_model(mode: int = 0, model_dtype: ms.dtype = ms.float16):
dist_model,
amp_level="O2",
dtype=model_dtype,
- custom_fp32_cells=WHITELIST_OPS,
+ custom_fp32_cells=BLACKLIST_OPS,
)
dist_model.set_train(False)

10 changes: 7 additions & 3 deletions mindone/trainers/callback.py
@@ -444,11 +444,13 @@ def on_train_step_end(self, run_context):
cur_step = cb_params.cur_step_num
if cur_step == self.end_step:
self.profiler.stop()
- self.profiler.analyse()
- _logger.info(f"finish analyzing profiler in step range [{self.start_step}, {self.end_step}]")
+ _logger.info(f"Finished profiling in step range [{self.start_step}, {self.end_step}]")
if self.exit_after_analyze:
run_context.request_stop()

+ def on_train_end(self, run_context):
+     self.profiler.analyse()


class ProfilerCallbackEpoch(Callback):
def __init__(self, start_epoch, stop_epoch, output_dir="./profiler_data"):
@@ -468,4 +470,6 @@ def on_train_epoch_end(self, run_context):
epoch_num = cb_params.cur_epoch_num
if epoch_num == self.stop_epoch:
self.profiler.stop()
- self.profiler.analyse()

+ def on_train_end(self, run_context):
+     self.profiler.analyse()
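
For context, the profiler lifecycle these two callbacks now follow, sketched with the public `mindspore.Profiler` API (paths and step hooks are illustrative; run on a backend where profiling is supported):

```python
import mindspore as ms

# Collect during the chosen steps, analyse once after training ends.
profiler = ms.Profiler(output_path="./profiler_data", start_profile=False)

profiler.start()    # e.g. in on_train_step_begin at start_step
# ... profiled training steps run here ...
profiler.stop()     # in on_train_step_end at end_step (analysis no longer happens here)
profiler.analyse()  # deferred to on_train_end, so it does not block the remaining training steps
```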