31 changes: 16 additions & 15 deletions examples/janus/README.md
@@ -79,9 +79,10 @@ huggingface-cli download deepseek-ai/Janus-Pro-7B --revision refs/pr/110 --local

The code has been tested in the following environments:

| mindspore | ascend driver | firmware | cann tookit/kernel |
| :---: | :---: | :---: | :---: |
| 2.5.0 | 24.1.0 |7.35.23 | 8.0.RC3.beta1 |
| mindspore | ascend driver | firmware    | cann toolkit/kernel |
| :-------: | :-----------: | :---------: | :----------------: |
| 2.6.0 | 25.0.RC1.1 | 7.7.0.1.231 | 8.1.RC1 |
| 2.7.0 | 25.2.0 | 7.7.0.6.236 | 8.2.RC1 |


### Installation
@@ -138,36 +139,36 @@ Please refer to [training.md](docs/training.md)

### Multimodal Understanding

Experiments are tested on Ascend Atlas 800T A2 machines with mindspore 2.5.0 **graph** mode:
Experiments are tested on Ascend Atlas 800T A2 machines with mindspore 2.7.0 **graph** mode:

| model | # card(s) | image size | attn. type | throughput (token/s)|
|:-:|:-:|:-:|:-:|:-:|
| Janus-Pro-1B | 1 | 384x384 | Eager | 16.6|
| Janus-Pro-7B | 1 | 384x384 | Eager | 12.2|
| Janus-Pro-1B | 1 | 384x384 | Eager | 17.5 |
| Janus-Pro-7B | 1 | 384x384 | Eager | 13.6 |


Experiments are tested on Ascend Atlas 800T A2 machines with mindspore 2.5.0 **pynative** mode:
Experiments are tested on Ascend Atlas 800T A2 machines with mindspore 2.7.0 **pynative** mode:

| model | # card(s) | image size | attn. type | throughput (token/s)|
|:-:|:-:|:-:|:-:|:-:|
| Janus-Pro-1B | 1 | 384x384 | Eager | 5.88 |
| Janus-Pro-7B | 1 | 384x384 | Eager | 3.30|
| Janus-Pro-1B | 1 | 384x384 | Eager | 7.55 |
| Janus-Pro-7B | 1 | 384x384 | Eager | 6.39 |

### Visual Generation

Experiments are tested on Ascend Atlas 800T A2 machines with mindspore 2.5.0 **graph** mode:
Experiments are tested on Ascend Atlas 800T A2 machines with mindspore 2.7.0 **graph** mode:

| model | # card(s) | batch Size | image size | attn. type | throughput (token/s)| s/img |
|:-:|:-:|:-:|:-:|:-:|:-:|:-:|
| Janus-Pro-1B | 1 | 1 | 384x384 | Eager | 16.2 | ~ 40 |
| Janus-Pro-7B | 1 | 1 | 384x384 | Eager | 11.9 | ~ 52 |
| Janus-Pro-1B | 1 | 1 | 384x384 | Eager | 14.6 | ~ 44 |
| Janus-Pro-7B | 1 | 1 | 384x384 | Eager | 12.2 | ~ 51 |

Experiments are tested on Ascend Atlas 800T A2 machines with mindspore 2.5.0 **pynative** mode:
Experiments are tested on Ascend Atlas 800T A2 machines with mindspore 2.7.0 **pynative** mode:

| model | # card(s) | batch size| image size | attn. type | throughput (token/s)| s/img |
|:-:|:-:|:-:|:-:|:-:|:-:|:-:|
| Janus-Pro-1B | 1 | 1 | 384x384 | Eager | 4.52 | ~ 127|
| Janus-Pro-7B | 1 | 1 | 384x384 | Eager | 3.56 | ~ 162|
| Janus-Pro-1B | 1 | 1 | 384x384 | Eager | 7.08 | ~ 81 |
| Janus-Pro-7B | 1 | 1 | 384x384 | Eager | 5.81 | ~ 99 |

* All performance numbers were measured with KV-Cache enabled.
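As a rough sanity check, the s/img column is approximately the number of generated image tokens divided by the token throughput, plus image-decoding overhead. A sketch of this relation, assuming each 384x384 image corresponds to 24x24 = 576 generated tokens (an assumption for illustration, not stated in the tables):

```python
# Rough relation between the two reported columns, assuming each
# 384x384 image is generated autoregressively as 24x24 = 576 tokens.
TOKENS_PER_IMAGE = 24 * 24  # 576 (assumed)

def approx_seconds_per_image(tokens_per_sec: float) -> float:
    """Lower bound on s/img from pure token generation time."""
    return TOKENS_PER_IMAGE / tokens_per_sec

# e.g. Janus-Pro-1B in graph mode: 576 / 14.6 is roughly 39 s of pure
# token generation; the reported ~44 s/img would then include the
# remaining VQ-decoding and host-side overhead.
```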

30 changes: 15 additions & 15 deletions examples/janus/docs/training.md
@@ -40,11 +40,11 @@ bash scripts/run_sft_text.sh # if no manual patching, by default it should be c

Patching `janus/models/modeling_vlm.py`: **Single task for pure text**
```diff
# @ L428
# @ L431
-- def construct(
++ # def construct( # just comment the whole function out

# @ L476
# @ L479
-- def construct_graph_single_task(
++ def construct(
```
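The manual edits above can also be scripted. A rough sketch, using the method names and file path from the diff above; it assumes `def construct(` appears in the file before `def construct_graph_single_task(`, as in the current source layout:

```python
from pathlib import Path

def patch_single_task(path: str = "janus/models/modeling_vlm.py") -> None:
    """Comment out the generic construct() and promote the
    single-task variant, mirroring the manual diff above."""
    src = Path(path).read_text()
    # Comment out the original entry point (first occurrence only;
    # "def construct(" does not match "def construct_graph_...").
    src = src.replace("def construct(", "# def construct(", 1)
    # Promote the single-task graph function to be the entry point.
    src = src.replace("def construct_graph_single_task(", "def construct(", 1)
    Path(path).write_text(src)
```

Note that running the script twice would also comment out the promoted method, so it should only be applied to a clean checkout.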
@@ -77,45 +77,45 @@ We also implemented **a stage-3 SFT for medical data aiming for building a radio

> [!NOTE]
> We achieve higher training throughput by enabling graph-mode compute. However, this requires predefining a compute graph for the VLM for each of the three tasks, since each task passes the VLM a different combination of input arguments.
> This feature is for MindSpore 2.5.0 only; it is no longer supported in MindSpore 2.7.0.
>
> To run `scripts/run_sft_mixed_graph.sh`, simply go into `janus/models/modeling_vlm.py`, and patch `construct_*()` into `construct()` as follows.
```diff
# @ L428
# @ L431
-- def construct(
++ # def construct( # just comment the whole function out

# @ L570
# @ L573
-- def construct_graph_mixed_task(
++ def construct(
```

#### Pynative Mode SFT Training for Mixed Tasks
```diff
# @ L428
# @ L431
-- def construct(
++ # def construct( # just comment the whole function out

# @ L516
# @ L519
-- def construct_pynative_mixed_task(
++ def construct(
```

## Performance

Experiments are tested on Ascend Atlas 800T A2 machines with mindspore 2.5.0 pynative mode:
Experiments are tested on Ascend Atlas 800T A2 machines with mindspore 2.7.0 pynative mode:

| model | task | # card(s) | image size | max_length | batch size | step time (s/step)|
|:-:|:--:| :-:|:-:|:-:|:-:|:-:|
| Janus-Pro-1B | T2I | 1 | 384x384 | 1024 | 8 | 0.66 |
| Janus-Pro-1B | VQA | 1 | 384x384 | 1024 | 4 | 0.59 |
| Janus-Pro-1B | Text | 1 | n.a. | 512 | 8 | 0.50 |
| Janus-Pro-7B | T2I | 1 | 384x384 | 1024 | 1 | 0.49 |
| Janus-Pro-7B | VQA | 1 | 384x384 | 1024 | 1 | 0.66 |
| Janus-Pro-7B | Text | 1 | n.a. | 512 | 1 | 0.53 |
| Janus-Pro-1B | T2I | 1 | 384x384 | 1024 | 8 | 0.60 |
| Janus-Pro-1B | VQA | 1 | 384x384 | 1024 | 4 | 0.42 |
| Janus-Pro-1B | Text | 1 | n.a. | 512 | 8 | 0.46 |
| Janus-Pro-7B | T2I | 1 | 384x384 | 1024 | 1 | 0.51 |
| Janus-Pro-7B | VQA | 1 | 384x384 | 1024 | 1 | 0.46 |
| Janus-Pro-7B | Text | 1 | n.a. | 512 | 1 | 0.57 |

For mixed-SFT:

| model | task | ms_mode | # card(s) | image size | max_length | batch size | step time (s/step)|
|:-:|:--:| :-:|:-:|:-:|:-:|:-:|:-:|
| Janus-Pro-1B | mixed | pynative | 1 | 384x384 | 1024 | 6 | 3.05 |
| Janus-Pro-1B | mixed | graph | 1 | 384x384 | 1024 | 6 | 2.36 |
| Janus-Pro-1B | mixed | pynative | 1 | 384x384 | 1024 | 6 | 2.30 |
6 changes: 5 additions & 1 deletion examples/janus/generation_inference.py
@@ -187,7 +187,7 @@ def generate(
# ms context
ms.set_context(mode=args.ms_mode)
if args.ms_mode == 0:
ms.set_context(jit_config={"jit_level": "O0"}, enable_compile_cache=True)
ms.set_context(jit_config={"jit_level": "O0"})
set_random_seed(args.seed)

# specify the path to the model
@@ -212,6 +212,10 @@ def generate(
vl_gpt = set_model_param_dtype(vl_gpt, dtype)
vl_gpt.set_train(False)

if args.ms_mode == 0:
# in graph mode, cache class is not supported yet
vl_gpt.language_model._supports_cache_class = False

if args.ms_mode == 0 and not args.use_cache:
bs = args.parallel_size * 2
hidden_size = vl_gpt.language_model.model.layers[0].hidden_size
6 changes: 5 additions & 1 deletion examples/janus/inference.py
@@ -136,7 +136,7 @@ def multimodal_understanding(
# ms context
ms.set_context(mode=args.ms_mode)
if args.ms_mode == 0:
ms.set_context(jit_config={"jit_level": "O0"}, enable_compile_cache=True)
ms.set_context(jit_config={"jit_level": "O0"})

# specify the path to the model
vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(args.model_path)
@@ -163,6 +163,10 @@

vl_gpt.set_train(False)

if args.ms_mode == 0:
# in graph mode, cache class is not supported yet
vl_gpt.language_model._supports_cache_class = False

# infer
answer, prepare_inputs = multimodal_understanding(
args.image,
3 changes: 2 additions & 1 deletion examples/janus/janus/models/modeling_vlm.py
@@ -412,7 +412,8 @@ def und_with_loss(

# FIXME same workaround as above, for the ms2.5.0 graph mode constraint
image_seq_mask = image_seq_mask.nonzero().squeeze()
inputs_embeds[image_seq_mask] = image_embeds
if image_seq_mask.numel() > 0:
inputs_embeds[image_seq_mask] = image_embeds

inputs_embeds = inputs_embeds.reshape(B, S, D)
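The guard added above avoids indexing with an empty index tensor when a batch contains no image tokens. The same pattern expressed in NumPy terms, with hypothetical shapes for illustration:

```python
import numpy as np

# Flattened token embeddings (B*S, D) and a boolean mask marking
# image-token positions; shapes are hypothetical.
inputs_embeds = np.zeros((6, 4), dtype=np.float32)
image_seq_mask = np.array([False, True, True, False, False, False])
image_embeds = np.ones((2, 4), dtype=np.float32)

# Same pattern as the patch: convert the mask to integer indices and
# only scatter the image embeddings when at least one position is set.
idx = np.nonzero(image_seq_mask)[0]
if idx.size > 0:
    inputs_embeds[idx] = image_embeds
```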

50 changes: 0 additions & 50 deletions examples/janus/pyproject.toml

This file was deleted.

3 changes: 2 additions & 1 deletion examples/janus/scripts/run_sft_mixed_pynative.sh
@@ -29,4 +29,5 @@ python train.py \
--warmup_steps 50 \
--ckpt_save_steps 1000 \
--ckpt_max_keep 1 \
--output_path outputs/stage${stage}_mixed_lr${lr}_wd${wd}_bs${bs}_npp${npp}_mode${ms_mode} \ --mixed_task_rand_samp \
--output_path outputs/stage${stage}_mixed_lr${lr}_wd${wd}_bs${bs}_npp${npp}_mode${ms_mode} \
--mixed_task_rand_samp
1 change: 0 additions & 1 deletion examples/janus/scripts/run_sft_text.sh
@@ -8,7 +8,6 @@ dataset_meta_path=YOUR_DATA_PATH
pretrained_ckpt_path=YOUR_DOWNLOADED_JANUS_CKPT_PATH

python train.py \
--ms_mode 0 \
--model_path ${pretrained_ckpt_path} \
--load_weight True \
--task 'text' \
4 changes: 2 additions & 2 deletions examples/janus/train.py
@@ -228,7 +228,7 @@ def main(args):
dataloader = create_unified_dataloader_weightrandsamp(
vl_chat_processor,
t2i_csv_path=args.t2i_csv_path,
t2i_data_path=args.t2i_data_path,
t2i_data_dir=args.t2i_data_dir,
t2i_parquet_dir=args.t2i_parquet_dir,
text_data_dir=args.text_qa_data_dir,
vqa_data_dir=args.vqa_data_dir,
@@ -243,7 +243,7 @@ def main(args):
dataloader = create_dataloader_unified(
vl_chat_processor,
t2i_csv_path=args.t2i_csv_path,
t2i_data_path=args.t2i_data_path,
t2i_data_path=args.t2i_data_dir,
vqa_data_dir=args.vqa_data_dir,
text_qa_data_dir=args.text_qa_data_dir,
num_samples_vqa=100,