Skip to content

q #2

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open

q #2

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
176 changes: 171 additions & 5 deletions gnnrl/___real_pruning.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,16 @@
from torch import nn
from torch.nn.utils import prune
from torchvision import models
import sys
import os


# Get the absolute path of the parent directory of the current file
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))

# Add the parent directory to sys.path so the gnnrl package is importable
sys.path.append(parent_dir)
from gnnrl.networks import resnet
from gnnrl.graph_env.feedback_calculation import top5validate
from gnnrl.graph_env.network_pruning import real_pruning, channel_pruning
from gnnrl.utils.split_dataset import get_dataset
Expand All @@ -30,7 +39,7 @@ def parse_args():
return parser.parse_args()

def load_model(model_name):

print('=> Building model..')
if model_name == "vgg16":
net = models.vgg16(pretrained=True)
net = channel_pruning(net,torch.ones(100, 1))
Expand All @@ -42,6 +51,162 @@ def load_model(model_name):
sd = checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint
net.load_state_dict(sd)
net.cuda()
if args.model == 'mobilenet':
from networks.mobilenet import MobileNet
net = MobileNet(n_class=1000)
if args.finetuning:
print("Fine-Tuning...")
net = channel_pruning(net, torch.ones(100, 1))
if args.ckpt_path is not None: # assigned checkpoint path to resume from
print('=> Resuming from checkpoint..')
# path = os.path.join(args.ckpt_path, args.model+'ckpt.best.pth.tar')
path = args.ckpt_path
checkpoint = torch.load(path)
sd = checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint
net.load_state_dict(sd, False)
# net.apply(weights_init)
for name, layer in net.named_modules():
if hasattr(layer, 'reset_parameters'):
layer.reset_parameters()
net.cuda()


elif args.model == 'mobilenetv2':
net = models.mobilenet_v2(pretrained=True)
if args.finetuning:
net = channel_pruning(net, torch.ones(100, 1))
if args.ckpt_path is not None: # assigned checkpoint path to resume from
print('=> Resuming from checkpoint..')
path = args.ckpt_path
# path = args.ckpt_path
checkpoint = torch.load(path)
sd = checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint
net.load_state_dict(sd)
# for name,layer in net.named_modules():
# if hasattr(layer, 'reset_parameters'):
# layer.reset_parameters()
net.cuda()

elif args.model == 'mobilenet_0.5flops':
from networks.mobilenet_cifar100 import MobileNet
net = MobileNet(n_class=1000, profile='0.5flops')
net.cuda()

elif args.model == 'resnet18':
net = models.resnet18(pretrained=True)
net = channel_pruning(net, torch.ones(100, 1))
if args.ckpt_path is not None: # assigned checkpoint path to resume from
print('=> Resuming from checkpoint..')
path = args.ckpt_path
# path = os.path.join(args.ckpt_path, args.model+'ckpt.best.pth.tar')
# checkpoint = torch.load(args.ckpt_path, args.model+'ckpt.best.pth.tar')
checkpoint = torch.load(path)
sd = checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint
net.load_state_dict(sd)
net.cuda()

elif args.model == 'resnet50':
net = models.resnet50(pretrained=True)
net = channel_pruning(net, torch.ones(100, 1))
if args.ckpt_path is not None: # assigned checkpoint path to resume from
print('=> Resuming from checkpoint..')
path = args.ckpt_path
# path = os.path.join(args.ckpt_path, args.model+'ckpt.best.pth.tar')
# checkpoint = torch.load(args.ckpt_path, args.model+'ckpt.best.pth.tar')
checkpoint = torch.load(path)
sd = checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint
net.load_state_dict(sd)
net.cuda()

elif args.model == "resnet56":
net = resnet.__dict__['resnet56']()
# net = torch.nn.DataParallel(net,list(range(args.n_gpu)))
net = channel_pruning(net, torch.ones(100, 1))
if args.ckpt_path is not None: # assigned checkpoint path to resume from
print('=> Resuming from checkpoint..')
path = os.path.join(args.ckpt_path)

checkpoint = torch.load(path)
sd = checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint
net.load_state_dict(sd)
net.cuda()

