
Commit d9906fa

fix sync parallel and support low_cpu_mem_usage (#1814)
1 parent e6924bb commit d9906fa

File tree

21 files changed: +490 -25 lines

.github/pylint.conf

Lines changed: 2 additions & 1 deletion
```diff
@@ -216,7 +216,8 @@ disable=raw-checker-failed,
         consider-using-generator,
         fixme,
         use-a-generator,
-        nested-min-max
+        nested-min-max,
+        method-hidden

 # Enable the message, report, category or checker with the given id(s). You can
 # either give multiple identifier separated by comma (,) or put this option
```

llm/inference/chatglm3/cli_demo.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -1,9 +1,9 @@
 import os
 import platform
-from mindnlp.transformers import ChatGLM3Tokenizer, ChatGLM3ForConditionalGeneration
+from mindnlp.transformers import ChatGLM3Tokenizer, AutoModelForCausalLM

 tokenizer = ChatGLM3Tokenizer.from_pretrained("ZhipuAI/chatglm3-6b", mirror='modelscope', revision='master')
-model = ChatGLM3ForConditionalGeneration.from_pretrained("ZhipuAI/chatglm3-6b", mirror='modelscope', revision='master')
+model = AutoModelForCausalLM.from_pretrained("ZhipuAI/chatglm3-6b", mirror='modelscope', revision='master')
 model = model.set_train(False)

 os_name = platform.system()
```
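The demo now loads the checkpoint through the generic `AutoModelForCausalLM` entry point instead of the model-specific class. A minimal, hedged usage sketch of the loaded objects follows (assumptions: mindnlp mirrors the HF-style `generate`/`decode` API used elsewhere in this commit; the real cli_demo.py drives an interactive chat loop instead):

```python
# Hedged sketch only: single-turn generation with the tokenizer/model loaded above.
inputs = tokenizer("Hello, please introduce yourself.", return_tensors="ms")
outputs = model.generate(**inputs, max_new_tokens=64)  # assumed HF-style generate()
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```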
Lines changed: 60 additions & 0 deletions
```python
import mindspore
from mindnlp.transformers import LlamaTokenizer, LlamaForCausalLM, StaticCache
from mindnlp.core import ops
import time

prompts = [
    "Simply put, the theory of relativity states that ",
    "My favorite all time favorite condiment is ketchup.",
]

NUM_TOKENS_TO_GENERATE = 40

model_id = 'shakechen/llama-2-7b-hf'
tokenizer = LlamaTokenizer.from_pretrained(model_id, mirror='modelscope', pad_token="</s>", padding_side="right")
model = LlamaForCausalLM.from_pretrained(model_id, mirror='modelscope', use_safetensors=False, ms_dtype=mindspore.float16)

inputs = tokenizer(prompts, return_tensors="ms", padding=True)

def decode_one_tokens(model, cur_token, input_pos, cache_position, past_key_values):
    logits = model(
        cur_token,
        position_ids=input_pos,
        cache_position=cache_position,
        past_key_values=past_key_values,
        return_dict=False,
        use_cache=True
    )[0]
    new_token = ops.argmax(logits[:, -1], dim=-1)[:, None]
    return new_token

batch_size, seq_length = inputs["input_ids"].shape
# with no_grad():
past_key_values = StaticCache(
    config=model.config, max_batch_size=2, max_cache_len=1024, dtype=model.dtype
)
cache_position = ops.arange(seq_length)
generated_ids = ops.zeros(
    batch_size, seq_length + NUM_TOKENS_TO_GENERATE + 1, dtype=mindspore.int32
)
generated_ids[:, cache_position] = inputs["input_ids"].to(mindspore.int32)

logits = model(
    **inputs, cache_position=cache_position, past_key_values=past_key_values, return_dict=False, use_cache=True
)[0]
next_token = ops.argmax(logits[:, -1], dim=-1)[:, None]
generated_ids[:, seq_length] = next_token[:, 0]

model.compile(jit_config=mindspore.JitConfig(jit_syntax_level='STRICT'))

cache_position = mindspore.tensor([seq_length + 1])
for _ in range(1, NUM_TOKENS_TO_GENERATE):
    s = time.time()
    next_token = decode_one_tokens(model, next_token, None, cache_position, past_key_values)
    t = time.time()
    print(t - s)
    generated_ids[:, cache_position] = next_token.int()
    cache_position += 1

text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
print(text)
```

llm/inference/llama3/readme.md

Lines changed: 30 additions & 0 deletions
## Run distributed (pipeline parallel)

### use msrun (recommended)

`msrun` is MindSpore's launcher for multi-process parallel execution and usually gives the best performance. You can use it with the command below:

```bash
msrun --worker_num=2 --local_worker_num=2 --master_port=8118 --join=True run_llama3_distributed.py
```

If you use an Ascend NPU with a Kunpeng CPU, bind cores to get better performance:

```bash
msrun --worker_num=2 --local_worker_num=2 --master_port=8118 --join=True --bind_core=True run_llama3_distributed.py
```

### use mpirun

`mpirun` controls several aspects of program execution in Open MPI. You can use it with the command below:

```bash
mpirun -n 2 python run_llama3_distributed.py
```

If you use an Ascend NPU with a Kunpeng CPU, bind cores to get better performance:

```bash
mpirun --bind-to numa -n 2 python run_llama3_distributed.py
```
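For reference, a minimal, hedged sketch of what a `run_llama3_distributed.py` driver might look like (assumptions: the `device_map="auto"` pipeline dispatch flag and the ModelScope checkpoint id are illustrative only; the actual script shipped with this commit may differ):

```python
# Hypothetical sketch of a pipeline-parallel inference driver for two workers.
import mindspore
from mindspore.communication import init
from mindnlp.transformers import AutoTokenizer, AutoModelForCausalLM

init()  # each worker launched by msrun/mpirun joins the communication group

model_id = "LLM-Research/Meta-Llama-3-8B-Instruct"  # assumed checkpoint id
tokenizer = AutoTokenizer.from_pretrained(model_id, mirror="modelscope")
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    ms_dtype=mindspore.float16,
    mirror="modelscope",
    device_map="auto",  # assumed flag for splitting layers across the workers
)

inputs = tokenizer("Simply put, the theory of relativity states that", return_tensors="ms")
outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```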
Lines changed: 27 additions & 0 deletions
```python
import os
import psutil
import gc
from memory_profiler import profile
import mindspore
from mindnlp.transformers import AutoTokenizer, AutoModelForCausalLM

model_id = "LLM-Research/Meta-Llama-3-8B-Instruct"

@profile
def test():
    tokenizer = AutoTokenizer.from_pretrained(model_id, mirror='modelscope')
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        ms_dtype=mindspore.float16,
        mirror='modelscope',
        low_cpu_mem_usage=True
    )

if __name__ == '__main__':
    a = test()
    print('A:%.2f MB' % (psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024))
    del a
    gc.collect()
    print('B:%.2f MB' % (psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024))
```
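This script profiles the new `low_cpu_mem_usage` path added by this commit: the model is first built on the meta device (see `init_empty_weights` in `mindnlp/accelerate/big_modeling.py` below) and the checkpoint weights are then loaded into it directly, so the process never holds a second, randomly initialized copy of the parameters. For comparison, a hedged sketch of the default load path (same call with the flag omitted; exact peak RSS numbers will vary by machine):

```python
# Baseline for comparison: without low_cpu_mem_usage the model is first
# materialized with freshly initialized weights and then overwritten by the
# checkpoint, so peak host memory is expected to be noticeably higher.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    ms_dtype=mindspore.float16,
    mirror='modelscope',
)
```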

mindnlp/accelerate/__init__.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,13 @@
1616
# load_checkpoint_in_model,
1717
# synchronize_rng_states,
1818
)
19+
20+
from .big_modeling import (
21+
# cpu_offload,
22+
# cpu_offload_with_hook,
23+
# disk_offload,
24+
# dispatch_model,
25+
init_empty_weights,
26+
init_on_empty,
27+
# load_checkpoint_and_dispatch,
28+
)

mindnlp/accelerate/big_modeling.py

Lines changed: 93 additions & 0 deletions
````python
"""big modeling"""
from contextlib import contextmanager
from mindspore._c_expression import Tensor as Tensor_  # pylint: disable=no-name-in-module
from mindnlp.utils.testing_utils import parse_flag_from_env
from mindnlp.core import nn

@contextmanager
def init_empty_weights(include_buffers: bool = None):
    """
    A context manager under which models are initialized with all parameters on the meta device, therefore creating an
    empty model. Useful when just initializing the model would blow the available RAM.

    Args:
        include_buffers (`bool`, *optional*):
            Whether or not to also put all buffers on the meta device while initializing.

    Example:

    ```python
    import torch.nn as nn
    from accelerate import init_empty_weights

    # Initialize a model with 100 billion parameters in no time and without using any RAM.
    with init_empty_weights():
        tst = nn.Sequential(*[nn.Linear(10000, 10000) for _ in range(1000)])
    ```

    <Tip warning={true}>

    Any model created under this context manager has no weights. As such you can't do something like
    `model.to(some_device)` with it. To load weights inside your empty model, see [`load_checkpoint_and_dispatch`].
    Make sure to overwrite the default device_map param for [`load_checkpoint_and_dispatch`], otherwise dispatch is not
    called.

    </Tip>
    """
    if include_buffers is None:
        include_buffers = parse_flag_from_env("ACCELERATE_INIT_INCLUDE_BUFFERS", False)
    with init_on_empty(include_buffers=include_buffers) as f:
        yield f


@contextmanager
def init_on_empty(include_buffers: bool = None):
    """
    A context manager under which models are initialized with all parameters on the meta device.

    Args:
        include_buffers (`bool`, *optional*):
            Whether or not to also put all buffers on the meta device while initializing.

    Example:

    ```python
    from mindnlp.core import nn
    from mindnlp.accelerate import init_on_empty

    with init_on_empty():
        tst = nn.Linear(100, 100)  # parameters live on the meta device
    ```
    """
    if include_buffers is None:
        include_buffers = parse_flag_from_env("ACCELERATE_INIT_INCLUDE_BUFFERS", False)

    old_register_parameter = nn.Module.register_parameter
    if include_buffers:
        old_register_buffer = nn.Module.register_buffer

    def register_empty_parameter(module, name, param):
        old_register_parameter(module, name, param)
        if param is not None:
            kwargs = module._parameters[name].__dict__
            kwargs["requires_grad"] = param.requires_grad
            module._parameters[name].assign_value(Tensor_(shape=(), dtype=module._parameters[name].dtype))
            module._parameters[name].meta = True

    def register_empty_buffer(module, name, buffer, persistent=True):
        old_register_buffer(module, name, buffer, persistent=persistent)
        if buffer is not None:
            module._buffers[name].assign_value(Tensor_(shape=(), dtype=module._buffers[name].dtype))
            module._buffers[name].meta = True

    try:
        nn.Module.register_parameter = register_empty_parameter
        if include_buffers:
            nn.Module.register_buffer = register_empty_buffer
        yield
    finally:
        nn.Module.register_parameter = old_register_parameter
        if include_buffers:
            nn.Module.register_buffer = old_register_buffer
````
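The `low_cpu_mem_usage` path and the new `mindnlp.accelerate` exports build on this context manager. A minimal, hedged usage sketch (assumptions: `mindnlp.core.nn` provides `Sequential`/`Linear` and a `parameters()` iterator mirroring the torch API, as the docstring example suggests; the `.meta` check is illustrative):

```python
# Hedged sketch: build a large module without allocating real weight storage.
from mindnlp.accelerate import init_empty_weights
from mindnlp.core import nn

with init_empty_weights():
    # register_parameter is patched while inside the context, so every
    # parameter is replaced by an empty placeholder tensor flagged as meta.
    big = nn.Sequential(*[nn.Linear(4096, 4096) for _ in range(8)])

# Outside the context the patch is removed; the module still has no real
# weights, so a checkpoint loader must fill them before the module can run.
print(all(p.meta for p in big.parameters()))  # expected: True (assumption)
```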

mindnlp/accelerate/utils/modeling.py

Lines changed: 5 additions & 1 deletion
```diff
@@ -10,8 +10,12 @@
 from typing import Optional, Dict, Union, List, Tuple, Set
 import mindspore
 from mindspore.communication import get_group_size, get_rank
+from mindnlp.configs import SUPPORT_ASYNC_DIST_OP
 try:
-    from mindspore.communication.comm_func import isend, irecv, broadcast
+    if SUPPORT_ASYNC_DIST_OP:
+        from mindspore.communication.comm_func import send as isend, recv as irecv, broadcast
+    else:
+        from mindspore.communication.comm_func import isend, irecv, broadcast
 except:
     from mindnlp.parallel.comm_func import isend, irecv, broadcast

```

File renamed without changes.
