Skip to content

Commit 36d1e6a

Browse files
authored
(feat) Example MMaDA: update performance in readme (#1377)
* print text token throughputs * print image generation throughputs * update readme * revert zero.py get_optimizer_param_tuples * update performance of zero2 training
1 parent 86e88c0 commit 36d1e6a

File tree

5 files changed

+62
-11
lines changed

5 files changed

+62
-11
lines changed

examples/mmada/README.md

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,12 +40,12 @@ Here is the development plan of the project:
4040

4141
| MindSpore | Ascend Driver | Firmware | CANN toolkit/kernel |
4242
|:---------:|:-------------:|:-----------:|:-------------------:|
43-
| 2.6.0 | 24.1.RC3 | 7.6.0.1.220 | 8.0.RC3.beta1 |
43+
| 2.6.0/2.7.0 | 24.1.RC3.b080 | 7.5.T11.0.B088 | 8.1.RC1 |
4444

4545
</div>
4646

4747
1. Install
48-
[CANN 8.0.RC3.beta1](https://www.hiascend.com/developer/download/community/result?module=cann&cann=8.0.RC3.beta1)
48+
[CANN 8.1.RC1](https://www.hiascend.com/developer/download/community/result?module=cann&cann=8.1.RC1)
4949
and MindSpore according to the [official instructions](https://www.mindspore.cn/install).
5050
2. Install requirements
5151
```shell
@@ -98,7 +98,7 @@ python generate.py
9898

9999
### 2. MultiModal Generation
100100

101-
For multiModal generation, please run:
101+
For multimodal generation, please run:
102102
```
103103
python3 inference_mmu.py config=configs/mmada_demo.yaml mmu_image_root=./mmu_validation question='Please describe this image in detail.'
104104
```
@@ -109,10 +109,28 @@ The outputs are stored locally.
109109
For text-to-image generation, please run:
110110
```
111111
python3 inference_t2i.py config=configs/mmada_demo.yaml batch_size=1 validation_prompts_file=validation_prompts/text2image_prompts.txt guidance_scale=3.5 generation_timesteps=15
112-
mode='t2i'
113112
```
114113
The outputs are stored locally.
115114
115+
### Performance
116+
117+
The following experiments are tested on Ascend Atlas 800T A2 machines with mindspore **2.7.0** under **pynative** mode:
118+
119+
| model | # card(s) | batch size | task | throughput (token/s) |
120+
|:-:|:-:|:-:|:-:|:-:|
121+
| MMaDA-8B-Base | 1 | 1 | text generation | 12.56 |
122+
| MMaDA-8B-Base | 1 | 1 | mmu generation | 13.48 |
123+
| MMaDA-8B-Base | 1 | 1 | text-to-image generation| 167.50 |
124+
125+
The following experiments are tested on Ascend Atlas 800T A2 machines with mindspore **2.6.0** under **pynative** mode:
126+
127+
| model | # card(s) | batch size | task | throughput (token/s) |
128+
|:-:|:-:|:-:|:-:|:-:|
129+
| MMaDA-8B-Base | 1 | 1 | text generation | 12.53 |
130+
| MMaDA-8B-Base | 1 | 1 | mmu generation | 13.50 |
131+
| MMaDA-8B-Base | 1 | 1 | text-to-image generation| 168.60 |
132+
133+
116134
## 🔧 Training
117135
118136
@@ -164,6 +182,21 @@ msrun --bind_core=True --worker_num=8 --local_worker_num=8 --master_port=9000 --
164182
python training/train_mmada_stage2.py config=configs/mmada_finetune_artwork.yaml
165183
```
166184

185+
### Performance
186+
187+
The following experiments are tested on Ascend Atlas 800T A2 machines with mindspore **2.7.0** under **pynative** mode:
188+
189+
| model | # card(s) | batch size | parallelism |task | per batch time (seconds) |
190+
|:-:|:-:|:-:|:-:|:-:|:-:|
191+
| MMaDA-8B-Base | 8 | 4 | zero2 | finetune | 1.29 |
192+
193+
The following experiments are tested on Ascend Atlas 800T A2 machines with mindspore **2.6.0** under **pynative** mode:
194+
195+
| model | # card(s) | batch size | parallelism | task | per batch time (seconds) |
196+
|:-:|:-:|:-:|:-:|:-:|:-:|
197+
| MMaDA-8B-Base | 8 | 4 | zero2 | finetune | 1.30 |
198+
199+
167200

168201
## 🤝 Acknowledgments
169202

examples/mmada/generate.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,7 @@ def main():
168168
remasking="low_confidence",
169169
)
170170
print(f"Inference time: {time() - infer_start:.3f}s")
171+
print(f"Throughput: {out.shape[1] / (time() - infer_start):.3f} token/s")
171172
print(tokenizer.batch_decode(out[:, input_ids.shape[1] :], skip_special_tokens=True))
172173

173174

examples/mmada/inference_mmu.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
# limitations under the License.
1818

1919
import os
20+
from time import time
2021

2122
os.environ["TOKENIZERS_PARALLELISM"] = "true"
2223
os.environ["SAFETENSORS_WEIGHTS_NAME"] = "pytorch_model.safetensors" # vq_model
@@ -122,6 +123,8 @@ def draw_caption_on_image(
122123
responses = ["" for i in range(len(file_list))]
123124
images = []
124125
config.question = config.question.split(" *** ")
126+
127+
throughputs = []
125128
for i, file_name in enumerate(tqdm(file_list)):
126129
image_path = os.path.join(config.mmu_image_root, file_name)
127130
image_ori = Image.open(image_path).convert("RGB")
@@ -152,10 +155,12 @@ def draw_caption_on_image(
152155
],
153156
dim=1,
154157
)
158+
infer_start = time()
155159
output_ids = model.mmu_generate(input_ids, max_new_tokens=1024, steps=512, block_length=1024)
156160
text = uni_prompting.text_tokenizer.batch_decode(
157161
output_ids[:, input_ids.shape[1] :], skip_special_tokens=True
158162
)
163+
throughputs.append(output_ids.shape[1] / (time() - infer_start))
159164
print(text[0])
160165
responses[i] += text[0]
161166

@@ -169,3 +174,4 @@ def draw_caption_on_image(
169174
draw_caption_on_image(pil_images, responses, output_dir, file_list=file_list)
170175

171176
print("Generated captions are saved in", output_dir)
177+
print(f"Average throughput: {np.mean(throughputs):.3f} token/s")

examples/mmada/inference_t2i.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
# limitations under the License.
1818

1919
import os
20+
from time import time
2021

2122
os.environ["TOKENIZERS_PARALLELISM"] = "true"
2223
import numpy as np
@@ -116,6 +117,8 @@ def draw_caption_on_image(
116117
with open(config.dataset.params.validation_prompts_file, "r") as f:
117118
validation_prompts = f.read().splitlines()
118119
output_images, output_responses = [], []
120+
print("Generating images with batch size: ", config.training.batch_size)
121+
throughputs = []
119122
for step in tqdm(range(0, len(validation_prompts), config.training.batch_size)):
120123
prompts = validation_prompts[step : step + config.training.batch_size]
121124

@@ -133,7 +136,7 @@ def draw_caption_on_image(
133136
mask_schedule = get_mask_schedule(schedule, **args)
134137
else:
135138
mask_schedule = get_mask_schedule(config.training.get("mask_schedule", "cosine"))
136-
139+
infer_start = time()
137140
gen_token_ids = model.t2i_generate(
138141
input_ids=input_ids,
139142
uncond_input_ids=uncond_input_ids,
@@ -153,6 +156,7 @@ def draw_caption_on_image(
153156
images = vq_model.decode_code(gen_token_ids)
154157
output_images.append(images)
155158
output_responses.extend(prompts)
159+
throughputs.append(gen_token_ids.shape[1] / config.training.batch_size / (time() - infer_start))
156160

157161
images = mint.cat(output_images, dim=0)
158162
images = mint.clamp((images + 1.0) / 2.0, min=0.0, max=1.0)
@@ -162,5 +166,5 @@ def draw_caption_on_image(
162166
output_dir = "./inference_t2i_outputs/"
163167
os.makedirs(output_dir, exist_ok=True)
164168
draw_caption_on_image(pil_images, output_responses, output_dir)
165-
169+
print(f"Average throughput: {np.mean(throughputs):.3f} token/s")
166170
print("Generated images are saved in ", output_dir)

mindone/trainers/zero.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -271,12 +271,19 @@ def split_param(self, param):
271271

272272
def get_optimizer_param_tuples(self):
273273
param_tuples = []
274-
for attr in self.optimizer.__dict__:
275-
if isinstance(getattr(self.optimizer, attr), ms.ParameterTuple):
276-
if attr in ["_parameters", "parameters"]:
274+
if ms.get_context("mode") == ms.PYNATIVE_MODE:
275+
for name in self.optimizer._params_list:
276+
if name in ["_parameters", "parameters"]:
277277
continue
278-
_logger.debug(f"Add optimizer param_tuples {attr}")
279-
param_tuples.append(getattr(self.optimizer, attr))
278+
_logger.debug(f"Add optimizer param_tuples {name}")
279+
param_tuples.append(getattr(self.optimizer, name))
280+
else:
281+
for attr in self.optimizer.__dict__:
282+
if isinstance(getattr(self.optimizer, attr), ms.ParameterTuple):
283+
if attr in ["_parameters", "parameters"]:
284+
continue
285+
_logger.debug(f"Add optimizer param_tuples {attr}")
286+
param_tuples.append(getattr(self.optimizer, attr))
280287
return param_tuples
281288

282289
def dump_params_split_info(self, params_split_info):

0 commit comments

Comments
 (0)