# coding=utf-8
# Copyright 2025 The Nari Labs and HuggingFace Inc. team. All rights reserved.
#
# This code is adapted from https://github.com/huggingface/transformers
# with modifications to run transformers on mindspore.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
| 18 | +"""Feature extractor class for Dia""" |
| 19 | + |
| 20 | +from typing import Optional, Union |
| 21 | + |
| 22 | +import numpy as np |
| 23 | + |
| 24 | +from ...feature_extraction_sequence_utils import SequenceFeatureExtractor |
| 25 | +from ...feature_extraction_utils import BatchFeature |
| 26 | +from ...utils import PaddingStrategy, TensorType, logging |
| 27 | + |
| 28 | + |
| 29 | +logger = logging.get_logger(__name__) |
| 30 | + |
| 31 | + |
class DiaFeatureExtractor(SequenceFeatureExtractor):
    r"""
    Constructs a Dia feature extractor.

    This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains
    most of the main methods. Users should refer to this superclass for more information regarding those methods.

    Args:
        feature_size (`int`, *optional*, defaults to 1):
            The feature dimension of the extracted features. Use 1 for mono, 2 for stereo.
        sampling_rate (`int`, *optional*, defaults to 16000):
            The sampling rate at which the audio waveform should be digitized, expressed in hertz (Hz).
        padding_value (`float`, *optional*, defaults to 0.0):
            The value that is used for padding.
        hop_length (`int`, *optional*, defaults to 512):
            Hop length between successive windows; padded input lengths are rounded up to a multiple of this value.
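
    Example (a minimal sketch; in practice these values are usually loaded from a checkpoint via `from_pretrained`
    rather than passed by hand):

    ```python
    >>> import numpy as np

    >>> feature_extractor = DiaFeatureExtractor(feature_size=1, sampling_rate=16000, hop_length=512)
    >>> audio = np.zeros(16000, dtype=np.float32)  # 1 second of mono audio at 16 kHz
    >>> features = feature_extractor(audio, sampling_rate=16000, return_tensors="np")
    ```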
    """

    model_input_names = ["input_values", "n_quantizers"]

    def __init__(
        self,
        feature_size: int = 1,
        sampling_rate: int = 16000,
        padding_value: float = 0.0,
        hop_length: int = 512,
        **kwargs,
    ):
        super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs)
        self.hop_length = hop_length

    def __call__(
        self,
        raw_audio: Union[np.ndarray, list[float], list[np.ndarray], list[list[float]]],
        padding: Optional[Union[bool, str, PaddingStrategy]] = None,
        truncation: Optional[bool] = False,
        max_length: Optional[int] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        sampling_rate: Optional[int] = None,
    ) -> BatchFeature:
        """
        Main method to featurize and prepare for the model one or several sequence(s).

        Args:
            raw_audio (`np.ndarray`, `list[float]`, `list[np.ndarray]`, `list[list[float]]`):
                The sequence or batch of sequences to be processed. Each sequence can be a numpy array, a list of float
                values, a list of numpy arrays or a list of list of float values. The numpy array must be of shape
                `(num_samples,)` for mono audio (`feature_size = 1`), or `(2, num_samples)` for stereo audio
                (`feature_size = 2`).
            padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
                Select a strategy to pad the returned sequences (according to the model's padding side and padding
                index) among:

                - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
                  sequence is provided).
                - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
                  acceptable input length for the model if that argument is not provided.
                - `False` or `'do_not_pad'`: No padding (i.e., can output a batch with sequences of different
                  lengths).
            truncation (`bool`, *optional*, defaults to `False`):
                Activates truncation to cut input sequences longer than `max_length` to `max_length`.
            max_length (`int`, *optional*):
                Maximum length of the returned list and optionally padding length (see above).
            return_tensors (`str` or [`~utils.TensorType`], *optional*):
                If set, will return tensors instead of list of python integers. Acceptable values are:

                - `'tf'`: Return TensorFlow `tf.constant` objects.
                - `'pt'`: Return PyTorch `torch.Tensor` objects.
                - `'np'`: Return Numpy `np.ndarray` objects.
            sampling_rate (`int`, *optional*):
                The sampling rate at which the `audio` input was sampled. It is strongly recommended to pass
                `sampling_rate` at the forward call to prevent silent errors.
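
        Example (a minimal sketch assuming the default `feature_size=1` and `hop_length=512`; the shapes rely on
        `pad` rounding the longest length in the batch up to a multiple of `hop_length`):

        ```python
        >>> import numpy as np

        >>> feature_extractor = DiaFeatureExtractor()
        >>> batch = [np.zeros(8000, dtype=np.float32), np.zeros(16000, dtype=np.float32)]
        >>> features = feature_extractor(batch, sampling_rate=16000, return_tensors="np")
        >>> features["input_values"].shape  # (batch_size, channels, padded_length); 16000 is rounded up to 16384
        (2, 1, 16384)
        >>> features["padding_mask"].shape
        (2, 16384)
        ```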
        """
        if sampling_rate is not None:
            if sampling_rate != self.sampling_rate:
                raise ValueError(
                    f"The model corresponding to this feature extractor: {self} was trained using a sampling rate of"
                    f" {self.sampling_rate}. Please make sure that the provided audio input was sampled with"
                    f" {self.sampling_rate} and not {sampling_rate}."
                )
        else:
            logger.warning(
                f"It is strongly recommended to pass the `sampling_rate` argument to `{self.__class__.__name__}()`. "
                "Failing to do so can result in silent errors that might be hard to debug."
            )

        if padding and truncation:
            raise ValueError("Both padding and truncation were set. Make sure you only set one.")
        elif padding is None:
            # by default let's pad the inputs
            padding = True

        is_batched = bool(
            isinstance(raw_audio, (list, tuple)) and (isinstance(raw_audio[0], (np.ndarray, tuple, list)))
        )

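        # per-example arrays are transposed from `(num_channels, num_samples)` to `(num_samples, num_channels)`:
        # padding below runs along the leading time axis, and the stereo-to-mono averaging reduces the trailing axis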
        if is_batched:
            raw_audio = [np.asarray(audio, dtype=np.float32).T for audio in raw_audio]
        elif not is_batched and not isinstance(raw_audio, np.ndarray):
            raw_audio = np.asarray(raw_audio, dtype=np.float32)
        elif isinstance(raw_audio, np.ndarray) and raw_audio.dtype is np.dtype(np.float64):
            raw_audio = raw_audio.astype(np.float32)

        # always return batch
        if not is_batched:
            raw_audio = [np.asarray(raw_audio).T]

        # convert stereo to mono if necessary, unique to Dia
        for idx, example in enumerate(raw_audio):
            if self.feature_size == 2 and example.ndim == 2:
                raw_audio[idx] = np.mean(example, -1)

        # verify inputs are valid
        for idx, example in enumerate(raw_audio):
            if example.ndim > 2:
                raise ValueError(f"Expected input shape (channels, length) but got shape {example.shape}")
            if self.feature_size == 1 and example.ndim != 1:
                raise ValueError(f"Expected mono audio but example has {example.shape[-1]} channels")
            if self.feature_size == 2 and example.ndim != 1:  # note the conversion before
                raise ValueError(f"Expected stereo audio but example has {example.shape[-1]} channels")

        input_values = BatchFeature({"input_values": raw_audio})

        # temporarily treat it as if we were mono as we also convert stereo to mono
        original_feature_size = self.feature_size
        self.feature_size = 1

        # normal padding on batch
        padded_inputs = self.pad(
            input_values,
            max_length=max_length,
            truncation=truncation,
            padding=padding,
            return_attention_mask=True,
            pad_to_multiple_of=self.hop_length,
        )
        padded_inputs["padding_mask"] = padded_inputs.pop("attention_mask")

| 170 | + input_values = [] |
| 171 | + for example in padded_inputs.pop("input_values"): |
| 172 | + if self.feature_size == 1: |
| 173 | + example = example[..., None] |
| 174 | + input_values.append(example.T) |
| 175 | + |
| 176 | + padded_inputs["input_values"] = input_values |
| 177 | + if return_tensors is not None: |
| 178 | + padded_inputs = padded_inputs.convert_to_tensors(return_tensors) |
| 179 | + |
| 180 | + # rewrite back to original feature size |
| 181 | + self.feature_size = origingal_feature_size |
| 182 | + |
| 183 | + return padded_inputs |


__all__ = ["DiaFeatureExtractor"]