elif args.model == "resnet44":
net = resnet.__dict__['resnet44']()
# net = torch.nn.DataParallel(net,list(range(args.n_gpu)))
net = channel_pruning(net, torch.ones(100, 1))
if args.ckpt_path is not None: # assigned checkpoint path to resume from
print('=> Resuming from checkpoint..')
path = args.ckpt_path

checkpoint = torch.load(path)
sd = checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint
net.load_state_dict(sd)
net.cuda()

elif args.model == "resnet110":
net = resnet.__dict__['resnet110']()
# net = torch.nn.DataParallel(net,list(range(args.n_gpu)))
net = channel_pruning(net, torch.ones(120, 1))
if args.ckpt_path is not None: # assigned checkpoint path to resume from
print('=> Resuming from checkpoint..')
path = args.ckpt_path
checkpoint = torch.load(path)
sd = checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint
net.load_state_dict(sd)
net.cuda()

elif args.model == "resnet32":
net = resnet.__dict__['resnet32']()
# net = torch.nn.DataParallel(net,list(range(args.n_gpu)))
net = channel_pruning(net, torch.ones(120, 1))
if args.ckpt_path is not None: # assigned checkpoint path to resume from
print('=> Resuming from checkpoint..')
path = args.ckpt_path
checkpoint = torch.load(path)
sd = checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint
net.load_state_dict(sd)
net.cuda()

elif args.model == "resnet20":
net = resnet.__dict__['resnet20']()
# net = torch.nn.DataParallel(net,list(range(args.n_gpu)))
net = channel_pruning(net, torch.ones(100, 1))
if args.ckpt_path is not None: # assigned checkpoint path to resume from
print('=> Resuming from checkpoint..')
path = args.ckpt_path
checkpoint = torch.load(path)
sd = checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint
net.load_state_dict(sd)
net.cuda()

elif args.model == 'shufflenet':
from networks.shufflenet import shufflenet
net = shufflenet()
if args.finetuning:
print("Finetuning")
net = channel_pruning(net, torch.ones(100, 1))
if args.ckpt_path is not None: # assigned checkpoint path to resume from
print('=> Resuming from checkpoint..')
path = os.path.join(args.ckpt_path, args.model + 'ckpt.best.pth.tar')
# path = args.ckpt_path
checkpoint = torch.load(path)
sd = checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint
net.load_state_dict(sd)
net.cuda()
elif args.model == 'shufflenetv2':
from networks.shufflenetv2 import shufflenetv2
net = shufflenetv2()
if args.finetuning:
net = channel_pruning(net, torch.ones(100, 1))
if args.ckpt_path is not None: # assigned checkpoint path to resume from
print('=> Resuming from checkpoint..')
path = os.path.join(args.ckpt_path, args.model + 'ckpt.best.pth.tar')

checkpoint = torch.load(path)
sd = checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint
net.load_state_dict(sd)
net.cuda()

else:
raise KeyError
Expand Down Expand Up @@ -69,10 +234,11 @@ def load_model(model_name):

if args.ckpt_path is not None: # assigned checkpoint path to resume from
print('=> Resuming from pruned model..')
path = os.path.join(args.ckpt_path,'vgg16_20FLOPs.pth')
checkpoint = torch.load(path)
sd = checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint
net.load_state_dict(sd)
net = load_model(args.model)
# path = os.path.join(args.ckpt_path,'vgg16_20FLOPs.pth')
# checkpoint = torch.load(path)
# sd = checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint
# net.load_state_dict(sd)


criterion = nn.CrossEntropyLoss().to(device)
Expand Down
64 changes: 62 additions & 2 deletions gnnrl/fine_tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,15 @@
import argparse
import shutil
import math
import sys
import os
from cmath import phase

# Get the absolute path of the parent directory of the current file
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))

# Add the parent directory to sys.path so the gnnrl package is importable
sys.path.append(parent_dir)
import torch
import torch.nn as nn
import numpy as np
Expand All @@ -15,6 +23,9 @@
from gnnrl.utils.train_utils import accuracy, AverageMeter, progress_bar, get_output_folder
from gnnrl.graph_env.network_pruning import channel_pruning
from gnnrl.utils.split_dataset import get_dataset
from gnnrl.graph_env.flops_calculation import flops_caculation_forward


