- 
                Notifications
    You must be signed in to change notification settings 
- Fork 126
Open
Description
Traceback (most recent call last):
File "E:\New GG\a.py", line 11, in <module>
srgan_generator = torch.load(srgan_checkpoint)['generator'].to(device)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^
KeyError: 'generator'
import torch
from utils import *
from PIL import Image, ImageFont
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
srgan_checkpoint = "./RealESRGAN_x4plus_anime_6B.pth"

# Load the checkpoint and take the generator module stored under its
# 'generator' key. map_location lets a checkpoint that was saved from a GPU
# load on a CPU-only machine (the model is moved to `device` either way).
# NOTE(review): this expects a pickled nn.Module, not a bare state dict.
# torch.load unpickles arbitrary objects, so only load trusted files. On
# PyTorch >= 2.6 (weights_only=True by default), loading a full module may
# require weights_only=False; confirm against the installed version.
_checkpoint = torch.load(srgan_checkpoint, map_location=device)
if not isinstance(_checkpoint, dict) or 'generator' not in _checkpoint:
    # RealESRGAN_x4plus_anime_6B.pth, for example, has no 'generator' entry.
    # Real-ESRGAN releases are, as far as we know, bare state dicts. They must
    # be loaded into their own network class rather than used directly here.
    _found = (list(_checkpoint) if isinstance(_checkpoint, dict)
              else type(_checkpoint).__name__)
    raise KeyError(
        f"{srgan_checkpoint!r} has no 'generator' entry (found: {_found}). "
        "This script expects an SRGAN checkpoint saved as a dict containing "
        "the full generator model under 'generator'."
    )
srgan_generator = _checkpoint['generator'].to(device)
del _checkpoint  # release references to any other entries in the checkpoint
srgan_generator.eval()
def save_super_resolved_image(input_image_path, output_image_path, scale_factor=4):
    """Downscale an image, super-resolve it with ``srgan_generator``, and save it.

    The input is shrunk by ``scale_factor`` with bicubic resampling to build a
    low-resolution copy. That copy is run through the module-level
    ``srgan_generator`` on ``device``, and the result is written to
    ``output_image_path``.

    Args:
        input_image_path: Path of the image to read; it is converted to RGB.
        output_image_path: Destination path; PIL picks the format from it.
        scale_factor: Positive integer downscaling factor applied before
            super-resolution. Defaults to 4 (the original hard-coded value).
            It presumably should match the generator's upscaling factor;
            TODO confirm.
    """
    # Close the file handle promptly; convert() returns an independent copy.
    with Image.open(input_image_path) as image:
        input_image = image.convert('RGB')

    # Integer division, clamped so inputs smaller than scale_factor still give
    # a valid (>= 1 px) size instead of a zero-sized resize target.
    lr_size = (max(1, input_image.width // scale_factor),
               max(1, input_image.height // scale_factor))
    lr_img = input_image.resize(lr_size, Image.BICUBIC)

    # Inference only: no_grad avoids building an autograd graph for the forward pass.
    with torch.no_grad():
        lr_tensor = convert_image(lr_img, source='pil', target='imagenet-norm').unsqueeze(0).to(device)
        sr_tensor = srgan_generator(lr_tensor)

    # Drop the batch dimension and map the [-1, 1] output back to a PIL image.
    sr_img = convert_image(sr_tensor.squeeze(0).cpu(), source='[-1, 1]', target='pil')
    sr_img.save(output_image_path)
if __name__ == '__main__':
    # Demo run: paths are relative to the current working directory.
    input_image_path, output_image_path = "input.png", "output.jpg"
    save_super_resolved_image(input_image_path, output_image_path)
Metadata
Metadata
Assignees
Labels
No labels