feat(transformers): add CSM (v4.54.1) #1398
Conversation
Code Review
This pull request introduces the CSM (Conversational Speech Model) to the library, including its modeling, generation logic, and configuration. It is a significant addition that enables text-to-waveform generation.
My review focuses on performance optimizations in the new model code. I've identified a few areas where loops can be vectorized for significant performance gains, particularly in audio encoding/decoding paths and logit computation. These are marked with TODOs in the code, and my comments provide suggestions on how to address them.
Overall, the PR is well-structured, and the addition of the CSM model is a valuable contribution. Addressing the performance points will make it even better.
```python
for audio_codes_batch in generated_audio_codes:
    eos_idxs = (audio_codes_batch == self.config.codebook_eos_token_id).all(dim=-1).nonzero()
    if eos_idxs.numel() != 0:
        cutoff_idx = eos_idxs.min()
    else:
        cutoff_idx = audio_codes_batch.shape[0]

    audio_codes_batch = audio_codes_batch[:cutoff_idx]
    codec_decode_output = self.codec_model.decode(audio_codes_batch.transpose(0, 1).unsqueeze(0))
    audio.append(codec_decode_output.audio_values[0, 0])
```
The audio decoding loop iterates over each item in the batch, which is inefficient; it should be vectorized to process the entire batch at once. The TODO comment shows this is a known issue, but it is a significant performance bottleneck worth addressing now.
For example, you could pad the `audio_codes_batch` tensors to the same length, stack them into a single batch tensor, and call `self.codec_model.decode` once, as in the sketch below.
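A minimal sketch of that approach, under two assumptions not in the PR: `codec_model.decode` tolerates right-padded batches, and `codebook_pad_id` / `samples_per_frame` (hypothetical names) are available from the codec config to pad and trim.

```python
from mindspore import mint

# Sketch only: trim each item's codes at EOS, pad to a common frame length,
# decode once, then cut each waveform back to its true length.
trimmed, lengths = [], []
for audio_codes_batch in generated_audio_codes:
    eos_idxs = (audio_codes_batch == self.config.codebook_eos_token_id).all(dim=-1).nonzero()
    cutoff_idx = eos_idxs.min() if eos_idxs.numel() != 0 else audio_codes_batch.shape[0]
    trimmed.append(audio_codes_batch[:cutoff_idx])
    lengths.append(int(cutoff_idx))

max_len = max(lengths)
# Right-pad the frame dimension of each (frames, num_codebooks) tensor;
# codebook_pad_id is an assumed pad value the codec is expected to ignore.
padded = mint.stack(
    [mint.nn.functional.pad(t, (0, 0, 0, max_len - t.shape[0]), value=codebook_pad_id) for t in trimmed]
)
# Single batched call on (batch, num_codebooks, max_len) instead of a Python loop.
decoded = self.codec_model.decode(padded.transpose(1, 2))
# samples_per_frame (assumed known from the codec) maps frame counts back to sample counts.
audio = [decoded.audio_values[i, 0, : n * samples_per_frame] for i, n in enumerate(lengths)]
```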
```python
for batch_input_values, batch_input_values_cutoffs in zip(input_values, input_values_cutoffs):
    batch_input_values_cutoffs = batch_input_values_cutoffs[batch_input_values_cutoffs >= 0]
    for i in range(batch_input_values_cutoffs.shape[0] - 1):
        start_idx = batch_input_values_cutoffs[i]
        end_idx = batch_input_values_cutoffs[i + 1]
        audio_batch = batch_input_values[..., start_idx:end_idx]
        codec_outputs = self.codec_model.encode(audio_batch.unsqueeze(0))
        codebook_ids = codec_outputs.audio_codes.transpose(1, -1)
        audio_tokens_list.append(codebook_ids[0])
```
The audio token encoding loop iterates over each item in the batch, which is inefficient; it should be vectorized to process the entire batch at once. The TODO comment shows this is a known issue, but it is a significant performance bottleneck worth addressing now.
This would likely involve padding the audio segments to a uniform length before passing them to `self.codec_model.encode` in a single batch, as sketched below.
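A rough sketch of the batched variant. It assumes zero padding is acceptable to the codec and that the extra code frames produced by padding are trimmed or masked downstream; the `seg_lengths` bookkeeping is illustrative.

```python
from mindspore import mint

# Sketch only: gather all variable-length segments, zero-pad to a common
# length, and run the codec encoder once instead of once per segment.
segments, seg_lengths = [], []
for batch_input_values, batch_input_values_cutoffs in zip(input_values, input_values_cutoffs):
    batch_input_values_cutoffs = batch_input_values_cutoffs[batch_input_values_cutoffs >= 0]
    for i in range(batch_input_values_cutoffs.shape[0] - 1):
        seg = batch_input_values[..., batch_input_values_cutoffs[i] : batch_input_values_cutoffs[i + 1]]
        segments.append(seg)
        seg_lengths.append(seg.shape[-1])

max_len = max(seg_lengths)
# Zero-pad each segment on the right along the sample dimension.
padded = mint.stack([mint.nn.functional.pad(s, (0, max_len - s.shape[-1])) for s in segments])
# One batched encode over (num_segments, ..., max_len); frames coming from
# the padded region would still need to be dropped per segment afterwards.
codec_outputs = self.codec_model.encode(padded)
all_codes = codec_outputs.audio_codes.transpose(1, -1)
audio_tokens_list = [all_codes[j] for j in range(all_codes.shape[0])]
```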
```python
hidden_states = [
    mint.nn.functional.linear(hidden_states[:, codebook_idx, :], codebook_weight[codebook_idx].T)
    for codebook_idx in range(codebook_weight.shape[0])
]
hidden_states = mint.stack(hidden_states, dim=1)
```
The list comprehension in `CsmCodebooksHead.construct` applies one `linear` call per codebook to compute logits. This can be optimized with a single batched matrix multiplication (`mint.bmm`), which pays off especially for larger codebook counts:

```python
hidden_states = mint.bmm(hidden_states.transpose(0, 1), codebook_weight).transpose(0, 1)
```
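A small, self-contained equivalence check of the two formulations (shapes and values are illustrative, not from the PR):

```python
import numpy as np
from mindspore import Tensor, mint

# Illustrative shapes: hidden_states (batch, num_codebooks, hidden),
# codebook_weight (num_codebooks, hidden, vocab).
batch, num_codebooks, hidden, vocab = 2, 8, 16, 32
hidden_states = Tensor(np.random.randn(batch, num_codebooks, hidden).astype(np.float32))
codebook_weight = Tensor(np.random.randn(num_codebooks, hidden, vocab).astype(np.float32))

# Original per-codebook loop.
loop_logits = mint.stack(
    [
        mint.nn.functional.linear(hidden_states[:, i, :], codebook_weight[i].T)
        for i in range(num_codebooks)
    ],
    dim=1,
)
# Suggested single batched matmul.
bmm_logits = mint.bmm(hidden_states.transpose(0, 1), codebook_weight).transpose(0, 1)

assert loop_logits.shape == bmm_logits.shape == (batch, num_codebooks, vocab)
assert np.allclose(loop_logits.asnumpy(), bmm_logits.asnumpy(), atol=1e-5)
```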
What does this PR do?
Fixes # (issue)
Adds # (feature)
Before submitting
Did you make sure to update the documentation with your changes (e.g., record the feature in What's New)? Here are the documentation guidelines.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@xxx