Commit f9ba58d

add dia

1 parent: 8f6e9c9

14 files changed, +2355 −1 lines


mindone/transformers/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -429,6 +429,7 @@
 )
 from .models.depth_anything import DepthAnythingForDepthEstimation, DepthAnythingPreTrainedModel
 from .models.depth_pro import DepthProForDepthEstimation, DepthProImageProcessor, DepthProModel, DepthProPreTrainedModel
+from .models.dia import DiaForConditionalGeneration, DiaModel, DiaPreTrainedModel, DiaProcessor
 from .models.diffllama import (
     DiffLlamaForCausalLM,
     DiffLlamaForQuestionAnswering,
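
With this re-export in place, the four new Dia classes should be importable straight from the package root (a minimal sketch; it assumes nothing else in the package guards this import):

from mindone.transformers import DiaForConditionalGeneration, DiaModel, DiaPreTrainedModel, DiaProcessor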

mindone/transformers/generation/logits_process.py

Lines changed: 1 addition & 1 deletion
@@ -3197,7 +3197,7 @@ def __call__(self, input_ids: ms.Tensor, scores: ms.Tensor) -> ms.Tensor:
         # Create top k based on the combined CFG output
         _, top_k_indices = mint.topk(scores_processed, k=self.guidance_top_k, dim=-1)
         top_k_mask = mint.ones_like(scores_processed, dtype=ms.bool_)
-        top_k_mask = top_k_mask.scatter(dim=-1, index=top_k_indices, value=False)
+        top_k_mask = top_k_mask.scatter_(dim=-1, index=top_k_indices, value=False)
         # Only return conditioned logits with top k
         scores_processed = cond_logits.masked_fill(top_k_mask, -float("inf"))
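
The one-line fix swaps the out-of-place `.scatter` for the in-place `.scatter_` when clearing the top-k positions of the boolean mask (`mint` is MindSpore's torch-like functional namespace). A minimal standalone sketch of the resulting masking pattern; the tensor values and `k` below are made up for illustration, not taken from the source:

import mindspore as ms
from mindspore import mint

scores = ms.Tensor([[0.1, 0.9, 0.3, 0.7]], ms.float32)  # synthetic logits
_, top_k_indices = mint.topk(scores, k=2, dim=-1)
# start from an all-True mask, then clear the top-k slots in place
top_k_mask = mint.ones_like(scores, dtype=ms.bool_)
top_k_mask = top_k_mask.scatter_(dim=-1, index=top_k_indices, value=False)
# everything outside the top k is pushed to -inf
masked = scores.masked_fill(top_k_mask, -float("inf"))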

mindone/transformers/models/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -61,6 +61,7 @@
     deprecated,
     depth_anything,
     depth_pro,
+    dia,
     diffllama,
     dinov2,
     dinov2_with_registers,

mindone/transformers/models/auto/configuration_auto.py

Lines changed: 2 additions & 0 deletions
@@ -83,6 +83,7 @@
         ("depth_anything", "DepthAnythingConfig"),
         ("depth_pro", "DepthProConfig"),
         ("detr", "DetrConfig"),
+        ("dia", "DiaConfig"),
         ("diffllama", "DiffLlamaConfig"),
         ("dinov2", "Dinov2Config"),
         ("dinov2_with_registers", "Dinov2WithRegistersConfig"),
@@ -352,6 +353,7 @@
         ("depth_anything", "Depth Anything"),
         ("depth_pro", "DepthPro"),
         ("detr", "DETR"),
+        ("dia", "Dia"),
         ("diffllama", "DiffLlama"),
         ("dinov2", "DINOv2"),
         ("dinov2_with_registers", "DINOv2 with Registers"),

mindone/transformers/models/auto/feature_extraction_auto.py

Lines changed: 1 addition & 0 deletions
@@ -46,6 +46,7 @@
         ("chinese_clip", "ChineseCLIPFeatureExtractor"),
         ("convnext", "ConvNextFeatureExtractor"),
         ("cvt", "ConvNextFeatureExtractor"),
+        ("dia", "DiaFeatureExtractor"),
         ("flava", "FlavaFeatureExtractor"),
         ("seamless_m4t", "SeamlessM4TFeatureExtractor"),
     ]

mindone/transformers/models/auto/modeling_auto.py

Lines changed: 19 additions & 0 deletions
@@ -81,6 +81,7 @@
         ("deit", "DeiTModel"),
         ("depth_pro", "DepthProModel"),
         ("detr", "DetrModel"),
+        ("dia", "DiaModel"),
         ("diffllama", "DiffLlamaModel"),
         ("dinov2", "Dinov2Model"),
         ("dinov2_with_registers", "Dinov2WithRegistersModel"),
@@ -353,6 +354,7 @@
         ("data2vec-text", "Data2VecTextForMaskedLM"),
         ("deberta", "DebertaForMaskedLM"),
         ("deberta-v2", "DebertaV2ForMaskedLM"),
+        ("dia", "DiaForConditionalGeneration"),
         ("distilbert", "DistilBertForMaskedLM"),
         ("esm", "EsmForMaskedLM"),
         ("electra", "ElectraForMaskedLM"),
@@ -838,6 +840,7 @@

 MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict(
     [
+        ("dia", "DiaForConditionalGeneration"),
         ("moonshine", "MoonshineForConditionalGeneration"),
         ("pop2piano", "Pop2PianoForConditionalGeneration"),
         ("seamless_m4t", "SeamlessM4TForSpeechToText"),
@@ -1274,6 +1277,12 @@
     ]
 )

+MODEL_FOR_AUDIO_TOKENIZATION_NAMES = OrderedDict(
+    [
+        ("dac", "DacModel"),
+    ]
+)
+

 if version.parse(transformers.__version__) >= version.parse("4.51.0"):
     MODEL_MAPPING_NAMES.update({"qwen3_moe": "Qwen3MoeModel"})
@@ -1424,6 +1433,7 @@

 MODEL_FOR_IMAGE_TO_IMAGE_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES)

+MODEL_FOR_AUDIO_TOKENIZATION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_AUDIO_TOKENIZATION_NAMES)

 class AutoModelForMaskGeneration(_BaseAutoModelClass):
     _model_mapping = MODEL_FOR_MASK_GENERATION_MAPPING
@@ -1708,6 +1718,15 @@ class AutoModelForMaskedImageModeling(_BaseAutoModelClass):
 AutoModelForMaskedImageModeling = auto_class_update(AutoModelForMaskedImageModeling, head_doc="masked image modeling")


+class AutoModelForAudioTokenization(_BaseAutoModelClass):
+    _model_mapping = MODEL_FOR_AUDIO_TOKENIZATION_MAPPING
+
+
+AutoModelForAudioTokenization = auto_class_update(
+    AutoModelForAudioTokenization, head_doc="audio tokenization through codebooks"
+)
+
+
 class AutoModelWithLMHead(_AutoModelWithLMHead):
     @classmethod
     def from_config(cls, config):
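
The new `AutoModelForAudioTokenization` head currently maps only `dac`, the neural codec Dia uses to turn waveforms into discrete codebook tokens. A minimal usage sketch; the fully qualified import is what the diff guarantees, and the checkpoint id is illustrative rather than taken from this commit:

from mindone.transformers.models.auto.modeling_auto import AutoModelForAudioTokenization

# resolves to DacModel via MODEL_FOR_AUDIO_TOKENIZATION_NAMES above
codec = AutoModelForAudioTokenization.from_pretrained("descript/dac_44khz")  # illustrative checkpoint id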

mindone/transformers/models/auto/processing_auto.py

Lines changed: 1 addition & 0 deletions
@@ -55,6 +55,7 @@
         ("chameleon", "ChameleonProcessor"),
         ("chinese_clip", "ChineseCLIPProcessor"),
         ("colpali", "ColPaliProcessor"),
+        ("dia", "DiaProcessor"),
         ("flava", "FlavaProcessor"),
         ("idefics", "IdeficsProcessor"),
         ("instructblip", "InstructBlipProcessor"),
mindone/transformers/models/dia/__init__.py (filename missing from the page; inferred from the relative imports below)

Lines changed: 20 additions & 0 deletions (new file)

# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# This code is adapted from https://github.com/huggingface/transformers
# with modifications to run transformers on mindspore.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .feature_extraction_dia import *
from .generation_dia import *
from .modeling_dia import *
from .processing_dia import *
mindone/transformers/models/dia/feature_extraction_dia.py (filename missing from the page; inferred from the module docstring below)

Lines changed: 186 additions & 0 deletions (new file)
# coding=utf-8
# Copyright 2025 The Nari Labs and HuggingFace Inc. team. All rights reserved.
#
# This code is adapted from https://github.com/huggingface/transformers
# with modifications to run transformers on mindspore.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Feature extractor class for Dia"""

from typing import Optional, Union

import numpy as np

from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
from ...feature_extraction_utils import BatchFeature
from ...utils import PaddingStrategy, TensorType, logging


logger = logging.get_logger(__name__)

class DiaFeatureExtractor(SequenceFeatureExtractor):
    r"""
    Constructs a Dia feature extractor.

    This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains
    most of the main methods. Users should refer to this superclass for more information regarding those methods.

    Args:
        feature_size (`int`, *optional*, defaults to 1):
            The feature dimension of the extracted features. Use 1 for mono, 2 for stereo.
        sampling_rate (`int`, *optional*, defaults to 16000):
            The sampling rate at which the audio waveform should be digitized, expressed in hertz (Hz).
        padding_value (`float`, *optional*, defaults to 0.0):
            The value that is used for padding.
        hop_length (`int`, *optional*, defaults to 512):
            Overlap length between successive windows.
    """

    model_input_names = ["input_values", "n_quantizers"]

    def __init__(
        self,
        feature_size: int = 1,
        sampling_rate: int = 16000,
        padding_value: float = 0.0,
        hop_length: int = 512,
        **kwargs,
    ):
        super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs)
        self.hop_length = hop_length

    def __call__(
        self,
        raw_audio: Union[np.ndarray, list[float], list[np.ndarray], list[list[float]]],
        padding: Optional[Union[bool, str, PaddingStrategy]] = None,
        truncation: Optional[bool] = False,
        max_length: Optional[int] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        sampling_rate: Optional[int] = None,
    ) -> BatchFeature:
        """
        Main method to featurize and prepare for the model one or several sequence(s).

        Args:
            raw_audio (`np.ndarray`, `list[float]`, `list[np.ndarray]`, `list[list[float]]`):
                The sequence or batch of sequences to be processed. Each sequence can be a numpy array, a list of
                float values, a list of numpy arrays or a list of lists of float values. The numpy array must be of
                shape `(num_samples,)` for mono audio (`feature_size = 1`), or `(2, num_samples)` for stereo audio
                (`feature_size = 2`).
            padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
                Select a strategy to pad the returned sequences (according to the model's padding side and padding
                index) among:

                - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
                  sequence is provided).
                - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
                  acceptable input length for the model if that argument is not provided.
                - `False` or `'do_not_pad'`: No padding (i.e., the output batch can contain sequences of different
                  lengths).
            truncation (`bool`, *optional*, defaults to `False`):
                Activates truncation to cut input sequences longer than `max_length` to `max_length`.
            max_length (`int`, *optional*):
                Maximum length of the returned list and optionally padding length (see above).
            return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `'pt'`):
                If set, will return tensors instead of lists of python integers. Acceptable values are:

                - `'tf'`: Return TensorFlow `tf.constant` objects.
                - `'pt'`: Return PyTorch `torch.Tensor` objects.
                - `'np'`: Return Numpy `np.ndarray` objects.
            sampling_rate (`int`, *optional*):
                The sampling rate at which the `audio` input was sampled. It is strongly recommended to pass
                `sampling_rate` at the forward call to prevent silent errors.
        """
        if sampling_rate is not None:
            if sampling_rate != self.sampling_rate:
                raise ValueError(
                    f"The model corresponding to this feature extractor: {self} was trained using a sampling rate of"
                    f" {self.sampling_rate}. Please make sure that the provided audio input was sampled with"
                    f" {self.sampling_rate} and not {sampling_rate}."
                )
        else:
            logger.warning(
                f"It is strongly recommended to pass the `sampling_rate` argument to `{self.__class__.__name__}()`. "
                "Failing to do so can result in silent errors that might be hard to debug."
            )

        if padding and truncation:
            raise ValueError("Both padding and truncation were set. Make sure you only set one.")
        elif padding is None:
            # by default let's pad the inputs
            padding = True

        is_batched = bool(
            isinstance(raw_audio, (list, tuple)) and (isinstance(raw_audio[0], (np.ndarray, tuple, list)))
        )

        if is_batched:
            raw_audio = [np.asarray(audio, dtype=np.float32).T for audio in raw_audio]
        elif not is_batched and not isinstance(raw_audio, np.ndarray):
            raw_audio = np.asarray(raw_audio, dtype=np.float32)
        elif isinstance(raw_audio, np.ndarray) and raw_audio.dtype is np.dtype(np.float64):
            raw_audio = raw_audio.astype(np.float32)

        # always return batch
        if not is_batched:
            raw_audio = [np.asarray(raw_audio).T]

        # convert stereo to mono if necessary, unique to Dia
        for idx, example in enumerate(raw_audio):
            if self.feature_size == 2 and example.ndim == 2:
                raw_audio[idx] = np.mean(example, -1)

        # verify inputs are valid
        for idx, example in enumerate(raw_audio):
            if example.ndim > 2:
                raise ValueError(f"Expected input shape (channels, length) but got shape {example.shape}")
            if self.feature_size == 1 and example.ndim != 1:
                raise ValueError(f"Expected mono audio but example has {example.shape[-1]} channels")
            if self.feature_size == 2 and example.ndim != 1:  # note the conversion before
                raise ValueError(f"Expected stereo audio but example has {example.shape[-1]} channels")

        input_values = BatchFeature({"input_values": raw_audio})

        # temporarily treat it as if we were mono as we also convert stereo to mono
        original_feature_size = self.feature_size
        self.feature_size = 1

        # normal padding on batch
        padded_inputs = self.pad(
            input_values,
            max_length=max_length,
            truncation=truncation,
            padding=padding,
            return_attention_mask=True,
            pad_to_multiple_of=self.hop_length,
        )
        padded_inputs["padding_mask"] = padded_inputs.pop("attention_mask")

        input_values = []
        for example in padded_inputs.pop("input_values"):
            if self.feature_size == 1:
                example = example[..., None]
            input_values.append(example.T)

        padded_inputs["input_values"] = input_values
        if return_tensors is not None:
            padded_inputs = padded_inputs.convert_to_tensors(return_tensors)

        # restore the original feature size
        self.feature_size = original_feature_size

        return padded_inputs

__all__ = ["DiaFeatureExtractor"]
