diff --git a/internvl_chat/internvl/model/internlm2/configuration_internlm2.py b/internvl_chat/internvl/model/internlm2/configuration_internlm2.py
index 282b13b1e..e58366abd 100644
--- a/internvl_chat/internvl/model/internlm2/configuration_internlm2.py
+++ b/internvl_chat/internvl/model/internlm2/configuration_internlm2.py
@@ -95,6 +95,8 @@ def __init__( # pylint: disable=W0102
rope_theta=10000,
rope_scaling=None,
attn_implementation='eager',
+ rope_pos_id_version='default',
+ rope_pos_id_stride=None,
**kwargs,
):
self.vocab_size = vocab_size
@@ -115,6 +117,8 @@ def __init__( # pylint: disable=W0102
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
+ self.rope_pos_id_version = rope_pos_id_version
+ self.rope_pos_id_stride = rope_pos_id_stride
self._rope_scaling_validation()
self.attn_implementation = attn_implementation
diff --git a/internvl_chat/internvl/model/internlm2/modeling_internlm2.py b/internvl_chat/internvl/model/internlm2/modeling_internlm2.py
index 569513dff..85cef0c11 100644
--- a/internvl_chat/internvl/model/internlm2/modeling_internlm2.py
+++ b/internvl_chat/internvl/model/internlm2/modeling_internlm2.py
@@ -40,6 +40,8 @@
except: # noqa # pylint: disable=bare-except
BaseStreamer = None
+from internvl.v2pe_utils import V2PE
+
from .configuration_internlm2 import InternLM2Config
logger = logging.get_logger(__name__)
@@ -323,30 +325,45 @@ def __init__(self, config: InternLM2Config):
def _init_rope(self):
if self.config.rope_scaling is None:
- self.rotary_emb = InternLM2RotaryEmbedding(
- self.head_dim,
- max_position_embeddings=self.max_position_embeddings,
- base=self.config.rope_theta,
- )
- else:
- scaling_type = self.config.rope_scaling['type']
- scaling_factor = self.config.rope_scaling['factor']
- if scaling_type == 'dynamic':
- self.rotary_emb = InternLM2DynamicNTKScalingRotaryEmbedding(
+ if self.config.rope_pos_id_version.startswith('v2pe_'):
+ self.rotary_emb = V2PE(
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
base=self.config.rope_theta,
- scaling_factor=scaling_factor,
)
- elif scaling_type == 'linear':
- self.rotary_emb = InternLM2LinearScalingRotaryEmbedding(
+ else:
+ self.rotary_emb = InternLM2RotaryEmbedding(
+ self.head_dim,
+ max_position_embeddings=self.max_position_embeddings,
+ base=self.config.rope_theta,
+ )
+ else:
+ if self.config.rope_pos_id_version.startswith('v2pe_'):
+ warnings.warn(f'V2PE is not compatible with rope_scaling. When using V2PE, rope_scaling must be None. rope_scaling is {self.config.rope_scaling}')
+ self.rotary_emb = V2PE(
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
base=self.config.rope_theta,
- scaling_factor=scaling_factor,
)
else:
- raise ValueError("Currently we only support rotary embedding's type being 'dynamic' or 'linear'.")
+ scaling_type = self.config.rope_scaling['type']
+ scaling_factor = self.config.rope_scaling['factor']
+ if scaling_type == 'linear':
+ self.rotary_emb = InternLM2LinearScalingRotaryEmbedding(
+ self.head_dim,
+ max_position_embeddings=self.max_position_embeddings,
+ scaling_factor=scaling_factor,
+ base=self.config.rope_theta,
+ )
+ elif scaling_type == 'dynamic':
+ self.rotary_emb = InternLM2DynamicNTKScalingRotaryEmbedding(
+ self.head_dim,
+ max_position_embeddings=self.max_position_embeddings,
+ scaling_factor=scaling_factor,
+ base=self.config.rope_theta,
+ )
+ else:
+ raise ValueError(f'Unknown RoPE scaling type {scaling_type}')
return self.rotary_emb
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
@@ -391,8 +408,12 @@ def forward(
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
- cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+ if self.config.rope_pos_id_version.startswith('v2pe_'):
+ cos, sin = self.rotary_emb(value_states, global_posid=position_ids)
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, torch.arange(0,position_ids.shape[1]).unsqueeze(0))
+ else:
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
if past_key_value is not None:
# reuse k, v, self_attention
@@ -493,10 +514,12 @@ def forward(
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
-
- cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+ if self.config.rope_pos_id_version.startswith('v2pe_'):
+ cos, sin = self.rotary_emb(value_states, global_posid=position_ids)
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, torch.arange(0,position_ids.shape[1]).unsqueeze(0))
+ else:
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
if past_key_value is not None:
# reuse k, v, self_attention
@@ -1145,13 +1168,15 @@ def prepare_inputs_for_generation(
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1]:]
+ elif position_ids is not None:
+            if self.config.rope_pos_id_version != 'default' and past_key_values is not None:
+                position_ids = (position_ids[:, -1] + attention_mask[:, position_ids.shape[1]:].sum(dim=1)).unsqueeze(1)
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {'inputs_embeds': inputs_embeds}
else:
model_inputs = {'input_ids': input_ids}
-
model_inputs.update(
{
'position_ids': position_ids,
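
Illustrative sketch (not part of the patch): in the v2pe_ branches above, the rotary cache is built directly from the supplied global position ids, so row i of the cached cos/sin already corresponds to token i in sequence order; apply_rotary_pos_emb can therefore index the cache with a plain 0..L-1 arange even though the position values themselves may be fractional. A minimal sketch of that wiring, with rotary_emb and apply_rotary_pos_emb standing in for the objects defined in this file:

    import torch

    def rotate_with_v2pe(rotary_emb, apply_rotary_pos_emb, query_states, key_states, value_states, position_ids):
        # Build cos/sin directly from the (possibly fractional) V2PE position ids.
        cos, sin = rotary_emb(value_states, global_posid=position_ids)
        # The cache rows follow sequence order, so index them 0..L-1 instead of by position value.
        seq_index = torch.arange(0, position_ids.shape[1]).unsqueeze(0)
        return apply_rotary_pos_emb(query_states, key_states, cos, sin, seq_index)
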
diff --git a/internvl_chat/internvl/model/internvl_chat/configuration_internvl_chat.py b/internvl_chat/internvl/model/internvl_chat/configuration_internvl_chat.py
index 80abf7cba..36de6f7da 100644
--- a/internvl_chat/internvl/model/internvl_chat/configuration_internvl_chat.py
+++ b/internvl_chat/internvl/model/internvl_chat/configuration_internvl_chat.py
@@ -37,6 +37,8 @@ def __init__(
ps_version='v1',
min_dynamic_patch=1,
max_dynamic_patch=6,
+ rope_pos_id_version='default',
+ rope_pos_id_stride=None,
**kwargs):
super().__init__(**kwargs)
@@ -72,6 +74,8 @@ def __init__(
self.ps_version = ps_version # pixel shuffle version
self.min_dynamic_patch = min_dynamic_patch
self.max_dynamic_patch = max_dynamic_patch
+ self.rope_pos_id_version = rope_pos_id_version
+ self.rope_pos_id_stride = rope_pos_id_stride
self.hidden_size = self.llm_config.hidden_size
# By default, we use tie_word_embeddings=False for models of all sizes.
@@ -82,6 +86,8 @@ def __init__(
logger.info(f'ps_version: {self.ps_version}')
logger.info(f'min_dynamic_patch: {self.min_dynamic_patch}')
logger.info(f'max_dynamic_patch: {self.max_dynamic_patch}')
+ logger.info(f'rope_pos_id_version: {self.rope_pos_id_version}')
+ logger.info(f'rope_pos_id_stride: {self.rope_pos_id_stride}')
def to_dict(self):
"""
@@ -105,5 +111,7 @@ def to_dict(self):
output['ps_version'] = self.ps_version
output['min_dynamic_patch'] = self.min_dynamic_patch
output['max_dynamic_patch'] = self.max_dynamic_patch
+ output['rope_pos_id_version'] = self.rope_pos_id_version
+ output['rope_pos_id_stride'] = self.rope_pos_id_stride
return output
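
Illustrative sketch (not part of the patch): the two new fields must be set on both the chat config and its nested llm_config, which is exactly what the training entry points later in this patch do. A hedged example of enabling V2PE on a loaded config; the checkpoint path is a placeholder:

    from transformers import AutoConfig

    config = AutoConfig.from_pretrained('path/to/internvl_checkpoint', trust_remote_code=True)  # placeholder path
    config.rope_pos_id_version = 'v2pe_fix'   # one of: 'default', 'v2pe_fix', 'v2pe_rnd'
    config.rope_pos_id_stride = 64            # position-id span allotted to each image tile (v2pe_fix only)
    config.llm_config.rope_pos_id_version = config.rope_pos_id_version
    config.llm_config.rope_pos_id_stride = config.rope_pos_id_stride
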
diff --git a/internvl_chat/internvl/model/internvl_chat/modeling_internvl_chat.py b/internvl_chat/internvl/model/internvl_chat/modeling_internvl_chat.py
index 7242136c5..caeaf3945 100644
--- a/internvl_chat/internvl/model/internvl_chat/modeling_internvl_chat.py
+++ b/internvl_chat/internvl/model/internvl_chat/modeling_internvl_chat.py
@@ -13,6 +13,7 @@
from internvl.conversation import get_conv_template
from internvl.model.internlm2.modeling_internlm2 import InternLM2ForCausalLM
from internvl.model.phi3.modeling_phi3 import Phi3ForCausalLM
+from internvl.v2pe_utils import get_rope_pos_id
from peft import LoraConfig, get_peft_model
from torch import nn
from torch.nn import CrossEntropyLoss
@@ -158,6 +159,8 @@ def forward(
) -> Union[Tuple, CausalLMOutputWithPast]:
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ if isinstance(position_ids, list):
+ position_ids = torch.tensor(position_ids, dtype=torch.float32).to(pixel_values.device)
image_flags = image_flags.squeeze(-1)
input_embeds = self.language_model.get_input_embeddings()(input_ids).clone()
@@ -341,7 +344,7 @@ def batch_chat(self, tokenizer, pixel_values, questions, generation_config, num_
def chat(self, tokenizer, pixel_values, question, generation_config, history=None, return_history=False,
             num_patches_list=None, IMG_START_TOKEN='<img>', IMG_END_TOKEN='</img>', IMG_CONTEXT_TOKEN='<IMG_CONTEXT>',
- verbose=False):
+ verbose=False, **kwargs):
        if history is None and pixel_values is not None and '<image>' not in question:
            question = '<image>\n' + question
@@ -378,12 +381,40 @@ def chat(self, tokenizer, pixel_values, question, generation_config, history=Non
input_ids = model_inputs['input_ids'].to(device)
attention_mask = model_inputs['attention_mask'].to(device)
generation_config['eos_token_id'] = eos_token_id
- generation_output = self.generate(
- pixel_values=pixel_values,
- input_ids=input_ids,
- attention_mask=attention_mask,
- **generation_config
- )
+ rope_pos_id_version = self.config.rope_pos_id_version
+ if rope_pos_id_version.startswith('v2pe_'):
+ pos_ids = []
+ ret = {'input_ids': input_ids, 'attention_mask': attention_mask}
+ for i in range(input_ids.shape[0]):
+ cur_dtype = torch.float32
+ rope_pos_id_stride = self.config.rope_pos_id_stride
+ cur_pos_id = get_rope_pos_id(ret, tokenizer=tokenizer,
+ dtype=cur_dtype,
+ rope_pos_id_version=rope_pos_id_version,
+ position_id=torch.arange(0, input_ids.shape[1]),
+ IMG_START_TOKEN=IMG_START_TOKEN,
+ IMG_END_TOKEN=IMG_END_TOKEN, rope_pos_id_stride=rope_pos_id_stride, num_image_token=self.num_image_token)
+
+ cur_pos_id = torch.tensor(cur_pos_id).to(device)
+ pos_ids.append(cur_pos_id)
+
+ pos_ids = torch.stack(pos_ids).to(device)
+ generation_output = self.generate(
+ pixel_values=pixel_values,
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=pos_ids,
+ **generation_config
+ )
+ else:
+ self.language_model.rope_pos_id_version = 'default'
+ generation_output = self.generate(
+ pixel_values=pixel_values,
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ **generation_config
+ )
+
response = tokenizer.batch_decode(generation_output, skip_special_tokens=True)[0]
response = response.split(template.sep.strip())[0].strip()
history.append((question, response))
@@ -402,6 +433,7 @@ def generate(
pixel_values: Optional[torch.FloatTensor] = None,
input_ids: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
visual_features: Optional[torch.FloatTensor] = None,
generation_config: Optional[GenerationConfig] = None,
output_hidden_states: Optional[bool] = None,
@@ -430,6 +462,7 @@ def generate(
outputs = self.language_model.generate(
inputs_embeds=input_embeds,
attention_mask=attention_mask,
+ position_ids=position_ids,
generation_config=generation_config,
output_hidden_states=output_hidden_states,
use_cache=True,
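
Illustrative sketch (not part of the patch): when the config selects a v2pe_ version, chat() now derives position ids itself via get_rope_pos_id and forwards them to generate(), so the public calling convention stays the same. A usage sketch, assuming model, tokenizer and pixel_values are prepared as in the standard InternVL chat example:

    generation_config = dict(max_new_tokens=256, do_sample=False)
    response = model.chat(tokenizer, pixel_values, '<image>\nDescribe this image.', generation_config)
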
diff --git a/internvl_chat/internvl/patch/pad_data_collator.py b/internvl_chat/internvl/patch/pad_data_collator.py
index 6282040f6..dffd76ba0 100644
--- a/internvl_chat/internvl/patch/pad_data_collator.py
+++ b/internvl_chat/internvl/patch/pad_data_collator.py
@@ -72,8 +72,13 @@ def concat_pad_data_collator(features, max_item_length=None, pad_id=0):
feat['attention_mask'] = feat['input_ids'].ne(pad_id)
if 'position_ids' in feat:
- temp_position_ids = torch.LongTensor([pad_id] * max_item_length)
- temp_position_ids[:feat['position_ids'].shape[0]] = feat['position_ids']
+            temp_position_ids = [pad_id] * max_item_length
+            if isinstance(feat['position_ids'], (list, tuple)):
+                pos_length = len(feat['position_ids'])
+            else:
+                pos_length = feat['position_ids'].shape[0]
+            temp_position_ids[:pos_length] = feat['position_ids']
feat['position_ids'] = temp_position_ids
if 'loss_weight' in feat:
@@ -98,7 +103,7 @@ def concat_pad_data_collator(features, max_item_length=None, pad_id=0):
# Handling of all other possible keys.
# Again, we will use the first element to figure out which key/values are not None for this model.
for k, v in first.items():
- if k not in ('label', 'label_ids', 'pixel_values', 'image_flags') and \
+ if k not in ('label', 'label_ids', 'pixel_values', 'image_flags', 'position_ids') and \
v is not None and not isinstance(v, str):
if isinstance(v, torch.Tensor):
batch[k] = torch.stack([f[k] for f in features])
@@ -113,6 +118,13 @@ def concat_pad_data_collator(features, max_item_length=None, pad_id=0):
batch[k] = torch.concat(np.stack([f[k] for f in features]))
else:
batch[k] = torch.concat([f[k] for f in features])
+        if k == 'position_ids':
+ if isinstance(v, torch.Tensor):
+ batch[k] = torch.concat([f[k] for f in features])
+ elif isinstance(v, np.ndarray):
+ batch[k] = torch.concat(np.stack([f[k] for f in features]))
+ else:
+ batch[k] = [f[k] for f in features]
return batch
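
Illustrative sketch (not part of the patch): position_ids needs its own branch because V2PE ids are fractional floats; writing them into the LongTensor scratch buffer used previously would silently truncate them to integers, so they are kept as plain lists here and only converted to a float tensor later in InternVLChatModel.forward. A small demonstration of the truncation this avoids:

    import torch

    pos = [0.0, 1.0, 2.0, 2.5, 3.0, 3.5, 4.0, 5.0]   # fractional V2PE positions
    buf = torch.LongTensor([0] * 10)
    buf[:len(pos)] = torch.tensor(pos)               # tensor([0, 1, 2, 2, 3, 3, 4, 5, 0, 0]) -- fractions lost
    padded = pos + [0] * (10 - len(pos))             # keeping a plain list preserves them
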
diff --git a/internvl_chat/internvl/train/dataset_packed.py b/internvl_chat/internvl/train/dataset_packed.py
index 3b54ed475..ef52a7bd8 100644
--- a/internvl_chat/internvl/train/dataset_packed.py
+++ b/internvl_chat/internvl/train/dataset_packed.py
@@ -240,7 +240,22 @@ def update_buffer(self, buffer, new_sample):
assert buffer.keys() == new_sample.keys()
for k in buffer:
- buffer[k] = torch.cat([buffer[k], new_sample[k]])
+ if isinstance(buffer[k], list) or isinstance(new_sample[k], list):
+ # Handle list type
+ if isinstance(buffer[k], list):
+ buffer_data = buffer[k]
+ else:
+ buffer_data = buffer[k].tolist()
+
+ if isinstance(new_sample[k], list):
+ new_data = new_sample[k]
+ else:
+ new_data = new_sample[k].tolist()
+
+ buffer[k] = buffer_data + new_data
+ else:
+ # Handle tensor type
+ buffer[k] = torch.cat([buffer[k], new_sample[k]])
return buffer
@staticmethod
diff --git a/internvl_chat/internvl/train/internvl_chat_finetune.py b/internvl_chat/internvl/train/internvl_chat_finetune.py
index 42f669437..00be4befc 100644
--- a/internvl_chat/internvl/train/internvl_chat_finetune.py
+++ b/internvl_chat/internvl/train/internvl_chat_finetune.py
@@ -52,6 +52,7 @@
preprocess_internvl2_5, preprocess_mpt,
preprocess_phi3)
from internvl.train.dataset_packed import PackedDataset, packed_collate_fn
+from internvl.v2pe_utils import get_rope_pos_id
from PIL import Image, ImageFile, PngImagePlugin, UnidentifiedImageError
from torch.utils.data import Dataset
from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
@@ -70,6 +71,7 @@
print('petrel_client is not installed. Using PIL to load images.')
has_tcs_loader = False
+
# Set constants for image processing and logging
IGNORE_INDEX = -100
Image.MAX_IMAGE_PIXELS = None
@@ -157,6 +159,14 @@ class ModelArguments:
default=False,
metadata={'help': 'Set to True to use the liger kernel.'}
)
+ rope_pos_id_version: Optional[str] = field(
+ default='default',
+        metadata={'help': 'rope position id version for get_rope_pos_id: default, v2pe_fix, or v2pe_rnd'},
+ )
+ rope_pos_id_stride: Optional[int] = field(
+ default=None,
+        metadata={'help': 'position id stride per image tile, used when rope_pos_id_version is v2pe_fix'},
+ )
@dataclass
@@ -297,16 +307,22 @@ def __init__(
distributed_mode=False,
force_shuffle=False,
random_seed=0,
+ rope_pos_id_version='default',
+ rope_pos_id_stride=None,
):
super(LazySupervisedDataset, self).__init__()
self.ds_name = ds_name
self.tokenizer = tokenizer
self.template_name = template_name
self.num_image_token = num_image_token
+ self.rope_pos_id_version = rope_pos_id_version
+ self.rope_pos_id_stride = rope_pos_id_stride
logger.info(f'[Dataset] num_image_token: {num_image_token}')
logger.info(f'[Dataset] dynamic_image_size: {dynamic_image_size}')
logger.info(f'[Dataset] use_thumbnail: {use_thumbnail}')
logger.info(f'[Dataset] min_dynamic_patch: {min_dynamic_patch}, max_dynamic_patch: {max_dynamic_patch}')
+ logger.info(f'[Dataset] rope_pos_id_version: {rope_pos_id_version}')
+ logger.info(f'[Dataset] rope_pos_id_stride: {rope_pos_id_stride}')
self.image_size = image_size
self.is_train = is_train
@@ -421,6 +437,7 @@ def multi_modal_get_item(self, data_item):
# Build transformation function
transform = self.get_transform()
+ num_tiles = []
# Ensure the first conversation contains an image placeholder
        if '<image>' not in data_item['conversations'][0]['value']:
            data_item['conversations'][0]['value'] = '<image>\n' + data_item['conversations'][0]['value']
@@ -430,12 +447,15 @@ def multi_modal_get_item(self, data_item):
# Load the image using tcs_loader if available, otherwise use PIL
image = self.load_image(image_path)
+ orig_size = image.size
if self.dynamic_image_size: # If dynamic image size is enabled, preprocess the image dynamically
images = dynamic_preprocess(image, min_num=self.min_dynamic_patch, max_num=self.max_dynamic_patch,
image_size=self.image_size, use_thumbnail=self.use_thumbnail)
+ num_tiles.append(len(images))
else: # Otherwise, use the original image as a single patch
images = [image]
+ num_tiles.append(1)
# Apply the transformation to each image and stack the results into a tensor
pixel_values = [transform(image) for image in images]
@@ -461,12 +481,16 @@ def multi_modal_get_item(self, data_item):
image_end_token_id = self.tokenizer.convert_tokens_to_ids(IMG_END_TOKEN)
assert (ret['input_ids'][0] == image_end_token_id).sum() == 1, f'image tokens are truncated, this dataset is {self.ds_name}'
+        if self.rope_pos_id_version in ['v2pe_fix', 'v2pe_rnd']:
+            position_ids = get_rope_pos_id(ret, torch.float32, self.rope_pos_id_version, position_ids[0], tokenizer=self.tokenizer, num_image_token=self.num_image_token, rope_pos_id_stride=self.rope_pos_id_stride)
+ else:
+ position_ids = position_ids[0]
# Create the final return dictionary
ret = dict(
input_ids=ret['input_ids'][0],
labels=ret['labels'][0],
attention_mask=ret['attention_mask'][0],
- position_ids=position_ids[0],
+ position_ids=position_ids,
pixel_values=pixel_values,
image_flags=torch.tensor([1] * num_patches, dtype=torch.long)
)
@@ -511,12 +535,16 @@ def multi_modal_multi_image_get_item(self, data_item):
image_end_token_id = self.tokenizer.convert_tokens_to_ids(IMG_END_TOKEN)
assert (ret['input_ids'][0] == image_end_token_id).sum() == num_image, f'image tokens are truncated, this dataset is {self.ds_name}'
+        if self.rope_pos_id_version in ['v2pe_fix', 'v2pe_rnd']:
+            position_ids = get_rope_pos_id(ret, torch.float32, self.rope_pos_id_version, position_ids[0], tokenizer=self.tokenizer, num_image_token=self.num_image_token, rope_pos_id_stride=self.rope_pos_id_stride)
+ else:
+ position_ids = position_ids[0]
# Create the final return dictionary
ret = dict(
input_ids=ret['input_ids'][0],
labels=ret['labels'][0],
attention_mask=ret['attention_mask'][0],
- position_ids=position_ids[0],
+ position_ids=position_ids,
pixel_values=pixel_values,
image_flags=torch.tensor([1] * num_patches, dtype=torch.long)
)
@@ -567,12 +595,16 @@ def video_get_item(self, data_item):
position_ids = ret['attention_mask'].long().cumsum(-1) - 1
position_ids.masked_fill_(ret['attention_mask'] == 0, 1)
+        if self.rope_pos_id_version in ['v2pe_fix', 'v2pe_rnd']:
+            position_ids = get_rope_pos_id(ret, torch.float32, self.rope_pos_id_version, position_ids[0], tokenizer=self.tokenizer, num_image_token=self.num_image_token, rope_pos_id_stride=self.rope_pos_id_stride)
+ else:
+ position_ids = position_ids[0]
# Create the final return dictionary
ret = dict(
input_ids=ret['input_ids'][0],
labels=ret['labels'][0],
attention_mask=ret['attention_mask'][0],
- position_ids=position_ids[0],
+ position_ids=position_ids,
pixel_values=pixel_values,
image_flags=torch.tensor([1] * num_patches, dtype=torch.long)
)
@@ -711,6 +743,8 @@ def build_datasets(
min_num_frame=8,
max_num_frame=32,
normalize_type='imagenet',
+ rope_pos_id_version='default',
+ rope_pos_id_stride=None,
):
datasets = []
lengths = []
@@ -749,6 +783,8 @@ def build_datasets(
distributed_mode=data_args.use_packed_ds,
force_shuffle=data_args.use_packed_ds,
random_seed=ds_idx,
+ rope_pos_id_version=rope_pos_id_version,
+ rope_pos_id_stride=rope_pos_id_stride,
)
logger.info(f'Add dataset: {ds_name} with length: {len(dataset)}')
datasets.append(dataset)
@@ -906,6 +942,10 @@ def main():
config.ps_version = model_args.ps_version
config.min_dynamic_patch = data_args.min_dynamic_patch
config.max_dynamic_patch = data_args.max_dynamic_patch
+ config.rope_pos_id_version = model_args.rope_pos_id_version
+ config.rope_pos_id_stride = model_args.rope_pos_id_stride
+ config.llm_config.rope_pos_id_version = model_args.rope_pos_id_version
+ config.llm_config.rope_pos_id_stride = model_args.rope_pos_id_stride
model = InternVLChatModel.from_pretrained(
model_args.model_name_or_path, torch_dtype=torch.bfloat16, config=config)
else:
@@ -935,6 +975,10 @@ def main():
use_thumbnail=data_args.use_thumbnail, ps_version=model_args.ps_version,
min_dynamic_patch=data_args.min_dynamic_patch, max_dynamic_patch=data_args.max_dynamic_patch)
internvl_chat_config.force_image_size = data_args.force_image_size
+ internvl_chat_config.rope_pos_id_version = model_args.rope_pos_id_version
+ internvl_chat_config.rope_pos_id_stride = model_args.rope_pos_id_stride
+ internvl_chat_config.llm_config.rope_pos_id_version = model_args.rope_pos_id_version
+ internvl_chat_config.llm_config.rope_pos_id_stride = model_args.rope_pos_id_stride
logger.info('Building InternVLChatModel...')
model = InternVLChatModel(internvl_chat_config, vision_model, llm)
model.img_context_token_id = img_context_token_id
@@ -983,7 +1027,8 @@ def main():
dynamic_image_size=data_args.dynamic_image_size, use_thumbnail=data_args.use_thumbnail,
min_dynamic_patch=data_args.min_dynamic_patch, max_dynamic_patch=data_args.max_dynamic_patch,
normalize_type=data_args.normalize_type, min_num_frame=data_args.min_num_frame,
- max_num_frame=data_args.max_num_frame)
+                               max_num_frame=data_args.max_num_frame, rope_pos_id_version=model_args.rope_pos_id_version,
+ rope_pos_id_stride=model_args.rope_pos_id_stride)
def _freeze_params(module):
for param in module.parameters():
diff --git a/internvl_chat/internvl/train/internvl_chat_pretrain.py b/internvl_chat/internvl/train/internvl_chat_pretrain.py
index d7962ae92..6fd97b77c 100644
--- a/internvl_chat/internvl/train/internvl_chat_pretrain.py
+++ b/internvl_chat/internvl/train/internvl_chat_pretrain.py
@@ -52,6 +52,7 @@
preprocess_internvl2_5, preprocess_mpt,
preprocess_phi3)
from internvl.train.dataset_packed import PackedDataset, packed_collate_fn
+from internvl.v2pe_utils import get_rope_pos_id
from PIL import Image, ImageFile, PngImagePlugin, UnidentifiedImageError
from torch.utils.data import Dataset
from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
@@ -157,6 +158,14 @@ class ModelArguments:
default=False,
metadata={'help': 'Set to True to use the liger kernel.'}
)
+ rope_pos_id_version: Optional[str] = field(
+ default='default',
+        metadata={'help': 'rope position id version for get_rope_pos_id: default, v2pe_fix, or v2pe_rnd'},
+ )
+ rope_pos_id_stride: Optional[int] = field(
+ default=None,
+        metadata={'help': 'position id stride per image tile, used when rope_pos_id_version is v2pe_fix'},
+ )
@dataclass
@@ -297,16 +306,22 @@ def __init__(
distributed_mode=False,
force_shuffle=False,
random_seed=0,
+ rope_pos_id_version='default',
+ rope_pos_id_stride=None,
):
super(LazySupervisedDataset, self).__init__()
self.ds_name = ds_name
self.tokenizer = tokenizer
self.template_name = template_name
self.num_image_token = num_image_token
+ self.rope_pos_id_version = rope_pos_id_version
+ self.rope_pos_id_stride = rope_pos_id_stride
logger.info(f'[Dataset] num_image_token: {num_image_token}')
logger.info(f'[Dataset] dynamic_image_size: {dynamic_image_size}')
logger.info(f'[Dataset] use_thumbnail: {use_thumbnail}')
logger.info(f'[Dataset] min_dynamic_patch: {min_dynamic_patch}, max_dynamic_patch: {max_dynamic_patch}')
+ logger.info(f'[Dataset] rope_pos_id_version: {rope_pos_id_version}')
+ logger.info(f'[Dataset] rope_pos_id_stride: {rope_pos_id_stride}')
self.image_size = image_size
self.is_train = is_train
@@ -464,6 +479,7 @@ def multi_modal_get_item(self, data_item):
# Build transformation function
transform = self.get_transform()
+ num_tiles = []
# Ensure the first conversation contains an image placeholder
        if '<image>' not in data_item['conversations'][0]['value']:
            data_item['conversations'][0]['value'] = '<image>\n' + data_item['conversations'][0]['value']
@@ -473,12 +489,15 @@ def multi_modal_get_item(self, data_item):
# Load the image using tcs_loader if available, otherwise use PIL
image = self.load_image(image_path)
+ orig_size = image.size
if self.dynamic_image_size: # If dynamic image size is enabled, preprocess the image dynamically
images = dynamic_preprocess(image, min_num=self.min_dynamic_patch, max_num=self.max_dynamic_patch,
image_size=self.image_size, use_thumbnail=self.use_thumbnail)
+ num_tiles.append(len(images))
else: # Otherwise, use the original image as a single patch
images = [image]
+ num_tiles.append(1)
# Apply the transformation to each image and stack the results into a tensor
pixel_values = [transform(image) for image in images]
@@ -504,12 +523,16 @@ def multi_modal_get_item(self, data_item):
image_end_token_id = self.tokenizer.convert_tokens_to_ids(IMG_END_TOKEN)
assert (ret['input_ids'][0] == image_end_token_id).sum() == 1, f'image tokens are truncated, this dataset is {self.ds_name}'
+        if self.rope_pos_id_version in ['v2pe_fix', 'v2pe_rnd']:
+            position_ids = get_rope_pos_id(ret, torch.float32, self.rope_pos_id_version, position_ids[0], tokenizer=self.tokenizer, num_image_token=self.num_image_token, rope_pos_id_stride=self.rope_pos_id_stride)
+ else:
+ position_ids = position_ids[0]
# Create the final return dictionary
ret = dict(
input_ids=ret['input_ids'][0],
labels=ret['labels'][0],
attention_mask=ret['attention_mask'][0],
- position_ids=position_ids[0],
+ position_ids=position_ids,
pixel_values=pixel_values,
image_flags=torch.tensor([1] * num_patches, dtype=torch.long)
)
@@ -554,12 +577,16 @@ def multi_modal_multi_image_get_item(self, data_item):
image_end_token_id = self.tokenizer.convert_tokens_to_ids(IMG_END_TOKEN)
assert (ret['input_ids'][0] == image_end_token_id).sum() == num_image, f'image tokens are truncated, this dataset is {self.ds_name}'
+        if self.rope_pos_id_version in ['v2pe_fix', 'v2pe_rnd']:
+            position_ids = get_rope_pos_id(ret, torch.float32, self.rope_pos_id_version, position_ids[0], tokenizer=self.tokenizer, num_image_token=self.num_image_token, rope_pos_id_stride=self.rope_pos_id_stride)
+ else:
+ position_ids = position_ids[0]
# Create the final return dictionary
ret = dict(
input_ids=ret['input_ids'][0],
labels=ret['labels'][0],
attention_mask=ret['attention_mask'][0],
- position_ids=position_ids[0],
+ position_ids=position_ids,
pixel_values=pixel_values,
image_flags=torch.tensor([1] * num_patches, dtype=torch.long)
)
@@ -610,12 +637,16 @@ def video_get_item(self, data_item):
position_ids = ret['attention_mask'].long().cumsum(-1) - 1
position_ids.masked_fill_(ret['attention_mask'] == 0, 1)
+        if self.rope_pos_id_version in ['v2pe_fix', 'v2pe_rnd']:
+            position_ids = get_rope_pos_id(ret, torch.float32, self.rope_pos_id_version, position_ids[0], tokenizer=self.tokenizer, num_image_token=self.num_image_token, rope_pos_id_stride=self.rope_pos_id_stride)
+ else:
+ position_ids = position_ids[0]
# Create the final return dictionary
ret = dict(
input_ids=ret['input_ids'][0],
labels=ret['labels'][0],
attention_mask=ret['attention_mask'][0],
- position_ids=position_ids[0],
+ position_ids=position_ids,
pixel_values=pixel_values,
image_flags=torch.tensor([1] * num_patches, dtype=torch.long)
)
@@ -755,6 +786,8 @@ def build_datasets(
min_num_frame=8,
max_num_frame=32,
normalize_type='imagenet',
+ rope_pos_id_version='default',
+ rope_pos_id_stride=None,
):
datasets = []
lengths = []
@@ -793,6 +826,8 @@ def build_datasets(
distributed_mode=data_args.use_packed_ds,
force_shuffle=data_args.use_packed_ds,
random_seed=ds_idx,
+ rope_pos_id_version=rope_pos_id_version,
+ rope_pos_id_stride=rope_pos_id_stride,
)
logger.info(f'Add dataset: {ds_name} with length: {len(dataset)}')
datasets.append(dataset)
@@ -950,6 +985,10 @@ def main():
config.ps_version = model_args.ps_version
config.min_dynamic_patch = data_args.min_dynamic_patch
config.max_dynamic_patch = data_args.max_dynamic_patch
+ config.rope_pos_id_version = model_args.rope_pos_id_version
+ config.rope_pos_id_stride = model_args.rope_pos_id_stride
+ config.llm_config.rope_pos_id_version = model_args.rope_pos_id_version
+ config.llm_config.rope_pos_id_stride = model_args.rope_pos_id_stride
model = InternVLChatModel.from_pretrained(
model_args.model_name_or_path, torch_dtype=torch.bfloat16, config=config)
else:
@@ -977,8 +1016,13 @@ def main():
pad2square=data_args.pad2square, template=data_args.conv_style,
select_layer=model_args.vision_select_layer, dynamic_image_size=data_args.dynamic_image_size,
use_thumbnail=data_args.use_thumbnail, ps_version=model_args.ps_version,
- min_dynamic_patch=data_args.min_dynamic_patch, max_dynamic_patch=data_args.max_dynamic_patch)
+ min_dynamic_patch=data_args.min_dynamic_patch, max_dynamic_patch=data_args.max_dynamic_patch,
+ rope_pos_id_version=model_args.rope_pos_id_version, rope_pos_id_stride=model_args.rope_pos_id_stride)
internvl_chat_config.force_image_size = data_args.force_image_size
+ internvl_chat_config.rope_pos_id_version = model_args.rope_pos_id_version
+ internvl_chat_config.rope_pos_id_stride = model_args.rope_pos_id_stride
+ internvl_chat_config.llm_config.rope_pos_id_version = model_args.rope_pos_id_version
+ internvl_chat_config.llm_config.rope_pos_id_stride = model_args.rope_pos_id_stride
logger.info('Building InternVLChatModel...')
model = InternVLChatModel(internvl_chat_config, vision_model, llm)
model.img_context_token_id = img_context_token_id
@@ -1027,7 +1071,8 @@ def main():
dynamic_image_size=data_args.dynamic_image_size, use_thumbnail=data_args.use_thumbnail,
min_dynamic_patch=data_args.min_dynamic_patch, max_dynamic_patch=data_args.max_dynamic_patch,
normalize_type=data_args.normalize_type, min_num_frame=data_args.min_num_frame,
- max_num_frame=data_args.max_num_frame)
+ max_num_frame=data_args.max_num_frame, rope_pos_id_version=model_args.rope_pos_id_version,
+ rope_pos_id_stride=model_args.rope_pos_id_stride)
def _freeze_params(module):
for param in module.parameters():
diff --git a/internvl_chat/internvl/v2pe_utils.py b/internvl_chat/internvl/v2pe_utils.py
new file mode 100644
index 000000000..35803df9b
--- /dev/null
+++ b/internvl_chat/internvl/v2pe_utils.py
@@ -0,0 +1,116 @@
+import random
+
+import torch
+import torch.nn as nn
+
+
+# Compute RoPE position ids for V2PE; used both when preparing training data and at inference in chat().
+def get_rope_pos_id(ret, dtype, rope_pos_id_version='default', position_id=None,
+                    IMG_START_TOKEN='<img>', IMG_END_TOKEN='</img>', rope_pos_id_stride=None, tokenizer=None, num_image_token=256):
+ image_start_token_id = tokenizer.convert_tokens_to_ids(IMG_START_TOKEN)
+ image_end_token_id = tokenizer.convert_tokens_to_ids(IMG_END_TOKEN)
+ rope_pos_id_list = []
+ assert ret['input_ids'].shape[0] == 1, 'batch size should be 1, other batch sizes are not supported yet'
+ input_ids_0 = ret['input_ids'][0]
+ attention_mask_0 = ret['attention_mask'][0]
+ image_start_token_id_idxs = torch.where(input_ids_0 == image_start_token_id)[0]
+ image_end_token_id_idxs = torch.where(input_ids_0 == image_end_token_id)[0]
+ num_tiles = (image_end_token_id_idxs - image_start_token_id_idxs) // num_image_token
+ last_record_pos_id = -1
+ start_index = 0
+
+ assert rope_pos_id_version in ['v2pe_fix', 'v2pe_rnd', 'default'], f'{rope_pos_id_version} not supported for eval'
+
+ for i in range(len(image_start_token_id_idxs)):
+
+ num_tile = num_tiles[i]
+
+ rope_pos_id_pre = attention_mask_0[start_index:image_start_token_id_idxs[i] + 1].long().cumsum(-1) - 1 + (last_record_pos_id + 1)
+ rope_pos_id_pre.masked_fill_(attention_mask_0[start_index:image_start_token_id_idxs[i] + 1] == 0, 1)
+ rope_pos_id_list.append(rope_pos_id_pre)
+
+ last_record_pos_id = rope_pos_id_pre[-1].long()
+
+        if rope_pos_id_version == 'v2pe_fix':
+            assert rope_pos_id_stride is not None, 'rope_pos_id_stride must not be None when rope_pos_id_version is v2pe_fix'
+            small_stride = rope_pos_id_stride / num_image_token
+            split_img_id_idxs = torch.linspace(last_record_pos_id, last_record_pos_id + small_stride * num_image_token * num_tile, num_image_token * num_tile + 1)[1:].to(dtype=dtype)
+            rope_pos_id_list.append(split_img_id_idxs)
+            last_record_pos_id = torch.ceil(split_img_id_idxs[-1]).long()
+        elif rope_pos_id_version == 'v2pe_rnd':
+            rope_pos_id_stride = random.choice([1, 2, 4, 8, 16, 32, 64, 128, 256])
+            small_stride = rope_pos_id_stride / num_image_token
+            split_img_id_idxs = torch.linspace(last_record_pos_id, last_record_pos_id + small_stride * num_image_token * num_tile, num_image_token * num_tile + 1)[1:].to(dtype=dtype)
+            rope_pos_id_list.append(split_img_id_idxs)
+            last_record_pos_id = torch.ceil(split_img_id_idxs[-1]).long()
+ elif rope_pos_id_version == 'default':
+ split_img_id_idxs = torch.linspace(last_record_pos_id,
+ last_record_pos_id + (num_tile - 1) * num_image_token,
+ (num_tile - 1) * num_image_token + 1)[1:].to(dtype=dtype)
+ rope_pos_id_list.append(split_img_id_idxs)
+ thumbnail_id_idxs = torch.linspace(last_record_pos_id + (num_tile - 1) * num_image_token,
+ last_record_pos_id + num_tile * num_image_token,
+ num_image_token + 1)[1:].to(dtype=dtype)
+ rope_pos_id_list.append(thumbnail_id_idxs)
+ last_record_pos_id = (last_record_pos_id + num_tile * num_image_token).long()
+ else:
+ raise NotImplementedError(f'not implement for {rope_pos_id_version}')
+
+ start_index = image_start_token_id_idxs[i] + num_tile * num_image_token + 1
+ assert input_ids_0[start_index] == image_end_token_id
+ assert start_index == image_end_token_id_idxs[i]
+
+ assert image_end_token_id_idxs[-1] == start_index
+ rope_pos_id_pre = attention_mask_0[start_index:].long().cumsum(-1) - 1 + (last_record_pos_id + 1)
+ rope_pos_id_pre.masked_fill_(attention_mask_0[start_index:] == 0, 1)
+ rope_pos_id_list.append(rope_pos_id_pre)
+
+    rope_pos_id_list = [pos.to('cpu') for pos in rope_pos_id_list]
+ rope_pos_id = torch.cat(rope_pos_id_list).to(dtype=dtype)
+ if rope_pos_id_version == 'default':
+ rope_pos_id = rope_pos_id.long()
+ assert torch.equal(rope_pos_id, position_id.to(rope_pos_id.device)), (rope_pos_id, position_id.to(rope_pos_id.device))
+ assert torch.allclose(rope_pos_id, position_id.to(rope_pos_id.device), atol=1e-32)
+
+ assert rope_pos_id.shape == input_ids_0.shape
+
+ return list(rope_pos_id.numpy())
+
+
+# Rotary Position Embedding for V2PE
+class V2PE(nn.Module):
+ def __init__(self, dim, max_position_embeddings=2048, base=10000):
+ super().__init__()
+
+ self.dim = dim
+ self.max_position_embeddings = max_position_embeddings
+ self.base = base
+ self.inv_freq = None
+ # inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
+ # self.register_buffer('inv_freq', inv_freq, persistent=False)
+
+ self.max_seq_len_cached = -1
+
+ def _set_cos_sin_cache(self, pos_id, device, dtype):
+ if self.inv_freq is None:
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim))
+ del self.inv_freq
+ self.register_buffer('inv_freq', inv_freq, persistent=False)
+
+        pos_id = pos_id.squeeze(0)
+ freqs = torch.outer(pos_id, self.inv_freq.to(device=pos_id.device))
+
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
+ emb = torch.cat((freqs, freqs), dim=-1)
+ self.register_buffer('cos_cached', emb.cos().to(dtype), persistent=False)
+ self.register_buffer('sin_cached', emb.sin().to(dtype), persistent=False)
+
+ def forward(self, x, global_posid=None):
+ # x: [bs, num_attention_heads, seq_len, head_size]
+ self._set_cos_sin_cache(pos_id=global_posid, device=x.device, dtype=x.dtype)
+
+ return (
+ self.cos_cached[:].to(dtype=x.dtype),
+ self.sin_cached[:].to(dtype=x.dtype),
+ )
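
Worked example (illustrative, not part of the new module), using toy values num_image_token=4 and rope_pos_id_stride=2 so the numbers stay small:

    # token:     'A'   'B'   <img>   t1    t2    t3    t4    </img>   'C'
    # position:   0     1     2      2.5   3.0   3.5   4.0    5        6
    #
    # Text tokens and <img> keep consecutive integer positions. The num_image_token * num_tile
    # image tokens then advance by rope_pos_id_stride / num_image_token = 2 / 4 = 0.5 each, so a
    # full tile occupies only rope_pos_id_stride positions. Text after the image resumes from
    # ceil(last image position) + 1. 'v2pe_rnd' samples the stride from {1, 2, 4, 8, 16, 32, 64, 128, 256}
    # per sample, and 'default' reproduces the ordinary 0..L-1 integer ids.
    #
    # These fractional ids are what V2PE.forward consumes through global_posid, e.g.:
    #   rotary = V2PE(dim=128)
    #   cos, sin = rotary(hidden_states, global_posid=torch.tensor([[0.0, 1.0, 2.0, 2.5, 3.0]]))
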