Skip to content

Commit 0f769a4

Browse files
authored
Add files via upload
1 parent 01fdcb4 commit 0f769a4

File tree

2 files changed

+50
-0
lines changed

2 files changed

+50
-0
lines changed

tools/weight_fuse_fcclip.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
import torch
2+
import argparse
3+
4+
def get_parser():
5+
parser = argparse.ArgumentParser(description="Fuse weights from two models")
6+
parser.add_argument("--model_first_phase_path", type=str, required=True, help="Path to the first phase model")
7+
parser.add_argument("--model_sem_seg_path", type=str, required=True, help="Path to the semantic segmentation model")
8+
parser.add_argument("--output_path", type=str, required=True, help="Path to save the fused model")
9+
return parser
10+
11+
def main():
12+
parser = get_parser()
13+
args = parser.parse_args()
14+
15+
model_first_phase_dict = torch.load(args.model_first_phase_path)
16+
model_dict_sem_seg = torch.load(args.model_sem_seg_path)
17+
18+
for key in model_dict_sem_seg["model"].keys():
19+
if key.startswith("sem_seg_head"):
20+
model_first_phase_dict["model"][key] = model_dict_sem_seg["model"][key]
21+
22+
torch.save(model_first_phase_dict, args.output_path)
23+
24+
if __name__ == "__main__":
25+
main()

tools/weight_fuse_maftp.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
import torch
2+
import argparse
3+
4+
def get_parser():
5+
parser = argparse.ArgumentParser(description="Fuse weights from two models")
6+
parser.add_argument("--model_first_phase_path", type=str, required=True, help="Path to the first phase model")
7+
parser.add_argument("--model_sem_seg_path", type=str, required=True, help="Path to the semantic segmentation model")
8+
parser.add_argument("--output_path", type=str, required=True, help="Path to save the fused model")
9+
return parser
10+
11+
def main():
12+
parser = get_parser()
13+
args = parser.parse_args()
14+
15+
model_first_phase_dict = torch.load(args.model_first_phase_path)
16+
model_dict_sem_seg = torch.load(args.model_sem_seg_path)
17+
18+
for key in model_dict_sem_seg.keys():
19+
if key.startswith("sem_seg_head"):
20+
model_first_phase_dict["model"][key] = model_dict_sem_seg[key]
21+
22+
torch.save(model_first_phase_dict, args.output_path)
23+
24+
if __name__ == "__main__":
25+
main()

0 commit comments

Comments
 (0)