#!/usr/bin/env python
# coding: utf-8

"""
Image Super Resolution using SRGAN
Demonstrates perceptual loss-based super resolution
"""

import importlib.util
import os
import subprocess
import sys

import torch
import torchvision.transforms as T
import matplotlib.pyplot as plt
from PIL import Image

# Ensure the directory that receives every generated figure exists.
os.makedirs('./figs', exist_ok=True)

# Print the start-up banner: a rule, the title, and another rule.
_rule = "=" * 60
for _line in (_rule, "SRGAN: Image Super Resolution", _rule):
    print(_line)

print("\n1. Setting up SRGAN model...")

_CHECKPOINT = 'srgan.pth.tar'
_DRIVE_ID = '1_PJ1Uimbr0xrPjE8U3Q_bG7XycGgsbVo'
_MODELS_URL = ("https://raw.githubusercontent.com/sizhky/"
               "a-PyTorch-Tutorial-to-Super-Resolution/master/models.py")


def _pip_install(package):
    """Install *package* with the pip that belongs to this interpreter."""
    # A bare `pip` on PATH may install into a different Python environment
    # than the one running this script.
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", package])
    # Packages installed after start-up may be missed by the import system's
    # cached directory listings unless the caches are invalidated.
    importlib.invalidate_caches()


def _checkpoint_ok(path):
    """Return True if *path* is a non-empty file that is not an HTML page.

    ``wget -O`` creates the output file even when the download fails, and
    Google Drive may answer with an HTML page instead of the file, so plain
    existence is not enough to decide that the download can be skipped.
    """
    try:
        with open(path, 'rb') as fh:
            head = fh.read(64)
    except OSError:
        return False
    return bool(head) and not head.lstrip().startswith(b'<')


def _print_manual_download_help():
    """Explain how to fetch the checkpoint by hand."""
    print("\n" + "!" * 60)
    print("ERROR: Automatic download failed!")
    print("!" * 60)
    print("\nPlease download manually:")
    print("1. Go to: https://drive.google.com/file/d/1_PJ1Uimbr0xrPjE8U3Q_bG7XycGgsbVo/view")
    print("2. Download the file and save as 'srgan.pth.tar'")
    print("3. Place it in the current directory")
    print("4. Re-run this script")
    print("!" * 60)


# Each asset is checked independently, so a checkpoint provided by hand does
# not prevent the other dependencies from being installed.
if importlib.util.find_spec("torch_snippets") is None:
    print("   Installing dependencies...")
    _pip_install("torch_snippets")

if not os.path.isfile('models.py') or os.path.getsize('models.py') == 0:
    print("   Downloading model architecture...")
    # Download to a temporary name so a failed transfer never leaves an empty
    # models.py that later runs would mistake for a good one.
    subprocess.check_call(["wget", "-q", _MODELS_URL, "-O", "models.py.part"])
    os.replace("models.py.part", "models.py")

if not _checkpoint_ok(_CHECKPOINT):
    print("   Downloading pretrained weights...")
    print("   (This may take a few minutes...)")
    # TLS verification stays on: this file is later unpickled with
    # weights_only=False, which can execute code.
    try:
        subprocess.check_call([
            "wget", "-q",
            f"https://drive.google.com/uc?export=download&id={_DRIVE_ID}",
            "-O", _CHECKPOINT,
        ])
    except (subprocess.CalledProcessError, OSError) as err:
        print(f"   wget download failed ({err}); trying gdown...")

    # wget can "succeed" yet write an HTML page, so validate rather than
    # relying on the exit status alone.
    if not _checkpoint_ok(_CHECKPOINT):
        try:
            _pip_install("gdown")
            import gdown
            gdown.download(
                f"https://drive.google.com/uc?id={_DRIVE_ID}",
                _CHECKPOINT,
                quiet=False,
            )
        except Exception as err:  # gdown/pip failures vary; reported below
            print(f"   gdown download failed: {err}")

    if not _checkpoint_ok(_CHECKPOINT):
        # Delete the bogus file so the next run retries the download instead
        # of handing it to torch.load.
        if os.path.exists(_CHECKPOINT):
            os.remove(_CHECKPOINT)
        _print_manual_download_help()
        sys.exit(1)

# Import torch_snippets at module level, not inside the setup branch above, so the import also runs when the checkpoint already exists.
from torch_snippets import *

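# Import models.py (downloaded above) before loading the checkpoint: unpickling the stored model object needs the classes it references to be importable.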
import models

# Prefer the GPU whenever PyTorch can see one.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f"   Using device: {device}")

# Load model
print("   Loading pretrained SRGAN model...")
# The checkpoint stores a whole model object under 'generator' (it is moved
# and switched to eval mode below), so weights_only must be False. That mode
# unpickles arbitrary objects and can run code: only use trusted files.
_checkpoint = torch.load('srgan.pth.tar', map_location=device, weights_only=False)
model = _checkpoint['generator'].to(device)
del _checkpoint  # keep only the generator; release the rest of the checkpoint
model.eval()
print("   ✓ Model loaded successfully")

# Download test image
print("\n2. Downloading test image...")


def _is_valid_image(path):
    """Return True if *path* can be opened by PIL and passes ``verify()``."""
    try:
        # The context manager closes the file handle; leaving it open would
        # make the os.remove() calls below fail on Windows.
        with Image.open(path) as img:
            img.verify()
    except Exception:  # verify() raises many different types for bad files
        return False
    return True


def _report_image_download_failure():
    """Tell the user how to supply their own test image."""
    print("\n" + "!" * 60)
    print("ERROR: Could not download test image!")
    print("!" * 60)
    print("\nPlease provide your own image:")
    print("1. Place any JPG image in the current directory")
    print("2. Rename it to 'MyImage.jpg'")
    print("3. Re-run this script")
    print("!" * 60)


# Check if file exists and is valid
need_download = True
if os.path.exists('MyImage.jpg'):
    if _is_valid_image('MyImage.jpg'):
        print("   Using existing MyImage.jpg")
        need_download = False
    else:
        print("   Existing file is corrupted, re-downloading...")
        os.remove('MyImage.jpg')

if need_download:
    try:
        # Wikimedia Commons cat image as a known-good test input.
        subprocess.check_call([
            "wget", "-q",
            "https://upload.wikimedia.org/wikipedia/commons/thumb/3/3a/Cat03.jpg/481px-Cat03.jpg",
            "-O", "MyImage.jpg"
        ], timeout=10)
        # A zero exit status does not prove the payload is an image.
        downloaded = _is_valid_image('MyImage.jpg')
    except (subprocess.CalledProcessError, subprocess.TimeoutExpired, OSError):
        downloaded = False

    if not downloaded:
        # wget -O creates the file even on failure; remove the partial output.
        if os.path.exists('MyImage.jpg'):
            os.remove('MyImage.jpg')
        _report_image_download_failure()
        sys.exit(1)
    print("   ✓ Downloaded test image")
else:
    print("   ✓ Test image ready")

# Define preprocessing and postprocessing
# preprocess: PIL image -> tensor normalised with the ImageNet channel
# mean/std, moved to `device` (defined earlier in the script).
preprocess = T.Compose([
    T.ToTensor(),
    T.Normalize([0.485, 0.456, 0.406],
                [0.229, 0.224, 0.225]),
    T.Lambda(lambda x: x.to(device))
])

# postprocess: maps the generator output from [-1, 1] to [0, 1] and converts
# it to a PIL image. The [-1, 1] range is assumed from the (x + 1) / 2 mapping;
# TODO confirm against the generator's output activation in models.py.
postprocess = T.Compose([
    # Clamp before ToPILImage: it scales float tensors by 255 and casts to
    # uint8, so out-of-range values would wrap instead of saturating.
    T.Lambda(lambda x: ((x.cpu().detach() + 1) / 2).clamp(0, 1)),
    T.ToPILImage()
])

# Load and downscale the original image
print("\n3. Processing image...")
# Convert explicitly to RGB: `preprocess` normalises exactly three channels,
# and JPEG cannot store RGBA, so greyscale/palette/RGBA inputs would fail.
with Image.open('MyImage.jpg') as _src:
    image = _src.convert('RGB')
w, h = image.size
print(f"   Original image size: {w}x{h}")

reduction_factor = 4
low_w, low_h = w // reduction_factor, h // reduction_factor
if low_w == 0 or low_h == 0:
    sys.exit(f"ERROR: MyImage.jpg ({w}x{h}) is too small to downscale "
             f"by {reduction_factor}x")
# Resample explicitly with bicubic; older Pillow releases defaulted
# Image.resize to nearest-neighbour.
low_res_image = image.resize((low_w, low_h), Image.BICUBIC)
print(f"   Low resolution size: {low_w}x{low_h}")

# Save the low resolution input
low_res_image.save('./figs/01_low_resolution_input.jpg')
print("   ✓ Saved low resolution input")

# Preprocess for model
im = preprocess(low_res_image)

# Generate super-resolved image
print("\n4. Generating super-resolution image...")
with torch.no_grad():
    # im[None] adds a leading batch dimension; [0] removes it from the output.
    sr = model(im[None])[0]
    sr = postprocess(sr)

print(f"   Super resolution size: {sr.size[0]}x{sr.size[1]}")

# Save super-resolution output (a PIL Image after postprocess, not a tensor)
sr.save('./figs/02_super_resolution_output.jpg')
print("   ✓ Saved super resolution output")

# Create comparison figure
print("\n5. Creating comparison visualization...")
fig, axes = plt.subplots(1, 2, figsize=(12, 6))
fig.suptitle('SRGAN: Image Super Resolution', fontsize=16, fontweight='bold')

# (image, title) for each panel, left to right.
panels = [
    (low_res_image,
     f'Low Resolution Input\n({w//reduction_factor}x{h//reduction_factor})'),
    (sr, f'Super Resolution Output\n({sr.size[0]}x{sr.size[1]})'),
]
for ax, (img, title) in zip(axes, panels):
    ax.imshow(img)
    ax.set_title(title, fontsize=12)
    ax.axis('off')

fig.tight_layout()
# Save before showing: once the display window is closed the figure may be
# discarded.
fig.savefig('./figs/03_comparison.png', dpi=150, bbox_inches='tight')
print("   ✓ Saved comparison figure")

# Display the same figure rather than rebuilding an identical one.
plt.show()
plt.close(fig)

# Report values derived from the images actually produced, so the summary
# stays correct if the reduction factor or the model changes.
low_w, low_h = low_res_image.size
scale = sr.size[0] // low_w

print("\n" + "=" * 60)
print("SUPER RESOLUTION COMPLETE!")
print("=" * 60)
print("\nGenerated files in ./figs/:")
print("  - 01_low_resolution_input.jpg")
print("  - 02_super_resolution_output.jpg")
print("  - 03_comparison.png")
print(f"\nUpscaling factor: {scale}x")
print(f"Input size: {low_w}x{low_h} pixels")
print(f"Output size: {sr.size[0]}x{sr.size[1]} pixels")
print("=" * 60)
