[Feature]V2PE #1000

Open · wants to merge 7 commits into base: main
@@ -95,6 +95,8 @@ def __init__( # pylint: disable=W0102
rope_theta=10000,
rope_scaling=None,
attn_implementation='eager',
rope_pos_id_version='default',
rope_pos_id_stride=None,
**kwargs,
):
self.vocab_size = vocab_size
@@ -115,6 +117,8 @@ def __init__( # pylint: disable=W0102
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self.rope_pos_id_version = rope_pos_id_version
self.rope_pos_id_stride = rope_pos_id_stride
self._rope_scaling_validation()

self.attn_implementation = attn_implementation
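The two new fields thread through the LLM configuration shown above. A minimal usage sketch, assuming the __init__ above belongs to InternLM2Config (the class imported by modeling_internlm2.py further down) and that the concrete values are placeholders; only the keyword names and the 'v2pe_' prefix convention come from this PR:

    from internvl.model.internlm2.configuration_internlm2 import InternLM2Config  # assumed import path

    llm_config = InternLM2Config(
        rope_theta=10000,
        rope_scaling=None,                # V2PE requires rope_scaling to stay None (see the _init_rope warning below)
        attn_implementation='eager',
        rope_pos_id_version='v2pe_fix',   # placeholder value; any string starting with 'v2pe_' selects the V2PE rotary embedding
        rope_pos_id_stride=64,            # placeholder stride for image-token position ids
    )
    print(llm_config.rope_pos_id_version, llm_config.rope_pos_id_stride)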
69 changes: 47 additions & 22 deletions internvl_chat/internvl/model/internlm2/modeling_internlm2.py
@@ -40,6 +40,8 @@
except: # noqa # pylint: disable=bare-except
BaseStreamer = None

from internvl.v2pe_utils import V2PE

from .configuration_internlm2 import InternLM2Config

logger = logging.get_logger(__name__)
@@ -323,30 +325,45 @@ def __init__(self, config: InternLM2Config):

     def _init_rope(self):
         if self.config.rope_scaling is None:
-            self.rotary_emb = InternLM2RotaryEmbedding(
-                self.head_dim,
-                max_position_embeddings=self.max_position_embeddings,
-                base=self.config.rope_theta,
-            )
-        else:
-            scaling_type = self.config.rope_scaling['type']
-            scaling_factor = self.config.rope_scaling['factor']
-            if scaling_type == 'dynamic':
-                self.rotary_emb = InternLM2DynamicNTKScalingRotaryEmbedding(
+            if self.config.rope_pos_id_version.startswith('v2pe_'):
+                self.rotary_emb = V2PE(
                     self.head_dim,
                     max_position_embeddings=self.max_position_embeddings,
                     base=self.config.rope_theta,
-                    scaling_factor=scaling_factor,
                 )
-            elif scaling_type == 'linear':
-                self.rotary_emb = InternLM2LinearScalingRotaryEmbedding(
+            else:
+                self.rotary_emb = InternLM2RotaryEmbedding(
                     self.head_dim,
                     max_position_embeddings=self.max_position_embeddings,
                     base=self.config.rope_theta,
                 )
+        else:
+            if self.config.rope_pos_id_version.startswith('v2pe_'):
+                warnings.warn(f'V2PE is not compatible with rope_scaling. When using V2PE, rope_scaling must be None. rope_scaling is {self.config.rope_scaling}')
+                self.rotary_emb = V2PE(
+                    self.head_dim,
+                    max_position_embeddings=self.max_position_embeddings,
+                    base=self.config.rope_theta,
+                    scaling_factor=scaling_factor,
+                )
             else:
-                raise ValueError("Currently we only support rotary embedding's type being 'dynamic' or 'linear'.")
+                scaling_type = self.config.rope_scaling['type']
+                scaling_factor = self.config.rope_scaling['factor']
+                if scaling_type == 'linear':
+                    self.rotary_emb = InternLM2LinearScalingRotaryEmbedding(
+                        self.head_dim,
+                        max_position_embeddings=self.max_position_embeddings,
+                        scaling_factor=scaling_factor,
+                        base=self.config.rope_theta,
+                    )
+                elif scaling_type == 'dynamic':
+                    self.rotary_emb = InternLM2DynamicNTKScalingRotaryEmbedding(
+                        self.head_dim,
+                        max_position_embeddings=self.max_position_embeddings,
+                        scaling_factor=scaling_factor,
+                        base=self.config.rope_theta,
+                    )
+                else:
+                    raise ValueError(f'Unknown RoPE scaling type {scaling_type}')
         return self.rotary_emb

def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
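For orientation, the V2PE branch above only requires that self.rotary_emb(value_states, global_posid=position_ids) return (cos, sin) tables that can later be indexed with torch.arange(seq_len) inside apply_rotary_pos_emb. A minimal sketch of such a module follows; the real V2PE class lives in internvl.v2pe_utils and is not shown in this diff, so everything beyond the call signature is an assumption:

    import torch
    from torch import nn

    class V2PESketch(nn.Module):
        # Rotary embedding evaluated at caller-supplied (possibly fractional) positions
        # instead of the usual 0..seq_len-1 integer grid.
        def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
            super().__init__()
            inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device).float() / dim))
            self.register_buffer('inv_freq', inv_freq, persistent=False)

        def forward(self, x, global_posid=None):
            # global_posid: [1, seq_len] float positions (e.g. strided image-token ids)
            pos = global_posid.squeeze(0).to(self.inv_freq)
            freqs = torch.outer(pos, self.inv_freq)      # [seq_len, dim // 2]
            emb = torch.cat((freqs, freqs), dim=-1)      # [seq_len, dim]
            # Row i is later selected via apply_rotary_pos_emb(..., torch.arange(seq_len).unsqueeze(0))
            return emb.cos().to(x.dtype), emb.sin().to(x.dtype)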
@@ -391,8 +408,12 @@ def forward(
         kv_seq_len = key_states.shape[-2]
         if past_key_value is not None:
             kv_seq_len += past_key_value[0].shape[-2]
-        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+        if self.config.rope_pos_id_version.startswith('v2pe_'):
+            cos, sin = self.rotary_emb(value_states, global_posid=position_ids)
+            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, torch.arange(0,position_ids.shape[1]).unsqueeze(0))
+        else:
+            cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

         if past_key_value is not None:
             # reuse k, v, self_attention
@@ -493,10 +514,12 @@ def forward(
         kv_seq_len = key_states.shape[-2]
         if past_key_value is not None:
             kv_seq_len += past_key_value[0].shape[-2]
-
-        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+        if self.config.rope_pos_id_version.startswith('v2pe_'):
+            cos, sin = self.rotary_emb(value_states, global_posid=position_ids)
+            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, torch.arange(0,position_ids.shape[1]).unsqueeze(0))
+        else:
+            cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

         if past_key_value is not None:
             # reuse k, v, self_attention
@@ -1145,13 +1168,15 @@ def prepare_inputs_for_generation(
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1]:]
elif position_ids is not None:
if self.config.rope_pos_id_version!='default' and past_key_values is not None:
position_ids=(position_ids[:,-1]+attention_mask[:,position_ids.shape[1]:].sum(dim=1)).unsqueeze(1)

# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {'inputs_embeds': inputs_embeds}
else:
model_inputs = {'input_ids': input_ids}

model_inputs.update(
{
'position_ids': position_ids,
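During incremental decoding, the new elif branch above advances the last (possibly fractional) position id by the number of attention-mask entries appended since the cached prefix. A tiny worked example with assumed values:

    import torch

    position_ids = torch.tensor([[0.0, 1.0, 1.25, 1.5, 2.5]])   # cached prefix of 5 tokens with fractional V2PE ids
    attention_mask = torch.ones(1, 6)                            # prefix (5) plus one newly generated token
    next_id = (position_ids[:, -1] + attention_mask[:, position_ids.shape[1]:].sum(dim=1)).unsqueeze(1)
    print(next_id)  # -> 3.5 (last id 2.5 advanced by 1 for the new token)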
@@ -37,6 +37,8 @@ def __init__(
ps_version='v1',
min_dynamic_patch=1,
max_dynamic_patch=6,
rope_pos_id_version='default',
rope_pos_id_stride=None,
**kwargs):
super().__init__(**kwargs)

@@ -72,6 +74,8 @@ def __init__(
self.ps_version = ps_version # pixel shuffle version
self.min_dynamic_patch = min_dynamic_patch
self.max_dynamic_patch = max_dynamic_patch
self.rope_pos_id_version = rope_pos_id_version
self.rope_pos_id_stride = rope_pos_id_stride

self.hidden_size = self.llm_config.hidden_size
# By default, we use tie_word_embeddings=False for models of all sizes.
@@ -82,6 +86,8 @@ def __init__(
logger.info(f'ps_version: {self.ps_version}')
logger.info(f'min_dynamic_patch: {self.min_dynamic_patch}')
logger.info(f'max_dynamic_patch: {self.max_dynamic_patch}')
logger.info(f'rope_pos_id_version: {self.rope_pos_id_version}')
logger.info(f'rope_pos_id_stride: {self.rope_pos_id_stride}')

def to_dict(self):
"""
@@ -105,5 +111,7 @@ def to_dict(self):
output['ps_version'] = self.ps_version
output['min_dynamic_patch'] = self.min_dynamic_patch
output['max_dynamic_patch'] = self.max_dynamic_patch
output['rope_pos_id_version'] = self.rope_pos_id_version
output['rope_pos_id_stride'] = self.rope_pos_id_stride

return output
@@ -13,6 +13,7 @@
from internvl.conversation import get_conv_template
from internvl.model.internlm2.modeling_internlm2 import InternLM2ForCausalLM
from internvl.model.phi3.modeling_phi3 import Phi3ForCausalLM
from internvl.v2pe_utils import get_rope_pos_id
from peft import LoraConfig, get_peft_model
from torch import nn
from torch.nn import CrossEntropyLoss
@@ -158,6 +159,8 @@ def forward(
) -> Union[Tuple, CausalLMOutputWithPast]:
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

if isinstance(position_ids, list):
position_ids = torch.tensor(position_ids, dtype=torch.float32).to(pixel_values.device)
image_flags = image_flags.squeeze(-1)
input_embeds = self.language_model.get_input_embeddings()(input_ids).clone()

@@ -341,7 +344,7 @@ def batch_chat(self, tokenizer, pixel_values, questions, generation_config, num_

def chat(self, tokenizer, pixel_values, question, generation_config, history=None, return_history=False,
num_patches_list=None, IMG_START_TOKEN='<img>', IMG_END_TOKEN='</img>', IMG_CONTEXT_TOKEN='<IMG_CONTEXT>',
-             verbose=False):
+             verbose=False, **kwargs):

if history is None and pixel_values is not None and '<image>' not in question:
question = '<image>\n' + question
@@ -378,12 +381,40 @@ def chat(self, tokenizer, pixel_values, question, generation_config, history=Non
         input_ids = model_inputs['input_ids'].to(device)
         attention_mask = model_inputs['attention_mask'].to(device)
         generation_config['eos_token_id'] = eos_token_id
-        generation_output = self.generate(
-            pixel_values=pixel_values,
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            **generation_config
-        )
+        rope_pos_id_version = self.config.rope_pos_id_version
+        if rope_pos_id_version.startswith('v2pe_'):
+            pos_ids = []
+            ret = {'input_ids': input_ids, 'attention_mask': attention_mask}
+            for i in range(input_ids.shape[0]):
+                cur_dtype = torch.float32
+                rope_pos_id_stride = self.config.rope_pos_id_stride
+                cur_pos_id = get_rope_pos_id(ret, tokenizer=tokenizer,
+                                             dtype=cur_dtype,
+                                             rope_pos_id_version=rope_pos_id_version,
+                                             position_id=torch.arange(0, input_ids.shape[1]),
+                                             IMG_START_TOKEN=IMG_START_TOKEN,
+                                             IMG_END_TOKEN=IMG_END_TOKEN, rope_pos_id_stride=rope_pos_id_stride, num_image_token=self.num_image_token)
+
+                cur_pos_id = torch.tensor(cur_pos_id).to(device)
+                pos_ids.append(cur_pos_id)
+
+            pos_ids = torch.stack(pos_ids).to(device)
+            generation_output = self.generate(
+                pixel_values=pixel_values,
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                position_ids=pos_ids,
+                **generation_config
+            )
+        else:
+            self.language_model.rope_pos_id_version = 'default'
+            generation_output = self.generate(
+                pixel_values=pixel_values,
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                **generation_config
+            )

         response = tokenizer.batch_decode(generation_output, skip_special_tokens=True)[0]
         response = response.split(template.sep.strip())[0].strip()
         history.append((question, response))
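The position ids fed to generate() above come from get_rope_pos_id in internvl.v2pe_utils, which is not part of this diff. As an illustration only, and under the assumption that V2PE advances image tokens by a small fractional step while text tokens advance by 1 (the step value below is made up), the resulting ids look roughly like this:

    def toy_rope_pos_ids(token_types, visual_step=1.0 / 64):
        # token_types: sequence of 'text' / 'image' markers; returns one position id per token
        pos, ids = 0.0, []
        for t in token_types:
            pos += 1.0 if t == 'text' else visual_step
            ids.append(pos)
        return ids

    print(toy_rope_pos_ids(['text', 'text', 'image', 'image', 'image', 'text']))
    # [1.0, 2.0, 2.015625, 2.03125, 2.046875, 3.046875]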
@@ -402,6 +433,7 @@ def generate(
pixel_values: Optional[torch.FloatTensor] = None,
input_ids: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
visual_features: Optional[torch.FloatTensor] = None,
generation_config: Optional[GenerationConfig] = None,
output_hidden_states: Optional[bool] = None,
@@ -430,6 +462,7 @@
outputs = self.language_model.generate(
inputs_embeds=input_embeds,
attention_mask=attention_mask,
position_ids=position_ids,
generation_config=generation_config,
output_hidden_states=output_hidden_states,
use_cache=True,
18 changes: 15 additions & 3 deletions internvl_chat/internvl/patch/pad_data_collator.py
@@ -72,8 +72,13 @@ def concat_pad_data_collator(features, max_item_length=None, pad_id=0):
         feat['attention_mask'] = feat['input_ids'].ne(pad_id)

         if 'position_ids' in feat:
-            temp_position_ids = torch.LongTensor([pad_id] * max_item_length)
-            temp_position_ids[:feat['position_ids'].shape[0]] = feat['position_ids']
+            temp_position_ids = [pad_id] * max_item_length
+            if isinstance(feat['position_ids'], (list, tuple)):
+                pos_length = len(feat['position_ids'])
+                temp_position_ids[:pos_length] = feat['position_ids']
+            else:
+                pos_length = feat['position_ids'].shape[0]
+                temp_position_ids[:pos_length] = feat['position_ids']
             feat['position_ids'] = temp_position_ids

         if 'loss_weight' in feat:
@@ -98,7 +103,7 @@ def concat_pad_data_collator(features, max_item_length=None, pad_id=0):
# Handling of all other possible keys.
# Again, we will use the first element to figure out which key/values are not None for this model.
for k, v in first.items():
if k not in ('label', 'label_ids', 'pixel_values', 'image_flags') and \
if k not in ('label', 'label_ids', 'pixel_values', 'image_flags', 'position_ids') and \
v is not None and not isinstance(v, str):
if isinstance(v, torch.Tensor):
batch[k] = torch.stack([f[k] for f in features])
Expand All @@ -113,6 +118,13 @@ def concat_pad_data_collator(features, max_item_length=None, pad_id=0):
batch[k] = torch.concat(np.stack([f[k] for f in features]))
else:
batch[k] = torch.concat([f[k] for f in features])
if k in ('position_ids',):
if isinstance(v, torch.Tensor):
batch[k] = torch.concat([f[k] for f in features])
elif isinstance(v, np.ndarray):
batch[k] = torch.concat(np.stack([f[k] for f in features]))
else:
batch[k] = [f[k] for f in features]
return batch


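The list handling added above matters because the previous collator forced position ids through torch.LongTensor, which silently truncates the fractional V2PE ids that the list path keeps intact. A quick illustration with assumed values:

    import torch

    v2pe_ids = [0.0, 1.0, 1.015625]
    print(torch.LongTensor(v2pe_ids))   # tensor([0, 1, 1]): fractional steps are lost
    print(v2pe_ids)                     # unchanged when padded and collected as a plain Python list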
17 changes: 16 additions & 1 deletion internvl_chat/internvl/train/dataset_packed.py
@@ -240,7 +240,22 @@ def update_buffer(self, buffer, new_sample):

         assert buffer.keys() == new_sample.keys()
         for k in buffer:
-            buffer[k] = torch.cat([buffer[k], new_sample[k]])
+            if isinstance(buffer[k], list) or isinstance(new_sample[k], list):
+                # Handle list type
+                if isinstance(buffer[k], list):
+                    buffer_data = buffer[k]
+                else:
+                    buffer_data = buffer[k].tolist()
+
+                if isinstance(new_sample[k], list):
+                    new_data = new_sample[k]
+                else:
+                    new_data = new_sample[k].tolist()
+
+                buffer[k] = buffer_data + new_data
+            else:
+                # Handle tensor type
+                buffer[k] = torch.cat([buffer[k], new_sample[k]])
         return buffer

@staticmethod
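A small illustration, with assumed values, of what the branch above does when packing samples: tensor-valued keys are still concatenated with torch.cat, while list-valued V2PE position ids are merged as Python lists (tensors are converted via .tolist() first):

    import torch

    buffer = {'input_ids': torch.tensor([1, 2]), 'position_ids': [0.0, 1.0]}
    new_sample = {'input_ids': torch.tensor([3]), 'position_ids': torch.tensor([1.015625])}
    # After update_buffer(buffer, new_sample):
    #   buffer['input_ids']    -> tensor([1, 2, 3])
    #   buffer['position_ids'] -> [0.0, 1.0, 1.015625]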