2 files changed: +4 −2

experiments/auto_parallel

@@ -32,6 +32,7 @@ def parallelize_llama(
     NOTE: The passed-in model preferably should be on meta device. Otherwise,
     the model must fit on GPU or CPU memory.
     """
+
     def input_fn():
         global_batch_size = job_config.training.global_batch_size
         if global_batch_size < 0:
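The hunk above only inserts a blank line after the docstring; the rest is context around `input_fn`, which reads `job_config.training.global_batch_size` and special-cases negative values. A hedged sketch of the usual meaning of that sentinel, with `local_batch_size` and `dp_degree` as assumed names (the real resolution logic is elided from this diff):

```python
# Illustrative only: one common convention is that a negative global batch
# size means "derive it from the per-rank batch size and data-parallel degree".
def resolve_global_batch_size(
    global_batch_size: int, local_batch_size: int, dp_degree: int
) -> int:
    if global_batch_size < 0:
        # Assumed fallback, not taken from this PR.
        return local_batch_size * dp_degree
    return global_batch_size
```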

@@ -12,6 +12,8 @@

 import torch
 from torch.distributed.elastic.multiprocessing.errors import record
+from torch.distributed.tensor import DTensor
+
 import torchtitan.components.ft as ft
 import torchtitan.protocols.train_spec as train_spec_module
 from torchtitan.components.checkpoint import CheckpointManager
@@ -23,7 +25,6 @@
 )
 from torchtitan.config_manager import ConfigManager, JobConfig
 from torchtitan.distributed import ParallelDims, utils as dist_utils
-from torch.distributed.tensor import DTensor
 from torchtitan.protocols.model_converter import build_model_converters
 from torchtitan.tools import utils
 from torchtitan.tools.logging import init_logger, logger
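These two hunks only relocate `from torch.distributed.tensor import DTensor` so it sits with the other `torch` imports rather than between the `torchtitan` ones; behavior is unchanged. For readers new to the API, a minimal standalone `DTensor` sketch, unrelated to this PR (run under `torchrun`, e.g. with two processes):

```python
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Shard, distribute_tensor

# Build a 1-D mesh over the launched ranks (CPU/gloo keeps the demo portable).
mesh = init_device_mesh("cpu", (2,))

x = torch.randn(8, 8)
dx = distribute_tensor(x, mesh, placements=[Shard(0)])  # shard rows across ranks
assert isinstance(dx, DTensor)

full = dx.full_tensor()  # gather shards back into a regular local tensor
print(full.shape)        # torch.Size([8, 8])
```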
@@ -115,7 +116,7 @@ def __init__(self, job_config: JobConfig):

         # TODO(whc)
         # I do this because otherwise sometimes inductor will skip re-running passes like comms reordering
-        torch._inductor.config.force_disable_caches = True
+        torch._inductor.config.force_disable_caches = True

         # allow configuring inductor comms optimizations from torchtitan commandline
         torch._inductor.config.reorder_for_compute_comm_overlap = (
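The truncated assignment above wires a torchtitan command-line option into inductor's comms-reordering knob. A hedged sketch of the inductor settings involved: both config fields exist in `torch._inductor.config`, but the pass list shown is the upstream default at the time of writing and is an assumption, not taken from this PR:

```python
import torch

# Disable inductor caching so experimental graph passes (such as the comms
# reordering enabled below) re-run on every compile instead of being cached.
torch._inductor.config.force_disable_caches = True

# Reorder the compiled graph to overlap collectives with compute. The pass
# names below are inductor's stock defaults and may vary across versions.
torch._inductor.config.reorder_for_compute_comm_overlap = True
torch._inductor.config.reorder_for_compute_comm_overlap_passes = [
    "sink_waits",
    "raise_comms",
    "reorder_compute_for_overlap",
]
```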