diff --git a/configs/20B_internlm2.py b/configs/20B_internlm2.py new file mode 100644 index 000000000..14fc06996 --- /dev/null +++ b/configs/20B_internlm2.py @@ -0,0 +1,248 @@ +JOB_NAME = "7b_internlm2_train" +model_type = "INTERNLM2" +DO_ALERT = False + +VOCAB_SIZE = 92544 +SEQ_LEN = 16*1024 +HIDDEN_SIZE = 6144 +NUM_ATTENTION_HEAD = 48 +NUM_KV_ATTENTION_HEAD = 8 +MLP_RATIO = 8 / 3 +NUM_LAYER = 48 + + +MODEL_ONLY_FOLDER = "local:llm_ckpts/xxxx" +# Ckpt folder format: +# fs: 'local:/mnt/nfs/XXX' +SAVE_CKPT_FOLDER = "local:llm_ckpts" +LOAD_CKPT_FOLDER = "local:llm_ckpts/49" + +# boto3 Ckpt folder format: +# import os +# BOTO3_IP = os.environ["BOTO3_IP"] # boto3 bucket endpoint +# SAVE_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm" +# LOAD_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm/snapshot/1/" +CHECKPOINT_EVERY = 50 +ckpt = dict( + enable_save_ckpt=False, # enable ckpt save. + save_ckpt_folder=SAVE_CKPT_FOLDER, # Path to save training ckpt. + # 'auto_resume' is designed to automatically load the latest checkpoint from 'save_ckpt_folder' when encountering + # training interruptions/hangs caused by hardware failures, using a scheduling system (such as k8s/slurm) + # with an automatic restart mechanism upon training reboot. + # Please be aware that if `auto_resume` is not set (its default value is True), it will not load the checkpoint + # path specified in `load_ckpt_info` by default. + # If you want to initialize your model weights from another model, you must set `auto_resume` to False. + # If you want to train from scratch, please set `auto_resume` to False and 'load_ckpt_info' to None. + auto_resume=False, + checkpoint_every=CHECKPOINT_EVERY, + async_upload=True, # async ckpt upload. (only work for boto3 ckpt) + async_upload_tmp_folder="/dev/shm/internlm_tmp_ckpt/", # path for temporarily files during asynchronous upload. + oss_snapshot_freq=int(CHECKPOINT_EVERY / 2), # snapshot ckpt save frequency. +) + +TRAIN_FOLDER = "/mnt/petrelfs/share_data/caizheng/train_ds/tokenized_data" +VALID_FOLDER = "/mnt/petrelfs/share_data/caizheng/train_ds/tokenized_data" # "/path/to/dataset" +data = dict( + type="tokenized", + # tokenizer_path="/mnt/petrelfs/lusitian/tokenizer/hf-internlm2-tokenizer", + seq_len=SEQ_LEN, + # micro_num means the number of micro_batch contained in one gradient update + micro_num=4, + # packed_length = micro_bsz * SEQ_LEN + micro_bsz=1, + # defaults to the value of micro_num + valid_micro_num=4, + # defaults to 0, means disable evaluate + valid_every=0, + pack_sample_into_one=False, + total_steps=50, + skip_batches="", + # rampup_batch_size (str): A string with three space-separated integers representing the + # starting batch size, the increment, and the number of steps between + # each increment. For example, "192 24 8" means that the batch size (micro_num) + # starts at 192 and increases by 24 every 8 steps. Defaults to None. + # (IMPORTANT): The interval step size is 'micro_bsz'. 
+ rampup_batch_size="", + # Datasets with less than 50 rows will be discarded + min_length=50, + train_folder=TRAIN_FOLDER, + valid_folder=VALID_FOLDER, + empty_cache_and_diag_interval=200, + diag_outlier_ratio=1.1, +) + +grad_scaler = dict( + fp16=dict( + # the initial loss scale, defaults to 2**16 + initial_scale=2**16, + # the minimum loss scale, defaults to None + min_scale=1, + # the number of steps to increase loss scale when no overflow occurs + growth_interval=1000, + ), + # the multiplication factor for increasing loss scale, defaults to 2 + growth_factor=2, + # the multiplication factor for decreasing loss scale, defaults to 0.5 + backoff_factor=0.5, + # the maximum loss scale, defaults to None + max_scale=2**24, + # the number of overflows before decreasing loss scale, defaults to 2 + hysteresis=2, +) + +hybrid_zero_optimizer = dict( + # Enable low_level_optimzer overlap_communication + overlap_sync_grad=True, + overlap_sync_param=False, + # bucket size for nccl communication params + reduce_bucket_size=512 * 1024 * 1024, + # grad clipping + clip_grad_norm=1.0, +) + + +# loss config (dict): +# 1. label_smoothing +# 2. op_type: cross_entropy operator type, we support five types for loss computing, +# including ["torch_naive", "apex_naive", "py_naive", "flash_vocab_parallel", "py_vocab_parallel"] +# default is "py_vocab_parallel". +# "torch_naive": cross_entropy imported from torch, i.e. torch.nn.CrossEntropyLoss +# "apex_naive": cross_entropy from apex +# "py_naive": self-implemented cross_entropy +# "flash_vocab_parallel": vocab parallel cross_entropy imported from flash_attn +# "py_vocab_parallel": self-implemented vocab parallel cross_entropy + +# * op_types that ends with "naive" only support parallel_output=False; +# * if in no-GPU env, only "torch_naive" and "py_vocab_parallel" are supported. +loss = dict(label_smoothing=0, op_type="py_vocab_parallel") + +adam = dict( + lr=1e-4, + adam_beta1=0.9, + adam_beta2=0.95, + adam_beta2_c=0, + adam_eps=1e-8, + weight_decay=0.01, +) + +lr_scheduler = dict( + total_steps=data["total_steps"], + init_steps=0, # optimizer_warmup_step + warmup_ratio=0.01, + eta_min=1e-5, + last_epoch=-1, +) + +beta2_scheduler = dict( + init_beta2=adam["adam_beta2"], + c=adam["adam_beta2_c"], + cur_iter=-1, +) + +use_fp32_norm = False +model = dict( + checkpoint=0.5, # The proportion of layers for activation aheckpointing, the optional value are True/False/[0-1] + num_chunks=1, + num_attention_heads=NUM_ATTENTION_HEAD, + embed_split_hidden=True, + vocab_size=VOCAB_SIZE, + embed_grad_scale=1, + parallel_output=True, + hidden_size=HIDDEN_SIZE, + num_layers=NUM_LAYER, + no_bias=True, + mlp_ratio=MLP_RATIO, + apply_post_layer_norm=False, + dtype="torch.bfloat16", + norm_type="rmsnorm", + layer_norm_epsilon=1e-5, + num_kv_attention_heads=NUM_KV_ATTENTION_HEAD, + use_flash_attn=True, + # Whether the odd and even columns of the query and key in the model are normally interleaved. + # If it's True, the model's odd and even columns are normally ordered; if it's False, + # it means that the model has prematurely concatenated all odd columns and even columns in front + # and back, in order to improve the RoPE's computational efficiency. + # Example: + # qk_interleaved = True: q[-1] = [q1,q2,q3,q4,q5,q6,...], k[-1] = [k1,k2,k3,k4,k5,k6,...] + # qk_interleaved = False: q[-1] = [q1,q3,q5,...,q2,q4,q6,...], k[-1] = [k1,k3,k5,...,k2,k4,k6,...] + qk_interleaved=False, +) + +""" +zero1 parallel (dict): + 1. 
size: int + * if size <= 0, the size of the zero process group is equal to the size of the dp process group, + so parameters will be divided within the range of dp. + * if size == 1, zero is not used, and all dp groups retain the full amount of model parameters. + * if size > 1 and size <= dp world size, the world size of zero is a subset of dp world size. + For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8. +tensor parallel (dict): + 1. size: int, the size of tensor parallel. + 2. mode: str, the tensor parallel mode, should be in ['mtp', 'msp', 'fsp', 'isp'], + defaults to 'mtp', means the pure megatron tensor parallel without sequence parallel. + msp: megatron tensor parallel with sequence parallel, sequence parallel size = tensor parallel size. + fsp: tensor parallel by flash-attn with sequence parallel, sequence parallel size = tensor parallel size. + isp: customed intern sequence parallel without tensor parallel, can be used with weight parallel. +pipeline parallel (dict): + 1. size: int, the size of pipeline parallel. + 2. interleaved_overlap: bool, enable/disable communication overlap when using interleaved pipeline scheduler, + defaults to False. + 3. mode: str, the pipeline parallel mode, should be in ['1f1b', 'zbh1', 'zbv']. The defalut is 1f1b. +weight parallel (dict): + 1. size: int, the size of weight parallel. + 2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False. + + wdp = world_size / pp / wp + zero1 <= wdp +""" +parallel = dict( + zero1=dict(size=8), + tensor=dict(size=1, mode="isp"), + pipeline=dict(size=4, interleaved_overlap=True, mode="1f1b"), + weight=dict(size=4, overlap=True, launch_allgather_before="wo", forward_overlap_per="layer"), +) + +cudnn_deterministic = False +cudnn_benchmark = False + +monitor = dict( + # feishu alert configs + alert=dict( + enable_feishu_alert=DO_ALERT, + feishu_alert_address=None, # feishu webhook to send alert message + light_monitor_address=None, # light_monitor address to send heartbeat + alert_file_path=f"llm_alter/{JOB_NAME}_alert.log", + ), + tensorboard=dict( + queue_max_length=10, + ), +) + +# metric_dtype can be "fp32" or other string +# only when set to "fp32" will use fp32 to calc in metrics +# metric_dtype = "fp32" + +generation = dict( + ckpt_folder="/path/to/saved/ckpt", + output_folder="/path/to/save/generation", + batch_size=1, + eos_id=[2, 0], + bos_id=1, + max_length=100, + do_sample=True, + temperature=1.0, + top_k=50, + top_p=1.0, + repetition_penalty=1, + length_penalty=1.0, +) + +# cpu_offloading = dict( +# enable=True, +# num_layers=10, +# offloading_activations=True, +# ) + + +selective_checkpoint = True +selective_checkpoint_offload = False diff --git a/configs/20B_isp_sft.py b/configs/20B_isp_sft.py new file mode 100644 index 000000000..acc0eff29 --- /dev/null +++ b/configs/20B_isp_sft.py @@ -0,0 +1,273 @@ +JOB_NAME = "20b_internlm2_train" +TASK_NAME = "0312-20B-ckpt-Dweb-64k-t32w4z8-G32-S50" +# MEMORY_PATH = "20B_64k_32g" +model_type = "INTERNLM2" +DO_ALERT = False + +VOCAB_SIZE = 92544 +SEQ_LEN = 64*1024 +HIDDEN_SIZE = 6144 +NUM_ATTENTION_HEAD = 48 +NUM_KV_ATTENTION_HEAD = 8 +MLP_RATIO = 8 / 3 +NUM_LAYER = 48 + + +MODEL_ONLY_FOLDER = "local:llm_ckpts/xxxx" +# Ckpt folder format: +# fs: 'local:/mnt/nfs/XXX' +SAVE_CKPT_FOLDER = "local:llm_ckpts" +LOAD_CKPT_FOLDER = "local:llm_ckpts/49" + +# boto3 Ckpt folder format: +# import os +# BOTO3_IP = os.environ["BOTO3_IP"] # boto3 bucket endpoint +# SAVE_CKPT_FOLDER 
= f"boto3:s3://model_weights.{BOTO3_IP}/internlm" +# LOAD_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm/snapshot/1/" +CHECKPOINT_EVERY = 50 +ckpt = dict( + enable_save_ckpt=False, # enable ckpt save. + save_ckpt_folder=SAVE_CKPT_FOLDER, # Path to save training ckpt. + # load_ckpt_folder= dict(path=MODEL_ONLY_FOLDER, content=["model"], ckpt_type="normal"), + load_ckpt_folder="local:llm_ckpts/", + # 'load_ckpt_info' setting guide: + # 1. the 'path' indicate ckpt path, + # 2. the 'content‘ means what states will be loaded, support: "model", "sampler", "optimizer", "scheduler", "all" + # 3. the ’ckpt_type‘ means the type of checkpoint to be loaded, support: "internevo", "hf", or other custom-defined + # load function such as "llama" + load_ckpt_info=dict(path=MODEL_ONLY_FOLDER, content=("model",), ckpt_type="internevo"), + # 'auto_resume' is designed to automatically load the latest checkpoint from 'save_ckpt_folder' when encountering + # training interruptions/hangs caused by hardware failures, using a scheduling system (such as k8s/slurm) + # with an automatic restart mechanism upon training reboot. + # Please be aware that if `auto_resume` is not set (its default value is True), it will not load the checkpoint + # path specified in `load_ckpt_info` by default. + # If you want to initialize your model weights from another model, you must set `auto_resume` to False. + # If you want to train from scratch, please set `auto_resume` to False and 'load_ckpt_info' to None. + auto_resume=True, + checkpoint_every=CHECKPOINT_EVERY, + async_upload=True, # async ckpt upload. (only work for boto3 ckpt) + async_upload_tmp_folder="/dev/shm/internlm_tmp_ckpt/", # path for temporarily files during asynchronous upload. + oss_snapshot_freq=int(CHECKPOINT_EVERY / 2), # snapshot ckpt save frequency. +) + +# TRAIN_FOLDER = "/mnt/petrelfs/share_data/llm_data/0715_llama_tokenized_refined_real/train/" +TRAIN_FOLDER = "/mnt/petrelfs/share_data/caizheng/train_ds/tokenized_data" # "/path/to/dataset" +VALID_FOLDER = "/mnt/petrelfs/share_data/caizheng/train_ds/tokenized_data" # "/path/to/dataset" +data = dict( + type="tokenized", + seq_len=SEQ_LEN, + # micro_num means the number of micro_batch contained in one gradient update + micro_num=4, + # packed_length = micro_bsz * SEQ_LEN + micro_bsz=1, + # defaults to the value of micro_num + valid_micro_num=4, + # defaults to 0, means disable evaluate + valid_every=0, + pack_sample_into_one=False, + total_steps=50, + skip_batches="", + # rampup_batch_size (str): A string with three space-separated integers representing the + # starting batch size, the increment, and the number of steps between + # each increment. For example, "192 24 8" means that the batch size (micro_num) + # starts at 192 and increases by 24 every 8 steps. Defaults to None. + # (IMPORTANT): The interval step size is 'micro_bsz'. 
+ rampup_batch_size="", + # Datasets with less than 50 rows will be discarded + min_length=50, + train_folder=TRAIN_FOLDER, + valid_folder=VALID_FOLDER, + empty_cache_and_diag_interval=200, + diag_outlier_ratio=1.1, + # use_packed_dataset=False, +) + +grad_scaler = dict( + fp16=dict( + # the initial loss scale, defaults to 2**16 + initial_scale=2**16, + # the minimum loss scale, defaults to None + min_scale=1, + # the number of steps to increase loss scale when no overflow occurs + growth_interval=1000, + ), + # the multiplication factor for increasing loss scale, defaults to 2 + growth_factor=2, + # the multiplication factor for decreasing loss scale, defaults to 0.5 + backoff_factor=0.5, + # the maximum loss scale, defaults to None + max_scale=2**24, + # the number of overflows before decreasing loss scale, defaults to 2 + hysteresis=2, +) + +hybrid_zero_optimizer = dict( + # Enable low_level_optimzer overlap_communication + overlap_sync_grad=True, + overlap_sync_param=False, + # bucket size for nccl communication params + reduce_bucket_size=512 * 1024 * 1024, + # grad clipping + clip_grad_norm=1.0, +) + + +# loss config (dict): +# 1. label_smoothing +# 2. op_type: cross_entropy operator type, we support five types for loss computing, +# including ["torch_naive", "apex_naive", "py_naive", "flash_vocab_parallel", "py_vocab_parallel"] +# default is "py_vocab_parallel". +# "torch_naive": cross_entropy imported from torch, i.e. torch.nn.CrossEntropyLoss +# "apex_naive": cross_entropy from apex +# "py_naive": self-implemented cross_entropy +# "flash_vocab_parallel": vocab parallel cross_entropy imported from flash_attn +# "py_vocab_parallel": self-implemented vocab parallel cross_entropy + +# * op_types that ends with "naive" only support parallel_output=False; +# * if in no-GPU env, only "torch_naive" and "py_vocab_parallel" are supported. + +loss = dict( + label_smoothing=0, + op_type="py_vocab_parallel", # flash_vocab_parallel +) + +adam = dict( + lr=1e-4, + adam_beta1=0.9, + adam_beta2=0.95, + adam_beta2_c=0, + adam_eps=1e-8, + weight_decay=0.01, +) + +lr_scheduler = dict( + total_steps=data["total_steps"], + init_steps=0, # optimizer_warmup_step + warmup_ratio=0.01, + eta_min=1e-5, + last_epoch=-1, +) + +beta2_scheduler = dict( + init_beta2=adam["adam_beta2"], + c=adam["adam_beta2_c"], + cur_iter=-1, +) + +# cpu_offloading = dict( +# enable=True, +# num_layers=3, +# ) +selective_checkpoint = False +selective_checkpoint_offload = False + +use_fp32_norm = False +model = dict( + checkpoint=1, # The proportion of layers for activation aheckpointing, the optional value are True/False/[0-1] + num_attention_heads=NUM_ATTENTION_HEAD, + num_kv_attention_heads=NUM_KV_ATTENTION_HEAD, + embed_split_hidden=True, + vocab_size=VOCAB_SIZE, + embed_grad_scale=1, + parallel_output=True, + hidden_size=HIDDEN_SIZE, + num_layers=NUM_LAYER, + no_bias=True, + mlp_ratio=MLP_RATIO, + apply_post_layer_norm=False, + dtype="torch.bfloat16", # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32" + norm_type="rmsnorm", + layer_norm_epsilon=1e-5, + use_flash_attn=True, + num_chunks=1, # if num_chunks > 1, interleaved pipeline scheduler is used. + # Whether the odd and even columns of the query and key in the model are normally interleaved. 
+ # If it's True, the model's odd and even columns are normally ordered; if it's False, + # it means that the model has prematurely concatenated all odd columns and even columns in front + # and back, in order to improve the RoPE's computational efficiency. + # Example: + # qk_interleaved = True: q[-1] = [q1,q2,q3,q4,q5,q6,...], k[-1] = [k1,k2,k3,k4,k5,k6,...] + # qk_interleaved = False: q[-1] = [q1,q3,q5,...,q2,q4,q6,...], k[-1] = [k1,k3,k5,...,k2,k4,k6,...] + qk_interleaved=False, +) + +""" +zero1 parallel (dict): + 1. size: int + * if size <= 0, the size of the zero process group is equal to the size of the dp process group, + so parameters will be divided within the range of dp. + * if size == 1, zero is not used, and all dp groups retain the full amount of model parameters. + * if size > 1 and size <= dp world size, the world size of zero is a subset of dp world size. + For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8. +tensor parallel (dict): + 1. size: int, the size of tensor parallel. + 2. mode: str, the tensor parallel mode, should be in ['mtp', 'msp', 'fsp', 'isp'], + defaults to 'mtp', means the pure megatron tensor parallel without sequence parallel. + msp: megatron tensor parallel with sequence parallel, sequence parallel size = tensor parallel size. + fsp: tensor parallel by flash-attn with sequence parallel, sequence parallel size = tensor parallel size. + isp: customed intern sequence parallel without tensor parallel, can be used with weight parallel. +pipeline parallel (dict): + 1. size: int, the size of pipeline parallel. + 2. interleaved_overlap: bool, enable/disable communication overlap when using interleaved pipeline scheduler, + defaults to False. +weight parallel (dict): + 1. size: int, the size of weight parallel. + 2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False. + 3. launch_allgather_before: str, before which module to launch the all gather communication to + prefetch next layer's weight, should be in ['wqkv', 'attn', 'wo', 'w1'], defaults to 'wo'. + Must be used with forward_overlap_per 'layer'. + 4. forward_overlap_per: str, all gather prefetch granularity, per 'module' or per 'layer', defaults to 'layer'. +sequence_2D (dict): + 1. enable: bool, whether enable the 2D sequence parallel or not. + 2. head_size: int, the parallel degree of head parallelism (DeepSpeed Ulysses). + head_size * context_size should be equal tensor size. + 3. context_size: int, the parallel degree of context parallelism. + head_size * context_size should be equal tensor size. + 4. window_size: int, the sliding window size in context parallelism. + 5. device_placement_strategy: dict, + head_first: bool, if `True`, ranks of the same head parallel group are + given high priority for colocation on the same node; + if `False`, ranks of the same context parallel group are + given high priority for colocation on the same node; + interleaved: bool, if `head_first` is `False` and `window_size` > 1, this config could + interleaved the ranks in the same window to make full use of NIC as much as possible. 
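+    Example: with the `parallel` dict below (tensor=dict(size=32, mode="isp")), a matching 2D split is
+             head_size=8 and context_size=4, since head_size * context_size must equal the tensor
+             parallel size (8 * 4 = 32).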
+""" + +# wdp = world_size // wp // pp # isp +# dp = world_size // tp // pp +# zero1 size is up to wdp + +parallel = dict( + zero1=dict(size=-1), + tensor=dict(size=32, mode="isp"), + pipeline=dict(size=1, interleaved_overlap=True), + weight=dict(size=4, overlap=True, launch_allgather_before="wo", forward_overlap_per="layer"), + sequence_2D=dict( + enable=False, + head_size=8, + context_size=4, + window_size=1, + device_placement_strategy=dict(head_first=True, interleaved=False), + ), +) + +cudnn_deterministic = False +cudnn_benchmark = False + + +monitor = dict( + # feishu alert configs + alert=dict( + enable_feishu_alert=DO_ALERT, + feishu_alert_address=None, # feishu webhook to send alert message + light_monitor_address=None, # light_monitor address to send heartbeat + alert_file_path=f"llm_alter/{JOB_NAME}_alert.log", + ), + tensorboard=dict( + queue_max_length=10, + ), +) + +# metric_dtype can be "fp32" or other string +# only when set to "fp32" will use fp32 to calc in metrics +# metric_dtype = "fp32" + diff --git a/configs/7B_internlm2.py b/configs/7B_internlm2.py index 3c7bb9f4f..a546bcca1 100644 --- a/configs/7B_internlm2.py +++ b/configs/7B_internlm2.py @@ -1,13 +1,15 @@ JOB_NAME = "7b_internlm2_train" model_type = "INTERNLM2" +# MEMORY_PATH = "20B_16k_32g DO_ALERT = False +TASK_NAME = "0305-20B-ckpt-sc-fa-Dweb-32k-8144-z8-G32-S50" -VOCAB_SIZE = 92544 -SEQ_LEN = 2048 +VOCAB_SIZE = 103168 # 92544 +SEQ_LEN = 2*1024 HIDDEN_SIZE = 4096 NUM_ATTENTION_HEAD = 32 NUM_KV_ATTENTION_HEAD = 8 -MLP_RATIO = 3.5 +MLP_RATIO = 8 / 3 # 3.5 NUM_LAYER = 32 @@ -40,9 +42,11 @@ oss_snapshot_freq=int(CHECKPOINT_EVERY / 2), # snapshot ckpt save frequency. ) -TRAIN_FOLDER = None -VALID_FOLDER = None # "/path/to/dataset" +TRAIN_FOLDER = "/mnt/petrelfs/share_data/caizheng/train_ds/tokenized_data" +VALID_FOLDER = "/mnt/petrelfs/share_data/caizheng/train_ds/tokenized_data" # "/path/to/dataset" data = dict( + type="tokenized", + # tokenizer_path="/mnt/petrelfs/lusitian/tokenizer/hf-internlm2-tokenizer", seq_len=SEQ_LEN, # micro_num means the number of micro_batch contained in one gradient update micro_num=4, @@ -53,7 +57,7 @@ # defaults to 0, means disable evaluate valid_every=0, pack_sample_into_one=False, - total_steps=20000, + total_steps=50, skip_batches="", # rampup_batch_size (str): A string with three space-separated integers representing the # starting batch size, the increment, and the number of steps between @@ -139,7 +143,7 @@ use_fp32_norm = False model = dict( - checkpoint=False, + checkpoint=False, # The proportion of layers for activation aheckpointing, the optional value are True/False/[0-1] num_chunks=1, num_attention_heads=NUM_ATTENTION_HEAD, embed_split_hidden=True, @@ -191,10 +195,10 @@ 2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False. 
""" parallel = dict( - zero1=dict(size=-1), + zero1=dict(size=8), tensor=dict(size=2, mode="isp"), pipeline=dict(size=1, interleaved_overlap=True, mode="1f1b"), - weight=dict(size=2, overlap=True), + weight=dict(size=2, overlap=True, launch_allgather_before="wo", forward_overlap_per="layer"), ) cudnn_deterministic = False @@ -231,3 +235,13 @@ repetition_penalty=1, length_penalty=1.0, ) + +cpu_offloading = dict( + enable=True, + num_layers=10, + offloading_activations=True, + ) + + +selective_checkpoint = False +selective_checkpoint_offload = False diff --git a/configs/7B_isp_sft.py b/configs/7B_isp_sft.py index f269ab4e2..8275eaeba 100644 --- a/configs/7B_isp_sft.py +++ b/configs/7B_isp_sft.py @@ -1,9 +1,11 @@ JOB_NAME = "7b_train" +TASK_NAME = "0409-7B-base-128k-t16w4z4-G16-S50" +MEMORY_PATH = "910B-7B_128k_16g" model_type = "INTERNLM2" DO_ALERT = False VOCAB_SIZE = 103168 -SEQ_LEN = 2048 +SEQ_LEN = 1*1024 HIDDEN_SIZE = 4096 NUM_ATTENTION_HEAD = 32 NUM_KV_ATTENTION_HEAD = 8 @@ -49,20 +51,21 @@ ) # TRAIN_FOLDER = "/mnt/petrelfs/share_data/llm_data/0715_llama_tokenized_refined_real/train/" -TRAIN_FOLDER = None # "/path/to/dataset" -VALID_FOLDER = None # "/path/to/dataset" +TRAIN_FOLDER = None # "/mnt/petrelfs/share_data/caizheng/train_ds/tokenized_data" # "/path/to/dataset" +VALID_FOLDER = None # "/mnt/petrelfs/share_data/caizheng/train_ds/tokenized_data" # "/path/to/dataset" data = dict( + # type="tokenized", seq_len=SEQ_LEN, # micro_num means the number of micro_batch contained in one gradient update - micro_num=4, + micro_num=1, # packed_length = micro_bsz * SEQ_LEN - micro_bsz=2, + micro_bsz=1, # defaults to the value of micro_num valid_micro_num=4, # defaults to 0, means disable evaluate - valid_every=50, + valid_every=0, pack_sample_into_one=False, - total_steps=50000, + total_steps=50, skip_batches="", # rampup_batch_size (str): A string with three space-separated integers representing the # starting batch size, the increment, and the number of steps between @@ -76,7 +79,7 @@ valid_folder=VALID_FOLDER, empty_cache_and_diag_interval=200, diag_outlier_ratio=1.1, - # use_packed_dataset=False, + # use_packed_dataset=False, # NPU ISP下只能使用unpacked dataset ) grad_scaler = dict( @@ -125,7 +128,7 @@ loss = dict( label_smoothing=0, - op_type="flash_vocab_parallel", + op_type="py_vocab_parallel", # flash_vocab_parallel ) adam = dict( @@ -153,14 +156,15 @@ # cpu_offloading = dict( # enable=True, -# num_layers=3, +# num_layers=10, # ) -# selective_checkpoint = True -# selective_checkpoint_offload = False + +selective_checkpoint = True +selective_checkpoint_offload = True use_fp32_norm = False model = dict( - checkpoint=False, # The proportion of layers for activation aheckpointing, the optional value are True/False/[0-1] + checkpoint=1, # The proportion of layers for activation aheckpointing, the optional value are True/False/[0-1] num_attention_heads=NUM_ATTENTION_HEAD, num_kv_attention_heads=NUM_KV_ATTENTION_HEAD, embed_split_hidden=True, @@ -228,14 +232,19 @@ interleaved: bool, if `head_first` is `False` and `window_size` > 1, this config could interleaved the ranks in the same window to make full use of NIC as much as possible. 
""" + +# wdp = world_size // wp // pp # isp +# dp = world_size // tp // pp +# zero1 size is up to wdp + parallel = dict( zero1=dict(size=-1), - tensor=dict(size=2, mode="isp"), + tensor=dict(size=8, mode="isp"), pipeline=dict(size=1, interleaved_overlap=True), weight=dict(size=4, overlap=True, launch_allgather_before="wo", forward_overlap_per="layer"), sequence_2D=dict( enable=False, - head_size=2, + head_size=8, context_size=4, window_size=1, device_placement_strategy=dict(head_first=True, interleaved=False), @@ -262,3 +271,4 @@ # metric_dtype can be "fp32" or other string # only when set to "fp32" will use fp32 to calc in metrics # metric_dtype = "fp32" + diff --git a/doc/code-docs/locales/en/LC_MESSAGES/parallel.po b/doc/code-docs/locales/en/LC_MESSAGES/parallel.po index 6011ae72c..beafb4497 100644 --- a/doc/code-docs/locales/en/LC_MESSAGES/parallel.po +++ b/doc/code-docs/locales/en/LC_MESSAGES/parallel.po @@ -894,7 +894,7 @@ msgid "" "``sequence_2D.device_placement_strategy.interleavd`` 字段表示是否对context " "parallel的GPU重排,该字段在 " "``sequence_2D.device_placement_strategy.head_first=False`` 和 " -"``sequence_2D.window_size>1`` 时,推荐设置为 ``True``" +"``.window_size>1`` 时,推荐设置为 ``True``" msgstr "" "``sequence_2D.device_placement_strategy.interleavd`` determines whether to rearrange the GPUs for context parallel." "It is recommend to set it to True when ``sequence_2D.device_placement_strategy.head_first=False`` and ``sequence_2D.window_size>1``." diff --git a/fusion_result.json b/fusion_result.json new file mode 100644 index 000000000..ec747fa47 --- /dev/null +++ b/fusion_result.json @@ -0,0 +1 @@ +null \ No newline at end of file diff --git a/internlm/core/context/parallel_context.py b/internlm/core/context/parallel_context.py index 5278426ed..45ba64994 100644 --- a/internlm/core/context/parallel_context.py +++ b/internlm/core/context/parallel_context.py @@ -403,7 +403,7 @@ def init_global_dist(self, rank: int, world_size: int, backend: str, host: str, use_cpu (bool): whether to set up cpu process group. 
""" # initialize the default process group - init_method = f"tcp://[{host}]:{port}" + init_method = f"tcp://{host}:{port}" dist.init_process_group( rank=rank, world_size=world_size, diff --git a/internlm/core/parallel/comm/__init__.py b/internlm/core/parallel/comm/__init__.py index be170f286..a9185d7b2 100644 --- a/internlm/core/parallel/comm/__init__.py +++ b/internlm/core/parallel/comm/__init__.py @@ -1,3 +1,3 @@ -from .attn_offload import get_offload_manager, initialize_offload_manager +from .attn_offload import get_offload_npu_manager, initialize_offload_npu_manager -__all__ = ["initialize_offload_manager", "get_offload_manager"] +__all__ = ["initialize_offload_npu_manager", "get_offload_npu_manager"] diff --git a/internlm/core/parallel/comm/attn_offload.py b/internlm/core/parallel/comm/attn_offload.py index da23f3ae8..69c768230 100644 --- a/internlm/core/parallel/comm/attn_offload.py +++ b/internlm/core/parallel/comm/attn_offload.py @@ -1,9 +1,11 @@ import torch - +import torch_npu from internlm.utils.common import get_current_device +from internlm.core.context import global_context as gpc +import pdb global_attn_offload = None - +global_attn_npu_offload = None class AttnOffloadManager: """ @@ -14,16 +16,18 @@ def __init__(self, enable_cpu_offload: bool = False) -> None: # cpu offload overlapping self.cpu_offload = enable_cpu_offload # layer id mapping to flash attn output - self.fa_output_mapping = {} - self.fa_stream = torch.cuda.Stream() - self.d2h_final_event = torch.cuda.Event() - self.h2d_final_event = torch.cuda.Event() + self.fa_output_mapping = {} # 存储各层注意力输出的字典 + self.fa_stream = torch_npu.npu.Stream() # CUDA流用于异步传输 + self.d2h_final_event = torch_npu.npu.Event() # device to host事件 + self.h2d_final_event = torch_npu.npu.Event() # host to device事件 # prepare for tensor buffer - self.tensor_id_to_tensor_bufs = {} + self.tensor_id_to_tensor_bufs = {} # 按层和id缓存的GPU缓存区 - def get_tensor_buf_for_offloaded_tensor(self, tensor, layer_id, tensor_id): + def get_tensor_buf_for_offloaded_tensor(self, tensor, layer_id, tensor_id): # 分配存储缓存区 """Get tensor buffer for offloaded tensor.""" - layer_id = layer_id % 2 + layer_id = layer_id % 2 # 分奇偶双缓冲 + + # 检查对应层对应id的tensor是否在缓存中,否则分配新缓冲 if layer_id not in self.tensor_id_to_tensor_bufs: self.tensor_id_to_tensor_bufs[layer_id] = {} @@ -42,76 +46,104 @@ def get_tensor_buf_for_offloaded_tensor(self, tensor, layer_id, tensor_id): device=tensor.device, ) - self.tensor_id_to_tensor_bufs[layer_id][tensor_id] = buffer + self.tensor_id_to_tensor_bufs[layer_id][tensor_id] = buffer # 将空间分配给字典,字典定层定序 - return self.tensor_id_to_tensor_bufs[layer_id][tensor_id] + return self.tensor_id_to_tensor_bufs[layer_id][tensor_id] # 返回buffer空间 - def insert_fa_output_with_layer(self, layer_idx, output): + def insert_fa_output_with_layer(self, layer_idx, output): # 构建一个输出字典 assert layer_idx not in self.fa_output_mapping if self.cpu_offload is False: - self.fa_output_mapping[layer_idx] = output + self.fa_output_mapping[layer_idx] = output # 若无需offload,则将输出直接存储 return tensors = [] - for tensor_id, tensor in enumerate(output): - if tensor is None: + for tensor_id, item in enumerate(output): + if isinstance(item, torch.Tensor): + tensor_buf = self.get_tensor_buf_for_offloaded_tensor(item, layer_idx, tensor_id) + tensor_buf.copy_(item) + tensors.append(tensor_buf) + elif item is None: tensors.append(None) continue - tensor_buf = self.get_tensor_buf_for_offloaded_tensor(tensor, layer_idx, tensor_id) - tensor_buf.copy_(tensor) - tensors.append(tensor_buf) - 
self.fa_output_mapping[layer_idx] = tensors - - def get_fa_output_with_layer(self, layer_idx): + else: + tensors.append(item) + self.fa_output_mapping[layer_idx] = tensors # 若需offload,则将输出存储到buf中再给字典 + # if gpc.is_rank_for_log(): + # breakpoint() + # print(f"insert: {self.fa_output_mapping}") + + def get_fa_output_with_layer(self, layer_idx): # 取出输出 assert layer_idx in self.fa_output_mapping - return self.fa_output_mapping.pop(layer_idx) + return self.fa_output_mapping.pop(layer_idx) # 按层id取出输出 - def offload_fa_output_with_layer(self, layer_idx): + def offload_fa_output_with_layer(self, layer_idx): # 将输出offload至CPU + # if gpc.is_rank_for_log(): + # breakpoint() + # print(f"offload: {self.fa_output_mapping}") assert layer_idx in self.fa_output_mapping - self.fa_stream.wait_stream(torch.cuda.current_stream()) + self.fa_stream.wait_stream(torch_npu.npu.current_stream()) self.fa_stream.wait_event(self.d2h_final_event) - with torch.cuda.stream(self.fa_stream): - _gpu_tensors = self.fa_output_mapping.pop(layer_idx) + with torch_npu.npu.stream(self.fa_stream): + _gpu_tensors = self.fa_output_mapping.pop(layer_idx) # 获取GPU上输出,应该在缓存中 _cpu_tensors = [] for _tensor in _gpu_tensors: - if _tensor is None: - _cpu_tensors.append(_tensor) - continue - - _cpu_backup = torch.empty( + if isinstance(_tensor, torch.Tensor): + _cpu_backup = torch.empty( _tensor.size(), dtype=_tensor.dtype, layout=_tensor.layout, device="cpu", pin_memory=True, ) - _cpu_backup.copy_(_tensor, non_blocking=True) - _cpu_tensors.append(_cpu_backup) + _cpu_backup.copy_(_tensor, non_blocking=True) + _cpu_tensors.append(_cpu_backup) + elif _tensor is None: + _cpu_tensors.append(_tensor) + continue + else: + _cpu_tensors.append(_tensor) # _cpu_tensors.append(_tensor.to("cpu", non_blocking=False)) - self.fa_output_mapping[layer_idx] = _cpu_tensors - + self.fa_output_mapping[layer_idx] = _cpu_tensors # 用cuda流将输出从GPU(buf)放到CPU,字典记录Cpu——tensors + # if gpc.is_rank_for_log(): + # breakpoint() self.fa_stream.record_event(self.d2h_final_event) - def preload_fa_output_with_layer(self, layer_idx): + def preload_fa_output_with_layer(self, layer_idx):# 将输出重新载入gpu assert layer_idx in self.fa_output_mapping - - self.fa_stream.wait_stream(torch.cuda.current_stream()) + # breakpoint() + self.fa_stream.wait_stream(torch_npu.npu.current_stream()) self.fa_stream.wait_event(self.h2d_final_event) # Important: get device before with stream, in stream get device is error _device = get_current_device() - with torch.cuda.stream(self.fa_stream): + print(f"device: {_device}") + with torch_npu.npu.stream(self.fa_stream): _cpu_tensors = self.fa_output_mapping.pop(layer_idx) - self.fa_output_mapping[layer_idx] = [ - _tensor.to(device=_device, non_blocking=True) if _tensor is not None else _tensor - for _tensor in _cpu_tensors - ] - - self.fa_stream.record_event(self.h2d_final_event) + self.fa_output_mapping[layer_idx] = [] + for _tensor in _cpu_tensors: + if isinstance(_tensor, torch.Tensor): + _gpu_backup = torch.empty( + _tensor.size(), + dtype=_tensor.dtype, + layout=_tensor.layout, + device=_device, + # pin_memory=True, + ) + _gpu_backup.copy_(_tensor, non_blocking=True) + self.fa_output_mapping[layer_idx].append(_gpu_backup) + + elif _tensor is None: + self.fa_output_mapping[layer_idx].append(_tensor) + continue + else: + self.fa_output_mapping[layer_idx].append(_tensor) + # print(f"preload:{self.fa_output_mapping[layer_idx]}") + self.fa_stream.record_event(self.h2d_final_event) + # breakpoint() def initialize_offload_manager(enable_cpu_offload: bool = False): 
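# Illustrative usage of the manager above (a sketch, not part of the patch), assuming
# `selective_checkpoint_offload` is enabled, `i` is a checkpointed layer index, and `fa_outputs`
# stands for the tuple produced by the fused attention op:
#
#     mgr = get_offload_npu_manager()
#     mgr.insert_fa_output_with_layer(layer_idx=i, output=fa_outputs)  # forward: stash FA outputs in the parity buffer
#     mgr.offload_fa_output_with_layer(layer_idx=i)                    # ISP pre-forward hook: async copy to pinned CPU memory on fa_stream
#     ...
#     mgr.preload_fa_output_with_layer(layer_idx=i)                    # backward/recompute: async copy back to the NPU
#     fa_outputs = mgr.get_fa_output_with_layer(layer_idx=i)           # consumed exactly once (the entry is popped)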
@@ -125,3 +157,16 @@ def initialize_offload_manager(enable_cpu_offload: bool = False): def get_offload_manager(): assert global_attn_offload is not None return global_attn_offload + + +def initialize_offload_npu_manager(enable_cpu_offload: bool = False): + global global_attn_npu_offload + if global_attn_npu_offload is None: + global_attn_npu_offload = AttnOffloadManager(enable_cpu_offload) + + return global_attn_npu_offload + + +def get_offload_npu_manager(): + assert global_attn_npu_offload is not None + return global_attn_npu_offload diff --git a/internlm/core/parallel/comm/cpu_offload.py b/internlm/core/parallel/comm/cpu_offload.py index 89e5912b3..3eda62740 100644 --- a/internlm/core/parallel/comm/cpu_offload.py +++ b/internlm/core/parallel/comm/cpu_offload.py @@ -10,6 +10,7 @@ from typing import Any, Dict, Optional import torch +import torch_npu __all__ = ["get_cpu_offload_context"] @@ -302,8 +303,8 @@ def __init__( ) # allocate streams and events for synchronization - self.d2h_stream = torch.cuda.Stream() - self.h2d_stream = torch.cuda.Stream() + self.d2h_stream = torch_npu.npu.Stream() + self.h2d_stream = torch_npu.npu.Stream() def tensor_push(self, tensor: torch.Tensor, **kwargs) -> Any: torch_stray_tensor = False @@ -346,7 +347,7 @@ def tensor_pop(self, tensor_tag, **kwargs): def bulk_offload_group(self, group_to_offload): """Bulk offload group.""" - with torch.cuda.stream(self.d2h_stream): + with torch_npu.npu.stream(self.d2h_stream): for tensor_tag, state in self.tensor_tag_to_state.items(): group_id, _ = tensor_tag if group_id == group_to_offload: @@ -364,7 +365,7 @@ def synchronize_on_group_commit_forward(self, current_group): # For the first group, kickstart the offload after we have # the first compute completion if current_group == 0: - self.d2h_stream.wait_stream(torch.cuda.current_stream()) + self.d2h_stream.wait_stream(torch_npu.npu.current_stream()) self.bulk_offload_group(current_group) # Window map data structure helps us synchronize based on number @@ -373,8 +374,8 @@ def synchronize_on_group_commit_forward(self, current_group): if self.layer_window_map[self.offloaded_group_count] == current_group: # Stream synchronization both ways - self.d2h_stream.wait_stream(torch.cuda.current_stream()) - torch.cuda.current_stream().wait_stream(self.d2h_stream) + self.d2h_stream.wait_stream(torch_npu.npu.current_stream()) + torch_npu.npu.current_stream().wait_stream(self.d2h_stream) # Time to free the activation memory after usage for tensor_tag, _ in self.tensor_tag_to_buf.items(): @@ -399,7 +400,7 @@ def bulk_reload_group(self, group_to_reload): """Bulk reload group.""" assert group_to_reload < self.num_offload_group - with torch.cuda.stream(self.h2d_stream): + with torch_npu.npu.stream(self.h2d_stream): # move back tensors for tensor_label, state in self.tensor_tag_to_state.items(): group_id, _ = tensor_label @@ -420,8 +421,8 @@ def on_group_commit_backward(self): if self.layer_window_map[self.offloaded_group_count - 1] == self.current_group: # Stream synchronization both ways - self.h2d_stream.wait_stream(torch.cuda.current_stream()) - torch.cuda.current_stream().wait_stream(self.h2d_stream) + self.h2d_stream.wait_stream(torch_npu.npu.current_stream()) + torch_npu.npu.current_stream().wait_stream(self.h2d_stream) # Time to reload the next group self.bulk_reload_group(self.offloaded_group_count - 1) @@ -431,7 +432,7 @@ def on_group_commit_backward(self): # Last group computation needs to wait till all the reloads complete if self.current_group == 0: - 
torch.cuda.current_stream().wait_stream(self.h2d_stream) + torch_npu.npu.current_stream().wait_stream(self.h2d_stream) self.offloaded_group_count = 0 diff --git a/internlm/core/parallel/comm/isp.py b/internlm/core/parallel/comm/isp.py index 23a92980c..7fdee74e8 100644 --- a/internlm/core/parallel/comm/isp.py +++ b/internlm/core/parallel/comm/isp.py @@ -37,7 +37,7 @@ params_dispatch_with_condition, ) -from .attn_offload import get_offload_manager +from .attn_offload import get_offload_npu_manager # not really useful, only for code hint. @@ -333,7 +333,8 @@ def __init__( int(gpc.config.model.checkpoint * gpc.config.isp_num_layers) - 1, ] self.sc_offload = selective_ckpt_offload - + print(f"{self.layers_fa_not_release=}") + print(f"{self.sc_offload=}") # real overlap state for each chunk. self._overlap_states: Dict[int, ISPOverlapState] = {} @@ -427,10 +428,11 @@ def is_allgather_launch_module(name, module): self._overlap_states[cid].isp_modules.append(child) self._overlap_states[cid].index_to_isp_modules[idx].append(child) - setattr(child, "isp_name", name) + # setattr(child, "isp_name", name) setattr(child, "isp_layer_idx", idx) full_name = f"{cid}.{idx}.{name}" + setattr(child, "isp_name", full_name) setattr( child.weight, "isp_reduce_scatter_name", @@ -533,7 +535,7 @@ def _pre_forward_hook_for_prefetch_launch_module(self, module: nn.Module, *args) and block_index not in self.layers_fa_not_release and block_index < self._ckpt_block_num ): - get_offload_manager().offload_fa_output_with_layer(layer_idx=block_index) + get_offload_npu_manager().offload_fa_output_with_layer(layer_idx=block_index) # load previous layer's attn output from CPU to GPU asynchronizely if ( @@ -541,7 +543,7 @@ def _pre_forward_hook_for_prefetch_launch_module(self, module: nn.Module, *args) and gpc.config.selective_checkpoint and (0 <= (block_index - 1) < self._ckpt_block_num) ): - get_offload_manager().preload_fa_output_with_layer(layer_idx=block_index - 1) + get_offload_npu_manager().preload_fa_output_with_layer(layer_idx=block_index - 1) def _pre_forward_hook_for_module(self, module: nn.Module, *args): # pylint: disable=W0613 if module not in self._weight_global_handle: @@ -709,8 +711,15 @@ def weight_hook( result = self._bias_global_output[module] else: assert module is not None, "The module parameter must be specified" + # breakpoint() + try: result = self._weight_global_output[module] - + except KeyError: + print(f"{gpc.is_forward=}") + print(f"{self._ckpt_block_num=}") + print(module.isp_name) + exit() + return result def grad_hook( @@ -881,7 +890,6 @@ def forward( ] output_list_next = [torch.empty_like(input_list_next[0]) for _ in range(seq_world_size)] handle_next = dist.all_to_all(output_list_next, input_list_next, group=group, async_op=True) - handle_last.wait() outputs.append(torch.cat(output_list, dim=gather_idx[i]).contiguous()) diff --git a/internlm/core/parallel/comm/test.py b/internlm/core/parallel/comm/test.py new file mode 100644 index 000000000..e69de29bb diff --git a/internlm/core/trainer_builder.py b/internlm/core/trainer_builder.py index 4c18fd326..57c994c96 100644 --- a/internlm/core/trainer_builder.py +++ b/internlm/core/trainer_builder.py @@ -1,6 +1,7 @@ import gc import logging import time +import os from functools import partial from typing import Dict, List, Optional, Union @@ -11,7 +12,7 @@ from internlm.checkpoint.checkpoint_manager import CheckpointManager from internlm.core.context import global_context as gpc from internlm.core.context.process_group_initializer import ParallelMode 
-from internlm.core.parallel.comm import initialize_offload_manager +from internlm.core.parallel.comm.attn_offload import initialize_offload_npu_manager from internlm.core.trainer import Trainer from internlm.data.streaming.utils import streaming_simple_resume from internlm.data.train_state import get_train_state @@ -112,7 +113,7 @@ def __init__( criterion = self._initialize_criterion() # initialize cpu offload manager for selective checkpoint - initialize_offload_manager(gpc.config.get("selective_checkpoint_offload", False)) + initialize_offload_npu_manager(gpc.config.get("selective_checkpoint_offload", False)) # initialize train state train_state = get_train_state(train_dl) @@ -257,16 +258,30 @@ def fit(self): self.train() train_iter = iter(self.train_dl) + # memory_trace start + # torch.cuda.memory._record_memory_history() + with initialize_llm_profile(profiling=self.profiling, start_time=self.current_time) as prof: gc.disable() for batch_count in range(self.train_state.batch_count, gpc.config.data.total_steps): if self._process_batch(batch_count, train_iter, prof): break - + self.ckpt_manager.wait_async_upload_finish() def _process_batch(self, batch_count: int, train_iter, prof) -> bool: empty_cache_and_diag(batch_count, interval=gpc.config.data.empty_cache_and_diag_interval) + + # set task_name + task_name = gpc.config.TASK_NAME + + if gpc.config.IS_MEMORY_TRACE == True: + memory_trace_path = gpc.config.MEMORY_PATH + + # start record memory_trace + # if gpc.config.IS_MEMORY_TRACE == True and ((batch_count + 1) % 10 == 0 or batch_count == 0) and gpc.is_rank_for_log(): + # torch.cuda.memory._record_memory_history() + start_time = time.time() timer("one-batch").start() @@ -288,6 +303,24 @@ def _process_batch(self, batch_count: int, train_iter, prof) -> bool: if self._should_evaluate(): self._evaluate() + # snapshot and close memory_trace + # if gpc.config.IS_MEMORY_TRACE == True and ((batch_count + 1) % 10 == 0 or batch_count == 0) and gpc.is_rank_for_log(): + + # print(f"batch_count:{batch_count}") + # path = os.path.join("/mnt/petrelfs/lusitian/workspace/InternEvo-fork/memory_trace", memory_trace_path, task_name, f"no{batch_count}.pickle") + + # directory = os.path.dirname(path) + # # 检查目录是否存在,如果不存在则创建 + # if not os.path.exists(directory): + # print(f"Directory {directory} does not exist. 
Creating it...") + # os.makedirs(directory, exist_ok=True) # exist_ok=True 避免路径已存在时报错 + # else: + # print(f"Directory {directory} already exists.") + + # torch.cuda.memory._dump_snapshot(path) + # torch.cuda.memory._record_memory_history(enabled=None) + + if self.ckpt_manager.try_save_checkpoint(self.train_state): return True diff --git a/internlm/data/tokenized/dummy_dataset.py b/internlm/data/tokenized/dummy_dataset.py index dcb6c027d..028bb41ae 100644 --- a/internlm/data/tokenized/dummy_dataset.py +++ b/internlm/data/tokenized/dummy_dataset.py @@ -3,6 +3,7 @@ import numpy as np from torch.utils.data import Dataset +from internlm.core.context.parallel_context import global_context as gpc # from internlm.core.context.parallel_context import global_context as gpc @@ -30,7 +31,7 @@ def __init__(self, num_samples=10000, max_len=1024, fixed_seqlen: bool = False) while len(d) < max_len: r *= 2 d = list(range(n)) * r - # r = r % gpc.config.model.vocab_size + r = r % gpc.config.model.vocab_size d = [n, r] + d d = d[:max_len] data.append(d) diff --git a/internlm/data/train_state.py b/internlm/data/train_state.py index 2c678b05b..230c46a41 100644 --- a/internlm/data/train_state.py +++ b/internlm/data/train_state.py @@ -11,7 +11,7 @@ def get_train_state(dataloader): DataType.streaming.name, DataType.megatron.name, DataType.mocked.name, - ]: + ]: train_state = TrainState(gpc.config, dataloader.batch_sampler) else: raise ValueError(f"dataset type {gpc.config.data.type} is not supported") diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py index 4b525e12e..a62163e0d 100644 --- a/internlm/initialize/launch.py +++ b/internlm/initialize/launch.py @@ -100,6 +100,19 @@ def args_sanity_check(): if "model_type" not in gpc.config: gpc.config._add_item("model_type", ModelType.INTERNLM.name) + # set task_name + if "TASK_NAME" not in gpc.config: + gpc.config._add_item("TASK_NAME", "test_task") + + # set memory_trace_path for torch_memory_viz + if "MEMORY_PATH" not in gpc.config: + # gpc.config._add_item("MEMORT_PATH", "NONE") + gpc.config._add_item("IS_MEMORY_TRACE", False) + + # whether us memory_trace + if "IS_MEMORY_TRACE" not in gpc.config: + gpc.config._add_item("IS_MEMORY_TRACE", True) + # inject HuggingFace model config into IntrainTrain if is_using_hf(): inject_hf_config_before_launch(gpc.config.hf) @@ -440,25 +453,25 @@ def args_sanity_check(): # for NPU accelerator supports: 1)FA-True + Packed-True 2) FA-False + Packed-False # for DIPU accelerator supports: 1)FA-True + Packed-False 2) FA-False + Packed-False # for GPU accelerator supports: 1)FA-True + Packed-True 2) FA-False + Packed-False - if gpc.config.parallel["tensor"][ - "mode" - ] == TensorParallelMode.isp.name and internlm_accelerator.get_accelerator_backend() in [ - AcceleratorType.NPU, - AcceleratorType.DIPU, - AcceleratorType.DITORCH, - ]: - assert ( - gpc.config.data.use_packed_dataset is False - ), "only unpacked data is supported when tensor parallel mode is isp and accelerator type is NPU or DIPU" - - if internlm_accelerator.get_accelerator_backend() in [ - AcceleratorType.NPU, - AcceleratorType.DIPU, - AcceleratorType.DITORCH, - ]: - assert ( - gpc.config.model.use_flash_attn == gpc.config.data.use_packed_dataset - ), "use_packed_dataset should be set same value as use_flash_attn" + # if gpc.config.parallel["tensor"][ + # "mode" + # ] == TensorParallelMode.isp.name and internlm_accelerator.get_accelerator_backend() in [ + # AcceleratorType.NPU, + # AcceleratorType.DIPU, + # AcceleratorType.DITORCH, + # ]: + # 
assert ( + # gpc.config.data.use_packed_dataset is False + # ), "only unpacked data is supported when tensor parallel mode is isp and accelerator type is NPU or DIPU" + + # if internlm_accelerator.get_accelerator_backend() in [ + # AcceleratorType.NPU, + # AcceleratorType.DIPU, + # AcceleratorType.DITORCH, + # ]: + # assert ( + # gpc.config.model.use_flash_attn == gpc.config.data.use_packed_dataset + # ), "use_packed_dataset should be set same value as use_flash_attn" # adapt to old version's sequence parallel config if gpc.config.parallel["tensor"].get("mode", None) in [ diff --git a/internlm/model/ops/_flash_attn.py b/internlm/model/ops/_flash_attn.py index 1d1416d94..3a10d1ff1 100644 --- a/internlm/model/ops/_flash_attn.py +++ b/internlm/model/ops/_flash_attn.py @@ -84,6 +84,7 @@ def forward( # store attn forward output to avoid re-computation of attn when activation checkpoint is enabled if gpc.is_forward and gpc.config.selective_checkpoint and _is_ckpt_layer: + breakpoint() get_offload_manager().insert_fa_output_with_layer( layer_idx=layer_idx, output=(out, out_padded, softmax_lse, S_dmask, rng_state) ) diff --git a/internlm/model/ops/_flash_attn_npu.py b/internlm/model/ops/_flash_attn_npu.py new file mode 100644 index 000000000..5ebb9000f --- /dev/null +++ b/internlm/model/ops/_flash_attn_npu.py @@ -0,0 +1,142 @@ +# Copyright (c) 2024, Huawei Technologies. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
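+
+# This module wraps MindSpeed's FusionAttentionV2 kernels for Ascend NPU. When selective
+# checkpointing is enabled, the forward pass stores its outputs through the attention offload
+# manager so that the fused attention does not need to be recomputed when the layer is replayed.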
+ +import torch + +from internlm.accelerator import get_accelerator +from internlm.core.context import global_context as gpc +from internlm.core.parallel.comm import get_offload_npu_manager +try: + import mindspeed + from mindspeed.op_builder import FusionAttentionV2OpBuilder + + npu_mindspeed_fa_impl = True +except (ModuleNotFoundError, ImportError): + npu_mindspeed_fa_impl = False + +internlm_accelerator = get_accelerator() +device_backend = internlm_accelerator.get_accelerator_backend() + +__all__ = ["npu_fusion_attention"] + + +class FusionAttentionV2Function(torch.autograd.Function): + + @staticmethod + def forward(ctx, query, key, value, head_num, input_layout, pse, padding_mask, atten_mask, scale, keep_prob, + pre_tokens, next_tokens, inner_precise, prefix, actual_seq_qlen, actual_seq_kvlen, sparse_mode, + gen_mask_parallel, sync, pse_type, q_start_idx, kv_start_idx, layer_idx): + mindspeed_ops = FusionAttentionV2OpBuilder().load() + + _ckpt_block_num = int(gpc.config.model.checkpoint * gpc.config.isp_num_layers) + _is_ckpt_layer = gpc.config.cpu_offloading.num_layers <= layer_idx < _ckpt_block_num + + if gpc.is_forward is False and gpc.config.selective_checkpoint and _is_ckpt_layer: + outputs = get_offload_npu_manager().get_fa_output_with_layer(layer_idx) + # breakpoint() + # attention_in, softmax_max, softmax_sum, softmax_in, seed, offset, numels + else: + outputs = mindspeed_ops.npu_fusion_attention_v2(query, key, value, head_num, + input_layout, pse, + padding_mask, atten_mask, + scale, keep_prob, pre_tokens, + next_tokens, inner_precise, prefix, + actual_seq_qlen, actual_seq_kvlen, + sparse_mode, gen_mask_parallel, + sync, pse_type, q_start_idx, + kv_start_idx) + attention_in, softmax_max, softmax_sum, softmax_in, seed, offset, numels = outputs + # store attn forward output to avoid re-computation of attn when activation checkpoint is enabled + if gpc.is_forward and gpc.config.selective_checkpoint and _is_ckpt_layer: + # breakpoint() + get_offload_npu_manager().insert_fa_output_with_layer( + layer_idx=layer_idx, output=(attention_in, softmax_max, softmax_sum, softmax_in, seed, offset, numels) + ) + ctx.save_for_backward(query, key, value, pse, padding_mask, atten_mask, attention_in, + softmax_max, softmax_sum, softmax_in) + ctx.scale = scale + ctx.input_layout = input_layout + ctx.head_num = head_num + ctx.pre_tokens = pre_tokens + ctx.next_tokens = next_tokens + ctx.inner_precise = inner_precise + ctx.gen_mask_parallel = gen_mask_parallel + ctx.sync = sync + ctx.seed = seed + ctx.offset = offset + ctx.numels = numels + ctx.prefix = prefix + ctx.keep_prob = keep_prob + ctx.actual_seq_qlen = actual_seq_qlen + ctx.actual_seq_kvlen = actual_seq_kvlen + ctx.sparse_mode = sparse_mode + ctx.pse_type = pse_type + ctx.q_start_idx = q_start_idx + ctx.kv_start_idx = kv_start_idx + + return outputs + + @staticmethod + def backward(ctx, grad_outputs, dq=None, dk=None, dv=None, seed=0, offset=0, numels=0): + mindspeed_ops = FusionAttentionV2OpBuilder().load() + query, key, value, pse, padding_mask, atten_mask, attention_in, softmax_max, \ + softmax_sum, softmax_in = ctx.saved_tensors + results = mindspeed_ops.npu_fusion_attention_grad_v2( + query, key, value, grad_outputs, ctx.head_num, ctx.input_layout, pse, padding_mask, atten_mask, + softmax_max, softmax_sum, softmax_in, attention_in, ctx.scale, ctx.keep_prob, ctx.pre_tokens, + ctx.next_tokens, ctx.inner_precise, ctx.seed, ctx.offset, ctx.numels, ctx.prefix, ctx.actual_seq_qlen, + ctx.actual_seq_kvlen, ctx.sparse_mode, 
ctx.gen_mask_parallel, ctx.sync, ctx.pse_type, ctx.q_start_idx, + ctx.kv_start_idx) + + return results[0], results[1], results[2], None, None, results[3], None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None # 22 + + +def npu_fusion_attention(query, key, value, head_num, + input_layout, *, pse=None, + padding_mask=None, atten_mask=None, + scale=1., keep_prob=1., pre_tokens=2147483647, + next_tokens=2147483647, inner_precise=0, prefix=None, + actual_seq_qlen=None, actual_seq_kvlen=None, + sparse_mode=0, gen_mask_parallel=True, + sync=False, pse_type=1, q_start_idx=None, + kv_start_idx=None, layer_idx=0): + return FusionAttentionV2Function.apply(query, key, value, head_num, + input_layout, pse, + padding_mask, atten_mask, + scale, keep_prob, pre_tokens, + next_tokens, inner_precise, prefix, + actual_seq_qlen, actual_seq_kvlen, + sparse_mode, gen_mask_parallel, + sync, pse_type, q_start_idx, + kv_start_idx, layer_idx) + + +def npu_fusion_attention_grad(query, key, value, grad_outputs, + head_num, input_layout, *, pse=None, + padding_mask=None, atten_mask=None, + softmax_max=None, softmax_sum=None, softmax_in=None, attention_in=None, + scale=1., keep_prob=1., pre_tokens=2147483647, + next_tokens=2147483647, inner_precise=0, + seed=1234, offset=0, numels=0, prefix=None, + actual_seq_qlen=None, actual_seq_kvlen=None, sparse_mode=0, + gen_mask_parallel=True, sync=False, pse_type=1, q_start_idx=None, + kv_start_idx=None): + mindspeed_ops = FusionAttentionV2OpBuilder().load() + return mindspeed_ops.npu_fusion_attention_grad_v2(query, key, value, grad_outputs, head_num, input_layout, pse, + padding_mask, atten_mask, softmax_max, softmax_sum, softmax_in, + attention_in, scale, keep_prob, pre_tokens, next_tokens, + inner_precise, seed, offset, numels, prefix, actual_seq_qlen, + actual_seq_kvlen, sparse_mode, gen_mask_parallel, sync, + pse_type, q_start_idx, kv_start_idx) diff --git a/internlm/model/ops/attention.py b/internlm/model/ops/attention.py index 3aec51f55..88d34c32e 100644 --- a/internlm/model/ops/attention.py +++ b/internlm/model/ops/attention.py @@ -9,6 +9,7 @@ import math from enum import Enum from typing import Callable, Tuple +import pdb import torch from einops import rearrange, repeat @@ -56,6 +57,14 @@ except (ModuleNotFoundError, ImportError): is_torch_npu = False +from internlm.model.ops._flash_attn_npu import npu_fusion_attention as _npu_sc_fa_func +try: + from internlm.model.ops._flash_attn_npu import npu_fusion_attention as _npu_sc_fa_func + npu_scfa_impl = True +except (ModuleNotFoundError, ImportError): + npu_scfa_impl = False +print(f'---------{npu_scfa_impl}---------------', flush=True) + try: from deeplink_ext.internevo_ops import ( flash_attn_func as _deeplink_fixedlen_qkvsplited_func, @@ -359,6 +368,7 @@ def _npu_varlen_qkvsplited_attn( dropout_p=0.0, softmax_scale=None, causal=False, + layer_idx=0, # pylint: disable=W0613 ): return _flash_float32_compatibility_wrapper( (0, 1, 2), @@ -373,6 +383,7 @@ def _npu_varlen_qkvsplited_attn( dropout_p, softmax_scale, causal, + layer_idx=layer_idx, ) @@ -388,6 +399,7 @@ def _npu_varlen_qkvsplited_func( softmax_scale=None, causal=False, use_fixlen=False, + layer_idx=0, # pylint: disable=W0613 ): """Support Huawei Ascend's torch_npu flash attention. 
Tested version: @@ -409,7 +421,7 @@ def _npu_varlen_qkvsplited_func( output = pack_output_after_attn(output, cu_seqlens_q, packed_length) else: output = _npu_fused_varlen_qkvsplited_attn( - q, k, v, dropout_p, softmax_scale, causal, max_seqlen_q, max_seqlen_k, cu_seqlens_q, cu_seqlens_k + q, k, v, dropout_p, softmax_scale, causal, max_seqlen_q, max_seqlen_k, cu_seqlens_q, cu_seqlens_k, layer_idx=layer_idx ) return output @@ -462,6 +474,7 @@ def _npu_fused_varlen_qkvsplited_attn( cu_seqlens_q=None, cu_seqlens_kv=None, deterministic=False, + layer_idx=0, # pylint: disable=W0613 ): assert causal is True assert q.dtype in (torch.bfloat16, torch.float16) @@ -481,9 +494,8 @@ def _npu_fused_varlen_qkvsplited_attn( attention_mask = torch.triu(torch.ones(max_seqlen_q, max_seqlen_k, device=device), 1).bool() cu_seqlens_q = cu_seqlens_q[1:].tolist() cu_seqlens_kv = cu_seqlens_kv[1:].tolist() - - return _origin_npu_fixedlen_qkvsplited_func( - query=q, + + return _npu_sc_fa_func(query=q, key=k, value=v, head_num=N, @@ -491,15 +503,34 @@ def _npu_fused_varlen_qkvsplited_attn( pse=None, atten_mask=attention_mask, scale=softmax_scale, - sparse_mode=sparse_mode, - pre_tockens=S, # Used for sparse calculations, representing the left boundary of the slides window - next_tockens=0, keep_prob=1 - dropout_p, + pre_tokens=S, # Used for sparse calculations, representing the left boundary of the slides window + next_tokens=0, inner_precise=0 if not deterministic else 2, - actual_seq_kvlen=cu_seqlens_kv, actual_seq_qlen=cu_seqlens_q, + actual_seq_kvlen=cu_seqlens_kv, + sparse_mode=sparse_mode, + layer_idx=layer_idx )[0].unsqueeze(dim=0) + # return _origin_npu_fixedlen_qkvsplited_func( + # query=q, + # key=k, + # value=v, + # head_num=N, + # input_layout="TND", + # pse=None, + # atten_mask=attention_mask, + # scale=softmax_scale, + # sparse_mode=sparse_mode, + # pre_tockens=S, # Used for sparse calculations, representing the left boundary of the slides window + # next_tockens=0, + # keep_prob=1 - dropout_p, + # inner_precise=0 if not deterministic else 2, + # actual_seq_kvlen=cu_seqlens_kv, + # actual_seq_qlen=cu_seqlens_q, + # )[0].unsqueeze(dim=0) + def _npu_varlen_qkvpacked_attn( qkv: torch.Tensor, cu_seqlens, max_seqlen, dropout_p, softmax_scale=None, causal=False # pylint: disable=W0613 @@ -539,6 +570,7 @@ def _npu_varlen_kvpacked_attn( dropout_p, softmax_scale, causal, + layer_idx ) @@ -1012,7 +1044,7 @@ def _q_kv_with_cu_seqlens( causal = self.causal if causal is None else causal attn_type, op = _select_attn_op(AttnOpType.VarLenKVPacked) - +# breakpoint() dropout = self.dropout if attn_type is AttnType.Torch else self.dropout.p extra_args = (key_padding_mask,) if attn_type is AttnType.Torch else () diff --git a/internlm/train/pipeline.py b/internlm/train/pipeline.py index 945ee688a..3a0f99ea2 100644 --- a/internlm/train/pipeline.py +++ b/internlm/train/pipeline.py @@ -834,6 +834,10 @@ def record_current_batch_training_metrics( tflops = get_tflops_func(time_cost) + tgs_statistic.setdefault("sum_tflops", 0.00) + tgs_statistic["sum_tflops"] += tflops + tflops_avg = tgs_statistic["sum_tflops"] / tgs_statistic["sum_step"] + tgs_origin = round( num_tokens_in_batch * gpc.get_world_size(ParallelMode.DATA) @@ -880,6 +884,20 @@ def record_current_batch_training_metrics( bwd_time = round(timer("bwd").elapsed(), 2) infos["bwd_time"] = bwd_time + # 累积总时间并计算平均值 + tgs_statistic.setdefault("sum_fwd_bwd_time", 0.0) + tgs_statistic.setdefault("sum_bwd_time", 0.0) + tgs_statistic["sum_fwd_bwd_time"] += fwd_bwd_time + 
+        tgs_statistic["sum_bwd_time"] += bwd_time
+
+        fwd_bwd_avg = tgs_statistic["sum_fwd_bwd_time"] / tgs_statistic["sum_step"]
+        bwd_avg = tgs_statistic["sum_bwd_time"] / tgs_statistic["sum_step"]
+
+        infos["fwd_bwd_avg"] = round(fwd_bwd_avg, 2)
+        infos["bwd_avg"] = round(bwd_avg, 2)
+        infos["tflops_avg"] = round(tflops_avg, 2)
+
+
         for key, value in acc_perplex.items():
             infos[key] = value
 
diff --git a/run.py b/run.py
new file mode 100644
index 000000000..cb8055898
--- /dev/null
+++ b/run.py
@@ -0,0 +1,66 @@
+from internlm.core.context import global_context as gpc
+import subprocess
+import os
+
+gpc.load_config("./configs/7B_isp_sft.py")
+job_name = gpc.config.JOB_NAME
+task_name = gpc.config.TASK_NAME
+
+if "MEMORY_PATH" not in gpc.config:
+    task_folder = "NONE"
+    save_index = False
+else:
+    task_folder = gpc.config.MEMORY_PATH
+    save_index = True
+
+
+output_path = os.path.join("/mnt/petrelfs/lusitian/workspace/InternEvo-fork/log_output", task_folder, f"{task_name}.out")
+
+PARTITION = "llm_s"
+NODES = 4
+TOTAL_TASKS = 32
+TASKS_PER_NODE = 8
+GPUS_PER_TASK = 1
+CONFIG_FILE = "./configs/7B_isp_sft.py"
+
+
+def submit_job():
+
+    # Build the srun command list
+    command = [
+        "srun",
+        "-p", PARTITION,
+        # "-x", "HOST-10-140-60-6",
+        "-N", str(NODES),
+        "-n", str(TOTAL_TASKS),
+        "--ntasks-per-node", str(TASKS_PER_NODE),
+        "--gpus-per-task", str(GPUS_PER_TASK),
+    ]
+
+    # Optionally submit asynchronously and redirect the log output
+    if save_index:
+        # Make sure the output directory exists
+        # os.makedirs(os.path.dirname(output_path), exist_ok=True)
+        command[1:1] = ["--async", "-o", output_path]  # insert right after "srun", before the -p option
+
+    # Append the fixed trailing arguments
+    command += [
+        "python", "train.py",
+        "--config", CONFIG_FILE,
+        "--profiling"
+    ]
+
+    try:
+        # Run the command
+        subprocess.run(command, check=True)
+        print(f"✅ Job submitted; log output: {output_path}")
+    except subprocess.CalledProcessError as e:
+        print(f"❌ Submission failed: {e}")
+    except Exception as e:
+        print(f"❌ Unexpected error: {str(e)}")
+
+
+if __name__ == "__main__":
+    submit_job()
+
+
diff --git a/run_20B.sh b/run_20B.sh
new file mode 100644
index 000000000..5cfb8c666
--- /dev/null
+++ b/run_20B.sh
@@ -0,0 +1,9 @@
+srun -p llm_s --async -o /mnt/petrelfs/lusitian/workspace/InternEvo-fork/log_output/20B_16k_32g/03051-20B-ckpt-fa-Dweb-16k-8144-z8-G32-S50.out \
+-N 4 -n 32 --ntasks-per-node=8 --gpus-per-task=1 python train.py --config ./configs/20B_internlm2.py --profiling
+
+# 19 51 25 20B fa 16k
+# 19 55 35 7B CKPT SC FA 16K
+# 20 15 25 7B ckpt scoffload 16k
+# 20 26 35 7B cpuoffload 16k
+
+# 11 35 ckpt0.5
diff --git a/run_7B.sh b/run_7B.sh
new file mode 100644
index 000000000..0e5ed9d5d
--- /dev/null
+++ b/run_7B.sh
@@ -0,0 +1,7 @@
+srun -p llm_s --async -o /mnt/petrelfs/lusitian/workspace/InternEvo-fork/log_output/7B_16k_16g/0304-cpuoff10-fa-Dweb-16k-8122-z8-G16-S50.out \
+-N 2 -n 16 --ntasks-per-node=8 --gpus-per-task=1 python train.py --config ./configs/7B_internlm2.py --profiling
+
+
+#--async -o /mnt/petrelfs/lusitian/workspace/InternEvo-fork/log_output/7B_16k_16g/0304-ckpt-fa-Dweb-16k-8122-z8-G16-S50.out \
+
+
diff --git a/run_7b.sh b/run_7b.sh
new file mode 100755
index 000000000..83c26dbef
--- /dev/null
+++ b/run_7b.sh
@@ -0,0 +1,26 @@
+# A+K
+export MASTER_ADDR=127.0.0.1
+export MASTER_PORT=8666
+export WORLD_SIZE=1
+export RANK=0
+
+export GPU_NUMS=4
+export USER=root
+export TZ=UTC-8
+export HCCL_IF_BASE_PORT=30000
+export HCCL_CONNECT_TIMEOUT=1200
+export HCCL_INTRA_ROCE_ENABLE=1
+export HCCL_INTRA_PCIE_ENABLE=0
+
+
+export PYTHONPATH=/pj_data30t/tangzhiyi/ditorch:/pj_data30t/tangzhiyi/DeepLinkExt/:$PYTHONPATH
+export INTERNLM_ACCELERATOR=ditorch
+export DEEPLINK_EXT_PLATFORM_TYPE=torch_npu
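+# Note: the HCCL_* variables above are Ascend collective-communication (HCCL) settings
+# (communication base port, connection timeout, and intra-node RoCE/PCIe transport toggles);
+# the values are assumed to match this cluster's network fabric.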
+
+log_file="log_$(date +%Y%m%d_%H%M%S)"
+
+cd /pj_data30t/tangzhiyi/InternEvo
+
+# torchrun --master_addr=$MASTER_ADDR --master_port=$MASTER_PORT --nproc_per_node=4 --nnodes=$WORLD_SIZE --node_rank=$RANK train.py --config configs/7b.py --launcher torch --seed 1024
+
+torchrun --master_addr=$MASTER_ADDR --master_port=$MASTER_PORT --nproc_per_node=4 --nnodes=$WORLD_SIZE --node_rank=$RANK train.py --config configs/7b.py --launcher torch --seed 1024 2>&1 | tee -a /pj_data30t/logs/7b_internlm2_ckpt10_epoch1000.log
\ No newline at end of file
diff --git a/run_internevo_100b.sh b/run_internevo_100b.sh
new file mode 100644
index 000000000..c12ca6b37
--- /dev/null
+++ b/run_internevo_100b.sh
@@ -0,0 +1,23 @@
+source /usr/local/Ascend/ascend-toolkit/set_env.sh
+
+export INTERNLM_ACCELERATOR=npu
+export HCCL_IF_BASE_PORT=30000
+export HCCL_CONNECT_TIMEOUT=1200
+export HCCL_INTRA_ROCE_ENABLE=1
+export HCCL_INTRA_PCIE_ENABLE=0
+export ASCEND_HOME_PATH=/usr/local/Ascend/ascend-toolkit/latest
+
+# # use ditorch
+# export PYTHONPATH=/pjlab_data/code/ditorch:/pjlab_data/code/DeepLinkExt/:$PYTHONPATH
+# export INTERNLM_ACCELERATOR=ditorch
+# export DEEPLINK_EXT_PLATFORM_TYPE=torch_npu
+# export DITORCH_SHOW_DEVICE_AS_CUDA=0
+
+cd /pjlab_data/code/InternEvo
+
+echo "MASTER_ADDR: ${MASTER_ADDR}"
+echo "MASTER_PORT: ${MASTER_PORT}"
+echo "NNODES: ${WORLD_SIZE}"
+echo "RANK: ${RANK}"
+torchrun --master_addr=$MASTER_ADDR --master_port=$MASTER_PORT --nproc_per_node=8 --nnodes=$WORLD_SIZE --node_rank=$RANK train.py --config configs/100b.py --launcher torch 2>&1 | tee /pjlab_data/logs/internevo_100b_train_log_$RANK
+
diff --git a/run_internevo_20b.sh b/run_internevo_20b.sh
new file mode 100755
index 000000000..b3db9ab85
--- /dev/null
+++ b/run_internevo_20b.sh
@@ -0,0 +1,164 @@
[164 lines omitted: despite its name, the committed file is not a launch script but a terminal-session
transcript captured with script(1) on 2025-04-02 17:17:20+08:00 on an Ascend 910B node; it contains
shell prompts, ls/df output, and terminal escape sequences from browsing /mnt/cwai/pjlab_data, and
defines no training commands.]
diff --git a/search_module.py b/search_module.py
new file mode 100644
index 000000000..5e92eb97e
--- /dev/null
+++ b/search_module.py
@@ -0,0 +1,28 @@
+import inspect
+
+# Import the target object
+from torch_npu import npu_fusion_attention
+
+def get_object_path(obj):
+    # Modules (and some classes) expose a __file__ attribute directly
+    if hasattr(obj, '__file__'):
+        print(f"File path of the object: {obj.__file__}")
+        return obj.__file__
+
+    # For functions/methods, resolve the defining module via inspect
+    if inspect.isfunction(obj) or inspect.ismethod(obj):
+        module = inspect.getmodule(obj)
+        if module and hasattr(module, '__file__'):
+            print(f"Function/Method is defined in module: {module.__name__}")
+            print(f"Module file path: {module.__file__}")
+            return module.__file__
+        else:
+            print("Could not determine the file path of the function/method.")
+            return None
+
+    # The object carries no usable file-path information
+    print("The object does not have a clear file path attribute.")
+    return None
+
+# Locate the file that defines npu_fusion_attention
+get_object_path(npu_fusion_attention)
\ No newline at end of file
diff --git a/set_env.sh b/set_env.sh
new file mode 100644
index 000000000..9115647be
--- /dev/null
+++ b/set_env.sh
@@ -0,0 +1,11 @@
+export LD_LIBRARY_PATH=/usr/local/Ascend/driver/lib64:/usr/local/Ascend/driver/lib64/common:/usr/local/Ascend/driver/lib64/driver:$LD_LIBRARY_PATH
+export ASCEND_TOOLKIT_HOME=/usr/local/Ascend/ascend-toolkit/latest
+export LD_LIBRARY_PATH=${ASCEND_TOOLKIT_HOME}/lib64:${ASCEND_TOOLKIT_HOME}/lib64/plugin/opskernel:${ASCEND_TOOLKIT_HOME}/lib64/plugin/nnengine:${ASCEND_TOOLKIT_HOME}/opp/built-in/op_impl/ai_core/tbe/op_tiling/lib/linux/$(arch):$LD_LIBRARY_PATH
+export LD_LIBRARY_PATH=${ASCEND_TOOLKIT_HOME}/tools/aml/lib64:${ASCEND_TOOLKIT_HOME}/tools/aml/lib64/plugin:$LD_LIBRARY_PATH
+export PYTHONPATH=${ASCEND_TOOLKIT_HOME}/python/site-packages:${ASCEND_TOOLKIT_HOME}/opp/built-in/op_impl/ai_core/tbe:$PYTHONPATH
+export PATH=${ASCEND_TOOLKIT_HOME}/bin:${ASCEND_TOOLKIT_HOME}/compiler/ccec_compiler/bin:$PATH
+export ASCEND_AICPU_PATH=${ASCEND_TOOLKIT_HOME}
+export ASCEND_OPP_PATH=${ASCEND_TOOLKIT_HOME}/opp
+export TOOLCHAIN_HOME=${ASCEND_TOOLKIT_HOME}/toolkit
+export ASCEND_HOME_PATH=${ASCEND_TOOLKIT_HOME}
+
diff --git a/tokens_num.py b/tokens_num.py
new file mode 100644
index 000000000..67a9a14ce
--- /dev/null
+++ b/tokens_num.py
@@ -0,0 +1,27 @@
+
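+# Inspect a tokenized dataset's .bin.meta index and count how many samples exceed
+# several token-length thresholds.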
+# with open("/mnt/petrelfs/share_data/caizheng/train_ds/tokenized_data/en/refined-web-CC-MAIN-2013-20/train-000773-e51fcf79.bin.meta", "rb") as f:
+#     content = f.read()
+
+# print(content)
+
+import numpy as np
+
+meta_path = "/mnt/petrelfs/share_data/caizheng/train_ds/tokenized_data/en/refined-web-CC-MAIN-2013-20/train-000773-e51fcf79.bin.meta"
+meta = np.load(meta_path)  # load the .bin.meta file directly as a NumPy array
+
+# Example output:
+print(meta.shape)  # (N, 2)
+# print(meta[:100])  # the first rows as (cur, length) pairs
+
+# Threshold list (unit: number of tokens)
+thresholds = [16000, 32000, 64000, 128000]
+
+# Count the samples exceeding each threshold
+counts = {}
+for thresh in thresholds:
+    mask = meta[:, 1] > thresh  # the second column is the token count
+    counts[thresh] = np.sum(mask)
+
+# Print the results
+for thresh, count in counts.items():
+    print(f"Samples with more than {thresh//1000}k tokens: {count}")
\ No newline at end of file
diff --git a/torchrun.sh b/torchrun.sh
new file mode 100644
index 000000000..67a26f874
--- /dev/null
+++ b/torchrun.sh
@@ -0,0 +1 @@
+torchrun --nnodes=1 --nproc_per_node=8 --node_rank=0 --master_addr="10.201.20.60" --master_port=29500 train.py --config ./configs/7B_isp_sft.py --launcher "torch" --profiling 2>&1 | tee debug.log
\ No newline at end of file
diff --git a/train.py b/train.py
old mode 100755
new mode 100644
index 6e5e1399f..cd414f342
--- a/train.py
+++ b/train.py
@@ -11,7 +11,13 @@
 from internlm.model.builder import create_model
 from internlm.monitor import internevo_monitor
 from internlm.utils.common import parse_args
-
+import torch
+
+# torch_npu is only needed on Ascend machines; guard the import so the GPU/Slurm launch scripts in this patch keep working.
+try:
+    import torch_npu  # noqa: F401  # registers the Ascend NPU backend when present
+except (ModuleNotFoundError, ImportError):
+    torch_npu = None
 
 @internevo_monitor(feishu_alert=True, clean_run=True)
 def main(args):
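
For reference, a minimal sketch of how the new wrapper is meant to be driven from the variable-length path, mirroring _npu_fused_varlen_qkvsplited_attn above. It assumes an Ascend environment where internlm.model.ops._flash_attn_npu and its mindspeed FusionAttentionV2 builder import successfully; the helper name packed_causal_attention and the sparse_mode=3 value are illustrative assumptions, not part of the patch.

import torch

from internlm.model.ops._flash_attn_npu import npu_fusion_attention


def packed_causal_attention(q, k, v, cu_seqlens, softmax_scale):
    """Causal varlen attention over a packed TND batch (sketch, NPU only).

    q/k/v: (total_tokens, num_heads, head_dim) tensors; cu_seqlens: cumulative
    sequence lengths [0, len0, len0 + len1, ...] as used by the varlen path above.
    """
    _, num_heads, _ = q.shape
    seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
    max_seqlen = int(seqlens.max())
    # True marks positions that must NOT be attended to (strict upper triangle = future tokens).
    atten_mask = torch.triu(torch.ones(max_seqlen, max_seqlen, device=q.device), 1).bool()
    output = npu_fusion_attention(
        q, k, v, num_heads,
        input_layout="TND",
        atten_mask=atten_mask,
        scale=softmax_scale,
        keep_prob=1.0,                            # no attention dropout
        pre_tokens=max_seqlen,                    # left boundary of the causal window (the "S" in the patch)
        next_tokens=0,                            # no look-ahead
        actual_seq_qlen=cu_seqlens[1:].tolist(),
        actual_seq_kvlen=cu_seqlens[1:].tolist(),
        sparse_mode=3,                            # assumed value; attention.py passes its own sparse_mode
    )[0]                                          # index 0 of the result tuple is the attention output
    return output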