
Commit 9af34c1

mreso, lessw2020, larryliu0820, and vmpuri authored
Integrate distributed inference into torchchat cli (pytorch#1327)
* add pp_dim, distributed, num_gpus, num_nodes as cmd line args
* add tp_dim
* add elastic_launch
* working, can now launch from cli
* Remove numpy < 2.0 pin to align with pytorch (pytorch#1301)
  Fix pytorch#1296. Align with https://github.com/pytorch/pytorch/blame/main/requirements.txt#L5
* Update torchtune pin to 0.4.0-dev20241010 (pytorch#1300)
  Co-authored-by: vmpuri <puri@meta.com>
* Unbreak gguf util CI job by fixing numpy version (pytorch#1307)
  Setting numpy version to the range required by gguf: https://github.com/ggerganov/llama.cpp/blob/master/gguf-py/pyproject.toml
* Remove apparently-unused import torchvision in model.py (pytorch#1305)
  Co-authored-by: vmpuri <45368418+vmpuri@users.noreply.github.com>
* remove global var for tokenizer type + patch tokenizer to allow list of sequences
* make pp tp visible in interface
* Add llama 3.1 to dist_run.py
* [WIP] Move dist inf into its own generator
* Add initial generator interface to dist inference
* Added generate method and placeholder scheduler
* use prompt parameter for dist generation
* Enforce tp>=2
* Build tokenizer from TokenizerArgs
* Disable torchchat format + constrain possible models for distributed
* disable calling dist_run.py directly for now
* Restore original dist_run.py for now
* disable _maybe_parallelize_model again
* Reenable arg.model_name in dist_run.py
* Use singleton logger instead of print in generate
* Address PR comments; try/except in launch_dist_inference; added comments

---------

Co-authored-by: lessw2020 <lessw@etrillium.com>
Co-authored-by: Mengwei Liu <larryliu0820@users.noreply.github.com>
Co-authored-by: vmpuri <45368418+vmpuri@users.noreply.github.com>
Co-authored-by: vmpuri <puri@meta.com>
Co-authored-by: Scott Wolchok <swolchok@meta.com>
1 parent 7fe2c86 commit 9af34c1
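
The commit message notes that elastic_launch was added so distributed inference can be started from the regular torchchat CLI process. Below is a minimal, hedged sketch of that launching pattern using only public torch.distributed.launcher APIs; the worker function, world size, and rendezvous settings are illustrative assumptions, not the PR's actual launch_distributed / launch_dist_inference code.

# Illustrative sketch only: spawn one process per GPU with torch's elastic
# launcher, similar in spirit to the "launch from cli" flow this commit adds.
# The entrypoint and its arguments are hypothetical, not taken from the PR.
from torch.distributed.elastic.utils.distributed import get_free_port
from torch.distributed.launcher.api import LaunchConfig, elastic_launch


def inference_worker(model_name: str, pp: int, tp: int) -> None:
    # Each spawned rank would build its pipeline/tensor-parallel stage here.
    print(f"worker for {model_name}: pp={pp}, tp={tp}")


if __name__ == "__main__":
    pp, tp = 1, 2  # parallel degrees; world size = pp * tp
    config = LaunchConfig(
        min_nodes=1,
        max_nodes=1,
        nproc_per_node=pp * tp,
        rdzv_backend="c10d",
        rdzv_endpoint=f"localhost:{get_free_port()}",
    )
    elastic_launch(config, inference_worker)("llama3.1", pp, tp)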

File tree: 6 files changed, +1010 −53 lines


dist_run.py

Lines changed: 7 additions & 3 deletions
@@ -20,14 +20,14 @@
 from torch.distributed.pipelining import PipelineStage, ScheduleGPipe
 from torchchat.cli.builder import _initialize_tokenizer, TokenizerArgs
 
-from torchchat.distributed.logging_utils import SingletonLogger
-
 # TODO - these are not distributed specific, consider moving to new package
 from torchchat.distributed.checkpoint_utils import (
     get_hf_config_file,
     load_weights_from_hf_format,
     load_weights_from_torchchat_format,
 )
+
+from torchchat.distributed.logging_utils import SingletonLogger
 from torchchat.distributed.utils import (
     bytes_to_readable,
     Color as color,
@@ -153,7 +153,9 @@ def _load_model_weights(
         # This format stands for:
         # single binary file, OR
         # multiple binary files without index files.
-        load_weights_from_torchchat_format(stage_module, distribution, device, model_config)
+        load_weights_from_torchchat_format(
+            stage_module, distribution, device, model_config
+        )
     else:
         raise ValueError(f"Unknown checkpoint format: {chpt_from}")
 
@@ -593,9 +595,11 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
     parser.add_argument(
         "model_name",
         type=str,
+        default="llama3",
         help="Name of the model to load",
         choices=NAME_TO_DISTRIBUTION_AND_DTYPE.keys(),
     )
+
     parser.add_argument("--pp", type=int, default=1, help="Pipeline parallel degree")
     parser.add_argument(
         "--ntokens",

torchchat/cli/builder.py

Lines changed: 36 additions & 22 deletions
@@ -16,20 +16,14 @@
 import torch._inductor.config
 import torch.nn as nn
 
-from torchtune.models.llama3_2_vision._convert_weights import llama3_vision_meta_to_tune
-
-from torchchat.distributed import launch_distributed, ParallelDims, parallelize_llama
-
 from torch.distributed.device_mesh import DeviceMesh
+from torch.distributed.elastic.multiprocessing.errors import record
+from torch.distributed.elastic.utils.distributed import get_free_port
 
-from torchtune.models.convert_weights import meta_to_tune
-
-from torchtune.training import set_default_dtype
+from torchchat.distributed import launch_distributed, ParallelDims, parallelize_llama
 
 from torchchat.model import Model, ModelArgs, ModelType
 
-from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE
-
 from torchchat.model_config.model_config import resolve_model_config
 from torchchat.utils.build_utils import (
     device_sync,
@@ -40,6 +34,14 @@
 from torchchat.utils.measure_time import measure_time
 from torchchat.utils.quantize import quantize_model
 
+from torchtune.models.convert_weights import meta_to_tune
+
+from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE
+
+from torchtune.models.llama3_2_vision._convert_weights import llama3_vision_meta_to_tune
+
+from torchtune.training import set_default_dtype
+
 
 @dataclass
 class BuilderArgs:
@@ -55,7 +57,10 @@ class BuilderArgs:
     device: Optional[str] = None
     precision: torch.dtype = torch.float32
     setup_caches: bool = False
-    use_distributed: bool = False
+    distributed: bool = False
+    pp: int = 1
+    tp: int = 1
+    chpt_from: str = "hf"
     is_chat_model: bool = False
     prefill_possible: bool = False
     dynamic_shapes: bool = False
@@ -87,7 +92,9 @@ def __post_init__(self):
             ]
             for param, param_msg in ignored_params:
                 if param:
-                    print(f"Warning: {param_msg} ignored because an exported DSO or PTE path was specified")
+                    print(
+                        f"Warning: {param_msg} ignored because an exported DSO or PTE path was specified"
+                    )
         else:
             self.prefill_possible = True
 
@@ -153,7 +160,11 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
            dtype = torch.float16
        else:
            dtype = name_to_dtype(args.dtype, args.device)
-
+        # distributed args
+        distributed = getattr(args, "distributed", False)
+        pp = getattr(args, "pp", 1)
+        tp = getattr(args, "tp", 1)
+        chpt_from = getattr(args, "chpt_from", "hf")
         return cls(
             checkpoint_dir=checkpoint_dir,
             checkpoint_path=checkpoint_path,
@@ -167,7 +178,10 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
             device=args.device,
             precision=dtype,
             setup_caches=(output_dso_path or output_pte_path),
-            use_distributed=args.distributed,
+            distributed=distributed,
+            pp=pp,
+            tp=tp,
+            chpt_from=chpt_from,
             is_chat_model=is_chat_model,
             dynamic_shapes=getattr(args, "dynamic_shapes", False),
             max_seq_length=getattr(args, "max_seq_length", None),
@@ -397,10 +411,10 @@ def _load_model_default(builder_args: BuilderArgs) -> Model:
     # does not host any actual values, need to reinitialize them in the actual
     # device. Only do those buffer initialization, without initializing the entire
     # model.
-    decoder_config = model.config.transformer_args['decoder']
-    head_dim = decoder_config['embed_dim'] // decoder_config['num_heads']
-    max_seq_len = decoder_config['max_seq_len']
-    rope_base = decoder_config['rope_base']
+    decoder_config = model.config.transformer_args["decoder"]
+    head_dim = decoder_config["embed_dim"] // decoder_config["num_heads"]
+    max_seq_len = decoder_config["max_seq_len"]
+    rope_base = decoder_config["rope_base"]
     for submodule in model.modules():
         if isinstance(submodule, Llama3ScaledRoPE):
             submodule.__init__(head_dim, max_seq_len, rope_base)
@@ -476,18 +490,19 @@ def _maybe_parallelize_model(
 
 
 def _load_model(builder_args: BuilderArgs) -> Model:
-    world_mesh, parallel_dims = _maybe_init_distributed(builder_args)
+    # world_mesh, parallel_dims = _maybe_init_distributed(builder_args)
     if builder_args.gguf_path:
         model = _load_model_gguf(builder_args)
-    elif builder_args.use_distributed:
-        model = _init_model_on_meta_device(builder_args)
+    # elif builder_args.use_distributed:
+    #     model = _init_model_on_meta_device(builder_args)
     else:
         model = _load_model_default(builder_args)
-    model = _maybe_parallelize_model(model, builder_args, world_mesh, parallel_dims)
+    # model = _maybe_parallelize_model(model, builder_args, world_mesh, parallel_dims)
 
     model = model.to(device=builder_args.device, dtype=builder_args.precision)
     return model.eval()
 
+
 def _initialize_model(
     builder_args: BuilderArgs,
     quantize,
@@ -496,7 +511,6 @@ def _initialize_model(
     support_tensor_subclass: bool = True,
 ) -> Model:
     print("Loading model...")
-
     if builder_args.gguf_path and (builder_args.dso_path or builder_args.pte_path):
         print("Setting gguf_kwargs for generate.")
         is_dso = builder_args.dso_path is not None
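
The from_args change above reads each new distributed field with getattr and a fallback, so an argparse namespace that never registered these flags still produces a valid BuilderArgs. A stripped-down sketch of that pattern; the dataclass below is a simplified stand-in for BuilderArgs, not the real class:

import argparse
from dataclasses import dataclass


@dataclass
class MiniBuilderArgs:
    distributed: bool = False
    pp: int = 1
    tp: int = 1
    chpt_from: str = "hf"

    @classmethod
    def from_args(cls, args: argparse.Namespace) -> "MiniBuilderArgs":
        # getattr with a default keeps older or partial namespaces working.
        return cls(
            distributed=getattr(args, "distributed", False),
            pp=getattr(args, "pp", 1),
            tp=getattr(args, "tp", 1),
            chpt_from=getattr(args, "chpt_from", "hf"),
        )


print(MiniBuilderArgs.from_args(argparse.Namespace()))  # all defaults
print(MiniBuilderArgs.from_args(argparse.Namespace(distributed=True, pp=2, tp=2)))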

torchchat/cli/cli.py

Lines changed: 24 additions & 4 deletions
@@ -399,8 +399,7 @@ def _add_distributed_args(parser) -> None:
     parser.add_argument(
         "--distributed",
         action="store_true",
-        help=argparse.SUPPRESS,
-        # "Whether to enable distributed inference",
+        help="Whether to enable distributed inference",
     )
     parser.add_argument(
         "--dcp-dir",
@@ -409,6 +408,27 @@
         help=argparse.SUPPRESS,
         # "Use the specified model checkpoint directory",
     )
+    parser.add_argument(
+        "--pp",
+        "--pipeline-parallel",
+        type=int,
+        default=1,
+        help="Pipeline parallel degree",
+    )
+    parser.add_argument(
+        "--tp",
+        "--tensor-parallel",
+        type=int,
+        default=2,
+        help="Tensor parallel degree",
+    )
+    parser.add_argument(
+        "--chpt-from",
+        type=str,
+        default="hf",  # TODO: change to torchchat once we support it well
+        help="Checkpoint format to load from",
+        choices=["hf", "torchchat"],
+    )
 
 
 # Add CLI Args related to custom model inputs
@@ -425,13 +445,13 @@ def _add_custom_model_args(parser) -> None:
         "--params-path",
         type=Path,
         default=None,
-        help= "Use the specified parameter file, instead of one specified under torchchat.model_params",
+        help="Use the specified parameter file, instead of one specified under torchchat.model_params",
     )
     parser.add_argument(
         "--tokenizer-path",
         type=Path,
         default=None,
-        help= "Use the specified model tokenizer file, instead of the one downloaded from HuggingFace",
+        help="Use the specified model tokenizer file, instead of the one downloaded from HuggingFace",
     )
 
 
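
Both spellings of each new flag resolve to the same attribute because argparse derives the destination from the first long option, so --pipeline-parallel populates args.pp and --tensor-parallel populates args.tp. A standalone parser mirroring the added arguments (illustrative; it does not import torchchat):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--distributed", action="store_true",
                    help="Whether to enable distributed inference")
parser.add_argument("--pp", "--pipeline-parallel", type=int, default=1,
                    help="Pipeline parallel degree")
parser.add_argument("--tp", "--tensor-parallel", type=int, default=2,
                    help="Tensor parallel degree")
parser.add_argument("--chpt-from", type=str, default="hf",
                    choices=["hf", "torchchat"],
                    help="Checkpoint format to load from")

args = parser.parse_args(["--distributed", "--pipeline-parallel", "2", "--tp", "4"])
print(args.distributed, args.pp, args.tp, args.chpt_from)  # True 2 4 hf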
