diff --git a/internvl_chat/internvl/model/internlm2/configuration_internlm2.py b/internvl_chat/internvl/model/internlm2/configuration_internlm2.py index 282b13b1e..e58366abd 100644 --- a/internvl_chat/internvl/model/internlm2/configuration_internlm2.py +++ b/internvl_chat/internvl/model/internlm2/configuration_internlm2.py @@ -95,6 +95,8 @@ def __init__( # pylint: disable=W0102 rope_theta=10000, rope_scaling=None, attn_implementation='eager', + rope_pos_id_version='default', + rope_pos_id_stride=None, **kwargs, ): self.vocab_size = vocab_size @@ -115,6 +117,8 @@ def __init__( # pylint: disable=W0102 self.use_cache = use_cache self.rope_theta = rope_theta self.rope_scaling = rope_scaling + self.rope_pos_id_version = rope_pos_id_version + self.rope_pos_id_stride = rope_pos_id_stride self._rope_scaling_validation() self.attn_implementation = attn_implementation diff --git a/internvl_chat/internvl/model/internlm2/modeling_internlm2.py b/internvl_chat/internvl/model/internlm2/modeling_internlm2.py index 569513dff..85cef0c11 100644 --- a/internvl_chat/internvl/model/internlm2/modeling_internlm2.py +++ b/internvl_chat/internvl/model/internlm2/modeling_internlm2.py @@ -40,6 +40,8 @@ except: # noqa # pylint: disable=bare-except BaseStreamer = None +from internvl.v2pe_utils import V2PE + from .configuration_internlm2 import InternLM2Config logger = logging.get_logger(__name__) @@ -323,30 +325,45 @@ def __init__(self, config: InternLM2Config): def _init_rope(self): if self.config.rope_scaling is None: - self.rotary_emb = InternLM2RotaryEmbedding( - self.head_dim, - max_position_embeddings=self.max_position_embeddings, - base=self.config.rope_theta, - ) - else: - scaling_type = self.config.rope_scaling['type'] - scaling_factor = self.config.rope_scaling['factor'] - if scaling_type == 'dynamic': - self.rotary_emb = InternLM2DynamicNTKScalingRotaryEmbedding( + if self.config.rope_pos_id_version.startswith('v2pe_'): + self.rotary_emb = V2PE( self.head_dim, max_position_embeddings=self.max_position_embeddings, base=self.config.rope_theta, - scaling_factor=scaling_factor, ) - elif scaling_type == 'linear': - self.rotary_emb = InternLM2LinearScalingRotaryEmbedding( + else: + self.rotary_emb = InternLM2RotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.config.rope_theta, + ) + else: + if self.config.rope_pos_id_version.startswith('v2pe_'): + warnings.warn(f'V2PE is not compatible with rope_scaling. When using V2PE, rope_scaling must be None. 
rope_scaling is {self.config.rope_scaling}') + self.rotary_emb = V2PE( self.head_dim, max_position_embeddings=self.max_position_embeddings, base=self.config.rope_theta, - scaling_factor=scaling_factor, ) else: - raise ValueError("Currently we only support rotary embedding's type being 'dynamic' or 'linear'.") + scaling_type = self.config.rope_scaling['type'] + scaling_factor = self.config.rope_scaling['factor'] + if scaling_type == 'linear': + self.rotary_emb = InternLM2LinearScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.config.rope_theta, + ) + elif scaling_type == 'dynamic': + self.rotary_emb = InternLM2DynamicNTKScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.config.rope_theta, + ) + else: + raise ValueError(f'Unknown RoPE scaling type {scaling_type}') return self.rotary_emb def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): @@ -391,8 +408,12 @@ def forward( kv_seq_len = key_states.shape[-2] if past_key_value is not None: kv_seq_len += past_key_value[0].shape[-2] - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + if self.config.rope_pos_id_version.startswith('v2pe_'): + cos, sin = self.rotary_emb(value_states, global_posid=position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, torch.arange(0,position_ids.shape[1]).unsqueeze(0)) + else: + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: # reuse k, v, self_attention @@ -493,10 +514,12 @@ def forward( kv_seq_len = key_states.shape[-2] if past_key_value is not None: kv_seq_len += past_key_value[0].shape[-2] - - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + if self.config.rope_pos_id_version.startswith('v2pe_'): + cos, sin = self.rotary_emb(value_states, global_posid=position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, torch.arange(0,position_ids.shape[1]).unsqueeze(0)) + else: + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: # reuse k, v, self_attention @@ -1145,13 +1168,15 @@ def prepare_inputs_for_generation( position_ids.masked_fill_(attention_mask == 0, 1) if past_key_values: position_ids = position_ids[:, -input_ids.shape[1]:] + elif position_ids is not None: + if self.config.rope_pos_id_version!='default' and past_key_values is not None: + position_ids=(position_ids[:,-1]+attention_mask[:,position_ids.shape[1]:].sum(dim=1)).unsqueeze(1) # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: model_inputs = {'inputs_embeds': inputs_embeds} else: model_inputs = {'input_ids': input_ids} - model_inputs.update( { 'position_ids': position_ids, diff --git a/internvl_chat/internvl/model/internvl_chat/configuration_internvl_chat.py b/internvl_chat/internvl/model/internvl_chat/configuration_internvl_chat.py index 80abf7cba..36de6f7da 100644 --- 
a/internvl_chat/internvl/model/internvl_chat/configuration_internvl_chat.py +++ b/internvl_chat/internvl/model/internvl_chat/configuration_internvl_chat.py @@ -37,6 +37,8 @@ def __init__( ps_version='v1', min_dynamic_patch=1, max_dynamic_patch=6, + rope_pos_id_version='default', + rope_pos_id_stride=None, **kwargs): super().__init__(**kwargs) @@ -72,6 +74,8 @@ def __init__( self.ps_version = ps_version # pixel shuffle version self.min_dynamic_patch = min_dynamic_patch self.max_dynamic_patch = max_dynamic_patch + self.rope_pos_id_version = rope_pos_id_version + self.rope_pos_id_stride = rope_pos_id_stride self.hidden_size = self.llm_config.hidden_size # By default, we use tie_word_embeddings=False for models of all sizes. @@ -82,6 +86,8 @@ def __init__( logger.info(f'ps_version: {self.ps_version}') logger.info(f'min_dynamic_patch: {self.min_dynamic_patch}') logger.info(f'max_dynamic_patch: {self.max_dynamic_patch}') + logger.info(f'rope_pos_id_version: {self.rope_pos_id_version}') + logger.info(f'rope_pos_id_stride: {self.rope_pos_id_stride}') def to_dict(self): """ @@ -105,5 +111,7 @@ def to_dict(self): output['ps_version'] = self.ps_version output['min_dynamic_patch'] = self.min_dynamic_patch output['max_dynamic_patch'] = self.max_dynamic_patch + output['rope_pos_id_version'] = self.rope_pos_id_version + output['rope_pos_id_stride'] = self.rope_pos_id_stride return output diff --git a/internvl_chat/internvl/model/internvl_chat/modeling_internvl_chat.py b/internvl_chat/internvl/model/internvl_chat/modeling_internvl_chat.py index 7242136c5..caeaf3945 100644 --- a/internvl_chat/internvl/model/internvl_chat/modeling_internvl_chat.py +++ b/internvl_chat/internvl/model/internvl_chat/modeling_internvl_chat.py @@ -13,6 +13,7 @@ from internvl.conversation import get_conv_template from internvl.model.internlm2.modeling_internlm2 import InternLM2ForCausalLM from internvl.model.phi3.modeling_phi3 import Phi3ForCausalLM +from internvl.v2pe_utils import get_rope_pos_id from peft import LoraConfig, get_peft_model from torch import nn from torch.nn import CrossEntropyLoss @@ -158,6 +159,8 @@ def forward( ) -> Union[Tuple, CausalLMOutputWithPast]: return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if isinstance(position_ids, list): + position_ids = torch.tensor(position_ids, dtype=torch.float32).to(pixel_values.device) image_flags = image_flags.squeeze(-1) input_embeds = self.language_model.get_input_embeddings()(input_ids).clone() @@ -341,7 +344,7 @@ def batch_chat(self, tokenizer, pixel_values, questions, generation_config, num_ def chat(self, tokenizer, pixel_values, question, generation_config, history=None, return_history=False, num_patches_list=None, IMG_START_TOKEN='', IMG_END_TOKEN='', IMG_CONTEXT_TOKEN='', - verbose=False): + verbose=False, **kwargs): if history is None and pixel_values is not None and '' not in question: question = '\n' + question @@ -378,12 +381,40 @@ def chat(self, tokenizer, pixel_values, question, generation_config, history=Non input_ids = model_inputs['input_ids'].to(device) attention_mask = model_inputs['attention_mask'].to(device) generation_config['eos_token_id'] = eos_token_id - generation_output = self.generate( - pixel_values=pixel_values, - input_ids=input_ids, - attention_mask=attention_mask, - **generation_config - ) + rope_pos_id_version = self.config.rope_pos_id_version + if rope_pos_id_version.startswith('v2pe_'): + pos_ids = [] + ret = {'input_ids': input_ids, 'attention_mask': attention_mask} + for i in 
range(input_ids.shape[0]): + cur_dtype = torch.float32 + rope_pos_id_stride = self.config.rope_pos_id_stride + cur_pos_id = get_rope_pos_id(ret, tokenizer=tokenizer, + dtype=cur_dtype, + rope_pos_id_version=rope_pos_id_version, + position_id=torch.arange(0, input_ids.shape[1]), + IMG_START_TOKEN=IMG_START_TOKEN, + IMG_END_TOKEN=IMG_END_TOKEN, rope_pos_id_stride=rope_pos_id_stride, num_image_token=self.num_image_token) + + cur_pos_id = torch.tensor(cur_pos_id).to(device) + pos_ids.append(cur_pos_id) + + pos_ids = torch.stack(pos_ids).to(device) + generation_output = self.generate( + pixel_values=pixel_values, + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=pos_ids, + **generation_config + ) + else: + self.language_model.rope_pos_id_version = 'default' + generation_output = self.generate( + pixel_values=pixel_values, + input_ids=input_ids, + attention_mask=attention_mask, + **generation_config + ) + response = tokenizer.batch_decode(generation_output, skip_special_tokens=True)[0] response = response.split(template.sep.strip())[0].strip() history.append((question, response)) @@ -402,6 +433,7 @@ def generate( pixel_values: Optional[torch.FloatTensor] = None, input_ids: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, visual_features: Optional[torch.FloatTensor] = None, generation_config: Optional[GenerationConfig] = None, output_hidden_states: Optional[bool] = None, @@ -430,6 +462,7 @@ def generate( outputs = self.language_model.generate( inputs_embeds=input_embeds, attention_mask=attention_mask, + position_ids=position_ids, generation_config=generation_config, output_hidden_states=output_hidden_states, use_cache=True, diff --git a/internvl_chat/internvl/patch/pad_data_collator.py b/internvl_chat/internvl/patch/pad_data_collator.py index 6282040f6..dffd76ba0 100644 --- a/internvl_chat/internvl/patch/pad_data_collator.py +++ b/internvl_chat/internvl/patch/pad_data_collator.py @@ -72,8 +72,13 @@ def concat_pad_data_collator(features, max_item_length=None, pad_id=0): feat['attention_mask'] = feat['input_ids'].ne(pad_id) if 'position_ids' in feat: - temp_position_ids = torch.LongTensor([pad_id] * max_item_length) - temp_position_ids[:feat['position_ids'].shape[0]] = feat['position_ids'] + temp_position_ids = [pad_id] * max_item_length + if isinstance(feat['position_ids'], (list, tuple)): + pos_length = len(feat['position_ids']) + temp_position_ids[:pos_length] = feat['position_ids'] + else: + pos_length = feat['position_ids'].shape[0] + temp_position_ids[:pos_length] = feat['position_ids'] feat['position_ids'] = temp_position_ids if 'loss_weight' in feat: @@ -98,7 +103,7 @@ def concat_pad_data_collator(features, max_item_length=None, pad_id=0): # Handling of all other possible keys. # Again, we will use the first element to figure out which key/values are not None for this model. 
for k, v in first.items(): - if k not in ('label', 'label_ids', 'pixel_values', 'image_flags') and \ + if k not in ('label', 'label_ids', 'pixel_values', 'image_flags', 'position_ids') and \ v is not None and not isinstance(v, str): if isinstance(v, torch.Tensor): batch[k] = torch.stack([f[k] for f in features]) @@ -113,6 +118,13 @@ def concat_pad_data_collator(features, max_item_length=None, pad_id=0): batch[k] = torch.concat(np.stack([f[k] for f in features])) else: batch[k] = torch.concat([f[k] for f in features]) + if k in ('position_ids'): + if isinstance(v, torch.Tensor): + batch[k] = torch.concat([f[k] for f in features]) + elif isinstance(v, np.ndarray): + batch[k] = torch.concat(np.stack([f[k] for f in features])) + else: + batch[k] = [f[k] for f in features] return batch diff --git a/internvl_chat/internvl/train/dataset_packed.py b/internvl_chat/internvl/train/dataset_packed.py index 3b54ed475..ef52a7bd8 100644 --- a/internvl_chat/internvl/train/dataset_packed.py +++ b/internvl_chat/internvl/train/dataset_packed.py @@ -240,7 +240,22 @@ def update_buffer(self, buffer, new_sample): assert buffer.keys() == new_sample.keys() for k in buffer: - buffer[k] = torch.cat([buffer[k], new_sample[k]]) + if isinstance(buffer[k], list) or isinstance(new_sample[k], list): + # Handle list type + if isinstance(buffer[k], list): + buffer_data = buffer[k] + else: + buffer_data = buffer[k].tolist() + + if isinstance(new_sample[k], list): + new_data = new_sample[k] + else: + new_data = new_sample[k].tolist() + + buffer[k] = buffer_data + new_data + else: + # Handle tensor type + buffer[k] = torch.cat([buffer[k], new_sample[k]]) return buffer @staticmethod diff --git a/internvl_chat/internvl/train/internvl_chat_finetune.py b/internvl_chat/internvl/train/internvl_chat_finetune.py index 42f669437..00be4befc 100644 --- a/internvl_chat/internvl/train/internvl_chat_finetune.py +++ b/internvl_chat/internvl/train/internvl_chat_finetune.py @@ -52,6 +52,7 @@ preprocess_internvl2_5, preprocess_mpt, preprocess_phi3) from internvl.train.dataset_packed import PackedDataset, packed_collate_fn +from internvl.v2pe_utils import get_rope_pos_id from PIL import Image, ImageFile, PngImagePlugin, UnidentifiedImageError from torch.utils.data import Dataset from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer, @@ -70,6 +71,7 @@ print('petrel_client is not installed. 
Using PIL to load images.') has_tcs_loader = False + # Set constants for image processing and logging IGNORE_INDEX = -100 Image.MAX_IMAGE_PIXELS = None @@ -157,6 +159,14 @@ class ModelArguments: default=False, metadata={'help': 'Set to True to use the liger kernel.'} ) + rope_pos_id_version: Optional[str] = field( + default='default', + metadata={'help': 'version for get_rope_pos_id'}, + ) + rope_pos_id_stride: Optional[int] = field( + default=None, + metadata={'help': 'stride for the version v4 of get_rope_pos_id'}, + ) @dataclass @@ -297,16 +307,22 @@ def __init__( distributed_mode=False, force_shuffle=False, random_seed=0, + rope_pos_id_version='default', + rope_pos_id_stride=None, ): super(LazySupervisedDataset, self).__init__() self.ds_name = ds_name self.tokenizer = tokenizer self.template_name = template_name self.num_image_token = num_image_token + self.rope_pos_id_version = rope_pos_id_version + self.rope_pos_id_stride = rope_pos_id_stride logger.info(f'[Dataset] num_image_token: {num_image_token}') logger.info(f'[Dataset] dynamic_image_size: {dynamic_image_size}') logger.info(f'[Dataset] use_thumbnail: {use_thumbnail}') logger.info(f'[Dataset] min_dynamic_patch: {min_dynamic_patch}, max_dynamic_patch: {max_dynamic_patch}') + logger.info(f'[Dataset] rope_pos_id_version: {rope_pos_id_version}') + logger.info(f'[Dataset] rope_pos_id_stride: {rope_pos_id_stride}') self.image_size = image_size self.is_train = is_train @@ -421,6 +437,7 @@ def multi_modal_get_item(self, data_item): # Build transformation function transform = self.get_transform() + num_tiles = [] # Ensure the first conversation contains an image placeholder if '' not in data_item['conversations'][0]['value']: data_item['conversations'][0]['value'] = '\n' + data_item['conversations'][0]['value'] @@ -430,12 +447,15 @@ def multi_modal_get_item(self, data_item): # Load the image using tcs_loader if available, otherwise use PIL image = self.load_image(image_path) + orig_size = image.size if self.dynamic_image_size: # If dynamic image size is enabled, preprocess the image dynamically images = dynamic_preprocess(image, min_num=self.min_dynamic_patch, max_num=self.max_dynamic_patch, image_size=self.image_size, use_thumbnail=self.use_thumbnail) + num_tiles.append(len(images)) else: # Otherwise, use the original image as a single patch images = [image] + num_tiles.append(1) # Apply the transformation to each image and stack the results into a tensor pixel_values = [transform(image) for image in images] @@ -461,12 +481,16 @@ def multi_modal_get_item(self, data_item): image_end_token_id = self.tokenizer.convert_tokens_to_ids(IMG_END_TOKEN) assert (ret['input_ids'][0] == image_end_token_id).sum() == 1, f'image tokens are truncated, this dataset is {self.ds_name}' + if self.rope_pos_id_version in ['v2pe_fix', 'v2pe_rnd']: + position_ids = get_rope_pos_id(ret, torch.float32, self.rope_pos_id_version, position_ids[0], tokenizer=self.tokenizer, num_image_token=self.num_image_token) + else: + position_ids = position_ids[0] # Create the final return dictionary ret = dict( input_ids=ret['input_ids'][0], labels=ret['labels'][0], attention_mask=ret['attention_mask'][0], - position_ids=position_ids[0], + position_ids=position_ids, pixel_values=pixel_values, image_flags=torch.tensor([1] * num_patches, dtype=torch.long) ) @@ -511,12 +535,16 @@ def multi_modal_multi_image_get_item(self, data_item): image_end_token_id = self.tokenizer.convert_tokens_to_ids(IMG_END_TOKEN) assert (ret['input_ids'][0] == image_end_token_id).sum() == num_image, 
f'image tokens are truncated, this dataset is {self.ds_name}' + if self.rope_pos_id_version in ['v2pe_fix', 'v2pe_rnd']: + position_ids = get_rope_pos_id(ret, torch.float32, self.rope_pos_id_version, position_ids[0], tokenizer=self.tokenizer, num_image_token=self.num_image_token) + else: + position_ids = position_ids[0] # Create the final return dictionary ret = dict( input_ids=ret['input_ids'][0], labels=ret['labels'][0], attention_mask=ret['attention_mask'][0], - position_ids=position_ids[0], + position_ids=position_ids, pixel_values=pixel_values, image_flags=torch.tensor([1] * num_patches, dtype=torch.long) ) @@ -567,12 +595,16 @@ def video_get_item(self, data_item): position_ids = ret['attention_mask'].long().cumsum(-1) - 1 position_ids.masked_fill_(ret['attention_mask'] == 0, 1) + if self.rope_pos_id_version in ['v2pe_fix', 'v2pe_rnd']: + position_ids = get_rope_pos_id(ret, torch.float32, self.rope_pos_id_version, position_ids[0], tokenizer=self.tokenizer, num_image_token=self.num_image_token) + else: + position_ids = position_ids[0] # Create the final return dictionary ret = dict( input_ids=ret['input_ids'][0], labels=ret['labels'][0], attention_mask=ret['attention_mask'][0], - position_ids=position_ids[0], + position_ids=position_ids, pixel_values=pixel_values, image_flags=torch.tensor([1] * num_patches, dtype=torch.long) ) @@ -711,6 +743,8 @@ def build_datasets( min_num_frame=8, max_num_frame=32, normalize_type='imagenet', + rope_pos_id_version='default', + rope_pos_id_stride=None, ): datasets = [] lengths = [] @@ -749,6 +783,8 @@ def build_datasets( distributed_mode=data_args.use_packed_ds, force_shuffle=data_args.use_packed_ds, random_seed=ds_idx, + rope_pos_id_version=rope_pos_id_version, + rope_pos_id_stride=rope_pos_id_stride, ) logger.info(f'Add dataset: {ds_name} with length: {len(dataset)}') datasets.append(dataset) @@ -906,6 +942,10 @@ def main(): config.ps_version = model_args.ps_version config.min_dynamic_patch = data_args.min_dynamic_patch config.max_dynamic_patch = data_args.max_dynamic_patch + config.rope_pos_id_version = model_args.rope_pos_id_version + config.rope_pos_id_stride = model_args.rope_pos_id_stride + config.llm_config.rope_pos_id_version = model_args.rope_pos_id_version + config.llm_config.rope_pos_id_stride = model_args.rope_pos_id_stride model = InternVLChatModel.from_pretrained( model_args.model_name_or_path, torch_dtype=torch.bfloat16, config=config) else: @@ -935,6 +975,10 @@ def main(): use_thumbnail=data_args.use_thumbnail, ps_version=model_args.ps_version, min_dynamic_patch=data_args.min_dynamic_patch, max_dynamic_patch=data_args.max_dynamic_patch) internvl_chat_config.force_image_size = data_args.force_image_size + internvl_chat_config.rope_pos_id_version = model_args.rope_pos_id_version + internvl_chat_config.rope_pos_id_stride = model_args.rope_pos_id_stride + internvl_chat_config.llm_config.rope_pos_id_version = model_args.rope_pos_id_version + internvl_chat_config.llm_config.rope_pos_id_stride = model_args.rope_pos_id_stride logger.info('Building InternVLChatModel...') model = InternVLChatModel(internvl_chat_config, vision_model, llm) model.img_context_token_id = img_context_token_id @@ -983,7 +1027,8 @@ def main(): dynamic_image_size=data_args.dynamic_image_size, use_thumbnail=data_args.use_thumbnail, min_dynamic_patch=data_args.min_dynamic_patch, max_dynamic_patch=data_args.max_dynamic_patch, normalize_type=data_args.normalize_type, min_num_frame=data_args.min_num_frame, - max_num_frame=data_args.max_num_frame) + 
max_num_frame=data_args.max_num_frame,rope_pos_id_version=model_args.rope_pos_id_version, + rope_pos_id_stride=model_args.rope_pos_id_stride) def _freeze_params(module): for param in module.parameters(): diff --git a/internvl_chat/internvl/train/internvl_chat_pretrain.py b/internvl_chat/internvl/train/internvl_chat_pretrain.py index d7962ae92..6fd97b77c 100644 --- a/internvl_chat/internvl/train/internvl_chat_pretrain.py +++ b/internvl_chat/internvl/train/internvl_chat_pretrain.py @@ -52,6 +52,7 @@ preprocess_internvl2_5, preprocess_mpt, preprocess_phi3) from internvl.train.dataset_packed import PackedDataset, packed_collate_fn +from internvl.v2pe_utils import get_rope_pos_id from PIL import Image, ImageFile, PngImagePlugin, UnidentifiedImageError from torch.utils.data import Dataset from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer, @@ -157,6 +158,14 @@ class ModelArguments: default=False, metadata={'help': 'Set to True to use the liger kernel.'} ) + rope_pos_id_version: Optional[str] = field( + default='default', + metadata={'help': 'version for get_rope_pos_id'}, + ) + rope_pos_id_stride: Optional[int] = field( + default=None, + metadata={'help': 'stride for the version v4 of get_rope_pos_id'}, + ) @dataclass @@ -297,16 +306,22 @@ def __init__( distributed_mode=False, force_shuffle=False, random_seed=0, + rope_pos_id_version='default', + rope_pos_id_stride=None, ): super(LazySupervisedDataset, self).__init__() self.ds_name = ds_name self.tokenizer = tokenizer self.template_name = template_name self.num_image_token = num_image_token + self.rope_pos_id_version = rope_pos_id_version + self.rope_pos_id_stride = rope_pos_id_stride logger.info(f'[Dataset] num_image_token: {num_image_token}') logger.info(f'[Dataset] dynamic_image_size: {dynamic_image_size}') logger.info(f'[Dataset] use_thumbnail: {use_thumbnail}') logger.info(f'[Dataset] min_dynamic_patch: {min_dynamic_patch}, max_dynamic_patch: {max_dynamic_patch}') + logger.info(f'[Dataset] rope_pos_id_version: {rope_pos_id_version}') + logger.info(f'[Dataset] rope_pos_id_stride: {rope_pos_id_stride}') self.image_size = image_size self.is_train = is_train @@ -464,6 +479,7 @@ def multi_modal_get_item(self, data_item): # Build transformation function transform = self.get_transform() + num_tiles = [] # Ensure the first conversation contains an image placeholder if '' not in data_item['conversations'][0]['value']: data_item['conversations'][0]['value'] = '\n' + data_item['conversations'][0]['value'] @@ -473,12 +489,15 @@ def multi_modal_get_item(self, data_item): # Load the image using tcs_loader if available, otherwise use PIL image = self.load_image(image_path) + orig_size = image.size if self.dynamic_image_size: # If dynamic image size is enabled, preprocess the image dynamically images = dynamic_preprocess(image, min_num=self.min_dynamic_patch, max_num=self.max_dynamic_patch, image_size=self.image_size, use_thumbnail=self.use_thumbnail) + num_tiles.append(len(images)) else: # Otherwise, use the original image as a single patch images = [image] + num_tiles.append(1) # Apply the transformation to each image and stack the results into a tensor pixel_values = [transform(image) for image in images] @@ -504,12 +523,16 @@ def multi_modal_get_item(self, data_item): image_end_token_id = self.tokenizer.convert_tokens_to_ids(IMG_END_TOKEN) assert (ret['input_ids'][0] == image_end_token_id).sum() == 1, f'image tokens are truncated, this dataset is {self.ds_name}' + if self.rope_pos_id_version in ['v2pe_fix', 'v2pe_rnd']: + 
position_ids = get_rope_pos_id(ret, torch.float32, self.rope_pos_id_version, position_ids[0], tokenizer=self.tokenizer, num_image_token=self.num_image_token) + else: + position_ids = position_ids[0] # Create the final return dictionary ret = dict( input_ids=ret['input_ids'][0], labels=ret['labels'][0], attention_mask=ret['attention_mask'][0], - position_ids=position_ids[0], + position_ids=position_ids, pixel_values=pixel_values, image_flags=torch.tensor([1] * num_patches, dtype=torch.long) ) @@ -554,12 +577,16 @@ def multi_modal_multi_image_get_item(self, data_item): image_end_token_id = self.tokenizer.convert_tokens_to_ids(IMG_END_TOKEN) assert (ret['input_ids'][0] == image_end_token_id).sum() == num_image, f'image tokens are truncated, this dataset is {self.ds_name}' + if self.rope_pos_id_version in ['v2pe_fix', 'v2pe_rnd']: + position_ids = get_rope_pos_id(ret, torch.float32, self.rope_pos_id_version, position_ids[0], tokenizer=self.tokenizer, num_image_token=self.num_image_token) + else: + position_ids = position_ids[0] # Create the final return dictionary ret = dict( input_ids=ret['input_ids'][0], labels=ret['labels'][0], attention_mask=ret['attention_mask'][0], - position_ids=position_ids[0], + position_ids=position_ids, pixel_values=pixel_values, image_flags=torch.tensor([1] * num_patches, dtype=torch.long) ) @@ -610,12 +637,16 @@ def video_get_item(self, data_item): position_ids = ret['attention_mask'].long().cumsum(-1) - 1 position_ids.masked_fill_(ret['attention_mask'] == 0, 1) + if self.rope_pos_id_version in ['v2pe_fix', 'v2pe_rnd']: + position_ids = get_rope_pos_id(ret, torch.float32, self.rope_pos_id_version, position_ids[0], tokenizer=self.tokenizer, num_image_token=self.num_image_token) + else: + position_ids = position_ids[0] # Create the final return dictionary ret = dict( input_ids=ret['input_ids'][0], labels=ret['labels'][0], attention_mask=ret['attention_mask'][0], - position_ids=position_ids[0], + position_ids=position_ids, pixel_values=pixel_values, image_flags=torch.tensor([1] * num_patches, dtype=torch.long) ) @@ -755,6 +786,8 @@ def build_datasets( min_num_frame=8, max_num_frame=32, normalize_type='imagenet', + rope_pos_id_version='default', + rope_pos_id_stride=None, ): datasets = [] lengths = [] @@ -793,6 +826,8 @@ def build_datasets( distributed_mode=data_args.use_packed_ds, force_shuffle=data_args.use_packed_ds, random_seed=ds_idx, + rope_pos_id_version=rope_pos_id_version, + rope_pos_id_stride=rope_pos_id_stride, ) logger.info(f'Add dataset: {ds_name} with length: {len(dataset)}') datasets.append(dataset) @@ -950,6 +985,10 @@ def main(): config.ps_version = model_args.ps_version config.min_dynamic_patch = data_args.min_dynamic_patch config.max_dynamic_patch = data_args.max_dynamic_patch + config.rope_pos_id_version = model_args.rope_pos_id_version + config.rope_pos_id_stride = model_args.rope_pos_id_stride + config.llm_config.rope_pos_id_version = model_args.rope_pos_id_version + config.llm_config.rope_pos_id_stride = model_args.rope_pos_id_stride model = InternVLChatModel.from_pretrained( model_args.model_name_or_path, torch_dtype=torch.bfloat16, config=config) else: @@ -977,8 +1016,13 @@ def main(): pad2square=data_args.pad2square, template=data_args.conv_style, select_layer=model_args.vision_select_layer, dynamic_image_size=data_args.dynamic_image_size, use_thumbnail=data_args.use_thumbnail, ps_version=model_args.ps_version, - min_dynamic_patch=data_args.min_dynamic_patch, max_dynamic_patch=data_args.max_dynamic_patch) + 
min_dynamic_patch=data_args.min_dynamic_patch, max_dynamic_patch=data_args.max_dynamic_patch,
+        rope_pos_id_version=model_args.rope_pos_id_version, rope_pos_id_stride=model_args.rope_pos_id_stride)
     internvl_chat_config.force_image_size = data_args.force_image_size
+    internvl_chat_config.rope_pos_id_version = model_args.rope_pos_id_version
+    internvl_chat_config.rope_pos_id_stride = model_args.rope_pos_id_stride
+    internvl_chat_config.llm_config.rope_pos_id_version = model_args.rope_pos_id_version
+    internvl_chat_config.llm_config.rope_pos_id_stride = model_args.rope_pos_id_stride
     logger.info('Building InternVLChatModel...')
     model = InternVLChatModel(internvl_chat_config, vision_model, llm)
     model.img_context_token_id = img_context_token_id
@@ -1027,7 +1071,8 @@ def main():
         dynamic_image_size=data_args.dynamic_image_size, use_thumbnail=data_args.use_thumbnail,
         min_dynamic_patch=data_args.min_dynamic_patch, max_dynamic_patch=data_args.max_dynamic_patch,
         normalize_type=data_args.normalize_type, min_num_frame=data_args.min_num_frame,
-        max_num_frame=data_args.max_num_frame)
+        max_num_frame=data_args.max_num_frame, rope_pos_id_version=model_args.rope_pos_id_version,
+        rope_pos_id_stride=model_args.rope_pos_id_stride)

     def _freeze_params(module):
         for param in module.parameters():
diff --git a/internvl_chat/internvl/v2pe_utils.py b/internvl_chat/internvl/v2pe_utils.py
new file mode 100644
index 000000000..35803df9b
--- /dev/null
+++ b/internvl_chat/internvl/v2pe_utils.py
@@ -0,0 +1,116 @@
+import random
+
+import torch
+import torch.nn as nn
+
+
+# Compute RoPE position ids for a tokenized sample (used at evaluation/inference time).
+def get_rope_pos_id(ret, dtype, rope_pos_id_version='default', position_id=None,
+                    IMG_START_TOKEN='<img>', IMG_END_TOKEN='</img>', rope_pos_id_stride=None, tokenizer=None, num_image_token=256):
+    image_start_token_id = tokenizer.convert_tokens_to_ids(IMG_START_TOKEN)
+    image_end_token_id = tokenizer.convert_tokens_to_ids(IMG_END_TOKEN)
+    rope_pos_id_list = []
+    assert ret['input_ids'].shape[0] == 1, 'batch size should be 1, other batch sizes are not supported yet'
+    input_ids_0 = ret['input_ids'][0]
+    attention_mask_0 = ret['attention_mask'][0]
+    image_start_token_id_idxs = torch.where(input_ids_0 == image_start_token_id)[0]
+    image_end_token_id_idxs = torch.where(input_ids_0 == image_end_token_id)[0]
+    num_tiles = (image_end_token_id_idxs - image_start_token_id_idxs) // num_image_token
+    last_record_pos_id = -1
+    start_index = 0
+
+    assert rope_pos_id_version in ['v2pe_fix', 'v2pe_rnd', 'default'], f'{rope_pos_id_version} not supported for eval'
+
+    for i in range(len(image_start_token_id_idxs)):
+
+        num_tile = num_tiles[i]
+
+        rope_pos_id_pre = attention_mask_0[start_index:image_start_token_id_idxs[i] + 1].long().cumsum(-1) - 1 + (last_record_pos_id + 1)
+        rope_pos_id_pre.masked_fill_(attention_mask_0[start_index:image_start_token_id_idxs[i] + 1] == 0, 1)
+        rope_pos_id_list.append(rope_pos_id_pre)
+
+        last_record_pos_id = rope_pos_id_pre[-1].long()
+
+        if rope_pos_id_version == 'v2pe_fix':
+            assert rope_pos_id_stride is not None, 'when rope_pos_id_version is v2pe_fix, rope_pos_id_stride must not be None'
+            small_stride = rope_pos_id_stride / num_image_token
+            split_img_id_idxs = torch.linspace(last_record_pos_id, last_record_pos_id + small_stride * (num_image_token * num_tile), (num_image_token * num_tile + 1))[1:].to(dtype=dtype)
+            rope_pos_id_list.append(split_img_id_idxs)
+            last_record_pos_id = torch.ceil(split_img_id_idxs[-1]).long()
+        elif rope_pos_id_version == 'v2pe_rnd':
+            random_from = [1, 2, 4, 8, 16, 32, 64, 128, 256]
+            rope_pos_id_stride = random.choice(random_from)
+            small_stride = rope_pos_id_stride / num_image_token
+            split_img_id_idxs = torch.linspace(last_record_pos_id, last_record_pos_id + small_stride * (num_image_token * num_tile), (num_image_token * num_tile + 1))[1:].to(dtype=dtype)
+            rope_pos_id_list.append(split_img_id_idxs)
+            last_record_pos_id = torch.ceil(split_img_id_idxs[-1]).long()
+        elif rope_pos_id_version == 'default':
+            split_img_id_idxs = torch.linspace(last_record_pos_id,
+                                               last_record_pos_id + (num_tile - 1) * num_image_token,
+                                               (num_tile - 1) * num_image_token + 1)[1:].to(dtype=dtype)
+            rope_pos_id_list.append(split_img_id_idxs)
+            thumbnail_id_idxs = torch.linspace(last_record_pos_id + (num_tile - 1) * num_image_token,
+                                               last_record_pos_id + num_tile * num_image_token,
+                                               num_image_token + 1)[1:].to(dtype=dtype)
+            rope_pos_id_list.append(thumbnail_id_idxs)
+            last_record_pos_id = (last_record_pos_id + num_tile * num_image_token).long()
+        else:
+            raise NotImplementedError(f'not implemented for {rope_pos_id_version}')
+
+        start_index = image_start_token_id_idxs[i] + num_tile * num_image_token + 1
+        assert input_ids_0[start_index] == image_end_token_id
+        assert start_index == image_end_token_id_idxs[i]
+
+    assert image_end_token_id_idxs[-1] == start_index
+    rope_pos_id_pre = attention_mask_0[start_index:].long().cumsum(-1) - 1 + (last_record_pos_id + 1)
+    rope_pos_id_pre.masked_fill_(attention_mask_0[start_index:] == 0, 1)
+    rope_pos_id_list.append(rope_pos_id_pre)
+
+    rope_pos_id_list = [_.to('cpu') for _ in rope_pos_id_list]
+    rope_pos_id = torch.cat(rope_pos_id_list).to(dtype=dtype)
+    if rope_pos_id_version == 'default':
+        rope_pos_id = rope_pos_id.long()
+        assert torch.equal(rope_pos_id, position_id.to(rope_pos_id.device)), (rope_pos_id, position_id.to(rope_pos_id.device))
+        assert torch.allclose(rope_pos_id, position_id.to(rope_pos_id.device), atol=1e-32)
+
+    assert rope_pos_id.shape == input_ids_0.shape
+
+    return list(rope_pos_id.numpy())
+
+
+# Rotary Position Embedding for V2PE: accepts fractional (float) position ids.
+class V2PE(nn.Module):
+    def __init__(self, dim, max_position_embeddings=2048, base=10000):
+        super().__init__()
+
+        self.dim = dim
+        self.max_position_embeddings = max_position_embeddings
+        self.base = base
+        self.inv_freq = None
+        # inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
+        # self.register_buffer('inv_freq', inv_freq, persistent=False)
+
+        self.max_seq_len_cached = -1
+
+    def _set_cos_sin_cache(self, pos_id, device, dtype):
+        if self.inv_freq is None:
+            inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim))
+            del self.inv_freq
+            self.register_buffer('inv_freq', inv_freq, persistent=False)
+
+        pos_id = pos_id.squeeze(0)
+        freqs = torch.outer(pos_id, self.inv_freq.to(device=pos_id.device))
+
+        # Different from the paper, but uses a different permutation in order to obtain the same calculation.
+        emb = torch.cat((freqs, freqs), dim=-1)
+        self.register_buffer('cos_cached', emb.cos().to(dtype), persistent=False)
+        self.register_buffer('sin_cached', emb.sin().to(dtype), persistent=False)
+
+    def forward(self, x, global_posid=None):
+        # x: [bs, num_attention_heads, seq_len, head_size]
+        self._set_cos_sin_cache(pos_id=global_posid, device=x.device, dtype=x.dtype)
+
+        return (
+            self.cos_cached[:].to(dtype=x.dtype),
+            self.sin_cached[:].to(dtype=x.dtype),
+        )
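
How the V2PE position ids behave in practice: get_rope_pos_id above keeps ordinary +1 increments for text tokens, while every visual token advances the position by only rope_pos_id_stride / num_image_token, so a full tile of 256 image tokens occupies just rope_pos_id_stride positions. The following minimal sketch illustrates the v2pe_fix scheme under that reading; the toy token flags and the helper name build_v2pe_positions are invented for illustration and are not part of the repository.

import torch

def build_v2pe_positions(is_image_token, stride=64, num_image_token=256):
    # Text tokens advance the position by 1; image tokens by stride / num_image_token.
    delta = stride / num_image_token
    pos, cur = [], -1.0
    for is_img in is_image_token:
        cur += delta if is_img else 1.0
        pos.append(cur)
    return torch.tensor(pos, dtype=torch.float32)

# 4 text tokens, one 256-token image tile, 3 trailing text tokens.
flags = [False] * 4 + [True] * 256 + [False] * 3
pos = build_v2pe_positions(flags, stride=64, num_image_token=256)
print(pos[:4])                          # tensor([0., 1., 2., 3.])  text advances by 1
print(pos[4].item(), pos[259].item())   # 3.25 and 67.0: 256 image tokens span only 64 positions
print(pos[260:])                        # tensor([68., 69., 70.])   text resumes integer steps

With the stride equal to num_image_token the increments reduce to the standard +1-per-token positions, which is why the 'default' branch of get_rope_pos_id can assert equality with the plain cumulative-sum position ids.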
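Because the resulting position ids are fractional, the rotary embedding has to be computed directly from the float positions instead of an integer lookup, which is what the V2PE module above does via torch.outer. The sketch below shows that computation end to end, applying cos/sin with the usual Hugging Face rotate_half convention that InternLM2's apply_rotary_pos_emb follows; the shapes and broadcasting here are simplified assumptions, not the exact layout used in modeling_internlm2.py.

import torch

def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def v2pe_cos_sin(pos_id, head_dim, base=10000.0):
    # pos_id may be fractional (e.g. 3.25, 3.5, ...), which is the point of V2PE.
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
    freqs = torch.outer(pos_id.float(), inv_freq)   # [seq_len, head_dim // 2]
    emb = torch.cat((freqs, freqs), dim=-1)         # [seq_len, head_dim]
    return emb.cos(), emb.sin()

bsz, n_heads, seq_len, head_dim = 1, 2, 8, 16
q = torch.randn(bsz, n_heads, seq_len, head_dim)
k = torch.randn(bsz, n_heads, seq_len, head_dim)
pos_id = torch.tensor([0., 1., 2., 2.25, 2.5, 2.75, 3., 4.])   # mixed text/image increments
cos, sin = v2pe_cos_sin(pos_id, head_dim)
cos, sin = cos[None, None], sin[None, None]                     # broadcast over batch and heads
q_rot = q * cos + rotate_half(q) * sin
k_rot = k * cos + rotate_half(k) * sin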
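During generation, the chat() path builds these float position ids once for the prompt and forwards them through generate(); after the first step, the new branch in prepare_inputs_for_generation extends them by giving each newly generated token the last recorded position plus the count of newly attended tokens. The toy check below mirrors that update rule with invented values.

import torch

prompt_pos = torch.tensor([[0., 1., 2., 2.25, 2.5, 2.75, 3.]])   # float ids produced for the prompt
attention_mask = torch.ones(1, prompt_pos.shape[1] + 1)          # prompt plus one generated token
next_pos = (prompt_pos[:, -1] + attention_mask[:, prompt_pos.shape[1]:].sum(dim=1)).unsqueeze(1)
print(next_pos)   # tensor([[4.]]): 3.0 for the last prompt token plus 1 newly attended token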