def weights_init(m):
if isinstance(m, nn.Conv2d):
m.reset_parameters()
Expand Down Expand Up @@ -137,7 +148,7 @@ def get_model():
path = args.ckpt_path
checkpoint = torch.load(path)
sd = checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint
net.load_state_dict(sd)
net.load_state_dict(sd,False)
#net.apply(weights_init)
for name,layer in net.named_modules():
if hasattr(layer, 'reset_parameters'):
Expand Down Expand Up @@ -362,6 +373,8 @@ def test(epoch, test_loader, save=True):
top5 = AverageMeter()
end = time.time()

record_data = ''

with torch.no_grad():
for batch_idx, (inputs, targets) in enumerate(test_loader):
if use_cuda:
Expand All @@ -380,6 +393,44 @@ def test(epoch, test_loader, save=True):

progress_bar(batch_idx, len(test_loader), 'Loss: {:.3f} | Acc1: {:.3f}% | Acc5: {:.3f}%'
.format(losses.avg, top1.avg, top5.avg))
if save:
record_data = top1.avg
record_data = f'Epoch:{epoch} Acc1:{record_data}\n'
with open(r"{}/acc1_data.txt".format(log_dir), 'a') as f:
f.write(record_data)


# total_params = sum(p.numel() for p in net.parameters())
# print(f"Current model Total parameters: {total_params:,} ({total_params / 1e6:.2f}M)")
# # 新增FLOPs计算逻辑
# input_shape = (3, 224, 224) if args.dataset == 'imagenet' else (3, 32, 32)
# input_x = torch.randn(1, *input_shape).cuda()
# try:
# flops, _ = flops_caculation_forward(net, args.model, input_x)
# total_flops = sum(flops)
# print(f"Current model FLOPs: {total_flops / 1e9:.2f} GFLOPs")
# except NotImplementedError as e:
# print(f"FLOPs calculation failed: {e}")

# # 计算原始模型的FLOPs和参数量
# # 修正原始模型加载错误
# original_model_name = 'resnet110'
# o_net = resnet.__dict__[original_model_name]()
# o_net = torch.nn.DataParallel(o_net) # 删除:torch.nn.DataParallel(net)
# path = "E:\\GraduationProject\\GNNRL_1\\GNN-RL-Model-Compression\\gnnrl\\networks\\pretrained_models\\cifar10\\resnet110.th"
# checkpoint = torch.load(path, map_location='cuda')
# o_net.load_state_dict(checkpoint['state_dict'])
# o_net.eval() # 添加模型评估模式
#
# # 计算原始模型参数量
# o_total_params = sum(p.numel() for p in o_net.parameters())
# print(f"Original {original_model_name} parameters: {o_total_params:,} ({o_total_params / 1e6:.2f}M)")
#
# # 计算原始模型FLOPs
# input_shape_original = (3, 32, 32) # CIFAR-10输入尺寸
# input_x_original = torch.randn(1, *input_shape_original).cuda()
# o_flops, _ = flops_caculation_forward(o_net, original_model_name, input_x_original)
# print(f"Original {original_model_name} FLOPs: {sum(o_flops)/1e9:.2f} GFLOPs")

if save:
# writer.add_scalar('loss/test', losses.avg, epoch)
Expand All @@ -390,6 +441,10 @@ def test(epoch, test_loader, save=True):
if top1.avg > best_acc:
best_acc = top1.avg
is_best = True
record_data = top1.avg
record_data = f'Epoch:{epoch} Acc1:{record_data}\n'
with open(r"{}/best_acc1_data.txt".format(log_dir), 'w') as f:
f.write(record_data)

print('Current best acc: {}'.format(best_acc))
save_checkpoint({
Expand Down Expand Up @@ -500,4 +555,9 @@ def save_checkpoint(state, is_best, checkpoint_dir='.'):


--finetuning
'''
'''





Loading