diff --git a/mindnlp/transformers/models/jukebox/__init__.py b/mindnlp/transformers/models/jukebox/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/mindnlp/transformers/models/jukebox/configuration_jukebox.py b/mindnlp/transformers/models/jukebox/configuration_jukebox.py
new file mode 100644
index 000000000..6f981cd65
--- /dev/null
+++ b/mindnlp/transformers/models/jukebox/configuration_jukebox.py
@@ -0,0 +1,618 @@
+# coding=utf-8
+# Copyright 2022 The OpenAI Team Authors and HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Jukebox configuration""" + +import os +from typing import List, Union + +from mindnlp.utils import logging +from ...configuration_utils import PretrainedConfig + +logger = logging.get_logger(__name__) + + +_LARGE_ATTENTION = [ + "block_attn", + "transpose_block_attn", + "prev_block_attn", + "block_attn", + "transpose_block_attn", + "prev_block_attn", + "block_attn", + "transpose_block_attn", + "prev_block_attn", + "block_attn", + "transpose_block_attn", + "prev_block_attn", + "block_attn", + "transpose_block_attn", + "prev_block_attn", + "block_attn", + "transpose_block_attn", + "prev_block_attn", + "cross_attention", + "block_attn", + "transpose_block_attn", + "prev_block_attn", + "block_attn", + "transpose_block_attn", + "prev_block_attn", + "block_attn", + "transpose_block_attn", + "prev_block_attn", + "cross_attention", + "block_attn", + "transpose_block_attn", + "prev_block_attn", + "block_attn", + "transpose_block_attn", + "prev_block_attn", + "block_attn", + "transpose_block_attn", + "prev_block_attn", + "cross_attention", + "block_attn", + "transpose_block_attn", + "prev_block_attn", + "block_attn", + "transpose_block_attn", + "prev_block_attn", + "block_attn", + "transpose_block_attn", + "prev_block_attn", + "cross_attention", + "block_attn", + "transpose_block_attn", + "prev_block_attn", + "block_attn", + "transpose_block_attn", + "prev_block_attn", + "block_attn", + "transpose_block_attn", + "prev_block_attn", + "cross_attention", + "block_attn", + "transpose_block_attn", + "prev_block_attn", + "block_attn", + "transpose_block_attn", + "prev_block_attn", + "block_attn", + "transpose_block_attn", + "prev_block_attn", + "cross_attention", + "block_attn", + "transpose_block_attn", + "prev_block_attn", + "block_attn", + "transpose_block_attn", + "prev_block_attn", + "block_attn", + "transpose_block_attn", + "prev_block_attn", + "cross_attention", +] +_RawColumnPreviousRowAttention = ["block_attn", "transpose_block_attn", "prev_block_attn"] +_FullDenseAttention = ["dense_attention"] +_PrimePrimeDenseAttention = ["prime_attn", "prime_attn", "dense_attn"] + + +def full_dense_attention(layer): + return _FullDenseAttention[0] + + +def raw_column_previous_row_attention(layer): + return _RawColumnPreviousRowAttention[layer % 3] + + +def large_separated_enc_dec_w_lyrics(layer): + return _LARGE_ATTENTION[layer % 79] + + +def enc_dec_with_lyrics(layer): + if layer % 16 == 15: + return _PrimePrimeDenseAttention[layer % 3] + return _RawColumnPreviousRowAttention[layer % 3] + + +ATTENTION_PATTERNS = { + "full_dense_attention": full_dense_attention, + "raw_column_previous_row_attention": raw_column_previous_row_attention, # Alternate row, column and previous row attn + "large_separated_enc_dec_w_lyrics": large_separated_enc_dec_w_lyrics, # Used by large separated_enc_dec model with lyrics + "enc_dec_with_lyrics": enc_dec_with_lyrics, # Used by encoder_decoder model with lyrics +} + + +class JukeboxPriorConfig(PretrainedConfig): + """ + This is the configuration class to store the configuration of a [`JukeboxPrior`]. It is used to instantiate a + `JukeboxPrior` according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the top level prior from the + [openai/jukebox-1b-lyrics](https://huggingface.co/openai/jukebox + -1b-lyrics) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. 
Read the + documentation from [`PretrainedConfig`] for more information. + + + + Args: + act_fn (`str`, *optional*, defaults to `"quick_gelu"`): + Activation function. + alignment_head (`int`, *optional*, defaults to 2): + Head that is responsible of the alignment between lyrics and music. Only used to compute the lyric to audio + alignment + alignment_layer (`int`, *optional*, defaults to 68): + Index of the layer that is responsible of the alignment between lyrics and music. Only used to compute the + lyric to audio alignment + attention_multiplier (`float`, *optional*, defaults to 0.25): + Multiplier coefficient used to define the hidden dimension of the attention layers. 0.25 means that + 0.25*width of the model will be used. + attention_pattern (`str`, *optional*, defaults to `"enc_dec_with_lyrics"`): + Which attention pattern to use for the decoder/ + attn_dropout (`int`, *optional*, defaults to 0): + Dropout probability for the post-attention layer dropout in the decoder. + attn_res_scale (`bool`, *optional*, defaults to `False`): + Whether or not to scale the residuals in the attention conditioner block. + blocks (`int`, *optional*, defaults to 64): + Number of blocks used in the `block_attn`. A sequence of length seq_len is factored as `[blocks, seq_len // + blocks]` in the `JukeboxAttention` layer. + conv_res_scale (`int`, *optional*): + Whether or not to scale the residuals in the conditioner block. Since the top level prior does not have a + conditioner, the default value is to None and should not be modified. + num_layers (`int`, *optional*, defaults to 72): + Number of layers of the transformer architecture. + emb_dropout (`int`, *optional*, defaults to 0): + Embedding dropout used in the lyric decoder. + encoder_config (`JukeboxPriorConfig`, *optional*) : + Configuration of the encoder which models the prior on the lyrics. + encoder_loss_fraction (`float`, *optional*, defaults to 0.4): + Multiplication factor used in front of the lyric encoder loss. + hidden_size (`int`, *optional*, defaults to 2048): + Hidden dimension of the attention layers. + init_scale (`float`, *optional*, defaults to 0.2): + Initialization scales for the prior modules. + is_encoder_decoder (`bool`, *optional*, defaults to `True`): + Whether or not the prior is an encoder-decoder model. In case it is not, and `nb_relevant_lyric_tokens` is + greater than 0, the `encoder` args should be specified for the lyric encoding. + mask (`bool`, *optional*, defaults to `False`): + Whether or not to mask the previous positions in the attention. + max_duration (`int`, *optional*, defaults to 600): + Maximum supported duration of the generated song in seconds. + max_nb_genres (`int`, *optional*, defaults to 1): + Maximum number of genres that can be used to condition the model. + merged_decoder (`bool`, *optional*, defaults to `True`): + Whether or not the decoder and the encoder inputs are merged. This is used for the separated + encoder-decoder architecture + metadata_conditioning (`bool`, *optional*, defaults to `True)`: + Whether or not to condition on the artist and genre metadata. + metadata_dims (`List[int]`, *optional*, defaults to `[604, 7898]`): + Number of genres and the number of artists that were used to train the embedding layers of the prior + models. + min_duration (`int`, *optional*, defaults to 0): + Minimum duration of the generated audio on which the model was trained. 
+ mlp_multiplier (`float`, *optional*, defaults to 1.0): + Multiplier coefficient used to define the hidden dimension of the MLP layers. 0.25 means that 0.25*width of + the model will be used. + music_vocab_size (`int`, *optional*, defaults to 2048): + Number of different music tokens. Should be similar to the `JukeboxVQVAEConfig.nb_discrete_codes`. + n_ctx (`int`, *optional*, defaults to 6144): + Number of context tokens for each prior. The context tokens are the music tokens that are attended to when + generating music tokens. + n_heads (`int`, *optional*, defaults to 2): + Number of attention heads. + nb_relevant_lyric_tokens (`int`, *optional*, defaults to 384): + Number of lyric tokens that are used when sampling a single window of length `n_ctx` + res_conv_depth (`int`, *optional*, defaults to 3): + Depth of the `JukeboxDecoderConvBock` used to upsample the previously sampled audio in the + `JukeboxMusicTokenConditioner`. + res_conv_width (`int`, *optional*, defaults to 128): + Width of the `JukeboxDecoderConvBock` used to upsample the previously sampled audio in the + `JukeboxMusicTokenConditioner`. + res_convolution_multiplier (`int`, *optional*, defaults to 1): + Multiplier used to scale the `hidden_dim` of the `JukeboxResConv1DBlock`. + res_dilation_cycle (`int`, *optional*): + Dilation cycle used to define the `JukeboxMusicTokenConditioner`. Usually similar to the ones used in the + corresponding level of the VQVAE. The first prior does not use it as it is not conditioned on upper level + tokens. + res_dilation_growth_rate (`int`, *optional*, defaults to 1): + Dilation grow rate used between each convolutionnal block of the `JukeboxMusicTokenConditioner` + res_downs_t (`List[int]`, *optional*, defaults to `[3, 2, 2]`): + Downsampling rates used in the audio conditioning network + res_strides_t (`List[int]`, *optional*, defaults to `[2, 2, 2]`): + Striding used in the audio conditioning network + resid_dropout (`int`, *optional*, defaults to 0): + Residual dropout used in the attention pattern. + sampling_rate (`int`, *optional*, defaults to 44100): + Sampling rate used for training. + spread (`int`, *optional*): + Spread used in the `summary_spread_attention` pattern + timing_dims (`int`, *optional*, defaults to 64): + Dimension of the timing embedding. + zero_out (`bool`, *optional*, defaults to `False`): + Whether or not to zero out convolution weights when initializing. 
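+
+    Example (a minimal sketch; like the `JukeboxConfig` example in this file, it assumes the
+    class is importable from the top-level `transformers` package):
+
+    ```python
+    >>> from transformers import JukeboxPriorConfig
+
+    >>> # Initializing a JukeboxPriorConfig with the default (top-level prior) values
+    >>> configuration = JukeboxPriorConfig()
+
+    >>> # Default hidden dimension of the attention layers
+    >>> configuration.hidden_size
+    2048
+    ```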
+ """ + + model_type = "jukebox_prior" + attribute_map = { + "max_position_embeddings": "n_positions", + "num_attention_heads": "n_head", + } + + def __init__( + self, + act_fn="quick_gelu", + level=0, + alignment_head=2, + alignment_layer=68, + attention_multiplier=0.25, + attention_pattern="enc_dec_with_lyrics", + attn_dropout=0, + attn_res_scale=False, + blocks=64, + conv_res_scale=None, + num_layers=72, + emb_dropout=0, + encoder_config=None, + encoder_loss_fraction=0.4, + hidden_size=2048, + init_scale=0.2, + is_encoder_decoder=True, + lyric_vocab_size=80, + mask=False, + max_duration=600, + max_nb_genres=1, + merged_decoder=True, + metadata_conditioning=True, + metadata_dims=[604, 7898], + min_duration=0, + mlp_multiplier=1.0, + music_vocab_size=2048, + n_ctx=6144, + n_heads=2, + nb_relevant_lyric_tokens=384, + res_conv_depth=3, + res_conv_width=128, + res_convolution_multiplier=1, + res_dilation_cycle=None, + res_dilation_growth_rate=1, + res_downs_t=[3, 2, 2], + res_strides_t=[2, 2, 2], + resid_dropout=0, + sampling_rate=44100, + spread=None, + timing_dims=64, + zero_out=False, + **kwargs, + ): + super().__init__(**kwargs) + + self.act_fn = act_fn + self.alignment_head = alignment_head + self.alignment_layer = alignment_layer + self.attention_multiplier = attention_multiplier + self.attention_pattern = attention_pattern + self.attn_dropout = attn_dropout + self.attn_res_scale = attn_res_scale + self.blocks = blocks + self.conv_res_scale = conv_res_scale + self.num_layers = num_layers + self.emb_dropout = emb_dropout + self.music_vocab_size = music_vocab_size + if encoder_config is not None: + self.encoder_config = JukeboxPriorConfig(**encoder_config) + else: + self.encoder_config = None + self.encoder_loss_fraction = encoder_loss_fraction + self.init_scale = init_scale + self.is_encoder_decoder = is_encoder_decoder + self.lyric_vocab_size = lyric_vocab_size + self.level = level + self.mask = mask + self.max_duration = max_duration + self.max_nb_genres = max_nb_genres + self.merged_decoder = merged_decoder + self.metadata_conditioning = metadata_conditioning + self.metadata_dims = metadata_dims + self.min_duration = min_duration + self.mlp_multiplier = mlp_multiplier + self.n_ctx = n_ctx + self.n_heads = n_heads + self.nb_relevant_lyric_tokens = nb_relevant_lyric_tokens + self.res_conv_depth = res_conv_depth + self.res_conv_width = res_conv_width + self.res_convolution_multiplier = res_convolution_multiplier + self.res_dilation_cycle = res_dilation_cycle + self.res_dilation_growth_rate = res_dilation_growth_rate + self.res_downs_t = res_downs_t + self.res_strides_t = res_strides_t + self.resid_dropout = resid_dropout + self.sampling_rate = sampling_rate + self.spread = spread + self.timing_dims = timing_dims + self.hidden_size = hidden_size + self.zero_out = zero_out + + @classmethod + def from_pretrained( + cls, pretrained_model_name_or_path: Union[str, os.PathLike], level=0, **kwargs + ) -> "PretrainedConfig": + cls._set_token_in_kwargs(kwargs) + + config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) + + # get the prior config dict if we are loading from JukeboxConfig + if config_dict.get("model_type") == "jukebox": + config_dict = config_dict[f"prior_{level}"] + + if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type: + logger.warning( + f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " + f"{cls.model_type}. 
This is not supported for all configurations of models and can yield errors." + ) + + return cls.from_dict(config_dict, **kwargs) + + +class JukeboxVQVAEConfig(PretrainedConfig): + """ + This is the configuration class to store the configuration of a [`JukeboxVQVAE`]. It is used to instantiate a + `JukeboxVQVAE` according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the VQVAE from + [openai/jukebox-1b-lyrics](https://huggingface.co/openai/jukebox-1b-lyrics) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + act_fn (`str`, *optional*, defaults to `"relu"`): + Activation function of the model. + nb_discrete_codes (`int`, *optional*, defaults to 2048): + Number of codes of the VQVAE. + commit (`float`, *optional*, defaults to 0.02): + Commit loss multiplier. + conv_input_shape (`int`, *optional*, defaults to 1): + Number of audio channels. + conv_res_scale (`bool`, *optional*, defaults to `False`): + Whether or not to scale the residuals of the `JukeboxResConv1DBlock`. + embed_dim (`int`, *optional*, defaults to 64): + Embedding dimension of the codebook vectors. + hop_fraction (`List[int]`, *optional*, defaults to `[0.125, 0.5, 0.5]`): + Fraction of non-intersecting window used when continuing the sampling process. + levels (`int`, *optional*, defaults to 3): + Number of hierarchical levels that used in the VQVAE. + lmu (`float`, *optional*, defaults to 0.99): + Used in the codebook update, exponential moving average coefficient. For more detail refer to Appendix A.1 + of the original [VQVAE paper](https://arxiv.org/pdf/1711.00937v2.pdf) + multipliers (`List[int]`, *optional*, defaults to `[2, 1, 1]`): + Depth and width multipliers used for each level. Used on the `res_conv_width` and `res_conv_depth` + res_conv_depth (`int`, *optional*, defaults to 4): + Depth of the encoder and decoder block. If no `multipliers` are used, this is the same for each level. + res_conv_width (`int`, *optional*, defaults to 32): + Width of the encoder and decoder block. If no `multipliers` are used, this is the same for each level. + res_convolution_multiplier (`int`, *optional*, defaults to 1): + Scaling factor of the hidden dimension used in the `JukeboxResConv1DBlock`. + res_dilation_cycle (`int`, *optional*): + Dilation cycle value used in the `JukeboxResnet`. If an int is used, each new Conv1 block will have a depth + reduced by a power of `res_dilation_cycle`. + res_dilation_growth_rate (`int`, *optional*, defaults to 3): + Resnet dilation growth rate used in the VQVAE (dilation_growth_rate ** depth) + res_downs_t (`List[int]`, *optional*, defaults to `[3, 2, 2]`): + Downsampling rate for each level of the hierarchical VQ-VAE. + res_strides_t (`List[int]`, *optional*, defaults to `[2, 2, 2]`): + Stride used for each level of the hierarchical VQ-VAE. + sample_length (`int`, *optional*, defaults to 1058304): + Provides the max input shape of the VQVAE. Is used to compute the input shape of each level. + init_scale (`float`, *optional*, defaults to 0.2): + Initialization scale. + zero_out (`bool`, *optional*, defaults to `False`): + Whether or not to zero out convolution weights when initializing. 
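+
+    Example (an illustrative sketch; as with the `JukeboxConfig` example in this file, it assumes
+    the class is importable from the top-level `transformers` package):
+
+    ```python
+    >>> from transformers import JukeboxVQVAEConfig
+
+    >>> # Initializing a JukeboxVQVAEConfig with the default values
+    >>> configuration = JukeboxVQVAEConfig()
+
+    >>> # Default number of discrete codes in the VQ-VAE codebook
+    >>> configuration.nb_discrete_codes
+    2048
+    ```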
+ """ + + model_type = "jukebox_vqvae" + + def __init__( + self, + act_fn="relu", + nb_discrete_codes=2048, + commit=0.02, + conv_input_shape=1, + conv_res_scale=False, + embed_dim=64, + hop_fraction=[0.125, 0.5, 0.5], + levels=3, + lmu=0.99, + multipliers=[2, 1, 1], + res_conv_depth=4, + res_conv_width=32, + res_convolution_multiplier=1, + res_dilation_cycle=None, + res_dilation_growth_rate=3, + res_downs_t=[3, 2, 2], + res_strides_t=[2, 2, 2], + sample_length=1058304, + init_scale=0.2, + zero_out=False, + **kwargs, + ): + super().__init__(**kwargs) + + self.hop_fraction = hop_fraction + self.conv_input_shape = conv_input_shape + self.sample_length = sample_length + + # VQVAE parameters (all used) + self.levels = levels + self.embed_dim = embed_dim + self.nb_discrete_codes = nb_discrete_codes + self.res_conv_width = res_conv_width + self.res_conv_depth = res_conv_depth + self.res_convolution_multiplier = res_convolution_multiplier + self.res_dilation_growth_rate = res_dilation_growth_rate + self.res_dilation_cycle = res_dilation_cycle + self.multipliers = multipliers + self.res_downs_t = res_downs_t + self.res_strides_t = res_strides_t + self.lmu = lmu + self.commit = commit + self.conv_res_scale = conv_res_scale + self.act_fn = act_fn + self.init_scale = init_scale + self.zero_out = zero_out + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": + cls._set_token_in_kwargs(kwargs) + + config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) + + # get the text config dict if we are loading from CLIPConfig + if config_dict.get("model_type") == "jukebox": + config_dict = config_dict["vqvae_config"] + + if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type: + logger.warning( + f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " + f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." + ) + + return cls.from_dict(config_dict, **kwargs) + + +class JukeboxConfig(PretrainedConfig): + """ + This is the configuration class to store the configuration of a [`JukeboxModel`]. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. Instantiating a configuration with the defaults will + yield a similar configuration to that of + [openai/jukebox-1b-lyrics](https://huggingface.co/openai/jukebox-1b-lyrics) architecture. + + + The downsampling and stride are used to determine downsampling of the input sequence. For example, downsampling = + (5,3), and strides = (2, 2) will downsample the audio by 2^5 = 32 to get the first level of codes, and 2**8 = 256 + to get the second level codes. This is mostly true for training the top level prior and the upsamplers. + + Args: + vqvae_config (`JukeboxVQVAEConfig`, *optional*): + Configuration for the `JukeboxVQVAE` model. + prior_config_list (`List[JukeboxPriorConfig]`, *optional*): + List of the configs for each of the `JukeboxPrior` of the model. The original architecture uses 3 priors. + nb_priors (`int`, *optional*, defaults to 3): + Number of prior models that will sequentially sample tokens. Each prior is conditional auto regressive + (decoder) model, apart from the top prior, which can include a lyric encoder. The available models were + trained using a top prior and 2 upsampler priors. 
+ sampling_rate (`int`, *optional*, defaults to 44100): + Sampling rate of the raw audio. + timing_dims (`int`, *optional*, defaults to 64): + Dimensions of the JukeboxRangeEmbedding layer which is equivalent to traditional positional embedding + layer. The timing embedding layer converts the absolute and relative position in the currently sampled + audio to a tensor of length `timing_dims` that will be added to the music tokens. + min_duration (`int`, *optional*, defaults to 0): + Minimum duration of the audios to generate + max_duration (`float`, *optional*, defaults to 600.0): + Maximum duration of the audios to generate + max_nb_genres (`int`, *optional*, defaults to 5): + Maximum number of genres that can be used to condition a single sample. + metadata_conditioning (`bool`, *optional*, defaults to `True`): + Whether or not to use metadata conditioning, corresponding to the artist, the genre and the min/maximum + duration. + + Example: + + ```python + >>> from transformers import JukeboxModel, JukeboxConfig + + >>> # Initializing a Jukebox configuration + >>> configuration = JukeboxConfig() + + >>> # Initializing a model from the configuration + >>> model = JukeboxModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ``` + """ + + model_type = "jukebox" + + def __init__( + self, + vqvae_config=None, + prior_config_list=None, + nb_priors=3, + sampling_rate=44100, + timing_dims=64, + min_duration=0, + max_duration=600.0, + max_nb_genres=5, + metadata_conditioning=True, + **kwargs, + ): + if vqvae_config is None: + vqvae_config = {} + logger.info("vqvae_config is None. initializing the JukeboxVQVAE with default values.") + + self.vqvae_config = JukeboxVQVAEConfig(**vqvae_config) + if prior_config_list is not None: + self.prior_configs = [JukeboxPriorConfig(**prior_config) for prior_config in prior_config_list] + else: + self.prior_configs = [] + for prior_idx in range(nb_priors): + prior_config = kwargs.pop(f"prior_{prior_idx}", None) + if prior_config is None: + prior_config = {} + logger.info( + f"prior_{prior_idx}'s config is None. Initializing the JukeboxPriorConfig list with default" + " values." + ) + self.prior_configs.append(JukeboxPriorConfig(**prior_config)) + + self.hop_fraction = self.vqvae_config.hop_fraction + + self.nb_priors = nb_priors + + # Metadata conditioning + self.max_nb_genres = max_nb_genres + self.sampling_rate = sampling_rate + self.timing_dims = timing_dims + self.min_duration = min_duration + self.max_duration = max_duration + self.metadata_conditioning = metadata_conditioning + + super().__init__(**kwargs) + + @classmethod + def from_configs(cls, prior_configs: List[JukeboxPriorConfig], vqvae_config: JukeboxVQVAEConfig, **kwargs): + r""" + Instantiate a [`JukeboxConfig`] (or a derived class) from clip text model configuration and clip vision model + configuration. + + Returns: + [`JukeboxConfig`]: An instance of a configuration object + """ + prior_config_list = [config.to_dict() for config in prior_configs] + return cls(prior_config_list=prior_config_list, vqvae_config_dict=vqvae_config.to_dict(), **kwargs) + + def to_dict(self): + # Override the default to_dict to apply to_dict to the list of prior configs. 
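+        # `prior_configs` holds `JukeboxPriorConfig` objects, which are not JSON-serializable,
+        # so each one is serialized with its own `to_dict` and re-exposed under the
+        # `prior_config_list` key that `__init__` accepts.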
+ result = super().to_dict() + result["prior_config_list"] = [config.to_dict() for config in result.pop("prior_configs")] + return result +__all__ = [ + "JukeboxConfig", + "JukeboxPriorConfig", + "JukeboxVQVAEConfig", + ] diff --git a/mindnlp/transformers/models/jukebox/modeling_jukebox.py b/mindnlp/transformers/models/jukebox/modeling_jukebox.py new file mode 100644 index 000000000..1a507e925 --- /dev/null +++ b/mindnlp/transformers/models/jukebox/modeling_jukebox.py @@ -0,0 +1,2591 @@ +# coding=utf-8 +# Copyright 2022 The OpenAI Team Authors and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Mindspore Jukebox model.""" + +import math +import os +from typing import List, Optional, Tuple + +import numpy as np +import mindspore +import mindnlp.core.nn.functional as F +from mindnlp.core import nn, ops, no_grad, distributions +from mindnlp.core.nn import LayerNorm as FusedLayerNorm + +from ....common.activations import ACT2FN +from ...modeling_utils import PreTrainedModel +from ....utils import logging +from ....utils.logging import tqdm +from .configuration_jukebox import ATTENTION_PATTERNS, JukeboxConfig, JukeboxPriorConfig, JukeboxVQVAEConfig + + +logger = logging.get_logger(__name__) + + +def filter_logits(logits, top_k=0, top_p=0.0, filter_value=-float("Inf")): + """ + Filter a distribution of logits using top-k and/or nucleus (top-p) filtering + + Args: + logits (`mindspore.Tensor`): + logits distribution shape (vocabulary size) + top_k (`int`, *optional*, defaults to 0): + When `top_k >0` keep only top key tokens with highest probability (top-k filtering). + top_p (`int`, *optional*, defaults to 0): + When `top_p>0.0` keep the top tokens with cumulative probability >= `top_p` (nucleus filtering). + """ + logits = logits.clone() + top_k = min(top_k, logits.shape[-1]) # Safety check + + if top_k > 0: + # Remove all tokens with a probability less than the last token of the top-k + indices_to_remove = logits < ops.topk(logits, top_k, dim=-1)[0][..., -1:] + logits[indices_to_remove] = filter_value + + if top_p > 0.0: + sorted_logits, sorted_indices = ops.sort(logits, descending=True, dim=-1) + cumulative_probs = ops.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) + + # Remove tokens with cumulative probability above the threshold + sorted_indices_to_remove = cumulative_probs > top_p + # Shift the indices to the right to keep also the first token above the threshold + sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() + sorted_indices_to_remove[..., 0] = 0 + + # indices_to_remove = sorted_indices[sorted_indices_to_remove] + indices_to_remove = ops.zeros_like(logits, dtype=mindspore.bool_).scatter_( + dim=-1, index=sorted_indices, src=sorted_indices_to_remove + ) + logits[indices_to_remove] = filter_value + return logits + + +def get_relevant_lyric_tokens(full_tokens, max_n_lyric_tokens, total_length, offset, duration): + """ + Extract only the relevant tokens based on the character position. 
A total of `max_n_lyric_tokens` tokens will be + returned. If the provided token sequence is smaller, it will be padded, otherwise, only characters ranging from the + midpoint - `max_n_lyric_tokens//2` to the midpoint + `max_n_lyric_tokens//2` will be returned. This *focuses* on + the most relevant tokens (in time) for the sequence. + + Args: + full_tokens (`List[int]`): + List containing the token ids of the entire lyrics. + total_length (`int`): + Total expected length of the music (not all of it is generated, see duration), in samples. + offset (`int`): + Starting sample in the music. If the offset is greater than 0, the lyrics will be shifted take that into + account + duration (`int`): + Expected duration of the generated music, in samples. The duration has to be smaller than the total length, + which represent the overall length of the signal, + """ + full_tokens = full_tokens[0] + if len(full_tokens) < max_n_lyric_tokens: + tokens = ops.cat( + [ops.zeros(max_n_lyric_tokens - len(full_tokens), dtype=mindspore.int64).to(full_tokens), full_tokens] + ) + indices = [-1] * (max_n_lyric_tokens - len(full_tokens)) + list(range(0, len(full_tokens))) + else: + midpoint = int(len(full_tokens) * (offset + duration / 2.0) / total_length) + midpoint = min(max(midpoint, max_n_lyric_tokens // 2), len(full_tokens) - max_n_lyric_tokens // 2) + tokens = full_tokens[midpoint - max_n_lyric_tokens // 2 : midpoint + max_n_lyric_tokens // 2] + indices = list(range(midpoint - max_n_lyric_tokens // 2, midpoint + max_n_lyric_tokens // 2)) + return tokens.unsqueeze(dim=0), indices + + +# Break total_length into hops/windows of size n_ctx separated by hop_length +def get_starts(total_length, n_ctx, hop_length): + starts = [] + for start in range(0, total_length - n_ctx + hop_length, hop_length): + if start + n_ctx >= total_length: + # Last hop could be smaller, we make it n_ctx to maximise context + start = total_length - n_ctx + starts.append(start) + return starts + + +def get_alignment(music_tokens, labels, prior, config): + level = prior.levels - 1 # Top level used + n_ctx = prior.n_ctx + tokens = music_tokens[level] + batch_size, total_length = tokens.shape[0], tokens.shape[1] + if total_length < n_ctx: + padding_length = n_ctx - total_length + tokens = ops.cat( + [tokens, ops.zeros(batch_size, n_ctx - total_length, dtype=tokens.dtype)], dim=1 + ) + total_length = tokens.shape[1] + else: + padding_length = 0 + + hop_length = int(config.hop_fraction[-level - 1] * prior.n_ctx) + alignment_head, alignment_layer = config.prior_alignment_head[0], config.prior_alignment_layer[0] + attn_layers = {alignment_layer} + alignment_hops = {} + indices_hops = {} + for start in tqdm(get_starts(total_length, n_ctx, hop_length), desc="Computing lyric to music alignment "): + end = start + n_ctx + # set metadata offset, sample_length and lyrics tokens + metadata, indices_hop = prior.get_metadata(labels, start, config.sample_length, get_indices=True, offset=0) + tokens_bs = ops.chunk(tokens, batch_size, dim=0) + metadata_bs = ops.chunk(metadata, batch_size, dim=0) + w_hops = [] + for tokens_i, metadata_i in zip(tokens_bs, metadata_bs): + w_hop = prior.forward_tokens(tokens_i[:, start:end], [], metadata_i, get_attn_weights=attn_layers) + w_hops.append(w_hop[0][:, alignment_head]) + del w_hop + weights = ops.cat(w_hops, dim=0) + del w_hops + alignment_hop = weights.float().cpu().numpy() + del weights + + # alignment_hop has shape (bs, n_ctx, nb_relevant_lyric_tokens) + # indices_hop is a list of len=bs, each entry of len 
hps.nb_relevant_lyric_tokens + indices_hops[start] = indices_hop + alignment_hops[start] = alignment_hop + + # Combine attn for each hop into attn for full range + # Use indices to place them into correct place for corresponding source tokens + alignments = [] + for item in range(batch_size): + # Note each item has different length lyrics + full_tokens = labels[0, 3:] + alignment = np.zeros((total_length, len(full_tokens) + 1)) + for start in reversed(get_starts(total_length, n_ctx, hop_length)): + end = start + n_ctx + alignment_hop = alignment_hops[start][item] + indices = indices_hops[start][item] + alignment[start:end, indices] = alignment_hop + alignment = alignment[: total_length - padding_length, :-1] # remove token padding, and last lyric index + alignments.append(alignment) + return alignments + + +def save_temp_audio(fname, lvl, metas, aud): + aud = ops.clamp(aud, -1, 1).cpu().numpy() + for i in list(range(aud.shape[0])): + if metas is not None: + artists, genres, lyrics = list(metas)[i].values() + path = f"{fname}/lvl_{lvl}-{artists}-{genres}-{lyrics[:5]}-{i}" + np.save(path, aud[i]) + else: + np.save(f"{fname}/lvl_{lvl}-sample-{i}", aud[i]) + + +def get_mask(mask, query_length, key_value_length, blocks, spread, sample, sample_t): + # returns a mask of shape 1 x 1 x query_length x key_value_length or None if masking is not needed. + if mask is None or query_length == 1: + return None + offset = sample_t - query_length if sample else max(key_value_length - query_length, 0) + if mask == "autoregressive": + # Masked dense + mask = ops.ones(query_length, key_value_length).tril(offset) + elif mask == "summary": + # Masked summary + mask = ops.ones(query_length, query_length).tril() + mask = ops.ones(query_length, query_length).tril() + mask = mask.view(query_length, blocks, query_length // blocks)[:, :-1, -key_value_length // blocks :] + mask = ( + ops.pad( + mask, + (0, 0, 1, 0), + value=1, + ) + .contiguous() + .view(query_length, key_value_length) + ) + elif mask == "prime": + mask = ops.ones(query_length, key_value_length).tril(offset) + return mask.view(1, 1, query_length, key_value_length) + + +class JukeboxConv1D(nn.Module): + def __init__(self, input_width, output_width): + super().__init__() + self.input_width = input_width + self.output_width = output_width + weight = ops.zeros(input_width, output_width) + bias = ops.zeros(output_width) + self.weight = nn.Parameter(weight) + self.bias = nn.Parameter(bias) + + def forward(self, hidden_states): + size_out = (*hidden_states.shape[:-1], self.output_width) + hidden_states = ops.addmm( + self.bias.type_as(hidden_states), + hidden_states.view(-1, hidden_states.shape[-1]), + self.weight.type_as(hidden_states), + ) + hidden_states = hidden_states.view(*size_out) + return hidden_states + + +class JukeboxResConv1DBlock(nn.Module): + def __init__(self, config, conv_width, depth=1, res_scale=1.0): + super().__init__() + hidden_dim = config.res_convolution_multiplier * conv_width + dilation = config.res_dilation_growth_rate**depth + padding = dilation + + self.res_scale = res_scale + self.activation = nn.ReLU() + self.conv1d_1 = nn.Conv1d(conv_width, hidden_dim, 3, 1, padding, dilation) + self.conv1d_2 = nn.Conv1d(hidden_dim, conv_width, 1, 1, 0) + + def forward(self, hidden_states): + residuals = hidden_states + hidden_states = self.activation(hidden_states) + hidden_states = self.conv1d_1(hidden_states) + hidden_states = self.activation(hidden_states) + hidden_states = self.conv1d_2(hidden_states) + return residuals + self.res_scale * 
hidden_states + + +class JukeboxResnet1D(nn.Module): + def __init__(self, config, conv_width, n_depth, reverse_dilation=False): + super().__init__() + self.dilation_cycle = config.res_dilation_cycle + res_scale = 1.0 if not config.conv_res_scale else 1.0 / math.sqrt(n_depth) + + blocks = [] + for depth in range(n_depth): + block_depth = depth if self.dilation_cycle is None else depth % self.dilation_cycle + blocks.append(JukeboxResConv1DBlock(config, conv_width, block_depth, res_scale)) + + if reverse_dilation: + blocks = blocks[::-1] + self.resnet_block = nn.ModuleList(blocks) + + def forward(self, hidden_states): + for block in self.resnet_block: + hidden_states = block(hidden_states) + return hidden_states + + +class JukeboxEncoderConvBlock(nn.Module): + def __init__(self, config, embed_dim, hidden_dim, depth, down_t, stride_t): + super().__init__() + blocks = [] + filter_t = stride_t * 2 + pad_t = stride_t // 2 + if down_t > 0: + for i in range(down_t): + blocks.append(nn.Conv1d(embed_dim if i == 0 else hidden_dim, hidden_dim, filter_t, stride_t, pad_t)) + blocks.append(JukeboxResnet1D(config, hidden_dim, depth)) + self.proj_out = nn.Conv1d(hidden_dim, config.embed_dim, 3, 1, 1) + self.downsample_block = nn.ModuleList(blocks) + + def forward(self, hidden_states): + for block in self.downsample_block: + hidden_states = block(hidden_states) + hidden_states = self.proj_out(hidden_states) + return hidden_states + + +class JukeboxEncoder(nn.Module): + def __init__(self, config, width, depth, levels, downs_t, strides_t): + super().__init__() + self.levels = levels + self.level_blocks = nn.ModuleList() + + iterator = zip(list(range(self.levels)), downs_t, strides_t) + for i, down_t, stride_t in iterator: + self.level_blocks.append( + JukeboxEncoderConvBlock( + config, config.conv_input_shape if i == 0 else config.embed_dim, width, depth, down_t, stride_t + ) + ) + + def forward(self, hidden_states): + all_hidden_states = [] + + # 64, 32, ... 
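+        # (each level block applies further strided downsampling, so the hidden states
+        # appended here have progressively shorter time dimensions)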
+ for level in range(self.levels): + level_block = self.level_blocks[level] + hidden_states = level_block(hidden_states) + all_hidden_states.append(hidden_states) + + return all_hidden_states + + +class JukeboxDecoderConvBock(nn.Module): + def __init__(self, config, embed_dim, hidden_dim, depth, down_t, stride_t, reverse_dilation=True): + self.embed_dim = embed_dim + self.hidden_dim = hidden_dim + super().__init__() + blocks = [] + if down_t > 0: + filter_t = stride_t * 2 + pad_t = stride_t // 2 + self.proj_in = nn.Conv1d(embed_dim, hidden_dim, 3, 1, 1) + for i in range(down_t): + blocks.append(JukeboxResnet1D(config, hidden_dim, depth, reverse_dilation)) + blocks.append( + nn.ConvTranspose1d( + hidden_dim, hidden_dim if i < down_t - 1 else embed_dim, filter_t, stride_t, pad_t + ) + ) + self.upsample_block = nn.ModuleList(blocks) + + def forward(self, hidden_states): + hidden_states = self.proj_in(hidden_states) + for block in self.upsample_block: + hidden_states = block(hidden_states) + return hidden_states + + +class JukeboxDecoder(nn.Module): + def __init__(self, config, hidden_dim, depth, levels, downs_t, strides_t): + super().__init__() + self.levels = levels + self.level_blocks = nn.ModuleList() + for level, down_t, stride_t in zip(list(range(self.levels)), downs_t, strides_t): + self.level_blocks.append( + JukeboxDecoderConvBock(config, config.embed_dim, hidden_dim, depth, down_t, stride_t) + ) + + self.out = nn.Conv1d(config.embed_dim, config.conv_input_shape, 3, 1, 1) + + def forward(self, hidden_states, all_levels=True): + hidden_state = hidden_states[-1] + + # 32, 64 ... + for level in reversed(range(self.levels)): + level_block = self.level_blocks[level] + hidden_state = level_block(hidden_state) + + if level != 0 and all_levels: + hidden_state = hidden_state + hidden_states[level - 1] + + hidden_state = self.out(hidden_state) + return hidden_state + + +class JukeboxBottleneckBlock(nn.Module): + def __init__(self, config: JukeboxVQVAEConfig): + super().__init__() + self.nb_discrete_codes = config.nb_discrete_codes + self.codebook_width = config.embed_dim + self.mu = config.lmu + self.threshold = 1.0 + self.init = False + self.codebook_sum = None + self.codebook_elem = None + self.register_buffer("codebook", ops.zeros(self.nb_discrete_codes, self.codebook_width)) + + def _tile(self, hidden_states): + dim, embed_width = hidden_states.shape + if dim < self.nb_discrete_codes: + n_repeats = (self.nb_discrete_codes + dim - 1) // dim + std = 0.01 / np.sqrt(embed_width) + hidden_states = hidden_states.repeat(n_repeats, 1) + hidden_states = hidden_states + ops.randn_like(hidden_states) * std + return hidden_states + + def init_codebook(self, hidden_states): + nb_discrete_codes = self.nb_discrete_codes + self.init = True + codes = self._tile(hidden_states) + self.codebook = codes[ops.randperm(codes.shape[0])][:nb_discrete_codes] + self.codebook_sum = self.codebook + self.codebook_elem = ops.ones(nb_discrete_codes) + + def update_codebook(self, hidden_states, latent_states): + mu, codebook_width, nb_discrete_codes = self.mu, self.codebook_width, self.nb_discrete_codes + with no_grad(): + # Calculate new centres + # nb_discrete_codes, batch_size * seq_length + latent_states_onehot = ops.zeros(nb_discrete_codes, hidden_states.shape[0]) + latent_states_onehot.scatter_(0, latent_states.view(1, hidden_states.shape[0]), 1) + + _codebook_sum = ops.matmul(latent_states_onehot, hidden_states) + _codebook_elem = latent_states_onehot.sum(dim=-1) # nb_discrete_codes + codes = 
self._tile(hidden_states) + _random_codebook = codes[ops.randperm(codes.shape[0])][:nb_discrete_codes] + + # Update centres + old_codebook = self.codebook + self.codebook_sum = mu * self.codebook_sum + (1.0 - mu) * _codebook_sum + self.codebook_elem = mu * self.codebook_elem + (1.0 - mu) * _codebook_elem # nb_discrete_codes + usage = (self.codebook_elem.view(nb_discrete_codes, 1) >= self.threshold).float() + + norm_code = self.codebook_sum.view(nb_discrete_codes, codebook_width) / self.codebook_elem.view( + nb_discrete_codes, 1 + ) + self.codebook = usage * (norm_code) + (1 - usage) * _random_codebook + _codebook_prob = _codebook_elem / mindspore.ops.sum(_codebook_elem) # prob of each bin + entropy = -mindspore.ops.sum(_codebook_prob * ops.log(_codebook_prob + 1e-8)) # entropy ie how diverse + used_curr = (_codebook_elem >= self.threshold).sum() + usage = mindspore.ops.sum(usage) + dk = ops.norm(self.codebook - old_codebook) / np.sqrt(np.prod(old_codebook.shape)) + return {"entropy": entropy, "used_curr": used_curr, "usage": usage, "dk": dk} + + def preprocess(self, hidden_states): + hidden_states = hidden_states.permute(0, 2, 1).contiguous() + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + + if hidden_states.shape[-1] == self.codebook_width: + prenorm = ops.norm(hidden_states - ops.mean(hidden_states)) / np.sqrt(np.prod(hidden_states.shape)) + elif hidden_states.shape[-1] == 2 * self.codebook_width: + x1, x2 = hidden_states[..., : self.codebook_width], hidden_states[..., self.codebook_width :] + prenorm = (ops.norm(x1 - ops.mean(x1)) / np.sqrt(np.prod(x1.shape))) + ( + ops.norm(x2 - ops.mean(x2)) / np.sqrt(np.prod(x2.shape)) + ) + + # Normalise + hidden_states = x1 + x2 + + return hidden_states, prenorm + + def postprocess(self, latent_states, dequantised_states, x_shape): + batch_size, time = x_shape + dequantised_states = dequantised_states.view(batch_size, time, -1).permute(0, 2, 1).contiguous() + latent_states = latent_states.view(batch_size, time) + return latent_states, dequantised_states + + def quantise(self, latent_states): + # Calculate latent code latent_states + codebook_weights = self.codebook.t() + distance = ( + mindspore.ops.sum(latent_states**2, dim=-1, keepdim=True) + - 2 * ops.matmul(latent_states, codebook_weights) + + mindspore.ops.sum(codebook_weights**2, dim=0, keepdim=True) + ) # (batch_size * latent_states , codebook_weights) + min_distance, music_tokens = ops.min(distance,dim=-1) + fit = ops.mean(min_distance) + return music_tokens, fit + + def dequantise(self, music_tokens): + dequantised_states = F.embedding(music_tokens, self.codebook) + return dequantised_states + + def encode(self, latent_states): + samples, _, seq_len = latent_states.shape + + # Preprocess. + latent_states, _ = self.preprocess(latent_states) + + # Quantise + music_tokens, _ = self.quantise(latent_states) + + # Postprocess. 
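+        # reshape the flat (samples * seq_len,) token indices back to (samples, seq_len)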
+ music_tokens = music_tokens.view(samples, seq_len) + return music_tokens + + def decode(self, music_tokens): + samples, seq_len = music_tokens.shape + + # Dequantise + dequantised_states = self.dequantise(music_tokens) + + # Postprocess + dequantised_states = ( + dequantised_states.view(samples, seq_len, self.codebook_width).permute(0, 2, 1).contiguous() + ) + return dequantised_states + + def forward(self, hidden_states, update_codebook=True): + samples, _, seq_len = hidden_states.shape + + # Preprocess + hidden_states, prenorm = self.preprocess(hidden_states) + + # Init codebook if not inited + if update_codebook and not self.init: + self.init_codebook(hidden_states) + + # Quantise and dequantise through bottleneck + music_tokens, fit = self.quantise(hidden_states) + dequantised_states = self.dequantise(music_tokens) + + # Update embeddings + if update_codebook: + update_metrics = self.update_codebook(hidden_states, music_tokens) + else: + update_metrics = {} + + # Loss + commit_loss = ops.norm(dequantised_states.detach() - hidden_states) ** 2 / np.prod(hidden_states.shape) + + # Passthrough + dequantised_states = hidden_states + (dequantised_states - hidden_states).detach() + + # Postprocess + music_tokens, dequantised_states = self.postprocess(music_tokens, dequantised_states, (samples, seq_len)) + return music_tokens, dequantised_states, commit_loss, dict(fit=fit, pn=prenorm, **update_metrics) + + +class JukeboxBottleneck(nn.Module): + def __init__(self, config, levels): + super().__init__() + self.levels = levels + self.level_blocks = nn.ModuleList() + for level in range(self.levels): + self.level_blocks.append(JukeboxBottleneckBlock(config)) + + def encode(self, raw_audio): + music_tokens = [ + level_block.encode(hidden_states) for (level_block, hidden_states) in zip(self.level_blocks, raw_audio) + ] + return music_tokens + + def decode(self, music_tokens, start_level=0, end_level=None): + if end_level is None: + end_level = self.levels + quantised_audio = [ + level_block.decode(z) for (level_block, z) in zip(self.level_blocks[start_level:end_level], music_tokens) + ] + return quantised_audio + + def forward(self, input_audio): + music_tokens, quantised_states, commit_losses, metrics = [], [], [], [] + for level in range(self.levels): + level_block = self.level_blocks[-level - 1] + hidden_states = input_audio[level] + sampled_tokens, quantised_state, commit_loss, metric = level_block( + hidden_states, update_codebook=self.training + ) + music_tokens.append(sampled_tokens) + if not self.training: + # Be extra paranoid and make sure the encoder weights can't + # change from straight-through estimator + quantised_state = quantised_state.detach() + quantised_states.append(quantised_state) + commit_losses.append(commit_loss) + if self.training: + metrics.append(metric) + return music_tokens, quantised_states, commit_losses, metrics + + +JUKEBOX_START_DOCSTRING = r""" + + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config (`JukeboxConfig`): Model configuration class with all the parameters of the model. 
+ Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +class JukeboxVQVAE(PreTrainedModel): + config_class = JukeboxVQVAEConfig + base_model_prefix = "vqvae" + + def _init_weights(self, module): + if isinstance(module, nn.Embedding): # embed_tokens + module.weight.data.normal_(mean=0.0, std=0.02 * self.config.init_scale) + elif isinstance(module, JukeboxConv1D): + if self.config.zero_out: + module.weight.data.zero_() + else: + module.weight.data.normal_(mean=0.0, std=0.02 * self.config.init_scale) + elif isinstance(module, JukeboxResConv1DBlock) and self.config.zero_out: + module.conv1d_2.weight.data.zero_() + module.conv1d_2.bias.data.zero_() + if isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + def __init__(self, config: JukeboxVQVAEConfig): + super().__init__(config) + downs_t = config.res_downs_t + strides_t = config.res_strides_t + if not config.sample_length: + downsamples = [stride**down for stride, down in zip(strides_t, downs_t)] + top_raw_to_tokens = np.prod(downsamples) + config.sample_length = ( + config.sample_length_in_seconds * config.sampling_rate // top_raw_to_tokens + ) * top_raw_to_tokens + config.sample_length = config.sample_length.astype(int) + + self.nb_discrete_codes = config.nb_discrete_codes + self.commit = config.commit + self.sample_length = config.sample_length + + self.downsamples = [stride**down for stride, down in zip(strides_t, downs_t)] + self.hop_lengths = np.cumprod(self.downsamples) + self.levels = levels = config.levels + self.music_tokens_shapes = [ + (int(self.sample_length // self.hop_lengths[-level - 1])) for level in range(levels) + ] + + self.multipliers = config.multipliers if config.multipliers is not None else [1] * levels + + self.encoders = nn.ModuleList() + self.decoders = nn.ModuleList() + for level in range(levels): + width = config.res_conv_width * self.multipliers[level] + depth = config.res_conv_depth * self.multipliers[level] + self.encoders.append( + JukeboxEncoder(config, width, depth, level + 1, downs_t[: level + 1], strides_t[: level + 1]) + ) + self.decoders.append( + JukeboxDecoder(config, width, depth, level + 1, downs_t[: level + 1], strides_t[: level + 1]) + ) + + self.bottleneck = JukeboxBottleneck(config, levels) + + def _decode(self, music_tokens, start_level=0, end_level=None): + # Decode + if end_level is None: + end_level = self.levels + latent_states = self.bottleneck.decode(music_tokens, start_level=start_level, end_level=end_level) + # Use only lowest level + decoder, dequantised_state = self.decoders[start_level], latent_states[0:1] + dequantised_state = decoder(dequantised_state, all_levels=False) + dequantised_state = dequantised_state.permute(0, 2, 1) + return dequantised_state + + def decode(self, music_tokens, start_level=0, end_level=None, bs_chunks=1) -> mindspore.Tensor: + """ + Transforms the input `music_tokens` to their `raw_audio` representation. + + Args: + music_tokens (`mindspore.Tensor`): + Tensor of music tokens which will be decoded to raw audio by using the codebook. Each music token + should be an index to a corresponding `code` vector in the codebook. + start_level (`int`, *optional*): + Level at which the decoding process will start. Default to 0. 
+ end_level (`int`, *optional*): + Level at which the decoding process will start. Default to None. + bs_chunks (int, *optional*): + Number of chunks to process at the same time. + """ + token_chunks = [ops.chunk(token, bs_chunks, dim=0) for token in music_tokens] + dequantised_states = [] + for i in range(bs_chunks): + music_tokens_i = [chunks[i] for chunks in token_chunks] + dequantised_state = self._decode(music_tokens_i, start_level=start_level, end_level=end_level) + dequantised_states.append(dequantised_state) + return ops.cat(dequantised_states, dim=0) + + def _encode(self, raw_audio, start_level=0, end_level=None): + # Encode + if end_level is None: + end_level = self.levels + input_audio = raw_audio.permute(0, 2, 1).float() + latent_states = [] + for level in range(self.levels): + encoder = self.encoders[level] + latent_state = encoder(input_audio) + latent_states.append(latent_state[-1]) + music_tokens = self.bottleneck.encode(latent_states) + return music_tokens[start_level:end_level] + + def encode(self, input_audio, start_level=0, end_level=None, bs_chunks=1): + """ + Transforms the `input_audio` to a discrete representation made out of `music_tokens`. + + Args: + input_audio (`mindspore.Tensor`): + Raw audio which will be encoded to its discrete representation using the codebook. The closest `code` + form the codebook will be computed for each sequence of samples. + start_level (`int`, *optional*, defaults to 0): + Level at which the encoding process will start. Default to 0. + end_level (`int`, *optional*): + Level at which the encoding process will start. Default to None. + bs_chunks (int, *optional*, defaults to 1): + Number of chunks of raw audio to process at the same time. + """ + audio_chunks = ops.chunk(input_audio, bs_chunks, dim=0) + music_tokens_list = [] + for chunk_i in audio_chunks: + music_tokens_i = self._encode(chunk_i, start_level=start_level, end_level=end_level) + music_tokens_list.append(music_tokens_i) + music_tokens = [ops.cat(music_tokens_level, dim=0) for music_tokens_level in zip(*music_tokens_list)] + return music_tokens + + def sample(self, n_samples): + music_tokens = [ + ops.randint(0, self.nb_discrete_codes, size=(n_samples, *music_tokens_shape)) + for music_tokens_shape in self.music_tokens_shapes + ] + return self.decode(music_tokens) + + def forward(self, raw_audio: mindspore.Tensor) -> Tuple[mindspore.Tensor, mindspore.Tensor]: + """ + Forward pass of the VQ-VAE, encodes the `raw_audio` to latent states, which are then decoded for each level. + The commit loss, which ensure that the encoder's computed embeddings are close to the codebook vectors, is + computed. + + Args: + raw_audio (`mindspore.Tensor`): + Audio input which will be encoded and decoded. 
+ + Returns: + `Tuple[mindspore.Tensor, mindspore.Tensor]` + + + Example: + ```python + >>> from transformers import JukeboxVQVAE, set_seed + >>> import torch + + >>> model = JukeboxVQVAE.from_pretrained("openai/jukebox-1b-lyrics").eval() + >>> set_seed(0) + >>> zs = [torch.randint(100, (4, 1))] + >>> model.decode(zs).shape + torch.shape([4, 8, 1]) + ``` + """ + + # Encode/Decode + input_audio = raw_audio.permute(0, 2, 1).float() + latent_states = [] + for level in range(self.levels): + encoder = self.encoders[level] + latent_state = encoder(input_audio) + latent_states.append(latent_state[-1]) + + _, music_tokens, commit_losses, _ = self.bottleneck(latent_states) + dequantised_states = [] + for level in range(self.levels): + decoder = self.decoders[level] + dequantised_state = decoder(music_tokens[level : level + 1], all_levels=False) + dequantised_states.append(dequantised_state.permute(0, 2, 1)) + + commit_loss = sum(commit_losses) + loss = self.commit * commit_loss + + return dequantised_states, loss + + +class JukeboxMLP(nn.Module): + def __init__(self, config): + # a single channel is always used in original code + super().__init__() + embed_dim = config.hidden_size + hidden_dim = int(config.mlp_multiplier * embed_dim) + + self.c_fc = JukeboxConv1D(embed_dim, hidden_dim) + self.c_proj = JukeboxConv1D(hidden_dim, embed_dim) + self.act = ACT2FN[config.act_fn] + self.dropout = nn.Dropout(config.resid_dropout) + + def forward(self, hidden_states): + hidden_states = self.c_fc(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.c_proj(hidden_states) + hidden_states = self.dropout(hidden_states) + return hidden_states + + +class JukeboxLayerNorm(FusedLayerNorm): + def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True): + super().__init__(normalized_shape, eps=eps, elementwise_affine=elementwise_affine) + self.width = np.prod(normalized_shape) + self.max_numel = 65535 * self.width + + def forward(self, input): + if input.numel() > self.max_numel: + return F.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps).type_as(input) + else: + return super().forward(input).type_as(input) + + +class JukeboxAttention(nn.Module): + def __init__(self, config, n_ctx, attn_func="dense_attn"): + super().__init__() + self.embed_dim = config.hidden_size + self.n_heads = config.n_heads + self.dropout = config.attn_dropout + hidden_dim = int(config.attention_multiplier * self.embed_dim) + + self.head_dim = hidden_dim // config.n_heads + self.n_ctx = n_ctx + self.hidden_dim = hidden_dim + self.scale = self.head_dim**-0.25 + self.mask = config.mask + + if attn_func == "cross_attention": + self.c_attn = JukeboxConv1D(self.embed_dim, hidden_dim) + self.c_enc_kv = JukeboxConv1D(self.embed_dim, hidden_dim * 2) + else: + self.c_attn = JukeboxConv1D(self.embed_dim, hidden_dim * 3) + + self.c_proj = JukeboxConv1D(hidden_dim, self.embed_dim) + self.attn_dropout = nn.Dropout(config.attn_dropout) + self.resid_dropout = nn.Dropout(config.resid_dropout) + + # Sequence of length seq_len is factored as [blocks, seq_len // blocks] + self.attn_func = attn_func + if attn_func == "cross_attention": + self.qkv = self.decode_qkv + elif attn_func == "prime_attn": + self.qkv = self.prime_qkv + else: + self.qkv = self.factored_qkv + + ATTENTION_MAP = { + "dense_attn": (self.dense_attn, "autoregressive"), + "block_attn": (self.block_attn, "autoregressive"), + "transpose_block_attn": (self.transpose_block_attn, "autoregressive"), + "prev_block_attn": (self.prev_block_attn, 
None), + "summary_attn": (self.summary_attn, "summary"), + "summary_spread_attn": (self.summary_spread_attn, "summary"), + "cross_attention": (self.dense_attn, None), + "prime_attn": (self.prime_attn, "prime"), + } + self.attn, self.attn_mask = ATTENTION_MAP[attn_func] + + self.blocks = config.blocks + self.spread = config.spread + if self.blocks is not None: + self.block_ctx = self.n_ctx // self.blocks + + self.sample_t = 0 + self.cache = {} + self.encoder_len = config.nb_relevant_lyric_tokens # length of the encoder input ids + self.record_attn = False + + def _attn(self, query_states, key_states, value_states, sample): + scale = self.scale + if self.training: + attention_weight = ops.matmul(query_states * scale, key_states * scale) + else: + attention_weight = ops.matmul(query_states, key_states) + attention_weight.mul_(scale * scale) + attn_weight_type = attention_weight.dtype + attention_weight = attention_weight.float() + if self.mask: + # Generate appropriate mask to mask out all positions before current + # Might take up lot of memory for dense, so can cache it + mask = get_mask( + self.attn_mask, + query_states.shape[-2], + key_states.shape[-1], + self.blocks, + self.spread, + attention_weight, + sample, + self.sample_t, + ) + if mask is not None: + attention_weight = attention_weight * mask + -1e9 * (1 - mask) + attention_prob = F.softmax(attention_weight, dim=-1).type(attn_weight_type) + if self.record_attn: + self.attention_prob = attention_prob + if self.attn_func == "prime_attn": + # only keep music queries and lyrics keys/values + self.attention_prob = self.attention_prob[:, :, self.encoder_len :, : self.encoder_len] + attention_prob = self.attn_dropout(attention_prob) + context_states = ops.matmul(attention_prob, value_states) + return context_states + + def merge_heads(self, hidden_states): + hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous() + new_hidden_states_shape = (*hidden_states.shape[:-2], hidden_states.shape[-2] * hidden_states.shape[-1]) + return hidden_states.view(*new_hidden_states_shape) # in Tensorflow implem: fct merge_states + + def split_heads(self, hidden_states, is_key=False): + new_hidden_states_shape = ( + *hidden_states.shape[:-1], + self.n_heads, + hidden_states.shape[-1] // self.n_heads, + ) + hidden_states = hidden_states.view(*new_hidden_states_shape) # in Tensorflow implem: fct split_states + if is_key: + return hidden_states.permute(0, 2, 3, 1) + else: + return hidden_states.permute(0, 2, 1, 3) + + def dense_attn(self, query, key, value, sample): + query = self.split_heads(query) + key = self.split_heads(key, is_key=True) + value = self.split_heads(value) + context_states = self._attn(query, key, value, sample) + context_states = self.merge_heads(context_states) + return context_states + + def block_attn(self, query, key, value, sample): + block_ctx = self.block_ctx + batch_size, seq_len, embed_dim = value.shape # For sample, query_len= 1, key_len = value_len = sample_t + if sample: + return self.dense_attn(query, key, value, sample).view(batch_size, 1, embed_dim) + else: + query_length = query.shape[1] + query = query.view(batch_size * query_length // block_ctx, block_ctx, embed_dim) + if query_length < seq_len: + seq_len = query_length + key = key[:, -seq_len:].contiguous() + value = value[:, -seq_len:].contiguous() + key = key.view(batch_size * seq_len // block_ctx, block_ctx, embed_dim) + value = value.view(batch_size * seq_len // block_ctx, block_ctx, embed_dim) + return self.dense_attn(query, key, value, 
sample).view(batch_size, seq_len, embed_dim) + + def transpose_block_attn(self, query, key, value, sample): + block_ctx = self.block_ctx + batch_size, seq_len, embed_dim = value.shape # For sample, query_len= 1, key_len = value_len = sample_t + if sample: + block_len = (seq_len - 1) % block_ctx + key = key[:, block_len::block_ctx, :] + value = value[:, block_len::block_ctx, :] + return self.dense_attn(query, key, value, sample).view(batch_size, 1, embed_dim) + else: + query_length = query.shape[1] + query = query.view(batch_size, query_length // block_ctx, block_ctx, embed_dim) + query = query.transpose(1, 2).contiguous() + query = query.view(batch_size * block_ctx, query_length // block_ctx, embed_dim) + + key = key.view(batch_size, seq_len // block_ctx, block_ctx, embed_dim) + key = key.transpose(1, 2).contiguous() + key = key.view(batch_size * block_ctx, seq_len // block_ctx, embed_dim) + + value = value.view(batch_size, seq_len // block_ctx, block_ctx, embed_dim) + value = value.transpose(1, 2).contiguous() + value = value.view(batch_size * block_ctx, seq_len // block_ctx, embed_dim) + + block_attn = self.dense_attn(query, key, value, sample) + block_attn = block_attn.view(batch_size, block_ctx, query_length // block_ctx, embed_dim) + block_attn = block_attn.transpose(1, 2).contiguous() + block_attn = block_attn.view(batch_size, query_length, embed_dim) + + return block_attn + + def prev_block_attn(self, query, key, value, sample): + block_ctx = self.block_ctx + batch_size, seq_len, embed_dim = value.shape # For sample, query_len= 1, key_len = value_len = sample_t + if sample: + block = (seq_len - 1) // block_ctx + prev_l = (block - 1) * block_ctx + if block > 0: + key = key[:, prev_l : prev_l + block_ctx, :] + value = value[:, prev_l : prev_l + block_ctx, :] + else: + key = ops.zeros(batch_size, block_ctx, embed_dim, dtype=query.dtype) + value = ops.zeros(batch_size, block_ctx, embed_dim, dtype=query.dtype) + return self.dense_attn(query, key, value, sample).view(batch_size, 1, embed_dim) + else: + query_length = query.shape[1] + query = query.view(batch_size * query_length // block_ctx, block_ctx, embed_dim) + + key = key.view(batch_size, seq_len // block_ctx, block_ctx, embed_dim)[:, :-1, :, :] + key = ops.pad(key, (0, 0, 0, 0, 1, 0)) + key = key.view(batch_size * seq_len // block_ctx, block_ctx, embed_dim) + + value = value.view(batch_size, seq_len // block_ctx, block_ctx, embed_dim)[:, :-1, :, :] + value = ops.pad(value, (0, 0, 0, 0, 1, 0)) + value = value.view(batch_size * seq_len // block_ctx, block_ctx, embed_dim) + + if query_length < seq_len: + nb_query_blocks = query_length // block_ctx + nb_key_blocks = seq_len // block_ctx + seq_len = query_length + key = key.view(batch_size, nb_key_blocks, block_ctx, embed_dim)[:, -nb_query_blocks:] + key = key.contiguous().view(batch_size * nb_query_blocks, block_ctx, embed_dim) + + value = value.view(batch_size, nb_key_blocks, block_ctx, embed_dim)[:, -nb_query_blocks:] + value = value.contiguous().view(batch_size * nb_query_blocks, block_ctx, embed_dim) + + return self.dense_attn(query, key, value, sample).view(batch_size, seq_len, embed_dim) + + def summary_attn(self, query, key, value, sample): + blocks = self.blocks + block_ctx = self.block_ctx + batch_size, seq_len, embed_dim = value.shape # For sample, query_len= 1, key_len = value_len = sample_t + if sample: + key = key[:, block_ctx - 1 : blocks * block_ctx - 1 : block_ctx, :] + key = ops.pad(key, (0, 0, 1, 0)) + + value = value[:, block_ctx - 1 : blocks * block_ctx - 1 : 
block_ctx, :] + value = ops.pad(value, (0, 0, 1, 0)) + return self.dense_attn(query, key, value, sample).view(batch_size, 1, embed_dim) + else: + key = key.view(batch_size, blocks, seq_len // blocks, embed_dim)[:, :-1, -1, :] + key = ops.pad(key, (0, 0, 1, 0)) # batch_size, blocks, embed_dim + + value = value.view(batch_size, blocks, seq_len // blocks, embed_dim)[:, :-1, -1, :] + value = ops.pad(value, (0, 0, 1, 0)) # batch_size, blocks, embed_dim + return self.dense_attn(query, key, value, sample).view(batch_size, seq_len, embed_dim) + + def summary_spread_attn(self, query, key, value, sample): + blocks = self.blocks + spread = self.spread + + batch_size, seq_len, embed_dim = value.shape # For sample, query_len= 1, key_len = value_len = sample_t + if sample: + raise NotImplementedError + else: + key = key.view(batch_size, blocks, seq_len // blocks, embed_dim)[:, :-1, -spread:, :] + key = ops.pad(key, (0, 0, 0, 0, 1, 0)).contiguous() + key = key.view(batch_size, blocks * spread, embed_dim) + + value = value.view(batch_size, blocks, seq_len // blocks, embed_dim)[:, :-1, -spread:, :] + value = ops.pad(value, (0, 0, 0, 0, 1, 0)).contiguous() + value = value.view(batch_size, blocks * spread, embed_dim) + + return self.dense_attn(query, key, value, sample).view(batch_size, seq_len, embed_dim) + + def prime_attn(self, query, key, value, sample): + encoder_len = self._encoder_len + key = key[:, :encoder_len] + value = value[:, :encoder_len] + return self.dense_attn(query, key, value, sample) + + def factored_qkv(self, hidden_states, last_encoder_hidden_states=None, sample=False): + curr_ctx = hidden_states.shape[1] + if last_encoder_hidden_states is not None: + raise TypeError("last_encoder_hidden_states should be None") + + query, key, value = hidden_states.chunk(3, dim=2) + if sample: + self.sample_t += curr_ctx + key, value = self._append_cache(key, value) + l_cache = self._suff_cache_len() + if self._cache_len() > l_cache: + self._slice_cache(-l_cache) + if curr_ctx > 1: + if self.attn_func != "dense_attn": + query = self._pad_to_block_ctx(query, query=True) + key = self._pad_to_block_ctx(key) + value = self._pad_to_block_ctx(value) + sample = False + else: + key = self.cache["key"] + value = self.cache["value"] + return query, key, value, sample + + def prime_qkv(self, hidden_states, last_encoder_hidden_states=None, sample=False): + curr_ctx = hidden_states.shape[1] + if last_encoder_hidden_states is not None: + raise TypeError("last_encoder_hidden_states should be None") + query, key, value = hidden_states.chunk(3, dim=2) + if sample: + if self._cache_len() < self._encoder_len: + self._append_cache(key, value) + if self._cache_len() > self._encoder_len: + self._slice_cache(0, self._encoder_len) + key, value = self.cache["key"], self.cache["value"] + self.sample_t += curr_ctx + return query, key, value, sample + + def decode_qkv(self, hidden_states, last_encoder_hidden_states=None, sample=False): + curr_ctx = hidden_states.shape[1] + query = hidden_states + if sample: + if self.sample_t == 0: + self.cache["key"], self.cache["value"] = self.c_enc_kv( + last_encoder_hidden_states.type_as(hidden_states) + ).chunk(2, dim=2) + key, value = self.cache["key"], self.cache["value"] + self.sample_t += curr_ctx + else: + key, value = self.c_enc_kv(last_encoder_hidden_states.type_as(hidden_states)).chunk(2, dim=2) + return query, key, value, sample + + def forward(self, hidden_states, last_encoder_hidden_states=None, sample=False): + curr_ctx = hidden_states.shape[1] + hidden_states = 
self.c_attn(hidden_states) + query, key, value, sample = self.qkv( + hidden_states, last_encoder_hidden_states=last_encoder_hidden_states, sample=sample + ) + attention_scores = self.attn(query, key, value, sample) + if attention_scores.shape[1] != curr_ctx: + offset = self._offset(curr_ctx) + attention_scores = attention_scores[:, offset : offset + curr_ctx, :].contiguous() + attention_scores = self.c_proj(attention_scores) + return self.resid_dropout(attention_scores) + + @property + def _encoder_len(self): + encoder_len = self.encoder_len + encoder_blocks = (encoder_len // self.blocks) + 1 + return encoder_blocks * self.blocks + + def _offset(self, curr_ctx): + if self.attn_func == "dense_attn": + return 0 + return (self.sample_t - curr_ctx) % self.block_ctx + + def _pad_to_block_ctx(self, hidden_states, query=False): + seq_len = hidden_states.shape[1] + offset = self._offset(seq_len) if query else 0 + n_blocks = (seq_len + offset + self.block_ctx - 1) // self.block_ctx + pad = n_blocks * self.block_ctx - seq_len - offset + if pad == 0 and offset == 0: + return hidden_states + else: + return F.pad(hidden_states, (0, 0, offset, pad)) + + def _cache_len(self): + return 0 if "key" not in self.cache else self.cache["key"].shape[1] + + def _suff_cache_len(self): + """ + Precondition: + key and value are appended with the current context and self.sample_t reflects the 1-indexed sample + location in the context. + """ + previous_block_length = (self.sample_t - 1) % self.block_ctx + 1 + self.block_ctx + REQUIRED_CACHE_LEN = { + "dense_attn": self.sample_t, + "block_attn": (self.sample_t - 1) % self.block_ctx + 1, + "transpose_block_attn": self.sample_t, + "prev_block_attn": self.sample_t if self.sample_t <= self.block_ctx else previous_block_length, + "cross_attn": self.encoder_len, + "prime_attn": min(self.sample_t, self._encoder_len), + } + + return REQUIRED_CACHE_LEN[self.attn_func] + + def _slice_cache(self, start, end=None): + self.cache["key"] = self.cache["key"][:, start:end] + self.cache["value"] = self.cache["value"][:, start:end] + + def _append_cache(self, key, value): + if "key" not in self.cache: + self.cache["key"] = key + self.cache["value"] = value + else: + old_key, old_value = key, value + key = ops.cat([self.cache["key"], old_key], dim=1) + value = ops.cat([self.cache["value"], old_value], dim=1) + del self.cache["key"] + del self.cache["value"] + del old_key + del old_value + self.cache["key"] = key + self.cache["value"] = value + return self.cache["key"], self.cache["value"] + + def del_cache(self): + self.sample_t = 0 + if "key" in self.cache: + del self.cache["key"] + if "value" in self.cache: + del self.cache["value"] + self.cache = {} + + +class JukeboxBlock(nn.Module): + def __init__(self, config, n_ctx, attn_func="dense_attn"): + super().__init__() + self.width = config.hidden_size + self.attn = JukeboxAttention(config, n_ctx, attn_func=attn_func) + + self.layer_norm_0 = JukeboxLayerNorm(config.hidden_size) + self.mlp = JukeboxMLP(config) + self.layer_norm_1 = JukeboxLayerNorm(config.hidden_size) + self.res_scale = 1.0 / config.num_layers if config.attn_res_scale else 1.0 + self.attn_func = attn_func + + def forward(self, hidden_states, last_encoder_hidden_states, sample=False): + residuals = hidden_states + hidden_states = self.layer_norm_0(hidden_states) + hidden_states = self.attn(hidden_states, last_encoder_hidden_states, sample) + + output_states = self.layer_norm_1(residuals + hidden_states) + output_states = self.mlp(output_states) + if self.res_scale == 1.0: + 
output = residuals + hidden_states + output_states + else: + output = residuals + self.res_scale * (hidden_states + output_states) + return output + + +class JukeboxLayerStack(nn.Module): + def __init__(self, config, n_ctx): + super().__init__() + self.n_ctx = n_ctx + self.width = config.hidden_size + self.num_layers = config.num_layers + self.blocks = config.blocks + self.attention_pattern = config.attention_pattern + if self.blocks is not None: + self.block_ctx = n_ctx // self.blocks + self.encoder_len = config.nb_relevant_lyric_tokens + self.n_heads = config.n_heads + + # Orders of attn_func + attention_pattern = ATTENTION_PATTERNS[self.attention_pattern] + self._attn_mods = nn.ModuleList() + for depth in range(self.num_layers): + self._attn_mods.append(JukeboxBlock(config, n_ctx, attn_func=attention_pattern(depth))) + + self.saved_attn_weights = [] + + def set_record_attn(self, record_attn): + """ + Makes forward prop dump self-attention softmaxes to self.saved_attn_weights. + + Args: + record_attn (`Union[bool,set]`): + Either a set of layer indices indicating which layers to store, or a boolean value indicating Whether + to dump all. + """ + + def _should_record_attn(layer_idx): + if isinstance(record_attn, bool): + return record_attn + return layer_idx in record_attn + + for i, layer in enumerate(self._attn_mods): + layer.attn.record_attn = _should_record_attn(i) + + if not record_attn: + self.saved_attn_weights = [] + + def forward(self, hidden_states, last_encoder_hidden_states=None, sample=False): + # Blocks + for i, attn_layer in enumerate(self._attn_mods): + if attn_layer.attn_func == "cross_attention": # attend to the lyrics + hidden_states = attn_layer( + hidden_states, last_encoder_hidden_states=last_encoder_hidden_states, sample=sample + ) + else: + hidden_states = attn_layer(hidden_states, last_encoder_hidden_states=None, sample=sample) + if attn_layer.attn.record_attn: + self.saved_attn_weights.append(attn_layer.attn.c_attn.weight) + return hidden_states + + def del_cache(self): + for attn_layer in self._attn_mods: + attn_layer.attn.del_cache() + + +class JukeboxPositionalEmbedding(nn.Module): + def __init__(self, embed_dim, width): + super().__init__() + self.pos_emb = nn.Parameter(ops.empty((embed_dim, width))) + + def forward(self): + pos_emb = self.pos_emb + return pos_emb + + +class JukeboxConditionalAutoregressive(nn.Module): + def __init__( + self, + config, + n_ctx=None, + embed_dim=None, + audio_conditioning=False, + metadata_conditioning=False, + is_encoder=False, + ): + """ + Autoregressive model on either lyric tokens or music tokens, or both. The attention pattern should be properly + set fro each configuration. + + Args: + config (`JukeboxPriorConfig`): + Model configuration class with all the parameters of the model. Initializing with a config file does + not load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. + n_ctx (`int`, *optional*): + Number of tokens or lyrics tokens provided in a single pass. + embed_dim (`int`, *optional*): + Either equals to the dimension of the codebook, or the sum of n_vocab (lyrics) and codeboook dimension, + if the model combines lyrics and music tokens, or simply n_vocab if the model is a seperate encoder + audio_conditioning (`bool`, *optional*, defaults to `False`): + Whether or not the prior supports conditionning on audio. 
+ metadata_conditioning (`bool`, *optional*, defaults to `False`): + Whether or not the prior supports conditionning on artitst, genres, lyrics and timing. + is_encoder (`bool`, *optional*, defaults to `False`): + Whether the model is an encoder only model. + """ + + super().__init__() + self.width = config.hidden_size + self.num_layers = config.num_layers + self.n_ctx = n_ctx if n_ctx is not None else config.n_ctx + self.embed_dim = embed_dim if embed_dim is not None else config.music_vocab_size + self.embed_tokens = nn.Embedding(self.embed_dim, config.hidden_size) + self.embed_tokens_dropout = nn.Dropout(config.emb_dropout) + self.metadata_conditioning = metadata_conditioning + self.audio_conditioning = audio_conditioning + if not metadata_conditioning: + self.start_token = nn.Parameter(ops.empty((1, config.hidden_size))) + self.pos_emb = JukeboxPositionalEmbedding(self.n_ctx, config.hidden_size) + self.pos_emb_dropout = nn.Dropout(config.emb_dropout) + + self.transformer = JukeboxLayerStack(config, n_ctx=self.n_ctx) + self.is_encoder = is_encoder + self.encoder_len = config.nb_relevant_lyric_tokens + + if config.merged_decoder: + # Merged piped model uses this setup + self.add_cond_after_transformer = False + self.share_embed_tokens_fc_proj_out = False + else: + self.add_cond_after_transformer = True + self.share_embed_tokens_fc_proj_out = True + + if not is_encoder: + self.fc_proj_out = nn.Linear(config.hidden_size, self.embed_dim, bias=False) + if self.share_embed_tokens_fc_proj_out: + self.fc_proj_out.weight = self.embed_tokens.weight + self.loss = nn.CrossEntropyLoss() + + def forward( + self, + tokens, + audio_conditioning=None, + metadata_conditioning=None, + last_encoder_hidden_states=None, + get_preds=False, + get_acts=False, + get_sep_loss=False, + ): + """ + Args: + tokens (`mindspore.tensor`): + Can represent music tokens, lyrics tokens or both, depending on the configuration. + """ + # Preprocess. 
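+        # `tokens` is flattened to (batch_size, seq_len); when the prior has no audio conditioning a zero
+        # tensor of shape (batch_size, 1, width) is substituted so the later additions broadcast over the
+        # sequence. The inputs are then shifted right by one position and position 0 is filled with either
+        # the metadata embedding or the learned start token, so the model predicts token t from tokens < t.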
+ batch_size = tokens.shape[0] + with no_grad(): + tokens = tokens.view(batch_size, -1).long() + + if not self.audio_conditioning: + audio_conditioning = ops.zeros( + (batch_size, 1, self.width), + dtype=self.transformer._attn_mods[0].mlp.c_fc.weight.dtype, + ) + + target = tokens # Target + hidden_states = self.embed_tokens(tokens) + # Shift by 1, and fill in start token + hidden_states = ops.cat((hidden_states[:, -1:], hidden_states[:, :-1]), dim=1) + if self.metadata_conditioning: + hidden_states[:, 0] = metadata_conditioning.view(batch_size, self.width) + else: + hidden_states[:, 0] = self.start_token + + hidden_states = ( + self.embed_tokens_dropout(hidden_states) + self.pos_emb_dropout(self.pos_emb()) + audio_conditioning + ) # Pos emb and dropout + + hidden_states = self.transformer( + hidden_states, last_encoder_hidden_states=last_encoder_hidden_states + ) # Transformer + if self.add_cond_after_transformer: # Piped doesnt add x_cond + hidden_states = hidden_states + audio_conditioning + + activations = hidden_states + if self.is_encoder: + return hidden_states + + hidden_states = self.fc_proj_out(hidden_states) # Predictions + loss_fn = nn.CrossEntropyLoss() + if get_sep_loss: + lyric_hidden_states = hidden_states[:, : self.encoder_len].reshape(-1, self.embed_dim) + token_hidden_states = hidden_states[:, self.encoder_len :].reshape(-1, self.embed_dim) + + lyric_loss = loss_fn(lyric_hidden_states, target[:, : self.encoder_len].reshape(-1)) / np.log(2.0) + music_token_loss = loss_fn(token_hidden_states, target[:, self.encoder_len :].reshape(-1)) / np.log(2.0) + + loss = (lyric_loss, music_token_loss) # Note order! Lyric is first + else: + loss = loss_fn(hidden_states.view(-1, self.embed_dim), target.view(-1)) / np.log(2.0) # Loss + + if get_preds: + return loss, hidden_states + elif get_acts: + return loss, activations + else: + return loss, None + + def get_emb(self, sample_t, n_samples, tokens, audio_conditioning, metadata_conditioning): + if sample_t == 0: + hidden_states = ops.zeros(n_samples, 1, self.width, dtype=self.embed_tokens.weight.dtype).to( + self.embed_tokens.weight + ) + if self.metadata_conditioning: + hidden_states[:, 0] = metadata_conditioning.view(n_samples, self.width) + else: + hidden_states[:, 0] = self.start_token + else: + hidden_states = self.embed_tokens(tokens) + if audio_conditioning.shape == (n_samples, self.n_ctx, self.width): + cond = audio_conditioning[:, sample_t : sample_t + 1, :] + else: + cond = audio_conditioning + # Pos emb, dropout is identity at eval time + hidden_states = hidden_states + self.pos_emb()[sample_t : sample_t + 1] + cond + return hidden_states, cond + + def sample( + self, + n_samples, + audio_conditioning=None, + metadata_conditioning=None, + last_encoder_hidden_states=None, + temp=1.0, + top_k=0, + top_p=0.0, + get_preds=False, + sample_tokens=None, + ): + if sample_tokens is None: + sample_tokens = self.n_ctx + + if not self.audio_conditioning: + audio_conditioning = ops.zeros( + (n_samples, 1, self.width), dtype=self.transformer._attn_mods[0].mlp.c_fc.weight.dtype + ).to(self.fc_proj_out) + + with no_grad(): + sampled_tokens = [] + tokens = None + if get_preds: + preds = [] + + iter = tqdm(range(0, sample_tokens), leave=False) + for sample_t in iter: + iter.set_description(f"Ancestral sampling {sample_tokens} music tokens", refresh=True) + hidden_states, cond = self.get_emb( + sample_t, n_samples, tokens, audio_conditioning, metadata_conditioning + ) + + hidden_states = self.transformer( + hidden_states, 
last_encoder_hidden_states=last_encoder_hidden_states, sample=True + ) + if self.add_cond_after_transformer: + hidden_states = hidden_states + cond + hidden_states = self.fc_proj_out(hidden_states) # Predictions + if get_preds: + preds.append(hidden_states.clone()) + # Adjust logits + hidden_states = hidden_states / temp + hidden_states = filter_logits(hidden_states, top_k=top_k, top_p=top_p) + # Sample and replace hidden_states + tokens = distributions.Categorical(logits=hidden_states).sample() + sampled_tokens.append(tokens.clone()) + + del tokens + self.transformer.del_cache() + + tokens = ops.cat(sampled_tokens, dim=1) + if get_preds: + preds = ops.cat(preds, dim=1) + if get_preds: + return tokens, preds + else: + return tokens + + def split_chunks(self, length, chunk_size): + n_passes = (length + chunk_size - 1) // chunk_size + chunk_sizes = [*[chunk_size] * (n_passes - 1), (length - 1) % chunk_size + 1] + return chunk_sizes + + def primed_sample( + self, + n_samples, + lyric_and_music_tokens, + audio_conditioning=None, + metadata_conditioning=None, + last_encoder_hidden_states=None, + temp=1.0, + top_k=0, + top_p=0.0, + get_preds=False, + chunk_size=None, + sample_tokens=None, + ): + if sample_tokens is None: + sample_tokens = self.n_ctx + # Preprocess. + batch_size = lyric_and_music_tokens.shape[0] + with no_grad(): + lyric_and_music_tokens = lyric_and_music_tokens.view(batch_size, -1).long() + + sampled_audio = ops.split(lyric_and_music_tokens, 1, dim=1) + sampled_audio = list(sampled_audio) + + if not self.audio_conditioning: + audio_conditioning = ops.zeros( + (n_samples, 1, self.width), dtype=self.transformer._attn_mods[0].mlp.c_fc.weight.dtype + ).to(lyric_and_music_tokens) + + with no_grad(): + if get_preds: + preds = [] + + # Fill up key/value cache for past context by runing forward pass. + # We do so in chunks instead of doing the whole past in one forward pass to reduce max memory usage. 
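+            # For example, priming with 8192 tokens and chunk_size=2048 gives chunk_sizes == [2048, 2048, 2048, 2048];
+            # each chunk is embedded and run through the transformer with sample=True purely so that its keys/values
+            # are written into the attention cache (the outputs are discarded unless get_preds is set).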
+ if chunk_size is None: + chunk_size = len(sampled_audio) + chunk_sizes = self.split_chunks(len(sampled_audio), chunk_size) + x_primes = [] + start = 0 + token = None + + for current_chunk_size in tqdm(chunk_sizes, desc="Preparing past key value", leave=False): + sampled_audio_prime, conds_prime = [], [] + for sample_t in range(start, start + current_chunk_size): + x_prime, cond_prime = self.get_emb( + sample_t, n_samples, token, audio_conditioning, metadata_conditioning + ) + token = sampled_audio[sample_t] + sampled_audio_prime.append(x_prime) + conds_prime.append(cond_prime) + start = start + current_chunk_size + x_prime, cond_prime = ops.cat(sampled_audio_prime, dim=1), ops.cat(conds_prime, dim=1) + del sampled_audio_prime + del conds_prime + if not get_preds: + del cond_prime + x_prime = self.transformer(x_prime, last_encoder_hidden_states=last_encoder_hidden_states, sample=True) + + if get_preds: + if self.add_cond_after_transformer: + x_prime = x_prime + cond_prime + del cond_prime + x_primes.append(x_prime) + else: + del x_prime + + if get_preds: + x_prime = ops.cat(x_primes, dim=1) + x_prime = self.fc_proj_out(x_prime) # Predictions + preds.append(x_prime) + + # the input of the encoder and decoder can be merged into (lyrics, music tokens) + input_tokens = sampled_audio[-1] + + itererator = tqdm( + range(len(sampled_audio), sample_tokens), + desc=f"Sampling {len(range(len(sampled_audio), sample_tokens))} music tokens", + leave=False, + ) + for sample_t in itererator: + hidden_states, cond = self.get_emb( + sample_t, n_samples, input_tokens, audio_conditioning, metadata_conditioning + ) + + hidden_states = self.transformer( + hidden_states, last_encoder_hidden_states=last_encoder_hidden_states, sample=True + ) + if self.add_cond_after_transformer: + hidden_states = hidden_states + cond + hidden_states = self.fc_proj_out(hidden_states) # Predictions + if get_preds: + preds.append(hidden_states) + # Adjust logits + hidden_states = hidden_states / temp + hidden_states = filter_logits(hidden_states, top_k=top_k, top_p=top_p) + # only music tokens are sampled + music_tokens = distributions.Categorical(logits=hidden_states).sample() + sampled_audio.append(music_tokens.clone()) + input_tokens = music_tokens + + del input_tokens, music_tokens + self.transformer.del_cache() + + music_tokens = ops.cat(sampled_audio, dim=1) + if get_preds: + preds = ops.cat(preds, dim=1) + if get_preds: + return music_tokens, preds + else: + return music_tokens + + +class JukeboxMusicTokenConditioner(nn.Module): + """ + The `JukeboxMusicTokenConditioner` takes music tokens as an input (coresponding to the codes of the VQVAE's + codebook) and upsamples it using a single layer of decoder convolution block (the same is used in the VQVAE). 
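+
+    The rough tensor flow is sketched below (for orientation only; `upsample_factor` is not an attribute of this
+    class, it stands for the temporal upsampling implied by `res_downs_t`/`res_strides_t` at the given level):
+
+    ```python
+    # music_tokens:          (batch, seq_len) integer codes from the upper level
+    # after `embed_tokens`:  (batch, seq_len, hidden_size)
+    # after `upsampler`:     (batch, seq_len * upsample_factor, hidden_size)
+    ```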
+ """ + + def __init__(self, config, level): + super().__init__() + self.embed_tokens = nn.Embedding(config.music_vocab_size, config.hidden_size) + config.embed_dim = config.music_vocab_size # setting correct argument for the `JukeboxDecoder` + + self.upsampler = JukeboxDecoderConvBock( + config, + config.hidden_size, + config.res_conv_width, + config.res_conv_depth, + config.res_downs_t[level], + config.res_strides_t[level], + reverse_dilation=False, + ) + self.layer_norm = JukeboxLayerNorm(config.hidden_size) + + def forward(self, music_tokens, raw_audio_conditionning=None): + """ + Args: + music_tokens (`mindspore.Tensor`): + Music tokens form the uper level in range(nb_discrete_codes) + raw_audio_conditionning (`mindspore.Tensor`, *optional*): + Audio used when primed sampling, raw audio information that conditions the generation + """ + if raw_audio_conditionning is None: + raw_audio_conditionning = 0.0 + # Embed music_tokens + music_tokens = music_tokens.long() + hidden_states = self.embed_tokens(music_tokens) + hidden_states = hidden_states + raw_audio_conditionning + + # Run conditioner + hidden_states = hidden_states.permute(0, 2, 1) + hidden_states = self.upsampler(hidden_states) + hidden_states = hidden_states.permute(0, 2, 1) + hidden_states = self.layer_norm(hidden_states) + return hidden_states + + +class JukeboxRangeEmbedding(nn.Module): + """ + The `JukeboxRangeEmbedding` interpolate the given [pos_start, pos_end] to obtain an equivalent of time positional + embedding of length `n_ctx`. + + Binning process : For each pos in position tensor, find its bin [start,end) mapped to [0,1,...,bins-1] [start,end) + -> [0,1) -> [0, bins) -> floor -> [0,...,bins-1] NOTE: Open ended interval on right, so start <= pos < end, not <= + end + """ + + def __init__(self, n_time, embed_dim, range, out_width, clamp=False): + super().__init__() + self.n_time = n_time + self.embed_dim = embed_dim + self.emb = nn.Embedding(embed_dim, out_width) + self.pos_min, self.pos_max = range + self.clamp = clamp + + def forward(self, pos_start, pos_end=None): + # Check if [pos_start,pos_end] in [pos_min, pos_max) + if not len(pos_start.shape) == 2: + raise TypeError(f"Expected shape with 2 dims, got {pos_start.shape}") + if not (self.pos_min <= pos_start).all() and (pos_start < self.pos_max).all(): + raise TypeError(f"Range is [{self.pos_min},{self.pos_max}), got {pos_start}") + + pos_start = pos_start.float() + if pos_end is not None: + if self.clamp: + pos_end = pos_end.clamp(self.pos_min, self.pos_max) + + pos_end = pos_end.float() + # Interpolate so that [pos_start, ..., pos_end] <-> position tensor of length n_ctx + n_time = self.n_time + if n_time != 1: + interpolation = ( + ops.arange(0, n_time, dtype=mindspore.float32).view(1, n_time) / n_time + ) + position = pos_start + (pos_end - pos_start) * interpolation + else: + position = pos_start + + # Bin each value to bins_ + # [0,1) -> [0,1..,embed_dim) -> [0,1...,embed_dim-1 + normalised_position = (position - self.pos_min) / (self.pos_max - self.pos_min) + bins_ = (self.embed_dim * normalised_position).floor().long().detach() + return self.emb(bins_) + + +class JukeboxLabelConditioner(nn.Module): + def __init__(self, config, include_time_signal): + super().__init__() + + embed_dim = config.hidden_size + timing_dims = config.timing_dims + sampling_rate = config.sampling_rate + nb_genres, nb_artists = config.metadata_dims + music_tokens_shape = config.n_ctx + + self.max_nb_genres = config.max_nb_genres + self.bow_genre_emb = nn.Embedding(nb_genres, 
embed_dim)
+        self.artist_emb = nn.Embedding(nb_artists, embed_dim)
+        self.include_time_signal = include_time_signal
+        if self.include_time_signal:
+            total_length_range = (config.min_duration * sampling_rate, config.max_duration * sampling_rate)
+            absolute_pos_range = (0.0, config.max_duration * sampling_rate)
+            relative_pos_range = (0.0, 1.0)
+            self.total_length_emb = JukeboxRangeEmbedding(1, timing_dims, total_length_range, embed_dim)
+            self.absolute_pos_emb = JukeboxRangeEmbedding(
+                music_tokens_shape, timing_dims, absolute_pos_range, embed_dim
+            )
+            self.relative_pos_emb = JukeboxRangeEmbedding(
+                music_tokens_shape, timing_dims, relative_pos_range, embed_dim, clamp=True
+            )
+
+    def forward(self, metadata):
+        total_length = metadata[:, 0:1]
+        offset = metadata[:, 1:2]
+        length = metadata[:, 2:3]
+        artist = metadata[:, 3:4]
+        genre = metadata[:, 4:]
+
+        # Start embedding of length 1
+        artist_emb = self.artist_emb(artist)
+        # Empty genre slots are denoted by -1. We mask these out.
+        mask = (genre >= 0).float().unsqueeze(2)
+        genre_emb = (self.bow_genre_emb(genre.clamp(0)) * mask).sum(dim=1, keepdim=True)
+        start_emb = genre_emb + artist_emb
+
+        # Pos embedding of length n_ctx
+        if self.include_time_signal:
+            start, end = offset, offset + length
+            total_length = total_length.float()
+            start = start.float()
+            end = end.float()
+            pos_emb = (
+                self.total_length_emb(total_length)
+                + self.absolute_pos_emb(start, end)
+                + self.relative_pos_emb(start / total_length, end / total_length)
+            )
+        else:
+            pos_emb = None
+        return start_emb, pos_emb
+
+
+class JukeboxPrior(PreTrainedModel):
+    """
+    The JukeboxPrior class is a wrapper around the various conditioning modules and the transformer. A JukeboxPrior
+    can be seen as a language model trained on music: it models the next `music token` prediction task. If a (lyric)
+    `encoder` is defined, it also models the `next character` prediction task on the lyrics. It can be conditioned on
+    timing, artist, genre, lyrics and codes from lower-level priors.
+
+    Args:
+        config (`JukeboxPriorConfig`):
+            Model configuration class with all the parameters of the model. Initializing with a config file does not
+            load the weights associated with the model, only the configuration. Check out the
+            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+        level (`int`, *optional*):
+            Current level of the Prior. Should be in range `[0,nb_priors]`.
+        nb_priors (`int`, *optional*, defaults to 3):
+            Total number of priors.
+        vqvae_encoder (`Callable`, *optional*):
+            Encoding method of the VQVAE encoder used in the forward pass of the model. Passing functions instead of
+            the vqvae module to avoid getting the parameters.
+        vqvae_decoder (`Callable`, *optional*):
+            Decoding method of the VQVAE decoder used in the forward pass of the model. Passing functions instead of
+            the vqvae module to avoid getting the parameters.
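+
+    Example (illustrative sketch only; it assumes `JukeboxPriorConfig` can be imported from the sibling
+    `configuration_jukebox` module and builds a randomly initialised prior rather than loading pretrained weights):
+
+    ```python
+    from mindnlp.transformers.models.jukebox.configuration_jukebox import JukeboxPriorConfig
+
+    config = JukeboxPriorConfig()          # default prior configuration
+    prior = JukeboxPrior(config, level=0)  # level 0 is the top-level prior
+    ```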
+ """ + + config_class = JukeboxPriorConfig + + def _init_weights(self, module): + init_scale = self.config.init_scale + + if isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=0.02 * init_scale) + elif isinstance(module, JukeboxConv1D): + if self.config.zero_out: + module.weight.data.zero_() + else: + module.weight.data.normal_(mean=0.0, std=0.02 * init_scale) + elif isinstance(module, JukeboxPositionalEmbedding): + module.pos_emb.data.normal_(mean=0.0, std=0.01 * init_scale) + elif isinstance(module, JukeboxRangeEmbedding): + module.emb.weight.data.normal_(mean=0.0, std=0.01 * init_scale) + elif isinstance(module, JukeboxConditionalAutoregressive) and hasattr(module, "lm_head"): + module.lm_head.weight.data.normal_(mean=0.0, std=0.02 * init_scale) + elif isinstance(module, JukeboxConditionalAutoregressive) and hasattr(module, "start_token"): + module.start_token.data.normal_(mean=0.0, std=0.01 * init_scale) + elif isinstance(module, JukeboxResConv1DBlock) and self.config.zero_out: + module.conv1d_2.weigth.data.zero_() + module.conv1d_2.bias.data.zero_() + if isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + def __init__(self, config: JukeboxPriorConfig, level=None, nb_priors=3, vqvae_encoder=None, vqvae_decoder=None): + super().__init__(config) + # Passing functions instead of the vqvae module to avoid getting params, only used in the + # forward loop + self.vqvae_encoder = vqvae_encoder + self.vqvae_decoder = vqvae_decoder + + self.levels = nb_priors + self.level = level if level is not None else config.level + + self.base_model_prefix = f"priors.{self.level}" + + self.n_ctx = config.n_ctx + + self.lyric_conditioning = config.nb_relevant_lyric_tokens > 0 + self.nb_relevant_lyric_tokens = config.nb_relevant_lyric_tokens + self.encoder_loss_fraction = config.encoder_loss_fraction + + # Audio conditioning : conditioning on music tokens (either from audio or from previous levels or both) + self.audio_conditioning = self.level != 0 + self.cond_level = self.level - 1 + if self.audio_conditioning: + self.conditioner_blocks = JukeboxMusicTokenConditioner(config, self.level) + + # metadata conditioning : contioning on timing, genres, and artist + self.metadata_conditioning = config.metadata_conditioning + if self.metadata_conditioning: + self.metadata_embedding = JukeboxLabelConditioner(config, include_time_signal=not self.audio_conditioning) + + # define encoder-decoder or encoder and decoder + self.is_encoder_decoder = config.is_encoder_decoder + if config.is_encoder_decoder: + # encoder-decoder transformer + self.input_shapes = [config.nb_relevant_lyric_tokens, config.n_ctx] + self.embed_dim_shift = [0, config.lyric_vocab_size] + self.width = config.hidden_size + + self.nb_relevant_lyric_tokens = config.nb_relevant_lyric_tokens + + self.prior = JukeboxConditionalAutoregressive( + config, + n_ctx=config.nb_relevant_lyric_tokens + config.n_ctx, + embed_dim=config.lyric_vocab_size + config.music_vocab_size, + audio_conditioning=(self.audio_conditioning or self.metadata_conditioning), + metadata_conditioning=True, + ) + + else: + # Separate encoder-decoder transformer + encoder_config = config.encoder_config + + if self.nb_relevant_lyric_tokens != 0 and self.lyric_conditioning: + self.lyric_acts_width = encoder_config.hidden_size + self.encoder_width = config.hidden_size + self.encoder_dim = config.lyric_vocab_size + self.encoder = 
JukeboxConditionalAutoregressive( + encoder_config, + n_ctx=self.nb_relevant_lyric_tokens, + embed_dim=self.encoder_dim, + audio_conditioning=False, + metadata_conditioning=False, + is_encoder=True, + ) + self.encoder.proj_in = JukeboxConv1D(encoder_config.hidden_size, config.hidden_size) + self.encoder.final_layer_norm = JukeboxLayerNorm(config.hidden_size) + self.encoder.lm_head = nn.Linear(config.hidden_size, config.lyric_vocab_size, bias=False) + else: + self.nb_relevant_lyric_tokens = 0 + + # decoder model on the tokens + self.prior = JukeboxConditionalAutoregressive( + config, + audio_conditioning=(self.audio_conditioning or self.metadata_conditioning), + metadata_conditioning=self.metadata_conditioning, + ) + + self.next_token_prediction_loss_dims = config.n_ctx + self.total_loss_dims = self.nb_relevant_lyric_tokens + self.next_token_prediction_loss_dims + + self.downsamples = [stride**down for stride, down in zip(config.res_strides_t, config.res_downs_t)] + self.cond_downsample = self.downsamples[self.level] if self.level != 0 else None + self.raw_to_tokens = np.prod(self.downsamples[: nb_priors - self.level]) + self.sample_length = self.n_ctx * self.raw_to_tokens + + logger.info( + f"Level:{self.level}, Cond downsample:{self.cond_downsample}, Raw to tokens:{self.raw_to_tokens}, Sample" + f" length:{self.sample_length}" + ) + + def get_metadata(self, labels, start, total_length, offset, get_indices=False): + metadata = labels.clone() + metadata[:, 0] = total_length + # Set sample_length to match this level + metadata[:, 2] = int(self.sample_length) + + # Set offset + metadata[:, 1:2] = int(offset * self.raw_to_tokens) + int(start * self.raw_to_tokens) + # here since metadata has the full token_list, we just need to selected the ones that are relevant + + # Set lyric tokens + metadata, indices = self.set_metadata_lyric_tokens(metadata) + if get_indices: + return metadata, indices + else: + return metadata + + def set_metadata_lyric_tokens(self, labels): + """ + Processes the full labels to only retreive the relevant lyric tokens and keep the metadata conditioning tokens. + """ + if self.nb_relevant_lyric_tokens > 0: + tokens_list = ops.zeros( + (labels.shape[0], self.nb_relevant_lyric_tokens), dtype=mindspore.int64) + indices_list = [] # whats the index of each current character in original array + for idx in range(labels.shape[0]): + full_tokens = labels.clone()[:, 4 + self.metadata_embedding.max_nb_genres :] + total_length, offset, duration = labels[idx, 0], labels[idx, 1], labels[idx, 2] + tokens, indices = get_relevant_lyric_tokens( + full_tokens, self.nb_relevant_lyric_tokens, total_length, offset, duration + ) + tokens_list[idx, :] = tokens + indices_list.append(indices) + + return ( + ops.cat((labels[:, : 4 + self.metadata_embedding.max_nb_genres], tokens_list), dim=-1), + indices_list, + ) + else: + return labels, None + + def get_music_tokens_conds(self, music_tokens, start, end): + """ + Extracts current level's conditioning music tokens. 
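+        For the top-level prior (`self.level == 0`) there is no upper level to condition on, so `None` is returned;
+        otherwise a one-element list holding the upper-level music tokens is returned.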
+ """ + if self.level != 0: + music_tokens_cond = music_tokens[self.level - 1] + music_tokens = music_tokens_cond[:, start // self.cond_downsample : end // self.cond_downsample] + missing_cond_len = self.n_ctx // self.cond_downsample - music_tokens_cond[-1].shape[-1] + if missing_cond_len > 0: + init_cond = ops.zeros(1, missing_cond_len).to(music_tokens_cond) + music_tokens_cond = ops.cat((music_tokens_cond, init_cond), dim=-1).long() + music_tokens_conds = [music_tokens_cond] + else: + music_tokens_conds = None + return music_tokens_conds + + def prior_preprocess(self, tokens, conds): + """ + Shifts the input tokens to account for the dictionary merge. The embed_dim_shift give by how much the music + tokens should be shifted by. It is equal to `lyric_vocab_size`. + """ + batch_size = tokens[0].shape[0] + for i in range(len(tokens)): + tokens[i] = (tokens[i] + int(self.embed_dim_shift[i])).view(batch_size, -1) + + for i in range(len(conds)): + if conds[i] is None: + conds[i] = ops.zeros( + (batch_size, self.input_shapes[i], self.width), dtype=tokens[0].dtype) + + return ops.cat(tokens, dim=1), ops.cat(conds, dim=1) + + def prior_postprocess(self, tokens): + """ + Shifts back the input tokens if the model uses an encoder decoder architecture. As the embedding layer is + shared, `prior_embed_dim_shift` shifts the music token ids by `lyric_vocab_size`. Only returns the music + tokens. + """ + batch_size = tokens.shape[0] + dims = (self.input_shapes[0], tokens.shape[1] - self.input_shapes[0]) + tokens = list(ops.split(tokens, dims, dim=1)) + + # Some of the input tokens might be shifted to take into account the voccabulary fusion + for i in range(len(tokens)): + bins_shift = int(self.embed_dim_shift[i]) + tokens[i] = (tokens[i] - bins_shift).view(batch_size, -1) + tokens[i] = ops.clamp(tokens[i], min=0) + # If not masking loss, model may have generated lyric/midi tokens which are now shifted <0 by bin_shift + return tokens[-1] + + def embed_tokens(self, music_tokens_conds): + """ + Embeds the upper level music tokens and upsamples them to provide as audio conditioning. + """ + music_tokens_conds = music_tokens_conds[: self.cond_level + 1] + audio_conditioning = None + for music_tokens_cond, conditioner_block in reversed(list(zip(music_tokens_conds, [self.conditioner_blocks]))): + audio_conditioning = conditioner_block(music_tokens_cond, audio_conditioning) + return audio_conditioning + + def encode(self, hidden_states, start_level=None, end_level=None, bs_chunks=1): + """ + Encodes the hidden states (raw audio) using the VQVAE's encoder. Returns latent_states. + """ + if start_level is None: + start_level = self.level + if end_level is None: + end_level = self.levels + # Get latents + with no_grad(): + latent_states = self.vqvae_encoder( + hidden_states, start_level=start_level, end_level=end_level, bs_chunks=bs_chunks + ) + return latent_states + + def decode(self, music_tokens, start_level=None, end_level=None, bs_chunks=1): + """ + Usamples the sequence of codebook vectors to a raw audio. + """ + if start_level is None: + start_level = self.level + if end_level is None: + end_level = self.levels + with no_grad(): + output = self.vqvae_decoder( + music_tokens, start_level=start_level, end_level=end_level, bs_chunks=bs_chunks + ) + return output + + def get_cond(self, music_tokens_conds, metadata): + """ + Converts the input tokens to input_embeddings. Splits the lyrics form the rest of the metadata. Lyric tokens + can be None. 
+ """ + if metadata is not None: + n_labels = metadata.shape[1] - self.nb_relevant_lyric_tokens + metadata, lyric_tokens = metadata[:, :n_labels], metadata[:, n_labels:] + else: + metadata, lyric_tokens = None, None + metadata_conditioning, metadata_pos = ( + self.metadata_embedding(metadata) if self.metadata_conditioning else (None, None) + ) + audio_conditioning = self.embed_tokens(music_tokens_conds) if self.audio_conditioning else metadata_pos + return audio_conditioning, metadata_conditioning, lyric_tokens + + def sample( + self, + n_samples, + music_tokens=None, + music_tokens_conds=None, + metadata=None, + temp=1.0, + top_k=0, + top_p=0.0, + chunk_size=None, + sample_tokens=None, + ): + """ + Ancestral/Prime sampling a window of tokens using the provided conditioning and metadatas. + + Args: + n_samples (`int`): + Number of samples to generate. + music_tokens (`List[mindspore.Tensor]`, *optional*): + Previously gemerated tokens at the current level. Used as context for the generation. + music_tokens_conds (`List[mindspore.Tensor]`, *optional*): + Upper-level music tokens generated by the previous prior model. Is `None` if the generation is not + conditionned on the upper-level tokens. + metadata (`List[mindspore.Tensor]`, *optional*): + List containing the metatdata tensor with the artist, genre and the lyric tokens. + temp (`float`, *optional*, defaults to 1.0): + Sampling temperature. + top_k (`int`, *optional*, defaults to 0): + Top k probabilities used for filtering. + top_p (`float`, *optional*, defaults to 0.0): + Top p probabilities used for filtering. + chunk_size (`int`, *optional*): + Size of the chunks used to prepare the cache of the transformer. + sample_tokens (`int`, *optional*): + Number of tokens to sample. + + """ + no_past_context = music_tokens is None or music_tokens.shape[1] == 0 + name = {True: "Ancestral", False: "Primed"}[no_past_context] + logger.info(f"{name} sampling {n_samples} samples with temp={temp}, top_k={top_k}, top_p={top_p}") + + with no_grad(): + # Currently audio_conditioning only uses immediately above layer + audio_conditioning, metadata_conditioning, lyric_tokens = self.get_cond(music_tokens_conds, metadata) + if self.is_encoder_decoder: + if no_past_context: # the prime_sample function will be used with music_tokens set to None + lyric_and_music_tokens, audio_conditioning = self.prior_preprocess( + [lyric_tokens], [None, audio_conditioning] + ) + else: + lyric_and_music_tokens, audio_conditioning = self.prior_preprocess( + [lyric_tokens, music_tokens], [None, audio_conditioning] + ) + if sample_tokens is not None: + sample_tokens += self.nb_relevant_lyric_tokens + music_tokens = self.prior.primed_sample( + n_samples, + lyric_and_music_tokens, + audio_conditioning, + metadata_conditioning, + temp=temp, + top_k=top_k, + top_p=top_p, + chunk_size=chunk_size, + sample_tokens=sample_tokens, + ) + music_tokens = self.prior_postprocess(music_tokens) + else: + last_encoder_hidden_states = self.get_encoder_states(lyric_tokens, sample=True) + if no_past_context: + music_tokens = self.prior.sample( + n_samples, + audio_conditioning, + metadata_conditioning, + last_encoder_hidden_states, + temp=temp, + top_k=top_k, + top_p=top_p, + sample_tokens=sample_tokens, + ) + else: + music_tokens = self.prior.primed_sample( + n_samples, + music_tokens, + audio_conditioning, + metadata_conditioning, + last_encoder_hidden_states, + temp=temp, + top_k=top_k, + top_p=top_p, + chunk_size=chunk_size, + sample_tokens=sample_tokens, + ) + return music_tokens + + def 
get_encoder_states(self, lyric_tokens, sample=False): + """ + Retreive the last hidden_states of the lyric encoder that will be attended to by the decoder. Forwards through + the lyric encoder. + """ + if self.nb_relevant_lyric_tokens != 0 and self.lyric_conditioning: + if sample: + self.encoder = self.encoder.to(lyric_tokens) + lyric_acts = self.encoder(lyric_tokens, None, None, None) + lyric_acts = self.encoder.proj_in(lyric_acts) + last_encoder_hidden_states = self.encoder.final_layer_norm(lyric_acts) + else: + last_encoder_hidden_states = None + return last_encoder_hidden_states + + def get_encoder_loss(self, last_encoder_hidden_states, target_lyrics): + """ + Computes the loss for the lyric encoder: next lyric token prediction. + """ + if self.lyric_conditioning: + last_encoder_hidden_states = self.encoder.lm_head(last_encoder_hidden_states) + encoder_loss = nn.functional.cross_entropy( + last_encoder_hidden_states.view(-1, self.encoder_dim), target_lyrics.view(-1) + ) / np.log(2.0) + else: + encoder_loss = mindspore.tensor(0.0) + return encoder_loss + + def forward_tokens( + self, music_tokens, music_tokens_conds=[], metadata=None, get_preds=False, get_attn_weights=False + ): + """ + Applies a forward pass using the conditioning tokens. Different from the classic forward as it does not use the + vqvae's encoding layers. + """ + if get_attn_weights: + self.prior.transformer.set_record_attn(get_attn_weights) + audio_conditioning, metadata_conditioning, lyric_tokens = self.get_cond(music_tokens_conds, metadata) + + if self.is_encoder_decoder: # the preprocess returns the full tokens (Lyrics and Music tokens), shifted + tokens, audio_conditioning = self.prior_preprocess( + [lyric_tokens, music_tokens], [None, audio_conditioning] + ) + (encoder_loss, next_token_prediction_loss), preds = self.prior( + tokens, audio_conditioning, metadata_conditioning, get_sep_loss=True, get_preds=get_preds + ) + else: + last_encoder_hidden_states = self.get_encoder_states(lyric_tokens) + encoder_loss = self.get_encoder_loss(last_encoder_hidden_states, lyric_tokens) + next_token_prediction_loss, preds = self.prior( + music_tokens, + audio_conditioning, + metadata_conditioning, + last_encoder_hidden_states, + get_preds=get_preds, + ) + loss = self.encoder_loss_fraction * encoder_loss * self.nb_relevant_lyric_tokens / self.total_loss_dims + loss += next_token_prediction_loss * self.next_token_prediction_loss_dims / self.total_loss_dims + + metrics = { + "bpd": next_token_prediction_loss.clone().detach(), + "encoder_loss": encoder_loss.clone().detach(), + "next_token_prediction_loss": next_token_prediction_loss.clone().detach(), + } + if get_preds: + metrics["preds"] = preds.clone().detach() + if get_attn_weights: + saved_attn_weights = self.prior.transformer.saved_attn_weights + self.prior.transformer.set_record_attn(False) + return saved_attn_weights + else: + return loss, metrics + + def forward( + self, + hidden_states: mindspore.Tensor, + metadata: Optional[List[mindspore.Tensor]], + decode: Optional[bool] = False, + get_preds: Optional[bool] = False, + ) -> List[mindspore.Tensor]: + """ + Encode the hidden states using the `vqvae` encoder, and then predicts the next token in the `forward_tokens` + function. The loss is the sum of the `encoder` loss and the `decoder` loss. + + Args: + hidden_states (`mindspore.Tensor`): + Hidden states which should be raw audio + metadata (`List[mindspore.Tensor]`, *optional*): + List containing the metadata conditioning tensorwith the lyric and the metadata tokens. 
+            decode (`bool`, *optional*, defaults to `False`):
+                Whether or not to decode the encoded tokens back to raw audio.
+            get_preds (`bool`, *optional*, defaults to `False`):
+                Whether or not to return the actual predictions of the model.
+        """
+        batch_size = hidden_states.shape[0]
+        music_tokens, *music_tokens_conds = self.encode(hidden_states, bs_chunks=batch_size)
+        loss, metrics = self.forward_tokens(
+            music_tokens=music_tokens,
+            music_tokens_conds=music_tokens_conds,
+            metadata=metadata,
+            get_preds=get_preds,
+        )
+        if decode:
+            dequantised_states = self.decode([music_tokens, *music_tokens_conds])
+        else:
+            dequantised_states = None
+        return dequantised_states, loss, metrics
+
+
+class JukeboxPreTrainedModel(PreTrainedModel):
+    """
+    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+    models.
+    """
+
+    config_class = JukeboxConfig
+    base_model_prefix = "jukebox"
+    supports_gradient_checkpointing = False
+
+    def _init_weights(self, module):
+        if isinstance(module, (JukeboxPrior, JukeboxVQVAE)):
+            module.apply(module._init_weights)
+
+
+class JukeboxModel(JukeboxPreTrainedModel):
+    _no_split_modules = ["JukeboxBlock"]
+
+    def __init__(self, config):
+        super().__init__(config)
+        vqvae_config = config.vqvae_config
+        self.vqvae = JukeboxVQVAE(vqvae_config)
+        self.set_shared_params(config)
+        self.priors = nn.ModuleList(
+            [JukeboxPrior(config.prior_configs[level], level) for level in range(config.nb_priors)]
+        )
+
+    def set_shared_params(self, model_config):
+        """
+        Initialises the parameters that are shared. This has to be done here because the list of `JukeboxPriorConfig`
+        is nested, and is thus unreachable in the `from_dict` function.
+        """
+        for config in model_config.prior_configs:
+            config.sampling_rate = model_config.sampling_rate
+            config.timing_dims = model_config.timing_dims
+            config.min_duration = model_config.min_duration
+            config.max_duration = model_config.max_duration
+            config.max_nb_genres = model_config.max_nb_genres
+            config.metadata_conditioning = model_config.metadata_conditioning
+
+    def decode(self, music_tokens, start_level=0, end_level=None, bs_chunks=1):
+        return self.vqvae.decode(music_tokens, start_level, end_level, bs_chunks)
+
+    def encode(self, input_audio, start_level=0, end_level=None, bs_chunks=1):
+        return self.vqvae.encode(input_audio, start_level, end_level, bs_chunks)
+
+    def split_batch(self, obj, n_samples, split_size):
+        n_passes = (n_samples + split_size - 1) // split_size
+        if isinstance(obj, mindspore.Tensor):
+            return ops.split(obj, split_size, dim=0)
+        elif isinstance(obj, list):
+            return list(zip(*[ops.split(item, split_size, dim=0) for item in obj]))
+        elif obj is None:
+            return [None] * n_passes
+        else:
+            raise TypeError("Unknown input type")
+
+    # Sample total_length tokens on the given level, one window of at most n_ctx tokens at a time
+    def sample_level(
+        self, music_tokens, labels, offset, sampling_kwargs, level, total_length, hop_length, max_batch_size
+    ):
+        if total_length >= self.priors[level].n_ctx:
+            iterator = get_starts(total_length, self.priors[level].n_ctx, hop_length)
+            for start in iterator:
+                music_tokens = self.sample_single_window(
+                    music_tokens, labels, offset, sampling_kwargs, level, start, max_batch_size
+                )
+
+        else:
+            music_tokens = self.sample_partial_window(
+                music_tokens, labels, offset, sampling_kwargs, level, total_length, max_batch_size
+            )
+        return music_tokens
+
+    @no_grad()
+    def _sample(
+        self,
+        music_tokens,
+        labels,
+        sample_levels,
+        metas=None,
+        chunk_size=32,
+        sampling_temperature=0.98,
+        lower_batch_size=16,
+        max_batch_size=16,
+        sample_length_in_seconds=24,
+        compute_alignments=False,
+        sample_tokens=None,
+        offset=0,
+        save_results=True,
+        sample_length=None,
+    ) -> List[mindspore.Tensor]:
+        """
+        Core sampling function used to generate music tokens. Iterates over the provided list of levels, while saving
+        the generated raw audio at each step.
+
+        Args:
+            music_tokens (`List[mindspore.Tensor]`):
+                A sequence of music tokens of length `self.levels` which will be used as context to continue the
+                sampling process. Should have `self.levels` tensors, each corresponding to the generation at a certain
+                level.
+            labels (`List[mindspore.Tensor]`):
+                List of length `n_samples`, and shape `(self.levels, 4 + self.config.max_nb_genre +
+                lyric_sequence_length)` metadata such as `artist_id`, `genre_id` and the full list of lyric tokens
+                which are used to condition the generation.
+            sample_levels (`List[int]`):
+                List of the desired levels at which the sampling will be done. A level is equivalent to the index of
+                the prior in the list of priors.
+            metas (`List[Any]`, *optional*):
+                Metadata used to generate the `labels`.
+            chunk_size (`int`, *optional*, defaults to 32):
+                Size of a chunk of audio, used to fill up the memory in chunks to prevent OOM errors. Bigger chunks
+                mean faster memory filling but higher memory consumption.
+            sampling_temperature (`float`, *optional*, defaults to 0.98):
+                Temperature used to adjust the randomness of the sampling.
+            lower_batch_size (`int`, *optional*, defaults to 16):
+                Maximum batch size for the lower level priors.
+            max_batch_size (`int`, *optional*, defaults to 16):
+                Maximum batch size for the top level priors.
+            sample_length_in_seconds (`int`, *optional*, defaults to 24):
+                Desired length of the generation in seconds.
+            compute_alignments (`bool`, *optional*, defaults to `False`):
+                Whether or not to compute the alignment between the lyrics and the audio using the top_prior.
+            sample_tokens (`int`, *optional*):
+                Precise number of tokens that should be sampled at each level. This is mostly useful for running dummy
+                experiments.
+            offset (`int`, *optional*, defaults to 0):
+                Audio offset used as conditioning, corresponds to the starting sample in the music. If the offset is
+                greater than 0, the lyrics will be shifted to take that into account.
+            save_results (`bool`, *optional*, defaults to `True`):
+                Whether or not to save the intermediate results. If `True`, will generate a folder named with the
+                start time.
+            sample_length (`int`, *optional*):
+                Desired length of the generation in samples.
+
+        Returns: List[mindspore.Tensor]
+
+        Example:
+
+        ```python
+        >>> from transformers import AutoTokenizer, JukeboxModel, set_seed
+        >>> import mindspore
+
+        >>> metas = dict(artist="Zac Brown Band", genres="Country", lyrics="I met a traveller from an antique land")
+        >>> tokenizer = AutoTokenizer.from_pretrained("openai/jukebox-1b-lyrics")
+        >>> model = JukeboxModel.from_pretrained("openai/jukebox-1b-lyrics", min_duration=0).eval()
+
+        >>> labels = tokenizer(**metas)["input_ids"]
+        >>> set_seed(0)
+        >>> zs = [ops.zeros(1, 0, dtype=mindspore.int64) for _ in range(3)]
+        >>> zs = model._sample(zs, labels, [0], sample_length=40 * model.priors[0].raw_to_tokens, save_results=False)
+        >>> zs[0]
+        tensor([[1853, 1369, 1150, 1869, 1379, 1789, 519, 710, 1306, 1100, 1229, 519,
+                 353, 1306, 1379, 1053, 519, 653, 1631, 1467, 1229, 1229, 10, 1647,
+                 1254, 1229, 1306, 1528, 1789, 216, 1631, 1434, 653, 475, 1150, 1528,
+                 1804, 541, 1804, 1434]])
+        ```
+        """
+
+        top_prior = self.priors[0]
+        if sample_length is not None:
+            total_length = sample_length
+        else:
+            total_length = (
+                int(sample_length_in_seconds * self.config.sampling_rate) // top_prior.raw_to_tokens
+            ) * top_prior.raw_to_tokens
+
+        if sample_levels is None:
+            sample_levels = range(len(self.priors))
+
+        # total length of the signal, might be a bit different from the actual generated length
+        self.total_length = total_length
+        for level in sample_levels:
+            sampling_kwargs = {
+                "temp": 0.99 if level == len(self.priors) - 1 else sampling_temperature,
+                "chunk_size": chunk_size,
+                "sample_tokens": sample_tokens,
+            }
+            # Set correct total_length, hop_length, labels and sampling_kwargs for level
+
+            total_token_to_sample = total_length // self.priors[level].raw_to_tokens
+            hop_length = int(self.config.hop_fraction[level] * self.priors[level].n_ctx)
+            max_batch_size = lower_batch_size if level != sample_levels else max_batch_size
+            music_tokens = self.sample_level(
+                music_tokens,
+                labels[level],
+                offset,
+                sampling_kwargs,
+                level,
+                total_token_to_sample,
+                hop_length,
+                max_batch_size,
+            )
+
+            if save_results:
+                self.vqvae.to(music_tokens[level])
+                # Decode sample
+                with no_grad():
+                    start_level = len(self.priors) - level - 1  # vqvae levels are reversed
+                    raw_audio = self.vqvae.decode(
+                        music_tokens[: level + 1], start_level=start_level, bs_chunks=music_tokens[level].shape[0]
+                    )
+                logdir = f"jukebox/level_{level}"
+                if not os.path.exists(logdir):
+                    os.makedirs(logdir)
+                save_temp_audio(logdir, level, metas=metas, aud=raw_audio.float())
+                if compute_alignments and self.priors[0] is not None and self.priors[0].nb_relevant_lyric_tokens > 0:
+                    with no_grad():
+                        alignments = get_alignment(music_tokens, labels[0], self.priors[0], self.config)
+                    mindspore.save_checkpoint({"alignments": alignments}, f"{logdir}/lyric_alignments.ckpt")
+
+        return music_tokens
+
+    def ancestral_sample(self, labels, n_samples=1, **sampling_kwargs) -> List[mindspore.Tensor]:
+        """
+        Example:
+
+        ```python
+        >>> from transformers import AutoTokenizer, JukeboxModel, set_seed
+
+        >>> model = JukeboxModel.from_pretrained("openai/jukebox-1b-lyrics", min_duration=0).eval()
+        >>> tokenizer = AutoTokenizer.from_pretrained("openai/jukebox-1b-lyrics")
+
+        >>> lyrics = "Hey, are you awake? Can you talk to me?"
+        >>> artist = "Zac Brown Band"
+        >>> genre = "Country"
+        >>> metas = tokenizer(artist=artist, genres=genre, lyrics=lyrics)
+        >>> set_seed(0)
+        >>> music_tokens = model.ancestral_sample(metas.input_ids, sample_length=400)
+
+        >>> with no_grad():
+        ...     model.decode(music_tokens)[:, :10].squeeze(-1)
+        tensor([[-0.0219, -0.0679, -0.1050, -0.1203, -0.1271, -0.0936, -0.0396, -0.0405,
+                 -0.0818, -0.0697]])
+        ```
+        """
+
+        sample_levels = sampling_kwargs.pop("sample_levels", list(range(len(self.priors))))
+        music_tokens = [
+            ops.zeros(n_samples, 0, dtype=mindspore.int64) for _ in range(len(self.priors))
+        ]
+        music_tokens = self._sample(music_tokens, labels, sample_levels, **sampling_kwargs)
+        return music_tokens
+
+    def continue_sample(self, music_tokens, labels, **sampling_kwargs) -> List[mindspore.Tensor]:
+        sample_levels = sampling_kwargs.pop("sample_levels", list(range(len(self.priors))))
+        music_tokens = self._sample(music_tokens, labels, sample_levels, **sampling_kwargs)
+        return music_tokens
+
+    def upsample(self, music_tokens, labels, **sampling_kwargs) -> List[mindspore.Tensor]:
+        sample_levels = sampling_kwargs.pop("sample_levels", list(range(len(self.priors) - 1)))
+        music_tokens = self._sample(music_tokens, labels, sample_levels, **sampling_kwargs)
+        return music_tokens
+
+    def primed_sample(self, raw_audio, labels, **sampling_kwargs) -> List[mindspore.Tensor]:
+        sample_levels = sampling_kwargs.pop("sample_levels", list(range(len(self.priors))))
+        self.vqvae.to(raw_audio).float()
+        with no_grad():
+            music_tokens = self.vqvae.encode(
+                raw_audio, start_level=0, end_level=len(self.priors), bs_chunks=raw_audio.shape[0]
+            )
+        music_tokens = self._sample(music_tokens, labels, sample_levels, **sampling_kwargs)
+        return music_tokens
+
+
+__all__ = [
+    "JukeboxModel",
+    "JukeboxPreTrainedModel",
+    "JukeboxVQVAE",
+    "JukeboxPrior",
+]
diff --git a/mindnlp/transformers/models/jukebox/tokenization_jukebox.py b/mindnlp/transformers/models/jukebox/tokenization_jukebox.py
new file mode 100644
index 000000000..e2aca90b4
--- /dev/null
+++ b/mindnlp/transformers/models/jukebox/tokenization_jukebox.py
@@ -0,0 +1,342 @@
+# coding=utf-8
+# Copyright 2022 The Open AI Team Authors and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tokenization classes for OpenAI Jukebox."""
+
+import json
+import os
+import re
+import unicodedata
+from json.encoder import INFINITY
+from typing import Any, Dict, List, Optional, Tuple
+
+import regex
+
+from ...tokenization_utils import AddedToken, PreTrainedTokenizer
+from ...tokenization_utils_base import BatchEncoding
+from ....utils import logging
+
+
+
+logger = logging.get_logger(__name__)
+
+VOCAB_FILES_NAMES = {
+    "artists_file": "artists.json",
+    "lyrics_file": "lyrics.json",
+    "genres_file": "genres.json",
+}
+
+
+class JukeboxTokenizer(PreTrainedTokenizer):
+    """
+    Constructs a Jukebox tokenizer. Jukebox can be conditioned on 3 different inputs:
+        - Artists, unique ids are associated with each artist from the provided dictionary.
+        - Genres, unique ids are associated with each genre from the provided dictionary.
+        - Lyrics, character-based tokenization. Must be initialized with the list of characters that are inside the
+          vocabulary.
+
+    This tokenizer does not require training.
+    It can process a varying number of inputs, since the conditioning of the model can be done on any combination of
+    the three queries. If `None` is provided, default values will be used.
+
+    The output depends on the number of genres on which the model should be conditioned (`n_genres`):
+    ```python
+    >>> from transformers import JukeboxTokenizer
+
+    >>> tokenizer = JukeboxTokenizer.from_pretrained("openai/jukebox-1b-lyrics")
+    >>> tokenizer("Alan Jackson", "Country Rock", "old town road")["input_ids"]
+    [tensor([[ 0, 0, 0, 6785, 546, 41, 38, 30, 76, 46, 41, 49,
+        40, 76, 44, 41, 27, 30]]), tensor([[ 0, 0, 0, 145, 0]]), tensor([[ 0, 0, 0, 145, 0]])]
+    ```
+
+    You can get around that behavior by passing `add_prefix_space=True` when instantiating this tokenizer or when you
+    call it on some text, but since the model was not pretrained this way, it might yield a decrease in performance.
+
+    If nothing is provided, the genres and the artist will either be selected randomly or set to None.
+
+    This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
+    this superclass for more information regarding those methods.
+
+    However, the code only supports composing the conditioning from various genres.
+
+    Args:
+        artists_file (`str`):
+            Path to the vocabulary file which contains a mapping between artists and ids. The default file supports
+            both "v2" and "v3".
+        genres_file (`str`):
+            Path to the vocabulary file which contains a mapping between genres and ids.
+        lyrics_file (`str`):
+            Path to the vocabulary file which contains the accepted characters for the lyrics tokenization.
+        version (`List[str]`, *optional*, defaults to `["v3", "v2", "v2"]`):
+            List of the tokenizer versions. The `5b-lyrics`'s top level prior model was trained using `v3` instead of
+            `v2`.
+        n_genres (`int`, *optional*, defaults to 5):
+            Maximum number of genres to use for composition.
+        max_n_lyric_tokens (`int`, *optional*, defaults to 512):
+            Maximum number of lyric tokens to keep.
+        unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
+            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+            token instead.
+    """
+
+    vocab_files_names = VOCAB_FILES_NAMES
+    model_input_names = ["input_ids", "attention_mask"]
+
+    def __init__(
+        self,
+        artists_file,
+        genres_file,
+        lyrics_file,
+        version=["v3", "v2", "v2"],
+        max_n_lyric_tokens=512,
+        n_genres=5,
+        unk_token="<|endoftext|>",
+        **kwargs,
+    ):
+        unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token
+        self.version = version
+        self.max_n_lyric_tokens = max_n_lyric_tokens
+        self.n_genres = n_genres
+        self._added_tokens_decoder = {0: unk_token}
+
+        with open(artists_file, encoding="utf-8") as vocab_handle:
+            self.artists_encoder = json.load(vocab_handle)
+
+        with open(genres_file, encoding="utf-8") as vocab_handle:
+            self.genres_encoder = json.load(vocab_handle)
+
+        with open(lyrics_file, encoding="utf-8") as vocab_handle:
+            self.lyrics_encoder = json.load(vocab_handle)
+
+        oov = r"[^A-Za-z0-9.,:;!?\-'\"()\[\] \t\n]+"
+        # In v2 the vocabulary had 80 characters (n_vocab=80); in v3 the `+` character was dropped, so n_vocab=79.
+        if len(self.lyrics_encoder) == 79:
+            oov = oov.replace(r"\-'", r"\-+'")
+
+        self.out_of_vocab = regex.compile(oov)
+        self.artists_decoder = {v: k for k, v in self.artists_encoder.items()}
+        self.genres_decoder = {v: k for k, v in self.genres_encoder.items()}
+        self.lyrics_decoder = {v: k for k, v in self.lyrics_encoder.items()}
+        super().__init__(
+            unk_token=unk_token,
+            n_genres=n_genres,
+            version=version,
+            max_n_lyric_tokens=max_n_lyric_tokens,
+            **kwargs,
+        )
+
+    @property
+    def vocab_size(self):
+        return len(self.artists_encoder) + len(self.genres_encoder) + len(self.lyrics_encoder)
+
+    def get_vocab(self):
+        return {
+            "artists_encoder": self.artists_encoder,
+            "genres_encoder": self.genres_encoder,
+            "lyrics_encoder": self.lyrics_encoder,
+        }
+
+    def _convert_token_to_id(self, list_artists, list_genres, list_lyrics):
+        """Converts the artist, genre and lyrics tokens to their index using the vocabulary.
+        The total_length, offset and duration have to be provided in order to select relevant lyrics and add padding to
+        the lyrics token sequence.
+        """
+        artists_id = [self.artists_encoder.get(artist, 0) for artist in list_artists]
+        for genres in range(len(list_genres)):
+            list_genres[genres] = [self.genres_encoder.get(genre, 0) for genre in list_genres[genres]]
+            list_genres[genres] = list_genres[genres] + [-1] * (self.n_genres - len(list_genres[genres]))
+
+        lyric_ids = [[self.lyrics_encoder.get(character, 0) for character in list_lyrics[0]], [], []]
+        return artists_id, list_genres, lyric_ids
+
+    def _tokenize(self, lyrics):
+        """
+        Converts a string into a sequence of tokens (string), using the tokenizer. Splits into words for word-based
+        vocabularies or into sub-words for sub-word-based vocabularies (BPE/SentencePiece/WordPiece).
+
+        Does NOT take care of added tokens. Only the lyrics are split into characters for the character-based vocabulary.
+        """
+        # only the lyrics are tokenized; with a character-based vocabulary this is simply a list of characters
+        return list(lyrics)
+
+    def tokenize(self, artist, genre, lyrics, **kwargs):
+        """
+        Converts three strings into three sequences of tokens using the tokenizer.
+        """
+        artist, genre, lyrics = self.prepare_for_tokenization(artist, genre, lyrics)
+        lyrics = self._tokenize(lyrics)
+        return artist, genre, lyrics
+
+    def prepare_for_tokenization(
+        self, artists: str, genres: str, lyrics: str, is_split_into_words: bool = False
+    ) -> Tuple[str, str, str, Dict[str, Any]]:
+        """
+        Performs any necessary transformations before tokenization.
+
+        Args:
+            artists (`str`):
+                The artist name to prepare. This will mostly lowercase the string.
+            genres (`str`):
+                The genre name to prepare. This will mostly lowercase the string.
+            lyrics (`str`):
+                The lyrics to prepare.
+            is_split_into_words (`bool`, *optional*, defaults to `False`):
+                Whether or not the input is already pre-tokenized (e.g., split into words). If set to `True`, the
+                tokenizer assumes the input is already split into words (for instance, by splitting it on whitespace)
+                which it will tokenize. This is useful for NER or token classification.
+        """
+        for idx in range(len(self.version)):
+            if self.version[idx] == "v3":
+                artists[idx] = artists[idx].lower()
+                genres[idx] = [genres[idx].lower()]
+            else:
+                artists[idx] = self._normalize(artists[idx]) + ".v2"
+                genres[idx] = [
+                    self._normalize(genre) + ".v2" for genre in genres[idx].split("_")
+                ]  # split is for the full dictionary with combined genres
+
+        if self.version[0] == "v2":
+            self.out_of_vocab = regex.compile(r"[^A-Za-z0-9.,:;!?\-'\"()\[\] \t\n]+")
+            vocab = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789.,:;!?-+'\"()[] \t\n"
+            self.vocab = {vocab[index]: index + 1 for index in range(len(vocab))}
+            self.vocab[""] = 0
+            self.n_vocab = len(vocab) + 1
+            self.lyrics_encoder = self.vocab
+            self.lyrics_decoder = {v: k for k, v in self.vocab.items()}
+            self.lyrics_decoder[0] = ""
+        else:
+            self.out_of_vocab = regex.compile(r"[^A-Za-z0-9.,:;!?\-+'\"()\[\] \t\n]+")
+
+        lyrics = self._run_strip_accents(lyrics)
+        lyrics = lyrics.replace("\\", "\n")
+        lyrics = self.out_of_vocab.sub("", lyrics), [], []
+        return artists, genres, lyrics
+
+    def _run_strip_accents(self, text):
+        """Strips accents from a piece of text."""
+        text = unicodedata.normalize("NFD", text)
+        output = []
+        for char in text:
+            cat = unicodedata.category(char)
+            if cat == "Mn":
+                continue
+            output.append(char)
+        return "".join(output)
+
+    def _normalize(self, text: str) -> str:
+        """
+        Normalizes the input text. This process is used for the genres and the artist names.
+
+        Args:
+            text (`str`):
+                Artist or genre string to normalize.
+        """
+
+        accepted = (
+            [chr(i) for i in range(ord("a"), ord("z") + 1)]
+            + [chr(i) for i in range(ord("A"), ord("Z") + 1)]
+            + [chr(i) for i in range(ord("0"), ord("9") + 1)]
+            + ["."]
+        )
+        accepted = frozenset(accepted)
+        pattern = re.compile(r"_+")
+        text = "".join([c if c in accepted else "_" for c in text.lower()])
+        text = pattern.sub("_", text).strip("_")
+        return text
+
+    def convert_lyric_tokens_to_string(self, lyrics: List[str]) -> str:
+        return " ".join(lyrics)
+
+    def __call__(self, artist, genres, lyrics="", return_tensors="pt") -> BatchEncoding:
+        """Converts the raw strings to a list of token ids.
+
+        Args:
+            artist (`str`):
+                Name of the artist.
+            genres (`str`):
+                List of genres that will be mixed to condition the audio.
+            lyrics (`str`, *optional*, defaults to `""`):
+                Lyrics used to condition the generation.
+        """
+        input_ids = [0, 0, 0]
+        artist = [artist] * len(self.version)
+        genres = [genres] * len(self.version)
+
+        artists_tokens, genres_tokens, lyrics_tokens = self.tokenize(artist, genres, lyrics)
+        artists_id, genres_ids, full_tokens = self._convert_token_to_id(artists_tokens, genres_tokens, lyrics_tokens)
+
+        attention_masks = [-INFINITY] * len(full_tokens[-1])
+        input_ids = [
+            self.convert_to_tensors(
+                [input_ids + [artists_id[i]] + genres_ids[i] + full_tokens[i]], tensor_type=return_tensors
+            )
+            for i in range(len(self.version))
+        ]
+        return BatchEncoding({"input_ids": input_ids, "attention_masks": attention_masks})
+
+    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
+        """
+        Saves the tokenizer's vocabulary dictionary to the provided save_directory.
+
+        Args:
+            save_directory (`str`):
+                A path to the directory where the vocabulary will be saved. It will be created if it doesn't exist.
+
+            filename_prefix (`Optional[str]`, *optional*):
+                A prefix to add to the names of the files saved by the tokenizer.
+
+        """
+        if not os.path.isdir(save_directory):
+            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
+            return
+
+        artists_file = os.path.join(
+            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["artists_file"]
+        )
+        with open(artists_file, "w", encoding="utf-8") as f:
+            f.write(json.dumps(self.artists_encoder, ensure_ascii=False))
+
+        genres_file = os.path.join(
+            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["genres_file"]
+        )
+        with open(genres_file, "w", encoding="utf-8") as f:
+            f.write(json.dumps(self.genres_encoder, ensure_ascii=False))
+
+        lyrics_file = os.path.join(
+            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["lyrics_file"]
+        )
+        with open(lyrics_file, "w", encoding="utf-8") as f:
+            f.write(json.dumps(self.lyrics_encoder, ensure_ascii=False))
+
+        return (artists_file, genres_file, lyrics_file)
+
+    def _convert_id_to_token(self, artists_index, genres_index, lyric_index):
+        """
+        Converts an index (integer) to a token (str) using the vocab.
+
+        Args:
+            artists_index (`int`):
+                Index of the artist in its corresponding dictionary.
+            genres_index (`Union[List[int], int]`):
+                Index of the genre in its corresponding dictionary.
+            lyric_index (`List[int]`):
+                List of character indices, each of which corresponds to a character.
+        """
+        artist = self.artists_decoder.get(artists_index)
+        genres = [self.genres_decoder.get(genre) for genre in genres_index]
+        lyrics = [self.lyrics_decoder.get(character) for character in lyric_index]
+        return artist, genres, lyrics
diff --git a/tests/transformers/models/jukebox/__pycache__/test_modeling_jukebox.cpython-39-pytest-7.2.0.pyc b/tests/transformers/models/jukebox/__pycache__/test_modeling_jukebox.cpython-39-pytest-7.2.0.pyc
new file mode 100644
index 000000000..fe6f433cf
Binary files /dev/null and b/tests/transformers/models/jukebox/__pycache__/test_modeling_jukebox.cpython-39-pytest-7.2.0.pyc differ
diff --git a/tests/transformers/models/jukebox/__pycache__/test_tokenization_jukebox.cpython-39-pytest-7.2.0.pyc b/tests/transformers/models/jukebox/__pycache__/test_tokenization_jukebox.cpython-39-pytest-7.2.0.pyc
new file mode 100644
index 000000000..f3574cf03
Binary files /dev/null and b/tests/transformers/models/jukebox/__pycache__/test_tokenization_jukebox.cpython-39-pytest-7.2.0.pyc differ
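
For reviewers, here is a minimal end-to-end usage sketch of the classes added in this diff, adapted from the docstring examples above. It assumes that `mindnlp.transformers` re-exports `JukeboxTokenizer` and `JukeboxModel` the same way upstream `transformers` does and that the `openai/jukebox-1b-lyrics` checkpoint is reachable; adjust the imports if the package layout differs.

```python
# Hypothetical usage sketch (not part of the diff); the import path is an assumption.
from mindnlp.transformers import JukeboxModel, JukeboxTokenizer

tokenizer = JukeboxTokenizer.from_pretrained("openai/jukebox-1b-lyrics")
model = JukeboxModel.from_pretrained("openai/jukebox-1b-lyrics", min_duration=0).eval()

# Conditioning metadata: a single artist, one or more genres, and the lyrics.
metas = tokenizer(artist="Zac Brown Band", genres="Country", lyrics="old town road")

# Generate a short sequence of music tokens, then decode them back to raw audio
# with the VQ-VAE, mirroring the `ancestral_sample` docstring example.
music_tokens = model.ancestral_sample(metas.input_ids, sample_length=400)
raw_audio = model.decode(music_tokens)
```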