
Commit baad2d7

Fix BLIP2 mixed precision on CPU (#179)
* Allow both the "cpu" string and torch.device("cpu") to be recognized during model loading.
* Make BLIP-2 mixed precision (AMP) compatible with CPU.
* Use dtype=torch.float16 as the default autocast dtype.
1 parent a557de5 commit baad2d7

6 files changed: +61 -44 lines changed

lavis/models/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -6,6 +6,7 @@
 """
 
 import logging
+import torch
 from omegaconf import OmegaConf
 from lavis.common.registry import registry
 
@@ -211,7 +212,7 @@ def load_model_and_preprocess(name, model_type, is_eval=False, device="cpu"):
             """
         )
 
-    if device == "cpu":
+    if device == "cpu" or device == torch.device("cpu"):
        model = model.float()

    return model.to(device), vis_processors, txt_processors
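For context, a minimal usage sketch of what this hunk enables: passing a torch.device now takes the same float32 CPU path as the plain "cpu" string. The model name and type below follow the LAVIS README examples and are illustrative, not part of this diff.

import torch
from lavis.models import load_model_and_preprocess

# Either "cpu" or torch.device("cpu") now hits the `model = model.float()`
# branch above, keeping the weights in float32 on CPU.
device = torch.device("cpu")
model, vis_processors, txt_processors = load_model_and_preprocess(
    name="blip2_feature_extractor", model_type="pretrain", is_eval=True, device=device
)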

lavis/models/blip2_models/blip2.py

Lines changed: 21 additions & 9 deletions
@@ -4,6 +4,7 @@
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
 """
+import contextlib
 import logging
 import os
 import time
@@ -32,6 +33,16 @@ def init_tokenizer(cls):
         tokenizer.add_special_tokens({"bos_token": "[DEC]"})
         return tokenizer
 
+    def maybe_autocast(self, dtype=torch.float16):
+        # if on cpu, don't use autocast
+        # if on gpu, use autocast with dtype if provided, otherwise use torch.float16
+        enable_autocast = self.device != torch.device("cpu")
+
+        if enable_autocast:
+            return torch.cuda.amp.autocast(dtype=dtype)
+        else:
+            return contextlib.nullcontext()
+
     @classmethod
     def init_Qformer(cls, num_query_token, vision_width, cross_attention_freq=2):
         encoder_config = BertConfig.from_pretrained("bert-base-uncased")
@@ -42,7 +53,7 @@ def init_Qformer(cls, num_query_token, vision_width, cross_attention_freq=2):
         encoder_config.query_length = num_query_token
         Qformer = BertLMHeadModel.from_pretrained(
             "bert-base-uncased", config=encoder_config
-            )
+        )
         query_tokens = nn.Parameter(
             torch.zeros(1, num_query_token, encoder_config.hidden_size)
         )
@@ -52,16 +63,17 @@ def init_Qformer(cls, num_query_token, vision_width, cross_attention_freq=2):
     @classmethod
     def init_vision_encoder(
         cls, model_name, img_size, drop_path_rate, use_grad_checkpoint, precision
-        ):
-        assert model_name in ["eva_clip_g","clip_L"], "vit model must be eva_clip_g or clip_L"
-        if model_name=="eva_clip_g":
+    ):
+        assert model_name in [
+            "eva_clip_g",
+            "clip_L",
+        ], "vit model must be eva_clip_g or clip_L"
+        if model_name == "eva_clip_g":
             visual_encoder = create_eva_vit_g(
                 img_size, drop_path_rate, use_grad_checkpoint, precision
             )
-        elif model_name=="clip_L":
-            visual_encoder = create_clip_vit_L(
-                img_size, use_grad_checkpoint, precision
-            )
+        elif model_name == "clip_L":
+            visual_encoder = create_clip_vit_L(img_size, use_grad_checkpoint, precision)
         ln_vision = LayerNorm(visual_encoder.num_features)
         return visual_encoder, ln_vision
 
@@ -80,7 +92,7 @@ def load_from_pretrained(self, url_or_filename):
 
         msg = self.load_state_dict(state_dict, strict=False)
 
-        logging.info("Missing keys {}".format(msg.missing_keys))
+        # logging.info("Missing keys {}".format(msg.missing_keys))
         logging.info("load checkpoint from %s" % url_or_filename)
 
         return msg
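The maybe_autocast helper added above is the core of the fix: each call site can wrap its compute in the same context manager instead of repeating enabled=... checks. A standalone sketch of the same dispatch, assuming only PyTorch (the free function here is illustrative; in the diff it is a method on Blip2Base that reads self.device):

import contextlib

import torch
import torch.nn as nn

def maybe_autocast(device, dtype=torch.float16):
    # GPU: open a CUDA autocast region in the requested dtype (float16 by default).
    # CPU: return a null context, so everything runs in float32.
    if device != torch.device("cpu"):
        return torch.cuda.amp.autocast(dtype=dtype)
    return contextlib.nullcontext()

layer = nn.Linear(3, 4)
x = torch.randn(2, 3)
with maybe_autocast(x.device):  # CPU tensor -> nullcontext, no mixed precision
    y = layer(x)
print(y.dtype)  # torch.float32 on CPU; torch.float16 inside a CUDA autocast region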

lavis/models/blip2_models/blip2_image_text_matching.py

Lines changed: 2 additions & 2 deletions
@@ -54,9 +54,9 @@ def forward(self, samples, match_head="itm"):
         image = samples["image"]
         caption = samples["text_input"]
 
-        with torch.cuda.amp.autocast(enabled=(self.device != torch.device("cpu"))):
+        with self.maybe_autocast():
             image_embeds = self.ln_vision(self.visual_encoder(image))
-            image_embeds = image_embeds.float()
+        image_embeds = image_embeds.float()
         image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(
             image.device
         )
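This hunk keeps the ViT forward inside the mixed-precision region but hoists the cast back to float32 out of it, so downstream code always receives fp32 embeddings. A small illustration of that cast, independent of LAVIS; a CPU bfloat16 autocast region is used here only to produce a reduced-precision tensor, whereas the model itself uses CUDA autocast (or no autocast at all on CPU):

import torch
import torch.nn as nn

proj = nn.Linear(8, 8)
x = torch.randn(2, 8)

# Inside the autocast region the linear layer returns a low-precision tensor;
# casting back to float32 once the region ends mirrors the hunk above.
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    h = proj(x)
print(h.dtype)  # torch.bfloat16
h = h.float()
print(h.dtype)  # torch.float32 for everything downstream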

lavis/models/blip2_models/blip2_opt.py

Lines changed: 18 additions & 14 deletions
@@ -59,7 +59,7 @@ def __init__(
         )
         if freeze_vit:
             for name, param in self.visual_encoder.named_parameters():
-                param.requires_grad = False
+                param.requires_grad = False
             self.visual_encoder = self.visual_encoder.eval()
             self.visual_encoder.train = disabled_train
             logging.info("freeze vision encoder")
@@ -95,7 +95,8 @@ def __init__(
 
     def forward(self, samples):
         image = samples["image"]
-        image_embeds = self.ln_vision(self.visual_encoder(image))
+        with self.maybe_autocast():
+            image_embeds = self.ln_vision(self.visual_encoder(image))
         image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(
             image.device
         )
@@ -138,12 +139,13 @@ def forward(self, samples):
         inputs_embeds = torch.cat([inputs_opt, inputs_embeds], dim=1)
         attention_mask = torch.cat([atts_opt, opt_tokens.attention_mask], dim=1)
 
-        outputs = self.opt_model(
-            inputs_embeds=inputs_embeds,
-            attention_mask=attention_mask,
-            return_dict=True,
-            labels=targets,
-        )
+        with self.maybe_autocast():
+            outputs = self.opt_model(
+                inputs_embeds=inputs_embeds,
+                attention_mask=attention_mask,
+                return_dict=True,
+                labels=targets,
+            )
         loss = outputs.loss
 
         return {"loss": loss}
@@ -177,9 +179,7 @@ def generate(
             captions (list): A list of strings of length batch_size * num_captions.
         """
         image = samples["image"]
-        with torch.cuda.amp.autocast(
-            enabled=(self.device != torch.device("cpu"))
-        ):
+        with self.maybe_autocast():
             image_embeds = self.ln_vision(self.visual_encoder(image))
             image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(
                 image.device
@@ -194,7 +194,9 @@ def generate(
             )
 
             inputs_opt = self.opt_proj(query_output.last_hidden_state)
-            atts_opt = torch.ones(inputs_opt.size()[:-1], dtype=torch.long).to(image.device)
+            atts_opt = torch.ones(inputs_opt.size()[:-1], dtype=torch.long).to(
+                image.device
+            )
 
             if "prompt" in samples.keys():
                 prompt = samples["prompt"]
@@ -203,7 +205,9 @@ def generate(
 
             prompt = [prompt] * image.size(0)
 
-            opt_tokens = self.opt_tokenizer(prompt, return_tensors="pt").to(image.device)
+            opt_tokens = self.opt_tokenizer(prompt, return_tensors="pt").to(
+                image.device
+            )
             input_ids = opt_tokens.input_ids
             attention_mask = torch.cat([atts_opt, opt_tokens.attention_mask], dim=1)
 
@@ -238,7 +242,7 @@ def generate(
 
     @classmethod
     def from_config(cls, cfg):
-        vit_model = cfg.get("vit_model","eva_clip_g")
+        vit_model = cfg.get("vit_model", "eva_clip_g")
         img_size = cfg.get("image_size")
         num_query_token = cfg.get("num_query_token")
         opt_model = cfg.get("opt_model")

lavis/models/blip2_models/blip2_qformer.py

Lines changed: 10 additions & 10 deletions
@@ -64,9 +64,9 @@ def __init__(
         )
         if freeze_vit:
             for name, param in self.visual_encoder.named_parameters():
-                param.requires_grad = False
+                param.requires_grad = False
             self.visual_encoder = self.visual_encoder.eval()
-            self.visual_encoder.train = disabled_train
+            self.visual_encoder.train = disabled_train
             logging.info("freeze vision encoder")
         self.Qformer, self.query_tokens = self.init_Qformer(
             num_query_token, self.visual_encoder.num_features, cross_attention_freq
@@ -90,7 +90,7 @@ def __init__(
     def forward(self, samples):
         image = samples["image"]
         text = samples["text_input"]
-
+
         image_embeds = self.ln_vision(self.visual_encoder(image))
         image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(
             image.device
@@ -247,7 +247,7 @@ def forward(self, samples):
             return_dict=True,
             labels=labels,
         )
-
+
         loss_lm = lm_output.loss
 
         return BlipOutput(
@@ -403,9 +403,9 @@ def extract_features(self, samples, mode="multimodal"):
                 image is not None
             ), "Image is not provided for mode 'image' or 'multimodal'"
             # return query features
-            with torch.cuda.amp.autocast(enabled=(self.device != torch.device("cpu"))):
+            with self.maybe_autocast():
                 image_embeds_frozen = self.ln_vision(self.visual_encoder(image))
-                image_embeds_frozen = image_embeds_frozen.float()
+            image_embeds_frozen = image_embeds_frozen.float()
             image_atts = torch.ones(
                 image_embeds_frozen.size()[:-1], dtype=torch.long
             ).to(self.device)
@@ -443,9 +443,9 @@ def extract_features(self, samples, mode="multimodal"):
 
         elif mode == "multimodal":
             # return multimodel query features
-            with torch.cuda.amp.autocast(enabled=(self.device != torch.device("cpu"))):
+            with self.maybe_autocast():
                 image_embeds_frozen = self.ln_vision(self.visual_encoder(image))
-                image_embeds_frozen = image_embeds_frozen.float()
+            image_embeds_frozen = image_embeds_frozen.float()
             image_atts = torch.ones(
                 image_embeds_frozen.size()[:-1], dtype=torch.long
             ).to(self.device)
@@ -482,10 +482,10 @@ def extract_features(self, samples, mode="multimodal"):
 
     @classmethod
     def from_config(cls, cfg):
-        vit_model = cfg.get("vit_model","eva_clip_g")
+        vit_model = cfg.get("vit_model", "eva_clip_g")
         img_size = cfg.get("image_size")
         num_query_token = cfg.get("num_query_token")
-        cross_attention_freq = cfg.get("cross_attention_freq",2)
+        cross_attention_freq = cfg.get("cross_attention_freq", 2)
 
         drop_path_rate = cfg.get("drop_path_rate", 0)
         use_grad_checkpoint = cfg.get("use_grad_checkpoint", False)

lavis/models/blip2_models/blip2_t5.py

Lines changed: 8 additions & 8 deletions
@@ -101,7 +101,9 @@ def __init__(
 
     def forward(self, samples):
         image = samples["image"]
-        image_embeds = self.ln_vision(self.visual_encoder(image))
+
+        with self.maybe_autocast():
+            image_embeds = self.ln_vision(self.visual_encoder(image))
         image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(
             image.device
         )
@@ -117,7 +119,7 @@ def forward(self, samples):
         inputs_t5 = self.t5_proj(query_output.last_hidden_state)
         atts_t5 = torch.ones(inputs_t5.size()[:-1], dtype=torch.long).to(image.device)
 
-        with torch.cuda.amp.autocast(dtype=torch.bfloat16):
+        with self.maybe_autocast(dtype=torch.bfloat16):
             input_tokens = self.t5_tokenizer(
                 samples["text_input"],
                 padding="longest",
@@ -182,9 +184,8 @@ def generate(
             captions (list): A list of strings of length batch_size * num_captions.
         """
         image = samples["image"]
-        enable_autocast = self.device != torch.device("cpu")
 
-        with torch.cuda.amp.autocast(enabled=enable_autocast):
+        with self.maybe_autocast():
             image_embeds = self.ln_vision(self.visual_encoder(image))
             image_embeds = image_embeds.float()
             image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(
@@ -220,7 +221,7 @@ def generate(
 
         encoder_atts = torch.cat([atts_t5, input_tokens.attention_mask], dim=1)
 
-        with torch.cuda.amp.autocast(enabled=enable_autocast, dtype=torch.bfloat16):
+        with self.maybe_autocast(dtype=torch.bfloat16):
             inputs_embeds = self.t5_model.encoder.embed_tokens(input_tokens.input_ids)
             inputs_embeds = torch.cat([inputs_t5, inputs_embeds], dim=1)
 
@@ -257,7 +258,7 @@ def predict_answers(
         **kwargs
     ):
         image = samples["image"]
-        with torch.cuda.amp.autocast(enabled=(self.device != torch.device("cpu"))):
+        with self.maybe_autocast():
             image_embeds = self.ln_vision(self.visual_encoder(image))
             image_embeds = image_embeds.float()
             image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(
@@ -288,8 +289,7 @@ def predict_answers(
 
         encoder_atts = torch.cat([atts_t5, input_tokens.attention_mask], dim=1)
 
-        device_type = "cuda" if "cuda" in str(self.device) else "cpu"
-        with torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16):
+        with self.maybe_autocast(dtype=torch.bfloat16):
             inputs_embeds = self.t5_model.encoder.embed_tokens(input_tokens.input_ids)
             inputs_embeds = torch.cat([inputs_t5, inputs_embeds], dim=1)
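Across the T5 hunks the default float16 autocast is kept for the vision tower, while the T5 model runs under maybe_autocast(dtype=torch.bfloat16), since T5 is known to be numerically unstable in float16; on CPU both become no-ops. For reference, a minimal check of what the dtype argument changes, guarded so it only runs on a CUDA machine (plain PyTorch autocast behavior, nothing LAVIS-specific):

import torch

if torch.cuda.is_available():
    a = torch.randn(4, 4, device="cuda")
    b = torch.randn(4, 4, device="cuda")
    with torch.cuda.amp.autocast(dtype=torch.bfloat16):
        print((a @ b).dtype)  # torch.bfloat16
    with torch.cuda.amp.autocast():  # float16 is the default autocast dtype
        print((a @ b).dtype)  # torch.float16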
