Skip to content

super_resolve.py not robust to Pillow version 10 or greater #17

@Declan-Curran1

Description

@Declan-Curran1

I had to edit `super_resolve.py` to fix the following issue:

:36: DeprecationWarning: getsize is deprecated and will be removed in Pillow 10 (2023-07-01). Use getbbox or getlength instead.

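`getsize` is removed in Pillow 10, so I replaced it with bounding-box measurement: the edited script calls `ImageDraw.textbbox`, which, like `getbbox`, returns `(left, top, right, bottom)`. The full edited file is pasted below.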


import os

import matplotlib.pyplot as plt
import torch
from PIL import Image, ImageDraw, ImageFont

from utils import *

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Model checkpoints
srgan_checkpoint = "C:/Users/decla/a-PyTorch-Tutorial-to-Super-Resolution-master/checkpoint_srgan.pth.tar"
srresnet_checkpoint = "C:/Users/decla/a-PyTorch-Tutorial-to-Super-Resolution-master/checkpoint_srresnet.pth.tar"


def _load_checkpoint(path):
    """Load a checkpoint file, mapping every stored tensor onto ``device``.

    ``map_location`` lets a checkpoint that was saved on a GPU be opened on a
    CPU-only machine (``device`` above falls back to CPU).

    The checkpoints are indexed below for whole module objects
    (``['model']`` / ``['generator']`` followed by ``.to()``), so they must be
    unpickled with ``weights_only=False``; PyTorch 2.6 made ``True`` the
    default, which rejects pickled modules.

    SECURITY: unpickling can execute arbitrary code - only load checkpoint
    files you trust.
    """
    try:
        return torch.load(path, map_location=device, weights_only=False)
    except TypeError:
        # PyTorch < 1.13 has no ``weights_only`` argument; those versions
        # always fully unpickle, so a plain call is equivalent.
        return torch.load(path, map_location=device)


# Load models
print("Loading SRResNet model...")
srresnet = _load_checkpoint(srresnet_checkpoint)['model'].to(device)
srresnet.eval()
print("Loading SRGAN model...")
srgan_generator = _load_checkpoint(srgan_checkpoint)['generator'].to(device)
srgan_generator.eval()

def _super_resolve(model, lr_img):
    """Run ``model`` on a PIL image and return its output as a PIL image.

    The input is converted with ``convert_image`` (from ``utils``) from 'pil'
    to 'imagenet-norm', and the model output is converted back from the
    '[-1, 1]' range to 'pil'. Inference runs under ``torch.no_grad()`` so no
    autograd graph is built.
    """
    lr_tensor = convert_image(lr_img, source='pil', target='imagenet-norm')
    with torch.no_grad():
        sr_tensor = model(lr_tensor.unsqueeze(0).to(device))
    return convert_image(sr_tensor.squeeze(0).cpu(), source='[-1, 1]', target='pil')


def _load_label_font(font_path, size):
    """Return a TrueType font, or Pillow's built-in font if ``font_path`` cannot be opened."""
    try:
        font = ImageFont.truetype(font_path, size=size)
        print("Loaded custom font.")
    except OSError:
        print("Defaulting to a terrible font. To use a font of your choice, include the link to its TTF file in the function.")
        font = ImageFont.load_default()
        print("Loaded default font.")
    return font


def _draw_label(draw, text, font, center_x, bottom_y):
    """Draw ``text`` in black, horizontally centred on ``center_x``, with its bottom edge at ``bottom_y``.

    Text is measured with ``ImageDraw.textbbox`` (Pillow >= 8), the
    replacement for ``getsize``, which Pillow 10 removed. The bbox is
    measured for an origin of (0, 0). Offsetting by the bbox centre and bottom
    therefore positions the visible glyphs, including any left bearing.
    """
    left, _, right, bottom = draw.textbbox((0, 0), text, font=font)
    draw.text(xy=(center_x - (left + right) / 2, bottom_y - bottom), text=text, font=font, fill='black')


def visualize_sr(img, halve=False,
                 output_dir="C:/Users/decla/a-PyTorch-Tutorial-to-Super-Resolution-master/data",
                 font_path="calibril.ttf", show=True):
    """Build a 2x2 comparison grid: bicubic, SRResNet, SRGAN and the original HR image.

    :param img: path (or file object) of the high-resolution image, passed to ``Image.open``
    :param halve: if True, first halve the HR image's width and height (LANCZOS)
    :param output_dir: directory in which the intermediate images and the grid
        (``img1.png``) are saved as PNG; pass ``None`` to skip saving
    :param font_path: TrueType font used for the panel labels; Pillow's default
        font is used if it cannot be opened
    :param show: if True, display the grid with matplotlib (``plt.show`` blocks)
    :return: the grid as a PIL ``Image``
    :raises: errors from loading the image, running the models, or saving
        propagate to the caller rather than being printed and swallowed
    """
    # Pillow >= 9.1 exposes filters as ``Image.Resampling``; older versions
    # only have the module-level constants.
    resample = getattr(Image, "Resampling", Image)

    # Load the HR image. The ``with`` closes the file handle; ``convert``
    # returns a new, fully loaded image, so it stays usable afterwards.
    print("Loading and processing HR image...")
    with Image.open(img, mode="r") as src:
        hr_img = src.convert('RGB')
    print(f"Original HR image size: {hr_img.size}")
    if halve:
        hr_img = hr_img.resize((int(hr_img.width / 2), int(hr_img.height / 2)), resample.LANCZOS)
        print(f"HR image resized to: {hr_img.size}")

    # Downsample 4x to get the LR input. The grid layout assumes the models
    # upscale it back by 4x (TODO confirm against the checkpoints' scale).
    lr_img = hr_img.resize((int(hr_img.width / 4), int(hr_img.height / 4)), resample.BICUBIC)
    print(f"LR image size: {lr_img.size}")

    # Bicubic upsampling (baseline)
    print("Performing bicubic upsampling...")
    bicubic_img = lr_img.resize((hr_img.width, hr_img.height), resample.BICUBIC)
    print(f"Bicubic image size: {bicubic_img.size}")

    # Super-resolution (SR) with SRResNet
    print("Generating SR image with SRResNet...")
    sr_img_srresnet = _super_resolve(srresnet, lr_img)
    print(f"SRResNet image size: {sr_img_srresnet.size}")

    # Super-resolution (SR) with SRGAN
    print("Generating SR image with SRGAN...")
    sr_img_srgan = _super_resolve(srgan_generator, lr_img)
    print(f"SRGAN image size: {sr_img_srgan.size}")

    if output_dir is not None:
        print("Saving intermediate images for verification...")
        for file_name, image in (("bicubic_img.png", bicubic_img),
                                 ("sr_img_srresnet.png", sr_img_srresnet),
                                 ("sr_img_srgan.png", sr_img_srgan),
                                 ("hr_img.png", hr_img)):
            image.save(os.path.join(output_dir, file_name), "PNG")

    # Create the grid on a white background
    print("Creating image grid...")
    margin = 40
    label_gap = 5  # pixels between a label's bottom edge and the top of its image
    grid_img = Image.new('RGB', (2 * hr_img.width + 3 * margin, 2 * hr_img.height + 3 * margin), (255, 255, 255))
    draw = ImageDraw.Draw(grid_img)
    font = _load_label_font(font_path, size=23)

    # Top-left corner of each panel; column/row offsets follow the original layout.
    left_x = margin
    right_x = 2 * margin + bicubic_img.width
    top_y = margin
    bottom_y = 2 * margin + sr_img_srresnet.height
    panels = (
        (bicubic_img, "Bicubic", left_x, top_y),
        (sr_img_srresnet, "SRResNet", right_x, top_y),
        (sr_img_srgan, "SRGAN", left_x, bottom_y),
        (hr_img, "Original HR", right_x, bottom_y),
    )
    for image, label, x, y in panels:
        grid_img.paste(image, (x, y))
        # Each label sits in the margin above its panel, centred on that panel.
        _draw_label(draw, label, font, center_x=x + image.width / 2, bottom_y=y - label_gap)

    if show:
        print("Displaying image grid...")
        plt.imshow(grid_img)
        plt.axis('off')
        plt.show()

    if output_dir is not None:
        print("Saving image grid...")
        grid_img.save(os.path.join(output_dir, "img1.png"), "PNG")
        print("Image saved successfully")

    return grid_img

# Script entry point: only runs when this file is executed directly, not on import.
# (The pasted version read ``if name == 'main':`` because markdown stripped the
# double underscores; ``name`` is undefined there and raises NameError.)
if __name__ == '__main__':
    print("Starting visualization...")
    grid_img = visualize_sr("C:/Users/decla/a-PyTorch-Tutorial-to-Super-Resolution-master/data/BSDS100/62096.png", halve=True)
    print("Visualization completed")

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions