#!/usr/bin/env python
# coding: utf-8

"""
StyleGAN2 Customization Script
Standalone version - runs without Jupyter, saves all outputs to ./figs
"""

import glob
import os
import shutil
import subprocess
import sys
import zipfile

import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image

# Create output directory for figures. It is created relative to the launch
# directory; after the chdir below, the same folder is addressed as '../figs'.
os.makedirs('./figs', exist_ok=True)

MODELS_URL = 'https://github.com/jacobhallberg/pytorch_stylegan_encoder/releases/download/v1.0/trained_models.zip'

print("Setting up StyleGAN2 encoder...")
if not os.path.exists('pytorch_stylegan_encoder'):
    # Clone the repo
    print("Cloning repository...")
    subprocess.check_call(['git', 'clone', 'https://github.com/sizhky/pytorch_stylegan_encoder.git'])

# Always change to the encoder directory; every relative path used below
# (including '../figs' and '../MyImage.jpg') is resolved from here.
os.chdir('pytorch_stylegan_encoder')

# The moved generator checkpoint is the marker for a completed setup.
if not os.path.exists('InterFaceGAN/models/pretrain/stylegan_ffhq.pth'):
    # Update submodules
    print("Updating submodules...")
    subprocess.check_call(['git', 'submodule', 'update', '--init', '--recursive'])

    # Download and unzip trained models
    print("Downloading trained models (this may take a few minutes)...")

    # Clean up any corrupted files from previous attempts
    if os.path.exists('trained_models.zip'):
        print("   Removing corrupted zip from previous attempt...")
        os.remove('trained_models.zip')

    try:
        subprocess.check_call([
            'wget', '-q', '--show-progress',
            MODELS_URL,
            '-O', 'trained_models.zip'
        ])
    except (subprocess.CalledProcessError, OSError):
        # OSError also covers wget not being installed (FileNotFoundError).
        print("   wget failed, trying with curl...")
        subprocess.check_call(['curl', '-L', '-o', 'trained_models.zip', MODELS_URL])

    # Verify the zip file is valid while extracting it
    print("Extracting models...")
    try:
        # zipfile silently overwrites files left by an interrupted earlier run,
        # where `unzip` would stop at an interactive "replace?" prompt.
        with zipfile.ZipFile('trained_models.zip') as archive:
            archive.extractall()
    except (zipfile.BadZipFile, EOFError):
        print("\n" + "!" * 60)
        print("ERROR: Downloaded zip file is corrupted!")
        print("!" * 60)
        print("\nPlease download manually:")
        print("1. Visit: https://github.com/jacobhallberg/pytorch_stylegan_encoder/releases/tag/v1.0")
        print("2. Download 'trained_models.zip'")
        print("3. Extract it in: pytorch_stylegan_encoder/")
        print("4. Run this script again")
        print("!" * 60)
        sys.exit(1)

    os.remove('trained_models.zip')

    # Install torch_snippets with the pip of the running interpreter, so the
    # package is importable by this very process further down.
    print("Installing dependencies...")
    subprocess.check_call([sys.executable, '-m', 'pip', 'install', '-qU', 'torch_snippets'])

    # Move pretrained model into the location InterFaceGAN loads it from
    os.makedirs('InterFaceGAN/models/pretrain', exist_ok=True)
    shutil.move('trained_models/stylegan_ffhq.pth', 'InterFaceGAN/models/pretrain')
    print("Setup complete!")
else:
    print("Using existing setup...")

# Make the current directory (pytorch_stylegan_encoder, after the chdir above)
# importable so the repo's `InterFaceGAN` and `models` packages resolve.
if os.getcwd() not in sys.path:
    sys.path.insert(0, os.getcwd())

# NOTE(review): star import kept as-is; nothing below visibly uses a
# torch_snippets name, but it may rebind names imported at the top of the file.
from torch_snippets import *
from InterFaceGAN.models.stylegan_generator import StyleGANGenerator

# Prefer the repository's post-processing; fall back to an inline equivalent.
try:
    from models.latent_optimizer import PostSynthesisProcessing
except ImportError:  # ModuleNotFoundError is a subclass of ImportError
    print("   Note: Using alternative post-processing method...")

    class PostSynthesisProcessing:
        """Fallback: compute (x + 1) * 127.5, clamp to [0, 255] and cast to uint8."""

        def __call__(self, image):
            # Assumes generator output is roughly in [-1, 1] — TODO confirm.
            return torch.clamp((image + 1) * 127.5, 0, 255).to(torch.uint8)


print("Loading StyleGAN2 models...")
# Construct the generator once and share its sub-networks. Each
# StyleGANGenerator(...) call builds a separate model (and presumably reloads
# the checkpoint), so building three of them only wasted time and GPU memory.
_stylegan_model = StyleGANGenerator("stylegan_ffhq").model
synthesizer = _stylegan_model.synthesis
mapper = _stylegan_model.mapping
trunc = _stylegan_model.truncation

post_processing = PostSynthesisProcessing()


def post_process(image):
    """Post-process a generator batch and return its first image as a uint8 numpy array."""
    return post_processing(image).detach().cpu().numpy().astype(np.uint8)[0]

def latent2image(latent):
    """Render a latent code through the synthesis network as an HWC uint8 image.

    Args:
        latent: tensor passed directly to ``synthesizer``. In this script it is
            either ``trunc(mapper(z))`` or the encoded latents loaded from
            pred_dlatents_myImage.npy (presumably (1, n_layers, 512) — TODO confirm).

    Returns:
        numpy uint8 array for the first image of the batch, axes reordered with
        ``transpose(1, 2, 0)`` so it can be passed to ``plt.imshow``.
    """
    # Inference only: post_process detaches the result, so recording an
    # autograd graph through the synthesis network would just waste memory.
    with torch.no_grad():
        chw = post_process(synthesizer(latent))
    # Presumably (C, H, W) -> (H, W, C), the layout matplotlib expects.
    return chw.transpose(1, 2, 0)

def save_image(img, filename, title=None):
    """Write ``img`` to ``filename`` as a 5x5-inch figure with no axes.

    Args:
        img: anything ``imshow`` accepts (here: HWC uint8 arrays).
        filename: output path; matplotlib picks the format from the extension.
        title: optional caption drawn above the image (skipped when falsy).
    """
    figure = plt.figure(figsize=(5, 5))
    ax = figure.gca()
    ax.imshow(img)
    if title:
        ax.set_title(title)
    ax.axis('off')
    figure.tight_layout()
    figure.savefig(filename, dpi=150, bbox_inches='tight')
    # Close explicitly so repeated calls don't accumulate open figures.
    plt.close(figure)
    print(f"Saved: {filename}")

# Step 1: sample z ~ N(0, I), map and truncate it, then render it.
print("\n1. Generating random face from noise...")
rand_latents = torch.randn(1, 512).cuda()
random_w = trunc(mapper(rand_latents))
random_face = latent2image(random_w)
save_image(random_face, '../figs/01_random_generated_face.png', 'Random Generated Face')

# Download and prepare user image
print("\n2. Downloading and preparing user image...")


def _is_valid_image(path):
    """Return True if PIL can open and verify ``path``; the file is always closed."""
    try:
        with Image.open(path) as candidate:
            candidate.verify()
    except Exception:  # PIL raises several unrelated exception types for bad files
        return False
    return True


# First, check if MyImage.jpg exists in the parent directory (where code4 would use it)
parent_image = '../MyImage.jpg'
need_download = False

if os.path.exists(parent_image):
    if _is_valid_image(parent_image):
        print("   Found MyImage.jpg in parent directory")
        # Copy it to current directory
        shutil.copy(parent_image, 'MyImage.jpg')
        print("   ✓ Using your custom image")
    else:
        print("   Image in parent directory is corrupted")
        need_download = True
elif os.path.exists('MyImage.jpg'):
    if _is_valid_image('MyImage.jpg'):
        print("   Using existing MyImage.jpg")
    else:
        print("   Existing file is corrupted, re-downloading...")
        os.remove('MyImage.jpg')
        need_download = True
else:
    need_download = True

if need_download:
    # NOTE(review): this fallback URL is a cat photo (Cat03.jpg), not a human
    # face; the face-alignment step below presumably finds no face in it.
    try:
        subprocess.check_call([
            "wget", "-q",
            "https://upload.wikimedia.org/wikipedia/commons/thumb/3/3a/Cat03.jpg/481px-Cat03.jpg",
            "-O", "MyImage.jpg"
        ], timeout=10)
    except (subprocess.CalledProcessError, subprocess.TimeoutExpired, OSError):
        # OSError covers wget not being installed.
        print("\n" + "!" * 60)
        print("ERROR: Could not download test image!")
        print("!" * 60)
        print("\nPlease provide your own face image:")
        print("1. Place a JPG image containing a face in: pytorch_stylegan_encoder/")
        print("   OR place it in the parent directory (same location as code4 uses)")
        print("2. Rename it to 'MyImage.jpg'")
        print("3. Re-run this script")
        print("!" * 60)
        sys.exit(1)
    print("   ✓ Downloaded test image")

# Clone the encoder repository if needed
if not os.path.exists('stylegan-encoder'):
    print("Setting up image encoder...")
    subprocess.check_call([
        "git", "clone", "https://github.com/sizhky/stylegan-encoder.git"
    ])

# Create necessary directories
os.makedirs("stylegan-encoder/raw_images", exist_ok=True)
os.makedirs("stylegan-encoder/aligned_images", exist_ok=True)

# Drop aligned outputs from earlier runs so the glob below only sees this run's result.
for stale in (glob.glob("stylegan-encoder/aligned_images/*.png")
              + glob.glob("stylegan-encoder/aligned_images/*.jpg")):
    os.remove(stale)

# Copy image to raw_images folder
shutil.copy("MyImage.jpg", "stylegan-encoder/raw_images/")

# Try to align the face; on any failure keep using the unaligned MyImage.jpg.
print("Aligning face in image...")
alignment_success = False
try:
    # First check if TensorFlow is installed (the alignment script needs it)
    try:
        import tensorflow  # noqa: F401  (availability check only)
        print("   TensorFlow found, proceeding with alignment...")
    except ImportError:
        print("   TensorFlow not found, installing (this may take a few minutes)...")
        # Use the running interpreter's pip so the align script (run with the
        # same interpreter below) can import it.
        subprocess.check_call([
            sys.executable, "-m", "pip", "install", "-q", "tensorflow-cpu", "--break-system-packages"
        ])

    # Run the alignment script
    subprocess.check_call([
        sys.executable, "stylegan-encoder/align_images.py",
        "stylegan-encoder/raw_images/", "stylegan-encoder/aligned_images/"
    ])

    # Copy the aligned image over MyImage.jpg; sort for a deterministic pick
    # when several faces were extracted.
    aligned_files = sorted(
        glob.glob("stylegan-encoder/aligned_images/*.png")
        + glob.glob("stylegan-encoder/aligned_images/*.jpg")
    )
    if aligned_files:
        shutil.copy(aligned_files[0], "MyImage.jpg")
        alignment_success = True
        print("   ✓ Face aligned successfully")
    else:
        print("   ⚠ Alignment produced no image; continuing with original image...")
except Exception as e:
    print(f"   ⚠ Face alignment failed: {str(e)}")
    print("   Continuing with original image (results may be suboptimal)...")
    # Keep using the original MyImage.jpg

# Load the (possibly aligned) image. copy() reads the pixels into memory so the
# file handle is closed before the same path is overwritten below.
with Image.open('MyImage.jpg') as source_image:
    img = source_image.copy()

# Convert to RGB if needed (JPEG only supports RGB)
if img.mode != 'RGB':
    print(f"   Converting {img.mode} image to RGB...")
    has_alpha = img.mode in ('RGBA', 'LA') or (
        img.mode == 'P' and 'transparency' in img.info
    )
    if has_alpha:
        # Composite transparent images onto a white background using the alpha channel.
        rgba = img.convert('RGBA')
        rgb_img = Image.new('RGB', img.size, (255, 255, 255))
        rgb_img.paste(rgba, mask=rgba.getchannel('A'))
        img = rgb_img
    else:
        # For other modes (L, CMYK, P without transparency, ...), convert directly
        img = img.convert('RGB')

# Resize to 1024x1024 if needed for StyleGAN
if img.size != (1024, 1024):
    print(f"   Resizing image from {img.size} to 1024x1024...")
    img = img.resize((1024, 1024), Image.Resampling.LANCZOS)

# Save the processed image (input for encode_image.py below)
img.save('MyImage.jpg', 'JPEG')
original_img = np.array(img)
save_image(original_img, '../figs/02_original_image.png', 'Original Image')

# Encode image to latent space
print("\n3. Encoding image to latent space...")
# Run the encoder with the interpreter running this script, so it sees the
# same installed packages (a bare "python" on PATH may be a different one).
subprocess.check_call([
    sys.executable, "encode_image.py",
    "./MyImage.jpg",
    "pred_dlatents_myImage.npy",
    "--use_latent_finder", "true",
    "--image_to_latent_path", "./trained_models/image_to_latent.pt"
])

# Load the predicted latents and render them back to an image
pred_dlatents = np.load('pred_dlatents_myImage.npy')
pred_dlatent = torch.from_numpy(pred_dlatents).float().cuda()
pred_image = latent2image(pred_dlatent)
save_image(pred_image, '../figs/03_reconstructed_image.png', 'Reconstructed from Latent')

# Feature transfer experiments
print("\n4. Performing feature transfer experiments...")


def _transfer_features(my_latents, idxs_to_swap, title, detail_label, out_path):
    """Swap a slice of style layers between the encoded image and a new random face.

    Draws one fresh random face and saves a 2x2 figure: the encoded image (A),
    A with the layers in ``idxs_to_swap`` taken from the random face, the
    random face (B), and B with those layers taken from A.

    Args:
        my_latents: encoded latents of the user image, already on the GPU;
            presumably shape (1, n_layers, 512) — TODO confirm.
        idxs_to_swap: slice over axis 1 (the layer axis). Python slices are
            half-open, so ``slice(0, 3)`` swaps layers 0, 1 and 2.
        title: figure title.
        detail_label: word used in the two mixed panel titles (e.g. "Style").
        out_path: PNG destination.
    """
    # One new random draw per experiment (same RNG call order as before).
    rand_latents = torch.randn(1, 512).cuda()
    with torch.no_grad():
        generated_image_latents = trunc(mapper(rand_latents))

    # Original latents with the selected layers taken from the generated face.
    mixed_original = my_latents.clone()
    mixed_original[:, idxs_to_swap] = generated_image_latents[:, idxs_to_swap]

    # Generated latents with the selected layers taken from the original image.
    mixed_generated = generated_image_latents.clone()
    mixed_generated[:, idxs_to_swap] = my_latents[:, idxs_to_swap]

    panels = [
        (latent2image(my_latents), 'Original'),
        (latent2image(mixed_original.float()), f'Original + Generated {detail_label}'),
        (latent2image(generated_image_latents), 'Generated'),
        (latent2image(mixed_generated.float()), f'Generated + Original {detail_label}'),
    ]

    fig, axes = plt.subplots(2, 2, figsize=(10, 10))
    fig.suptitle(title, fontsize=16)
    # axes.flat visits [0,0], [0,1], [1,0], [1,1] — the same layout as `panels`.
    for ax, (image, panel_title) in zip(axes.flat, panels):
        ax.imshow(image)
        ax.set_title(panel_title)
        ax.axis('off')
    fig.tight_layout()
    fig.savefig(out_path, dpi=150, bbox_inches='tight')
    plt.close(fig)
    print(f"Saved: {out_path}")


# Load the encoded latents once; the file holds a plain numeric array, so no
# pickle support is needed.
my_latents = torch.from_numpy(np.load('pred_dlatents_myImage.npy')).float().cuda()

# NOTE(review): the "Layers a-b" labels use the slices' half-open convention.
# Together the three slices swap layers 0-2, 4-14 and 16-17, so layers 3 and 15
# are never exchanged by any experiment — confirm whether that gap is intended.

# High-level features
print("  - Transferring high-level features (layers 0-3)...")
_transfer_features(
    my_latents, slice(0, 3),
    'Transfer High Level Features (Layers 0-3)', 'Style',
    '../figs/04_transfer_high_level_features.png',
)

# Mid-level features
print("  - Transferring mid-level features (layers 4-15)...")
_transfer_features(
    my_latents, slice(4, 15),
    'Transfer Granular Features (Layers 4-15)', 'Details',
    '../figs/05_transfer_granular_features.png',
)

# Fine features
print("  - Transferring fine features (layers 16-18)...")
_transfer_features(
    my_latents, slice(16, 18),
    'Transfer Fine Features (Layers 16-18)', 'Texture',
    '../figs/06_transfer_fine_features.png',
)

# Attribute manipulation - Smile
print("\n5. Manipulating facial attributes (smile)...")

# Remove results of earlier runs so only this run's edits are plotted
if os.path.exists('results_new_smile'):
    print("   Removing old results directory...")
    shutil.rmtree('results_new_smile')

# Edit the encoded latent with InterFaceGAN's smile boundary in 20 steps, using
# the interpreter that runs this script.
subprocess.check_call([
    sys.executable, "InterFaceGAN/edit.py",
    "-m", "stylegan_ffhq",
    "-o", "results_new_smile",
    "-b", "InterFaceGAN/boundaries/stylegan_ffhq_smile_w_boundary.npy",
    "-i", "pred_dlatents_myImage.npy",
    "-s", "WP",
    "--steps", "20"
])

generated_faces = sorted(glob.glob('results_new_smile/*.jpg'))
if generated_faces:
    # Create a grid of smile variations, 5 per row
    n_images = len(generated_faces)
    cols = 5
    rows = (n_images + cols - 1) // cols

    # squeeze=False always yields a 2-D axes array, even for a single row.
    fig, axes = plt.subplots(rows, cols, figsize=(15, 3*rows), squeeze=False)
    fig.suptitle('Smile Attribute Manipulation (from frown to smile)', fontsize=16)

    for idx, ax in enumerate(axes.flat):
        if idx < n_images:
            ax.imshow(plt.imread(generated_faces[idx]))
            ax.axis('off')
            ax.set_title(f'Step {idx}')
        else:
            # Hide empty subplots
            ax.axis('off')

    fig.tight_layout()
    fig.savefig('../figs/07_smile_manipulation.png', dpi=150, bbox_inches='tight')
    plt.close(fig)
    print("Saved: ../figs/07_smile_manipulation.png")
else:
    print("   ⚠ No edited images found in results_new_smile/; skipping smile figure")

# Final summary: list every PNG written to the shared figs directory.
summary_rule = "=" * 60
print("\n" + summary_rule)
print("COMPLETE! All images saved to ./figs/")
print(summary_rule)
print("\nGenerated files:")
for saved_png in sorted(glob.glob('../figs/*.png')):
    print(f"  - {saved_png}")
