From 4c6f4717db3c62afb4b0e252b7d1dd31368c42b0 Mon Sep 17 00:00:00 2001 From: gejq21 Date: Sat, 19 Apr 2025 11:30:08 +0800 Subject: [PATCH 1/6] add v2pe utils and support v2pe for finetune --- .../internvl/patch/pad_data_collator.py | 2 +- .../internvl/train/internvl_chat_finetune.py | 43 ++++++- internvl_chat/internvl/v2pe_utils.py | 116 ++++++++++++++++++ 3 files changed, 156 insertions(+), 5 deletions(-) create mode 100644 internvl_chat/internvl/v2pe_utils.py diff --git a/internvl_chat/internvl/patch/pad_data_collator.py b/internvl_chat/internvl/patch/pad_data_collator.py index 6282040f6..3d0ccf516 100644 --- a/internvl_chat/internvl/patch/pad_data_collator.py +++ b/internvl_chat/internvl/patch/pad_data_collator.py @@ -72,7 +72,7 @@ def concat_pad_data_collator(features, max_item_length=None, pad_id=0): feat['attention_mask'] = feat['input_ids'].ne(pad_id) if 'position_ids' in feat: - temp_position_ids = torch.LongTensor([pad_id] * max_item_length) + temp_position_ids = [pad_id] * max_item_length temp_position_ids[:feat['position_ids'].shape[0]] = feat['position_ids'] feat['position_ids'] = temp_position_ids diff --git a/internvl_chat/internvl/train/internvl_chat_finetune.py b/internvl_chat/internvl/train/internvl_chat_finetune.py index 42f669437..c3f431c8d 100644 --- a/internvl_chat/internvl/train/internvl_chat_finetune.py +++ b/internvl_chat/internvl/train/internvl_chat_finetune.py @@ -52,6 +52,7 @@ preprocess_internvl2_5, preprocess_mpt, preprocess_phi3) from internvl.train.dataset_packed import PackedDataset, packed_collate_fn +from internvl.v2pe_utils import get_rope_pos_id from PIL import Image, ImageFile, PngImagePlugin, UnidentifiedImageError from torch.utils.data import Dataset from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer, @@ -70,6 +71,7 @@ print('petrel_client is not installed. Using PIL to load images.') has_tcs_loader = False + # Set constants for image processing and logging IGNORE_INDEX = -100 Image.MAX_IMAGE_PIXELS = None @@ -264,6 +266,14 @@ class DataTrainingArguments: default=False, metadata={'help': 'Whether to gather all during loss reduction. 
Default is False.'}, ) + rope_pos_id_version: Optional[str] = field( + default='default', + metadata={'help': 'version for get_rope_pos_id'}, + ) + rope_pos_id_stride: Optional[int] = field( + default=None, + metadata={'help': 'stride for the version v4 of get_rope_pos_id'}, + ) class LazySupervisedDataset(Dataset): @@ -297,16 +307,22 @@ def __init__( distributed_mode=False, force_shuffle=False, random_seed=0, + rope_pos_id_version='default', + rope_pos_id_stride=None, ): super(LazySupervisedDataset, self).__init__() self.ds_name = ds_name self.tokenizer = tokenizer self.template_name = template_name self.num_image_token = num_image_token + self.rope_pos_id_version = rope_pos_id_version + self.rope_pos_id_stride = rope_pos_id_stride logger.info(f'[Dataset] num_image_token: {num_image_token}') logger.info(f'[Dataset] dynamic_image_size: {dynamic_image_size}') logger.info(f'[Dataset] use_thumbnail: {use_thumbnail}') logger.info(f'[Dataset] min_dynamic_patch: {min_dynamic_patch}, max_dynamic_patch: {max_dynamic_patch}') + logger.info(f'[Dataset] rope_pos_id_version: {rope_pos_id_version}') + logger.info(f'[Dataset] rope_pos_id_stride: {rope_pos_id_stride}') self.image_size = image_size self.is_train = is_train @@ -421,6 +437,7 @@ def multi_modal_get_item(self, data_item): # Build transformation function transform = self.get_transform() + num_tiles = [] # Ensure the first conversation contains an image placeholder if '<image>' not in data_item['conversations'][0]['value']: data_item['conversations'][0]['value'] = '<image>\n' + data_item['conversations'][0]['value'] @@ -430,12 +447,15 @@ def multi_modal_get_item(self, data_item): # Load the image using tcs_loader if available, otherwise use PIL image = self.load_image(image_path) + orig_size = image.size if self.dynamic_image_size: # If dynamic image size is enabled, preprocess the image dynamically - images = dynamic_preprocess(image, min_num=self.min_dynamic_patch, max_num=self.max_dynamic_patch, + images, boxes = dynamic_preprocess(image, min_num=self.min_dynamic_patch, max_num=self.max_dynamic_patch, image_size=self.image_size, use_thumbnail=self.use_thumbnail) + num_tiles.append(len(images)) else: # Otherwise, use the original image as a single patch images = [image] + num_tiles.append(1) # Apply the transformation to each image and stack the results into a tensor pixel_values = [transform(image) for image in images] @@ -461,12 +481,16 @@ def multi_modal_get_item(self, data_item): image_end_token_id = self.tokenizer.convert_tokens_to_ids(IMG_END_TOKEN) assert (ret['input_ids'][0] == image_end_token_id).sum() == 1, f'image tokens are truncated, this dataset is {self.ds_name}' + if self.rope_pos_id_version in ['v2pe_fix', 'v2pe_rnd']: + position_ids = get_rope_pos_id(ret, num_tiles, torch.float32, self.rope_pos_id_version, position_ids[0], tokenizer=self.tokenizer) + else: + position_ids = position_ids[0] # Create the final return dictionary ret = dict( input_ids=ret['input_ids'][0], labels=ret['labels'][0], attention_mask=ret['attention_mask'][0], - position_ids=position_ids[0], + position_ids=position_ids, pixel_values=pixel_values, image_flags=torch.tensor([1] * num_patches, dtype=torch.long) ) @@ -511,12 +535,16 @@ def multi_modal_multi_image_get_item(self, data_item): image_end_token_id = self.tokenizer.convert_tokens_to_ids(IMG_END_TOKEN) assert (ret['input_ids'][0] == image_end_token_id).sum() == num_image, f'image tokens are truncated, this dataset is {self.ds_name}' + if self.rope_pos_id_version in ['v2pe_fix', 'v2pe_rnd']: + position_ids = 
get_rope_pos_id(ret, num_tiles, torch.float32, self.rope_pos_id_version, position_ids[0], tokenizer=self.tokenizer) + else: + position_ids = position_ids[0] # Create the final return dictionary ret = dict( input_ids=ret['input_ids'][0], labels=ret['labels'][0], attention_mask=ret['attention_mask'][0], - position_ids=position_ids[0], + position_ids=position_ids, pixel_values=pixel_values, image_flags=torch.tensor([1] * num_patches, dtype=torch.long) ) @@ -567,12 +595,19 @@ def video_get_item(self, data_item): position_ids = ret['attention_mask'].long().cumsum(-1) - 1 position_ids.masked_fill_(ret['attention_mask'] == 0, 1) + image_end_token_id = self.tokenizer.convert_tokens_to_ids(IMG_END_TOKEN) + assert (ret['input_ids'][0] == image_end_token_id).sum() == num_patches, f'image tokens are truncated, this dataset is {self.ds_name}' + + if self.rope_pos_id_version in ['v2pe_fix', 'v2pe_rnd']: + position_ids = get_rope_pos_id(ret, [1] * num_patches, torch.float32, self.rope_pos_id_version, position_ids[0], tokenizer=self.tokenizer) + else: + position_ids = position_ids[0] # Create the final return dictionary ret = dict( input_ids=ret['input_ids'][0], labels=ret['labels'][0], attention_mask=ret['attention_mask'][0], - position_ids=position_ids[0], + position_ids=position_ids, pixel_values=pixel_values, image_flags=torch.tensor([1] * num_patches, dtype=torch.long) ) diff --git a/internvl_chat/internvl/v2pe_utils.py b/internvl_chat/internvl/v2pe_utils.py new file mode 100644 index 000000000..44fcaf318 --- /dev/null +++ b/internvl_chat/internvl/v2pe_utils.py @@ -0,0 +1,116 @@ +import torch +import torch.nn as nn +import random +# get rope pos id while evaluating +def get_rope_pos_id(ret, num_tiles, dtype, rope_pos_id_version='default', position_id=None, + IMG_START_TOKEN='<img>',IMG_END_TOKEN='</img>',rope_pos_id_stride=None, tokenizer=None): + image_start_token_id = tokenizer.convert_tokens_to_ids(IMG_START_TOKEN) + image_end_token_id = tokenizer.convert_tokens_to_ids(IMG_END_TOKEN) + num_image_token=256 + rope_pos_id_list = [] + + input_ids_0 = ret['input_ids'][0] + attention_mask_0 = ret['attention_mask'][0] + image_start_token_id_idxs = torch.where(input_ids_0 == image_start_token_id)[0] + image_end_token_id_idxs = torch.where(input_ids_0 == image_end_token_id)[0] + + last_record_pos_id = -1 + start_index = 0 + + assert rope_pos_id_version in ['v2pe_fix', 'v2pe_rnd', 'default'], f'{rope_pos_id_version} not supported for eval' + + for i in range(len(image_start_token_id_idxs)): + + num_tile = num_tiles[i] + + rope_pos_id_pre = attention_mask_0[start_index:image_start_token_id_idxs[i] + 1].long().cumsum(-1) - 1 + (last_record_pos_id + 1) + rope_pos_id_pre.masked_fill_(attention_mask_0[start_index:image_start_token_id_idxs[i] + 1] == 0, 1) + rope_pos_id_list.append(rope_pos_id_pre) + + last_record_pos_id = rope_pos_id_pre[-1].long() + + if rope_pos_id_version == 'v2pe_fix': + assert rope_pos_id_stride is not None, 'when rope_pos_id_version is fix, self.rope_pos_id_stride should not be None' + small_stride = rope_pos_id_stride / num_image_token + split_img_id_idxs = torch.linspace(last_record_pos_id,last_record_pos_id+small_stride*(num_image_token * num_tile ),(num_image_token * num_tile + 1))[1:].to(dtype=dtype) + rope_pos_id_list.append(split_img_id_idxs) + last_record_pos_id = torch.ceil(split_img_id_idxs[-1]).long() + elif rope_pos_id_version == 'v2pe_rnd': + random_from=[1,2,4,8,16,32,64,128,256] + rope_pos_id_stride=random.choice(random_from) + small_stride = rope_pos_id_stride / num_image_token + 
split_img_id_idxs = torch.linspace(last_record_pos_id,last_record_pos_id+small_stride*(num_image_token * num_tile ),(num_image_token * num_tile + 1))[1:].to(dtype=dtype) + rope_pos_id_list.append(split_img_id_idxs) + last_record_pos_id = torch.ceil(split_img_id_idxs[-1]).long() + elif rope_pos_id_version == 'default': + split_img_id_idxs = torch.linspace(last_record_pos_id, + last_record_pos_id + (num_tile - 1) * num_image_token, + (num_tile - 1) * num_image_token + 1)[1:].to(dtype=dtype) + rope_pos_id_list.append(split_img_id_idxs) + thumbnail_id_idxs = torch.linspace(last_record_pos_id + (num_tile - 1) * num_image_token, + last_record_pos_id + num_tile * num_image_token, + num_image_token + 1)[1:].to(dtype=dtype) + rope_pos_id_list.append(thumbnail_id_idxs) + last_record_pos_id = (last_record_pos_id + num_tile * num_image_token).long() + else: + raise NotImplementedError(f'not implemented for {rope_pos_id_version}') + + start_index = image_start_token_id_idxs[i] + num_tile * num_image_token + 1 + assert input_ids_0[start_index] == image_end_token_id + assert start_index == image_end_token_id_idxs[i] + + assert image_end_token_id_idxs[-1] == start_index + rope_pos_id_pre = attention_mask_0[start_index:].long().cumsum(-1) - 1 + (last_record_pos_id + 1) + rope_pos_id_pre.masked_fill_(attention_mask_0[start_index:] == 0, 1) + rope_pos_id_list.append(rope_pos_id_pre) + + rope_pos_id_list=[_.to('cpu') for _ in rope_pos_id_list] + rope_pos_id = torch.cat(rope_pos_id_list).to(dtype=dtype) + if rope_pos_id_version == 'default': + rope_pos_id = rope_pos_id.long() + assert torch.equal(rope_pos_id, position_id.to(rope_pos_id.device)), (rope_pos_id, position_id.to(rope_pos_id.device)) + assert torch.allclose(rope_pos_id, position_id.to(rope_pos_id.device), atol=1e-32) + + assert rope_pos_id.shape == input_ids_0.shape + + return list(rope_pos_id.numpy()) + + +# Rotary Position Embedding for V2PE +class V2PE(nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, scaling_factor=1.0,scale_img=False): + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + self.inv_freq = None + self.scaling_factor=scaling_factor + self.scale_img=scale_img + # inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + # self.register_buffer('inv_freq', inv_freq, persistent=False) + + self.max_seq_len_cached = -1 + + def _set_cos_sin_cache(self, pos_id, device, dtype): + if self.inv_freq is None: + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim)) + del self.inv_freq + self.register_buffer('inv_freq', inv_freq, persistent=False) + + pos_id=pos_id.squeeze(0) + freqs = torch.outer(pos_id, self.inv_freq.to(device=pos_id.device)) + + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer('cos_cached', emb.cos().to(dtype), persistent=False) + self.register_buffer('sin_cached', emb.sin().to(dtype), persistent=False) + + def forward(self, x, global_posid=None): + # x: [bs, num_attention_heads, seq_len, head_size] + self._set_cos_sin_cache(pos_id=global_posid, device=x.device, dtype=x.dtype,selected=selected) + + return ( + self.cos_cached[:].to(dtype=x.dtype), + self.sin_cached[:].to(dtype=x.dtype), + ) \ No newline at end of file
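For illustration, here is a minimal sketch (not part of the patches; the stride value is an assumption, while num_image_token=256 matches the code above) of the position ids that get_rope_pos_id produces for rope_pos_id_version='v2pe_fix'. Each visual token advances the position id by only rope_pos_id_stride / num_image_token, so a full 256-token tile occupies rope_pos_id_stride positions while text tokens keep unit steps:

import torch

num_image_token, rope_pos_id_stride, num_tile = 256, 64, 1  # stride 64 is illustrative
last_record_pos_id = 10  # position id of the <img> start token
small_stride = rope_pos_id_stride / num_image_token  # 0.25 position per visual token
split_img_id_idxs = torch.linspace(
    last_record_pos_id,
    last_record_pos_id + small_stride * num_image_token * num_tile,
    num_image_token * num_tile + 1)[1:]
# split_img_id_idxs == [10.25, 10.50, ..., 74.0]: 256 visual tokens span 64
# positions instead of 256; the following text resumes from ceil(74.0) + 1 = 75.

From 0c8993dcb82474bbd1f19e01ded5806d8aa5b8ab Mon Sep 17 00:00:00 2001 From: gejq21 Date: Sun, 20 Apr 2025 00:01:34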
+0800 Subject: [PATCH 2/6] add v2pe to internvl model --- .../internlm2/configuration_internlm2.py | 6 +- .../model/internlm2/modeling_internlm2.py | 56 ++++++++++++------- .../configuration_internvl_chat.py | 8 +++ .../internvl_chat/modeling_internvl_chat.py | 54 +++++++++++++++--- .../internvl/train/internvl_chat_finetune.py | 8 +++ internvl_chat/internvl/v2pe_utils.py | 3 +- 6 files changed, 106 insertions(+), 29 deletions(-) diff --git a/internvl_chat/internvl/model/internlm2/configuration_internlm2.py b/internvl_chat/internvl/model/internlm2/configuration_internlm2.py index 282b13b1e..297c53173 100644 --- a/internvl_chat/internvl/model/internlm2/configuration_internlm2.py +++ b/internvl_chat/internvl/model/internlm2/configuration_internlm2.py @@ -95,6 +95,8 @@ def __init__( # pylint: disable=W0102 rope_theta=10000, rope_scaling=None, attn_implementation='eager', + rope_pos_id_version='default', + rope_pos_id_stride=None, **kwargs, ): self.vocab_size = vocab_size @@ -115,8 +117,10 @@ def __init__( # pylint: disable=W0102 self.use_cache = use_cache self.rope_theta = rope_theta self.rope_scaling = rope_scaling + self.rope_pos_id_version = rope_pos_id_version + self.rope_pos_id_stride = rope_pos_id_stride self._rope_scaling_validation() - + self.attn_implementation = attn_implementation if self.attn_implementation is None: self.attn_implementation = 'eager' diff --git a/internvl_chat/internvl/model/internlm2/modeling_internlm2.py b/internvl_chat/internvl/model/internlm2/modeling_internlm2.py index 569513dff..983e9d42d 100644 --- a/internvl_chat/internvl/model/internlm2/modeling_internlm2.py +++ b/internvl_chat/internvl/model/internlm2/modeling_internlm2.py @@ -41,6 +41,7 @@ BaseStreamer = None from .configuration_internlm2 import InternLM2Config +from internvl.v2pe_utils import V2PE logger = logging.get_logger(__name__) @@ -323,30 +324,41 @@ def __init__(self, config: InternLM2Config): def _init_rope(self): if self.config.rope_scaling is None: - self.rotary_emb = InternLM2RotaryEmbedding( - self.head_dim, - max_position_embeddings=self.max_position_embeddings, - base=self.config.rope_theta, - ) - else: - scaling_type = self.config.rope_scaling['type'] - scaling_factor = self.config.rope_scaling['factor'] - if scaling_type == 'dynamic': - self.rotary_emb = InternLM2DynamicNTKScalingRotaryEmbedding( + if self.config.rope_pos_id_version == 'v2pe': + self.rotary_emb = V2PE( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.config.rope_theta, + rope_pos_id_stride=self.config.rope_pos_id_stride + ) + else: + self.rotary_emb = InternLM2RotaryEmbedding( self.head_dim, max_position_embeddings=self.max_position_embeddings, base=self.config.rope_theta, - scaling_factor=scaling_factor, ) - elif scaling_type == 'linear': + else: + if self.config.rope_pos_id_version == 'v2pe': + raise ValueError("V2PE is not compatible with rope_scaling. 
When using V2PE, rope_scaling must be None.") + + scaling_type = self.config.rope_scaling["type"] + scaling_factor = self.config.rope_scaling["factor"] + if scaling_type == "linear": self.rotary_emb = InternLM2LinearScalingRotaryEmbedding( self.head_dim, max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, base=self.config.rope_theta, + ) + elif scaling_type == "dynamic": + self.rotary_emb = InternLM2DynamicNTKScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor, + base=self.config.rope_theta, ) else: - raise ValueError("Currently we only support rotary embedding's type being 'dynamic' or 'linear'.") + raise ValueError(f"Unknown RoPE scaling type {scaling_type}") return self.rotary_emb def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): @@ -391,8 +403,12 @@ def forward( kv_seq_len = key_states.shape[-2] if past_key_value is not None: kv_seq_len += past_key_value[0].shape[-2] - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + if self.config.rope_pos_id_version == 'v2pe': + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, torch.arange(0,position_ids.shape[1]).unsqueeze(0)) + else: + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: # reuse k, v, self_attention @@ -493,10 +509,12 @@ def forward( kv_seq_len = key_states.shape[-2] if past_key_value is not None: kv_seq_len += past_key_value[0].shape[-2] - - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + if self.config.rope_pos_id_version == 'v2pe': + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, torch.arange(0,position_ids.shape[1]).unsqueeze(0)) + else: + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: # reuse k, v, self_attention diff --git a/internvl_chat/internvl/model/internvl_chat/configuration_internvl_chat.py b/internvl_chat/internvl/model/internvl_chat/configuration_internvl_chat.py index 80abf7cba..36de6f7da 100644 --- a/internvl_chat/internvl/model/internvl_chat/configuration_internvl_chat.py +++ b/internvl_chat/internvl/model/internvl_chat/configuration_internvl_chat.py @@ -37,6 +37,8 @@ def __init__( ps_version='v1', min_dynamic_patch=1, max_dynamic_patch=6, + rope_pos_id_version='default', + rope_pos_id_stride=None, **kwargs): super().__init__(**kwargs) @@ -72,6 +74,8 @@ def __init__( self.ps_version = ps_version # pixel shuffle version self.min_dynamic_patch = min_dynamic_patch self.max_dynamic_patch = max_dynamic_patch + self.rope_pos_id_version = rope_pos_id_version + self.rope_pos_id_stride = rope_pos_id_stride self.hidden_size = self.llm_config.hidden_size # By default, we use tie_word_embeddings=False for models of all sizes. 
@@ -82,6 +86,8 @@ def __init__( logger.info(f'ps_version: {self.ps_version}') logger.info(f'min_dynamic_patch: {self.min_dynamic_patch}') logger.info(f'max_dynamic_patch: {self.max_dynamic_patch}') + logger.info(f'rope_pos_id_version: {self.rope_pos_id_version}') + logger.info(f'rope_pos_id_stride: {self.rope_pos_id_stride}') def to_dict(self): """ @@ -105,5 +111,7 @@ def to_dict(self): output['ps_version'] = self.ps_version output['min_dynamic_patch'] = self.min_dynamic_patch output['max_dynamic_patch'] = self.max_dynamic_patch + output['rope_pos_id_version'] = self.rope_pos_id_version + output['rope_pos_id_stride'] = self.rope_pos_id_stride return output diff --git a/internvl_chat/internvl/model/internvl_chat/modeling_internvl_chat.py b/internvl_chat/internvl/model/internvl_chat/modeling_internvl_chat.py index 7242136c5..b9dc43927 100644 --- a/internvl_chat/internvl/model/internvl_chat/modeling_internvl_chat.py +++ b/internvl_chat/internvl/model/internvl_chat/modeling_internvl_chat.py @@ -24,6 +24,7 @@ from .configuration_internvl_chat import InternVLChatConfig from .modeling_intern_vit import InternVisionModel, has_flash_attn +from internvl.v2pe_utils import get_rope_pos_id logger = logging.get_logger(__name__) @@ -341,7 +342,7 @@ def batch_chat(self, tokenizer, pixel_values, questions, generation_config, num_ def chat(self, tokenizer, pixel_values, question, generation_config, history=None, return_history=False, num_patches_list=None, IMG_START_TOKEN='<img>', IMG_END_TOKEN='</img>', IMG_CONTEXT_TOKEN='<IMG_CONTEXT>', - verbose=False): + verbose=False, **kwargs): if history is None and pixel_values is not None and '<image>' not in question: question = '<image>\n' + question @@ -378,12 +379,49 @@ def chat(self, tokenizer, pixel_values, question, generation_config, history=Non input_ids = model_inputs['input_ids'].to(device) attention_mask = model_inputs['attention_mask'].to(device) generation_config['eos_token_id'] = eos_token_id - generation_output = self.generate( - pixel_values=pixel_values, - input_ids=input_ids, - attention_mask=attention_mask, - **generation_config - ) + + if 'rope_pos_id_version' in kwargs: + self.language_model.rope_pos_id_version = kwargs['rope_pos_id_version'] + pos_ids = [] + ret = {'input_ids': input_ids, 'attention_mask': attention_mask} + for i in range(input_ids.shape[0]): + if kwargs['rope_pos_id_version'] == 'default': + cur_dtype = torch.long + else: + cur_dtype = torch.float32 + + if 'rope_pos_id_stride' in kwargs: + rope_pos_id_stride = kwargs['rope_pos_id_stride'] + else: + rope_pos_id_stride = None + + cur_pos_id = get_rope_pos_id(ret, tokenizer=tokenizer, num_tiles=kwargs['num_tiles'][i], + dtype=cur_dtype, + rope_pos_id_version=kwargs['rope_pos_id_version'], + position_id=torch.arange(0, input_ids.shape[1]), + IMG_START_TOKEN=IMG_START_TOKEN, + IMG_END_TOKEN=IMG_END_TOKEN, rope_pos_id_stride=rope_pos_id_stride) + + cur_pos_id = torch.tensor(cur_pos_id).to(device) + pos_ids.append(cur_pos_id) + + pos_ids = torch.stack(pos_ids) + generation_output = self.generate( + pixel_values=pixel_values, + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=pos_ids, + **generation_config + ) + else: + self.language_model.rope_pos_id_version = 'default' + generation_output = self.generate( + pixel_values=pixel_values, + input_ids=input_ids, + attention_mask=attention_mask, + **generation_config + ) + response = tokenizer.batch_decode(generation_output, skip_special_tokens=True)[0] response = response.split(template.sep.strip())[0].strip() history.append((question, response)) @@ 
-402,6 +440,7 @@ def generate( pixel_values: Optional[torch.FloatTensor] = None, input_ids: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, visual_features: Optional[torch.FloatTensor] = None, generation_config: Optional[GenerationConfig] = None, output_hidden_states: Optional[bool] = None, @@ -430,6 +469,7 @@ def generate( outputs = self.language_model.generate( inputs_embeds=input_embeds, attention_mask=attention_mask, + position_ids=position_ids, generation_config=generation_config, output_hidden_states=output_hidden_states, use_cache=True, diff --git a/internvl_chat/internvl/train/internvl_chat_finetune.py b/internvl_chat/internvl/train/internvl_chat_finetune.py index c3f431c8d..b916482c7 100644 --- a/internvl_chat/internvl/train/internvl_chat_finetune.py +++ b/internvl_chat/internvl/train/internvl_chat_finetune.py @@ -941,6 +941,10 @@ def main(): config.ps_version = model_args.ps_version config.min_dynamic_patch = data_args.min_dynamic_patch config.max_dynamic_patch = data_args.max_dynamic_patch + config.rope_pos_id_version = model_args.rope_pos_id_version + config.rope_pos_id_stride = model_args.rope_pos_id_stride + config.llm_config.rope_pos_id_version = model_args.rope_pos_id_version + config.llm_config.rope_pos_id_stride = model_args.rope_pos_id_stride model = InternVLChatModel.from_pretrained( model_args.model_name_or_path, torch_dtype=torch.bfloat16, config=config) else: @@ -970,6 +974,10 @@ def main(): use_thumbnail=data_args.use_thumbnail, ps_version=model_args.ps_version, min_dynamic_patch=data_args.min_dynamic_patch, max_dynamic_patch=data_args.max_dynamic_patch) internvl_chat_config.force_image_size = data_args.force_image_size + internvl_chat_config.rope_pos_id_version = model_args.rope_pos_id_version + internvl_chat_config.rope_pos_id_stride = model_args.rope_pos_id_stride + internvl_chat_config.llm_config.rope_pos_id_version = model_args.rope_pos_id_version + internvl_chat_config.llm_config.rope_pos_id_stride = model_args.rope_pos_id_stride logger.info('Building InternVLChatModel...') model = InternVLChatModel(internvl_chat_config, vision_model, llm) model.img_context_token_id = img_context_token_id diff --git a/internvl_chat/internvl/v2pe_utils.py b/internvl_chat/internvl/v2pe_utils.py index 44fcaf318..544003e5a 100644 --- a/internvl_chat/internvl/v2pe_utils.py +++ b/internvl_chat/internvl/v2pe_utils.py @@ -78,7 +78,7 @@ def get_rope_pos_id(ret, num_tiles, dtype, rope_pos_id_version='default', positi # Rotary Position Embedding for V2PE class V2PE(nn.Module): - def __init__(self, dim, max_position_embeddings=2048, base=10000, scaling_factor=1.0,scale_img=False): + def __init__(self, dim, max_position_embeddings=2048, base=10000, scaling_factor=1.0): super().__init__() self.dim = dim @@ -86,7 +86,6 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, scaling_factor self.base = base self.inv_freq = None self.scaling_factor=scaling_factor - self.scale_img=scale_img # inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) # self.register_buffer('inv_freq', inv_freq, persistent=False) From 9c45ce0988a7244d1411078f8709bd129c054821 Mon Sep 17 00:00:00 2001 From: Weiyun1025 <1612893408@qq.com> Date: Sun, 20 Apr 2025 02:23:02 +0800 Subject: [PATCH 3/6] fixed bug and support v2pe in pretrain --- .../model/internlm2/modeling_internlm2.py | 48 ++++++++-------- .../internvl_chat/modeling_internvl_chat.py | 18 +++--- 
.../internvl/patch/pad_data_collator.py | 16 +++++- .../internvl/train/dataset_packed.py | 17 +++++- .../internvl/train/internvl_chat_finetune.py | 30 +++++----- .../internvl/train/internvl_chat_pretrain.py | 56 +++++++++++++++++-- internvl_chat/internvl/v2pe_utils.py | 5 +- 7 files changed, 132 insertions(+), 58 deletions(-) diff --git a/internvl_chat/internvl/model/internlm2/modeling_internlm2.py b/internvl_chat/internvl/model/internlm2/modeling_internlm2.py index 983e9d42d..e934d53da 100644 --- a/internvl_chat/internvl/model/internlm2/modeling_internlm2.py +++ b/internvl_chat/internvl/model/internlm2/modeling_internlm2.py @@ -324,12 +324,11 @@ def __init__(self, config: InternLM2Config): def _init_rope(self): if self.config.rope_scaling is None: - if self.config.rope_pos_id_version == 'v2pe': + if self.config.rope_pos_id_version.startswith('v2pe_'): self.rotary_emb = V2PE( self.head_dim, max_position_embeddings=self.max_position_embeddings, base=self.config.rope_theta, - rope_pos_id_stride=self.config.rope_pos_id_stride ) else: self.rotary_emb = InternLM2RotaryEmbedding( @@ -338,27 +337,32 @@ def _init_rope(self): base=self.config.rope_theta, ) else: - if self.config.rope_pos_id_version == 'v2pe': - raise ValueError("V2PE is not compatible with rope_scaling. When using V2PE, rope_scaling must be None.") - - scaling_type = self.config.rope_scaling["type"] - scaling_factor = self.config.rope_scaling["factor"] - if scaling_type == "linear": - self.rotary_emb = InternLM2LinearScalingRotaryEmbedding( - self.head_dim, - max_position_embeddings=self.max_position_embeddings, - scaling_factor=scaling_factor, - base=self.config.rope_theta, - ) - elif scaling_type == "dynamic": - self.rotary_emb = InternLM2DynamicNTKScalingRotaryEmbedding( + if self.config.rope_pos_id_version.startswith('v2pe_'): + warnings.warn(f"V2PE is not compatible with rope_scaling. When using V2PE, rope_scaling must be None. 
rope_scaling is {self.config.rope_scaling}") + self.rotary_emb = V2PE( self.head_dim, max_position_embeddings=self.max_position_embeddings, - scaling_factor=scaling_factor, base=self.config.rope_theta, ) else: - raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + scaling_type = self.config.rope_scaling["type"] + scaling_factor = self.config.rope_scaling["factor"] + if scaling_type == "linear": + self.rotary_emb = InternLM2LinearScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.config.rope_theta, + ) + elif scaling_type == "dynamic": + self.rotary_emb = InternLM2DynamicNTKScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.config.rope_theta, + ) + else: + raise ValueError(f"Unknown RoPE scaling type {scaling_type}") return self.rotary_emb def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): @@ -403,8 +407,8 @@ def forward( kv_seq_len = key_states.shape[-2] if past_key_value is not None: kv_seq_len += past_key_value[0].shape[-2] - if self.config.rope_pos_id_version == 'v2pe': - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + if self.config.rope_pos_id_version.startswith('v2pe_'): + cos, sin = self.rotary_emb(value_states, global_posid=position_ids) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, torch.arange(0,position_ids.shape[1]).unsqueeze(0)) else: cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) @@ -509,8 +513,8 @@ def forward( kv_seq_len = key_states.shape[-2] if past_key_value is not None: kv_seq_len += past_key_value[0].shape[-2] - if self.config.rope_pos_id_version == 'v2pe': - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + if self.config.rope_pos_id_version.startswith('v2pe_'): + cos, sin = self.rotary_emb(value_states, global_posid=position_ids) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, torch.arange(0,position_ids.shape[1]).unsqueeze(0)) else: cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) diff --git a/internvl_chat/internvl/model/internvl_chat/modeling_internvl_chat.py b/internvl_chat/internvl/model/internvl_chat/modeling_internvl_chat.py index b9dc43927..1c333a658 100644 --- a/internvl_chat/internvl/model/internvl_chat/modeling_internvl_chat.py +++ b/internvl_chat/internvl/model/internvl_chat/modeling_internvl_chat.py @@ -159,6 +159,8 @@ def forward( ) -> Union[Tuple, CausalLMOutputWithPast]: return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if isinstance(position_ids, list): + position_ids = torch.tensor(position_ids, dtype=torch.float32).to(pixel_values.device) image_flags = image_flags.squeeze(-1) input_embeds = self.language_model.get_input_embeddings()(input_ids).clone() @@ -379,21 +381,15 @@ def chat(self, tokenizer, pixel_values, question, generation_config, history=Non input_ids = model_inputs['input_ids'].to(device) attention_mask = model_inputs['attention_mask'].to(device) generation_config['eos_token_id'] = eos_token_id - - if 'rope_pos_id_version' in kwargs: + rope_pos_id_version = self.config.rope_pos_id_version + if rope_pos_id_version.startswith('v2pe_'): self.language_model.rope_pos_id_version = kwargs['rope_pos_id_version'] pos_ids = [] ret = {'input_ids': input_ids, 'attention_mask': attention_mask} for i in range(input_ids.shape[0]): - if kwargs['rope_pos_id_version'] == 'default': - 
cur_dtype = torch.long - else: - cur_dtype = torch.float32 - - if 'rope_pos_id_stride' in kwargs: - rope_pos_id_stride = kwargs['rope_pos_id_stride'] - else: - rope_pos_id_stride = None + + cur_dtype = torch.float32 + rope_pos_id_stride = self.config.rope_pos_id_stride cur_pos_id = get_rope_pos_id(ret, tokenizer=tokenizer, num_tiles=kwargs['num_tiles'][i], dtype=cur_dtype, diff --git a/internvl_chat/internvl/patch/pad_data_collator.py b/internvl_chat/internvl/patch/pad_data_collator.py index 3d0ccf516..dffd76ba0 100644 --- a/internvl_chat/internvl/patch/pad_data_collator.py +++ b/internvl_chat/internvl/patch/pad_data_collator.py @@ -73,7 +73,12 @@ def concat_pad_data_collator(features, max_item_length=None, pad_id=0): if 'position_ids' in feat: temp_position_ids = [pad_id] * max_item_length - temp_position_ids[:feat['position_ids'].shape[0]] = feat['position_ids'] + if isinstance(feat['position_ids'], (list, tuple)): + pos_length = len(feat['position_ids']) + temp_position_ids[:pos_length] = feat['position_ids'] + else: + pos_length = feat['position_ids'].shape[0] + temp_position_ids[:pos_length] = feat['position_ids'] feat['position_ids'] = temp_position_ids if 'loss_weight' in feat: @@ -98,7 +103,7 @@ def concat_pad_data_collator(features, max_item_length=None, pad_id=0): # Handling of all other possible keys. # Again, we will use the first element to figure out which key/values are not None for this model. for k, v in first.items(): - if k not in ('label', 'label_ids', 'pixel_values', 'image_flags') and \ + if k not in ('label', 'label_ids', 'pixel_values', 'image_flags', 'position_ids') and \ v is not None and not isinstance(v, str): if isinstance(v, torch.Tensor): batch[k] = torch.stack([f[k] for f in features]) @@ -113,6 +118,13 @@ def concat_pad_data_collator(features, max_item_length=None, pad_id=0): batch[k] = torch.concat(np.stack([f[k] for f in features])) else: batch[k] = torch.concat([f[k] for f in features]) + if k == 'position_ids': + if isinstance(v, torch.Tensor): + batch[k] = torch.concat([f[k] for f in features]) + elif isinstance(v, np.ndarray): + batch[k] = torch.concat(np.stack([f[k] for f in features])) + else: + batch[k] = [f[k] for f in features] return batch diff --git a/internvl_chat/internvl/train/dataset_packed.py b/internvl_chat/internvl/train/dataset_packed.py index 3b54ed475..7bee03301 100644 --- a/internvl_chat/internvl/train/dataset_packed.py +++ b/internvl_chat/internvl/train/dataset_packed.py @@ -240,7 +240,22 @@ def update_buffer(self, buffer, new_sample): assert buffer.keys() == new_sample.keys() for k in buffer: - buffer[k] = torch.cat([buffer[k], new_sample[k]]) + if isinstance(buffer[k], list) or isinstance(new_sample[k], list): + # Handle list type + if isinstance(buffer[k], list): + buffer_data = buffer[k] + else: + buffer_data = buffer[k].tolist() + + if isinstance(new_sample[k], list): + new_data = new_sample[k] + else: + new_data = new_sample[k].tolist() + + buffer[k] = buffer_data + new_data + else: + # Handle tensor type + buffer[k] = torch.cat([buffer[k], new_sample[k]]) return buffer @staticmethod
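The update_buffer change above is what allows packed batches to mix V2PE samples, whose position_ids are Python lists of floats, with ordinary samples whose buffers are tensors. A minimal sketch of the merge rule (values are made up for illustration; two tensors still go through torch.cat as before):

import torch

buffer = {'position_ids': [10.25, 10.5]}                # list from a V2PE sample
new_sample = {'position_ids': torch.tensor([0, 1, 2])}  # tensor from a text-only sample
# Mixed types fall back to list concatenation, mirroring update_buffer above.
merged = buffer['position_ids'] + new_sample['position_ids'].tolist()
assert merged == [10.25, 10.5, 0, 1, 2]

diff --git a/internvl_chat/internvl/train/internvl_chat_finetune.py b/internvl_chat/internvl/train/internvl_chat_finetune.py index b916482c7..10f19f98c 100644 --- a/internvl_chat/internvl/train/internvl_chat_finetune.py +++ b/internvl_chat/internvl/train/internvl_chat_finetune.py @@ -159,6 +159,14 @@ class ModelArguments: default=False, metadata={'help': 'Set to True to use the liger kernel.'} ) + rope_pos_id_version: Optional[str] = field( 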
default='default', + metadata={'help': 'version for get_rope_pos_id'}, + ) + rope_pos_id_stride: Optional[int] = field( + default=None, + metadata={'help': 'stride for the version v4 of get_rope_pos_id'}, + ) @dataclass @@ -266,14 +274,7 @@ class DataTrainingArguments: default=False, metadata={'help': 'Whether to gather all during loss reduction. Default is False.'}, ) - rope_pos_id_version: Optional[str] = field( - default='default', - metadata={'help': 'version for get_rope_pos_id'}, - ) - rope_pos_id_stride: Optional[int] = field( - default=None, - metadata={'help': 'stride for the version v4 of get_rope_pos_id'}, - ) + class LazySupervisedDataset(Dataset): @@ -450,7 +451,7 @@ def multi_modal_get_item(self, data_item): orig_size = image.size if self.dynamic_image_size: # If dynamic image size is enabled, preprocess the image dynamically - images, boxes = dynamic_preprocess(image, min_num=self.min_dynamic_patch, max_num=self.max_dynamic_patch, + images = dynamic_preprocess(image, min_num=self.min_dynamic_patch, max_num=self.max_dynamic_patch, image_size=self.image_size, use_thumbnail=self.use_thumbnail) num_tiles.append(len(images)) else: # Otherwise, use the original image as a single patch @@ -595,9 +596,6 @@ def video_get_item(self, data_item): position_ids = ret['attention_mask'].long().cumsum(-1) - 1 position_ids.masked_fill_(ret['attention_mask'] == 0, 1) - image_end_token_id = self.tokenizer.convert_tokens_to_ids(IMG_END_TOKEN) - assert (ret['input_ids'][0] == image_end_token_id).sum() == num_patches, f'image tokens are truncated, this dataset is {self.ds_name}' - if self.rope_pos_id_version in ['v2pe_fix', 'v2pe_rnd']: position_ids = get_rope_pos_id(ret, [1] * num_patches, torch.float32, self.rope_pos_id_version, position_ids[0], tokenizer=self.tokenizer) else: @@ -746,6 +744,8 @@ def build_datasets( min_num_frame=8, max_num_frame=32, normalize_type='imagenet', + rope_pos_id_version='default', + rope_pos_id_stride=None, ): datasets = [] lengths = [] @@ -784,6 +784,8 @@ def build_datasets( distributed_mode=data_args.use_packed_ds, force_shuffle=data_args.use_packed_ds, random_seed=ds_idx, + rope_pos_id_version=rope_pos_id_version, + rope_pos_id_stride=rope_pos_id_stride, ) logger.info(f'Add dataset: {ds_name} with length: {len(dataset)}') datasets.append(dataset) @@ -1026,8 +1028,8 @@ def main(): dynamic_image_size=data_args.dynamic_image_size, use_thumbnail=data_args.use_thumbnail, min_dynamic_patch=data_args.min_dynamic_patch, max_dynamic_patch=data_args.max_dynamic_patch, normalize_type=data_args.normalize_type, min_num_frame=data_args.min_num_frame, - max_num_frame=data_args.max_num_frame) - + max_num_frame=data_args.max_num_frame,rope_pos_id_version=model_args.rope_pos_id_version, + rope_pos_id_stride=model_args.rope_pos_id_stride) def _freeze_params(module): for param in module.parameters(): param.requires_grad = False diff --git a/internvl_chat/internvl/train/internvl_chat_pretrain.py b/internvl_chat/internvl/train/internvl_chat_pretrain.py index d7962ae92..bb4385029 100644 --- a/internvl_chat/internvl/train/internvl_chat_pretrain.py +++ b/internvl_chat/internvl/train/internvl_chat_pretrain.py @@ -52,6 +52,7 @@ preprocess_internvl2_5, preprocess_mpt, preprocess_phi3) from internvl.train.dataset_packed import PackedDataset, packed_collate_fn +from internvl.v2pe_utils import get_rope_pos_id from PIL import Image, ImageFile, PngImagePlugin, UnidentifiedImageError from torch.utils.data import Dataset from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer, @@ 
-157,6 +158,14 @@ class ModelArguments: default=False, metadata={'help': 'Set to True to use the liger kernel.'} ) + rope_pos_id_version: Optional[str] = field( + default='default', + metadata={'help': 'version for get_rope_pos_id'}, + ) + rope_pos_id_stride: Optional[int] = field( + default=None, + metadata={'help': 'stride for the version v4 of get_rope_pos_id'}, + ) @dataclass @@ -297,16 +306,22 @@ def __init__( distributed_mode=False, force_shuffle=False, random_seed=0, + rope_pos_id_version='default', + rope_pos_id_stride=None, ): super(LazySupervisedDataset, self).__init__() self.ds_name = ds_name self.tokenizer = tokenizer self.template_name = template_name self.num_image_token = num_image_token + self.rope_pos_id_version = rope_pos_id_version + self.rope_pos_id_stride = rope_pos_id_stride logger.info(f'[Dataset] num_image_token: {num_image_token}') logger.info(f'[Dataset] dynamic_image_size: {dynamic_image_size}') logger.info(f'[Dataset] use_thumbnail: {use_thumbnail}') logger.info(f'[Dataset] min_dynamic_patch: {min_dynamic_patch}, max_dynamic_patch: {max_dynamic_patch}') + logger.info(f'[Dataset] rope_pos_id_version: {rope_pos_id_version}') + logger.info(f'[Dataset] rope_pos_id_stride: {rope_pos_id_stride}') self.image_size = image_size self.is_train = is_train @@ -464,6 +479,7 @@ def multi_modal_get_item(self, data_item): # Build transformation function transform = self.get_transform() + num_tiles = [] # Ensure the first conversation contains an image placeholder if '<image>' not in data_item['conversations'][0]['value']: data_item['conversations'][0]['value'] = '<image>\n' + data_item['conversations'][0]['value'] @@ -473,12 +489,15 @@ def multi_modal_get_item(self, data_item): # Load the image using tcs_loader if available, otherwise use PIL image = self.load_image(image_path) + orig_size = image.size if self.dynamic_image_size: # If dynamic image size is enabled, preprocess the image dynamically images = dynamic_preprocess(image, min_num=self.min_dynamic_patch, max_num=self.max_dynamic_patch, image_size=self.image_size, use_thumbnail=self.use_thumbnail) + num_tiles.append(len(images)) else: # Otherwise, use the original image as a single patch images = [image] + num_tiles.append(1) # Apply the transformation to each image and stack the results into a tensor pixel_values = [transform(image) for image in images] @@ -504,12 +523,16 @@ def multi_modal_get_item(self, data_item): image_end_token_id = self.tokenizer.convert_tokens_to_ids(IMG_END_TOKEN) assert (ret['input_ids'][0] == image_end_token_id).sum() == 1, f'image tokens are truncated, this dataset is {self.ds_name}' + if self.rope_pos_id_version in ['v2pe_fix', 'v2pe_rnd']: + position_ids = get_rope_pos_id(ret, num_tiles, torch.float32, self.rope_pos_id_version, position_ids[0], tokenizer=self.tokenizer) + else: + position_ids = position_ids[0] # Create the final return dictionary ret = dict( input_ids=ret['input_ids'][0], labels=ret['labels'][0], attention_mask=ret['attention_mask'][0], - position_ids=position_ids[0], + position_ids=position_ids, pixel_values=pixel_values, image_flags=torch.tensor([1] * num_patches, dtype=torch.long) ) @@ -554,12 +577,16 @@ def multi_modal_multi_image_get_item(self, data_item): image_end_token_id = self.tokenizer.convert_tokens_to_ids(IMG_END_TOKEN) assert (ret['input_ids'][0] == image_end_token_id).sum() == num_image, f'image tokens are truncated, this dataset is {self.ds_name}' + if self.rope_pos_id_version in ['v2pe_fix', 'v2pe_rnd']: + position_ids = get_rope_pos_id(ret, num_tiles, torch.float32, 
self.rope_pos_id_version, position_ids[0], tokenizer=self.tokenizer) + else: + position_ids = position_ids[0] # Create the final return dictionary ret = dict( input_ids=ret['input_ids'][0], labels=ret['labels'][0], attention_mask=ret['attention_mask'][0], - position_ids=position_ids[0], + position_ids=position_ids, pixel_values=pixel_values, image_flags=torch.tensor([1] * num_patches, dtype=torch.long) ) @@ -610,12 +637,16 @@ def video_get_item(self, data_item): position_ids = ret['attention_mask'].long().cumsum(-1) - 1 position_ids.masked_fill_(ret['attention_mask'] == 0, 1) + if self.rope_pos_id_version in ['v2pe_fix', 'v2pe_rnd']: + position_ids = get_rope_pos_id(ret, [1] * num_patches, torch.float32, self.rope_pos_id_version, position_ids[0], tokenizer=self.tokenizer) + else: + position_ids = position_ids[0] # Create the final return dictionary ret = dict( input_ids=ret['input_ids'][0], labels=ret['labels'][0], attention_mask=ret['attention_mask'][0], - position_ids=position_ids[0], + position_ids=position_ids, pixel_values=pixel_values, image_flags=torch.tensor([1] * num_patches, dtype=torch.long) ) @@ -653,6 +684,7 @@ def pure_text_get_item(self, data_item): position_ids = ret['attention_mask'].long().cumsum(-1) - 1 position_ids.masked_fill_(ret['attention_mask'] == 0, 1) + # Create the final return dictionary ret = dict( input_ids=ret['input_ids'][0], @@ -755,6 +787,8 @@ def build_datasets( min_num_frame=8, max_num_frame=32, normalize_type='imagenet', + rope_pos_id_version='default', + rope_pos_id_stride=None, ): datasets = [] lengths = [] @@ -793,6 +827,8 @@ def build_datasets( distributed_mode=data_args.use_packed_ds, force_shuffle=data_args.use_packed_ds, random_seed=ds_idx, + rope_pos_id_version=rope_pos_id_version, + rope_pos_id_stride=rope_pos_id_stride, ) logger.info(f'Add dataset: {ds_name} with length: {len(dataset)}') datasets.append(dataset) @@ -950,6 +986,10 @@ def main(): config.ps_version = model_args.ps_version config.min_dynamic_patch = data_args.min_dynamic_patch config.max_dynamic_patch = data_args.max_dynamic_patch + config.rope_pos_id_version = model_args.rope_pos_id_version + config.rope_pos_id_stride = model_args.rope_pos_id_stride + config.llm_config.rope_pos_id_version = model_args.rope_pos_id_version + config.llm_config.rope_pos_id_stride = model_args.rope_pos_id_stride model = InternVLChatModel.from_pretrained( model_args.model_name_or_path, torch_dtype=torch.bfloat16, config=config) else: @@ -977,8 +1017,13 @@ def main(): pad2square=data_args.pad2square, template=data_args.conv_style, select_layer=model_args.vision_select_layer, dynamic_image_size=data_args.dynamic_image_size, use_thumbnail=data_args.use_thumbnail, ps_version=model_args.ps_version, - min_dynamic_patch=data_args.min_dynamic_patch, max_dynamic_patch=data_args.max_dynamic_patch) + min_dynamic_patch=data_args.min_dynamic_patch, max_dynamic_patch=data_args.max_dynamic_patch, + rope_pos_id_version=model_args.rope_pos_id_version, rope_pos_id_stride=model_args.rope_pos_id_stride) internvl_chat_config.force_image_size = data_args.force_image_size + internvl_chat_config.rope_pos_id_version = model_args.rope_pos_id_version + internvl_chat_config.rope_pos_id_stride = model_args.rope_pos_id_stride + internvl_chat_config.llm_config.rope_pos_id_version = model_args.rope_pos_id_version + internvl_chat_config.llm_config.rope_pos_id_stride = model_args.rope_pos_id_stride logger.info('Building InternVLChatModel...') model = InternVLChatModel(internvl_chat_config, vision_model, llm) model.img_context_token_id 
= img_context_token_id @@ -1027,7 +1072,8 @@ def main(): dynamic_image_size=data_args.dynamic_image_size, use_thumbnail=data_args.use_thumbnail, min_dynamic_patch=data_args.min_dynamic_patch, max_dynamic_patch=data_args.max_dynamic_patch, normalize_type=data_args.normalize_type, min_num_frame=data_args.min_num_frame, - max_num_frame=data_args.max_num_frame) + max_num_frame=data_args.max_num_frame, rope_pos_id_version=model_args.rope_pos_id_version, + rope_pos_id_stride=model_args.rope_pos_id_stride) def _freeze_params(module): for param in module.parameters(): diff --git a/internvl_chat/internvl/v2pe_utils.py b/internvl_chat/internvl/v2pe_utils.py index 544003e5a..28381b78f 100644 --- a/internvl_chat/internvl/v2pe_utils.py +++ b/internvl_chat/internvl/v2pe_utils.py @@ -78,14 +78,13 @@ def get_rope_pos_id(ret, num_tiles, dtype, rope_pos_id_version='default', positi # Rotary Position Embedding for V2PE class V2PE(nn.Module): - def __init__(self, dim, max_position_embeddings=2048, base=10000, scaling_factor=1.0): + def __init__(self, dim, max_position_embeddings=2048, base=10000): super().__init__() self.dim = dim self.max_position_embeddings = max_position_embeddings self.base = base self.inv_freq = None - self.scaling_factor=scaling_factor # inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) # self.register_buffer('inv_freq', inv_freq, persistent=False) @@ -107,7 +106,7 @@ def _set_cos_sin_cache(self, pos_id, device, dtype): def forward(self, x, global_posid=None): # x: [bs, num_attention_heads, seq_len, head_size] - self._set_cos_sin_cache(pos_id=global_posid, device=x.device, dtype=x.dtype,selected=selected) + self._set_cos_sin_cache(pos_id=global_posid, device=x.device, dtype=x.dtype) return ( self.cos_cached[:].to(dtype=x.dtype), From 3f07317d3a0983dbdac19b04039b67b850fe26ae Mon Sep 17 00:00:00 2001 From: Weiyun1025 <1612893408@qq.com> Date: Sun, 20 Apr 2025 03:31:27 +0800 Subject: [PATCH 4/6] support inference and fix bugs --- .../internvl/model/internlm2/modeling_internlm2.py | 6 ++++-- .../model/internvl_chat/modeling_internvl_chat.py | 11 ++++------- .../internvl/train/internvl_chat_finetune.py | 6 +++--- .../internvl/train/internvl_chat_pretrain.py | 6 +++--- internvl_chat/internvl/v2pe_utils.py | 9 ++++----- 5 files changed, 18 insertions(+), 20 deletions(-) diff --git a/internvl_chat/internvl/model/internlm2/modeling_internlm2.py b/internvl_chat/internvl/model/internlm2/modeling_internlm2.py index e934d53da..3e7634250 100644 --- a/internvl_chat/internvl/model/internlm2/modeling_internlm2.py +++ b/internvl_chat/internvl/model/internlm2/modeling_internlm2.py @@ -1159,7 +1159,7 @@ def prepare_inputs_for_generation( remove_prefix_length = input_ids.shape[1] - 1 input_ids = input_ids[:, remove_prefix_length:] - + position_ids = kwargs.get('position_ids', None) if attention_mask is not None and position_ids is None: # create position_ids on the fly for batch generation @@ -1167,13 +1167,15 @@ def prepare_inputs_for_generation( position_ids.masked_fill_(attention_mask == 0, 1) if past_key_values: position_ids = position_ids[:, -input_ids.shape[1]:] + elif position_ids is not None: + if self.config.rope_pos_id_version!='default' and past_key_values is not None: + position_ids=(position_ids[:,-1]+attention_mask[:,position_ids.shape[1]:].sum(dim=1)).unsqueeze(1) # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: model_inputs = {'inputs_embeds': 
inputs_embeds} else: model_inputs = {'input_ids': input_ids} - model_inputs.update( { 'position_ids': position_ids, diff --git a/internvl_chat/internvl/model/internvl_chat/modeling_internvl_chat.py b/internvl_chat/internvl/model/internvl_chat/modeling_internvl_chat.py index 1c333a658..29071a918 100644 --- a/internvl_chat/internvl/model/internvl_chat/modeling_internvl_chat.py +++ b/internvl_chat/internvl/model/internvl_chat/modeling_internvl_chat.py @@ -383,25 +383,22 @@ def chat(self, tokenizer, pixel_values, question, generation_config, history=Non generation_config['eos_token_id'] = eos_token_id rope_pos_id_version = self.config.rope_pos_id_version if rope_pos_id_version.startswith('v2pe_'): - self.language_model.rope_pos_id_version = kwargs['rope_pos_id_version'] pos_ids = [] ret = {'input_ids': input_ids, 'attention_mask': attention_mask} for i in range(input_ids.shape[0]): - cur_dtype = torch.float32 rope_pos_id_stride = self.config.rope_pos_id_stride - - cur_pos_id = get_rope_pos_id(ret, tokenizer=tokenizer, num_tiles=kwargs['num_tiles'][i], + cur_pos_id = get_rope_pos_id(ret, tokenizer=tokenizer, dtype=cur_dtype, - rope_pos_id_version=kwargs['rope_pos_id_version'], + rope_pos_id_version=rope_pos_id_version, position_id=torch.arange(0, input_ids.shape[1]), IMG_START_TOKEN=IMG_START_TOKEN, - IMG_END_TOKEN=IMG_END_TOKEN, rope_pos_id_stride=rope_pos_id_stride) + IMG_END_TOKEN=IMG_END_TOKEN, rope_pos_id_stride=rope_pos_id_stride, num_image_token=self.num_image_token) cur_pos_id = torch.tensor(cur_pos_id).to(device) pos_ids.append(cur_pos_id) - pos_ids = torch.stack(pos_ids) + pos_ids = torch.stack(pos_ids).to(device) generation_output = self.generate( pixel_values=pixel_values, input_ids=input_ids, diff --git a/internvl_chat/internvl/train/internvl_chat_finetune.py b/internvl_chat/internvl/train/internvl_chat_finetune.py index 10f19f98c..c8251315a 100644 --- a/internvl_chat/internvl/train/internvl_chat_finetune.py +++ b/internvl_chat/internvl/train/internvl_chat_finetune.py @@ -483,7 +483,7 @@ def multi_modal_get_item(self, data_item): assert (ret['input_ids'][0] == image_end_token_id).sum() == 1, f'image tokens are truncated, this dataset is {self.ds_name}' if self.rope_pos_id_version in ['v2pe_fix', 'v2pe_rnd']: - position_ids = get_rope_pos_id(ret, num_tiles, torch.float32, self.rope_pos_id_version, position_ids[0], tokenizer=self.tokenizer) + position_ids = get_rope_pos_id(ret, torch.float32, self.rope_pos_id_version, position_ids[0], tokenizer=self.tokenizer, num_image_token=self.num_image_token) else: position_ids = position_ids[0] # Create the final return dictionary @@ -537,7 +537,7 @@ def multi_modal_multi_image_get_item(self, data_item): assert (ret['input_ids'][0] == image_end_token_id).sum() == num_image, f'image tokens are truncated, this dataset is {self.ds_name}' if self.rope_pos_id_version in ['v2pe_fix', 'v2pe_rnd']: - position_ids = get_rope_pos_id(ret, num_tiles, torch.float32, self.rope_pos_id_version, position_ids[0], tokenizer=self.tokenizer) + position_ids = get_rope_pos_id(ret, torch.float32, self.rope_pos_id_version, position_ids[0], tokenizer=self.tokenizer, num_image_token=self.num_image_token) else: position_ids = position_ids[0] # Create the final return dictionary @@ -597,7 +597,7 @@ def video_get_item(self, data_item): position_ids.masked_fill_(ret['attention_mask'] == 0, 1) if self.rope_pos_id_version in ['v2pe_fix', 'v2pe_rnd']: - position_ids = get_rope_pos_id(ret, [1] * num_patches, torch.float32, self.rope_pos_id_version, position_ids[0], 
tokenizer=self.tokenizer) + position_ids = get_rope_pos_id(ret, torch.float32, self.rope_pos_id_version, position_ids[0], tokenizer=self.tokenizer, num_image_token=self.num_image_token) else: position_ids = position_ids[0] # Create the final return dictionary diff --git a/internvl_chat/internvl/train/internvl_chat_pretrain.py b/internvl_chat/internvl/train/internvl_chat_pretrain.py index bb4385029..26571e0c6 100644 --- a/internvl_chat/internvl/train/internvl_chat_pretrain.py +++ b/internvl_chat/internvl/train/internvl_chat_pretrain.py @@ -524,7 +524,7 @@ def multi_modal_get_item(self, data_item): assert (ret['input_ids'][0] == image_end_token_id).sum() == 1, f'image tokens are truncated, this dataset is {self.ds_name}' if self.rope_pos_id_version in ['v2pe_fix', 'v2pe_rnd']: - position_ids = get_rope_pos_id(ret, num_tiles, torch.float32, self.rope_pos_id_version, position_ids[0], tokenizer=self.tokenizer) + position_ids = get_rope_pos_id(ret, torch.float32, self.rope_pos_id_version, position_ids[0], tokenizer=self.tokenizer, num_image_token=self.num_image_token) else: position_ids = position_ids[0] # Create the final return dictionary @@ -578,7 +578,7 @@ def multi_modal_multi_image_get_item(self, data_item): assert (ret['input_ids'][0] == image_end_token_id).sum() == num_image, f'image tokens are truncated, this dataset is {self.ds_name}' if self.rope_pos_id_version in ['v2pe_fix', 'v2pe_rnd']: - position_ids = get_rope_pos_id(ret, num_tiles, torch.float32, self.rope_pos_id_version, position_ids[0], tokenizer=self.tokenizer) + position_ids = get_rope_pos_id(ret, torch.float32, self.rope_pos_id_version, position_ids[0], tokenizer=self.tokenizer, num_image_token=self.num_image_token) else: position_ids = position_ids[0] # Create the final return dictionary @@ -638,7 +638,7 @@ def video_get_item(self, data_item): position_ids.masked_fill_(ret['attention_mask'] == 0, 1) if self.rope_pos_id_version in ['v2pe_fix', 'v2pe_rnd']: - position_ids = get_rope_pos_id(ret, [1] * num_patches, torch.float32, self.rope_pos_id_version, position_ids[0], tokenizer=self.tokenizer) + position_ids = get_rope_pos_id(ret, torch.float32, self.rope_pos_id_version, position_ids[0], tokenizer=self.tokenizer, num_image_token=self.num_image_token) else: position_ids = position_ids[0] # Create the final return dictionary diff --git a/internvl_chat/internvl/v2pe_utils.py b/internvl_chat/internvl/v2pe_utils.py index 28381b78f..2e0d70e89 100644 --- a/internvl_chat/internvl/v2pe_utils.py +++ b/internvl_chat/internvl/v2pe_utils.py @@ -2,18 +2,17 @@ import torch.nn as nn import random # get rope pos id while evaluating -def get_rope_pos_id(ret, num_tiles, dtype, rope_pos_id_version='default', position_id=None, - IMG_START_TOKEN='<img>',IMG_END_TOKEN='</img>',rope_pos_id_stride=None, tokenizer=None): +def get_rope_pos_id(ret, dtype, rope_pos_id_version='default', position_id=None, + IMG_START_TOKEN='<img>',IMG_END_TOKEN='</img>',rope_pos_id_stride=None, tokenizer=None, num_image_token=256): image_start_token_id = tokenizer.convert_tokens_to_ids(IMG_START_TOKEN) image_end_token_id = tokenizer.convert_tokens_to_ids(IMG_END_TOKEN) - num_image_token=256 rope_pos_id_list = [] - + assert ret['input_ids'].shape[0] == 1, 'batch size should be 1, other batch sizes are not supported yet' input_ids_0 = ret['input_ids'][0] attention_mask_0 = ret['attention_mask'][0] image_start_token_id_idxs = torch.where(input_ids_0 == image_start_token_id)[0] image_end_token_id_idxs = torch.where(input_ids_0 == image_end_token_id)[0] + num_tiles = (image_end_token_id_idxs - 
image_start_token_id_idxs) // num_image_token last_record_pos_id = -1 start_index = 0 From 5cd4f3dba8245ac9833fcde4798c4b49d0526c11 Mon Sep 17 00:00:00 2001 From: gejq21 Date: Sun, 20 Apr 2025 11:45:34 +0800 Subject: [PATCH 5/6] fixed bugs --- .../model/internlm2/configuration_internlm2.py | 2 +- .../model/internlm2/modeling_internlm2.py | 17 +++++++++-------- .../internvl_chat/modeling_internvl_chat.py | 2 +- internvl_chat/internvl/train/dataset_packed.py | 4 ++-- .../internvl/train/internvl_chat_finetune.py | 4 ++-- internvl_chat/internvl/v2pe_utils.py | 9 ++++++--- 6 files changed, 21 insertions(+), 17 deletions(-) diff --git a/internvl_chat/internvl/model/internlm2/configuration_internlm2.py b/internvl_chat/internvl/model/internlm2/configuration_internlm2.py index 297c53173..e58366abd 100644 --- a/internvl_chat/internvl/model/internlm2/configuration_internlm2.py +++ b/internvl_chat/internvl/model/internlm2/configuration_internlm2.py @@ -120,7 +120,7 @@ def __init__( # pylint: disable=W0102 self.rope_pos_id_version = rope_pos_id_version self.rope_pos_id_stride = rope_pos_id_stride self._rope_scaling_validation() - + self.attn_implementation = attn_implementation if self.attn_implementation is None: self.attn_implementation = 'eager' diff --git a/internvl_chat/internvl/model/internlm2/modeling_internlm2.py b/internvl_chat/internvl/model/internlm2/modeling_internlm2.py index 3e7634250..85cef0c11 100644 --- a/internvl_chat/internvl/model/internlm2/modeling_internlm2.py +++ b/internvl_chat/internvl/model/internlm2/modeling_internlm2.py @@ -40,9 +40,10 @@ except: # noqa # pylint: disable=bare-except BaseStreamer = None -from .configuration_internlm2 import InternLM2Config from internvl.v2pe_utils import V2PE +from .configuration_internlm2 import InternLM2Config + logger = logging.get_logger(__name__) _CONFIG_FOR_DOC = 'InternLM2Config' @@ -338,23 +339,23 @@ def _init_rope(self): ) else: if self.config.rope_pos_id_version.startswith('v2pe_'): - warnings.warn(f"V2PE is not compatible with rope_scaling. When using V2PE, rope_scaling must be None. rope_scaling is {self.config.rope_scaling}") + warnings.warn(f'V2PE is not compatible with rope_scaling. When using V2PE, rope_scaling must be None. 
@@ -1159,7 +1160,7 @@ def prepare_inputs_for_generation(
             remove_prefix_length = input_ids.shape[1] - 1

         input_ids = input_ids[:, remove_prefix_length:]
-
+
         position_ids = kwargs.get('position_ids', None)
         if attention_mask is not None and position_ids is None:
             # create position_ids on the fly for batch generation
diff --git a/internvl_chat/internvl/model/internvl_chat/modeling_internvl_chat.py b/internvl_chat/internvl/model/internvl_chat/modeling_internvl_chat.py
index 29071a918..caeaf3945 100644
--- a/internvl_chat/internvl/model/internvl_chat/modeling_internvl_chat.py
+++ b/internvl_chat/internvl/model/internvl_chat/modeling_internvl_chat.py
@@ -13,6 +13,7 @@
 from internvl.conversation import get_conv_template
 from internvl.model.internlm2.modeling_internlm2 import InternLM2ForCausalLM
 from internvl.model.phi3.modeling_phi3 import Phi3ForCausalLM
+from internvl.v2pe_utils import get_rope_pos_id
 from peft import LoraConfig, get_peft_model
 from torch import nn
 from torch.nn import CrossEntropyLoss
@@ -24,7 +25,6 @@

 from .configuration_internvl_chat import InternVLChatConfig
 from .modeling_intern_vit import InternVisionModel, has_flash_attn
-from internvl.v2pe_utils import get_rope_pos_id

 logger = logging.get_logger(__name__)
diff --git a/internvl_chat/internvl/train/dataset_packed.py b/internvl_chat/internvl/train/dataset_packed.py
index 7bee03301..ef52a7bd8 100644
--- a/internvl_chat/internvl/train/dataset_packed.py
+++ b/internvl_chat/internvl/train/dataset_packed.py
@@ -246,12 +246,12 @@ def update_buffer(self, buffer, new_sample):
                     buffer_data = buffer[k]
                 else:
                     buffer_data = buffer[k].tolist()
-
+
                 if isinstance(new_sample[k], list):
                     new_data = new_sample[k]
                 else:
                     new_data = new_sample[k].tolist()
-
+
                 buffer[k] = buffer_data + new_data
             else:
                 # Handle tensor type
diff --git a/internvl_chat/internvl/train/internvl_chat_finetune.py b/internvl_chat/internvl/train/internvl_chat_finetune.py
index c8251315a..00be4befc 100644
--- a/internvl_chat/internvl/train/internvl_chat_finetune.py
+++ b/internvl_chat/internvl/train/internvl_chat_finetune.py
@@ -276,7 +276,6 @@ class DataTrainingArguments:
     )


-
 class LazySupervisedDataset(Dataset):
     """Dataset for supervised fine-tuning."""

@@ -539,7 +538,7 @@ def multi_modal_multi_image_get_item(self, data_item):
         if self.rope_pos_id_version in ['v2pe_fix', 'v2pe_rnd']:
             position_ids = get_rope_pos_id(ret, torch.float32, self.rope_pos_id_version, position_ids[0], tokenizer=self.tokenizer, num_image_token=self.num_image_token)
         else:
-            position_ids = position_ids[0] 
+            position_ids = position_ids[0]
         # Create the final return dictionary
         ret = dict(
             input_ids=ret['input_ids'][0],
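
On the two accepted `rope_pos_id_version` values wired through here: both keep unit increments for text tokens while shrinking the increment of visual tokens; 'v2pe_fix' derives that increment from the configured `rope_pos_id_stride`, while 'v2pe_rnd' draws it at random per sample during training so the model is exposed to many increments. A hedged sketch of one plausible reading — the candidate stride set and the fallback default are assumptions for illustration, not values from this patch:

import random

import torch

def visual_pos_ids(start, num_tokens, stride):
    # The num_tokens tokens of one tile advance the position id by `stride`
    # units in total, i.e. stride / num_tokens per token (fractional when
    # stride < num_tokens).
    step = stride / num_tokens
    return start + step * torch.arange(1, num_tokens + 1, dtype=torch.float32)

def pick_stride(version, rope_pos_id_stride=None):
    if version == 'v2pe_fix':
        return rope_pos_id_stride if rope_pos_id_stride is not None else 64  # assumed fallback
    if version == 'v2pe_rnd':
        return random.choice([1, 2, 4, 8, 16, 32, 64, 128, 256])  # assumed candidate set
    raise ValueError(f'unknown rope_pos_id_version: {version}')

ids = visual_pos_ids(start=10.0, num_tokens=256, stride=pick_stride('v2pe_fix', 64))
print(ids[0].item(), ids[-1].item())  # 10.25 74.0
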
@@ -1030,6 +1029,7 @@ def main():
             normalize_type=data_args.normalize_type,
             min_num_frame=data_args.min_num_frame,
             max_num_frame=data_args.max_num_frame,rope_pos_id_version=model_args.rope_pos_id_version,
             rope_pos_id_stride=model_args.rope_pos_id_stride)
+
     def _freeze_params(module):
         for param in module.parameters():
             param.requires_grad = False
diff --git a/internvl_chat/internvl/v2pe_utils.py b/internvl_chat/internvl/v2pe_utils.py
index 2e0d70e89..35803df9b 100644
--- a/internvl_chat/internvl/v2pe_utils.py
+++ b/internvl_chat/internvl/v2pe_utils.py
@@ -1,6 +1,9 @@
+import random
+
 import torch
 import torch.nn as nn
-import random
+
+
 # get rope pos id while evaluating
 def get_rope_pos_id(ret, dtype, rope_pos_id_version='default', position_id=None,
                     IMG_START_TOKEN='<img>',IMG_END_TOKEN='</img>',rope_pos_id_stride=None, tokenizer=None, num_image_token=256):
@@ -87,7 +90,7 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000):
         # inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
         # self.register_buffer('inv_freq', inv_freq, persistent=False)
-        self.max_seq_len_cached = -1 
+        self.max_seq_len_cached = -1

     def _set_cos_sin_cache(self, pos_id, device, dtype):
         if self.inv_freq is None:
@@ -110,4 +113,4 @@ def forward(self, x, global_posid=None):
         return (
             self.cos_cached[:].to(dtype=x.dtype),
             self.sin_cached[:].to(dtype=x.dtype),
-        )
\ No newline at end of file
+        )

From b7851a69ec6bef148ba3612af9df3aefd7298de5 Mon Sep 17 00:00:00 2001
From: gejq21
Date: Sun, 20 Apr 2025 11:45:44 +0800
Subject: [PATCH 6/6] fixed bugs

---
 internvl_chat/internvl/train/internvl_chat_pretrain.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/internvl_chat/internvl/train/internvl_chat_pretrain.py b/internvl_chat/internvl/train/internvl_chat_pretrain.py
index 26571e0c6..6fd97b77c 100644
--- a/internvl_chat/internvl/train/internvl_chat_pretrain.py
+++ b/internvl_chat/internvl/train/internvl_chat_pretrain.py
@@ -580,7 +580,7 @@ def multi_modal_multi_image_get_item(self, data_item):
         if self.rope_pos_id_version in ['v2pe_fix', 'v2pe_rnd']:
             position_ids = get_rope_pos_id(ret, torch.float32, self.rope_pos_id_version, position_ids[0], tokenizer=self.tokenizer, num_image_token=self.num_image_token)
         else:
-            position_ids = position_ids[0] 
+            position_ids = position_ids[0]
         # Create the final return dictionary
         ret = dict(
             input_ids=ret['input_ids'][0],
@@ -684,7 +684,6 @@ def pure_text_get_item(self, data_item):
         position_ids = ret['attention_mask'].long().cumsum(-1) - 1
         position_ids.masked_fill_(ret['attention_mask'] == 0, 1)
-
         # Create the final return dictionary
         ret = dict(
             input_ids=ret['input_ids'][0],
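
A closing note on the `V2PE` module whose whitespace is cleaned up above: its cache is keyed on the actual position-id tensor rather than on a maximum sequence length, which is why `_set_cos_sin_cache` receives `pos_id` instead of the usual `seq_len`, and why `inv_freq` can be built lazily (the eager construction is left behind as the commented-out lines in the diff). An illustrative sketch of that caching pattern, not the patch code itself:

import torch
import torch.nn as nn

class V2PECacheSketch(nn.Module):
    def __init__(self, dim, base=10000):
        super().__init__()
        self.dim, self.base = dim, base
        self.inv_freq = None      # built lazily on first use
        self.cos_cached = None
        self.sin_cached = None

    def _set_cos_sin_cache(self, pos_id, device, dtype):
        if self.inv_freq is None:
            self.inv_freq = 1.0 / (self.base ** (
                torch.arange(0, self.dim, 2, device=device).float() / self.dim))
        # pos_id is the full (possibly fractional) per-token position tensor,
        # so the cache always matches the current sequence exactly.
        freqs = torch.outer(pos_id.to(device).float(), self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        self.cos_cached = emb.cos().to(dtype)
        self.sin_cached = emb.sin().to(dtype)

    def forward(self, x, global_posid=None):
        self._set_cos_sin_cache(global_posid, x.device, x.dtype)
        return self.cos_cached, self.sin_cached
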