#!/usr/bin/env python
# coding: utf-8

# Understand the Reverse Diffusion Process - generate figures with custom images

import warnings
warnings.filterwarnings("ignore", category=FutureWarning)

import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from diffusers import DDPMPipeline, DDPMScheduler
from pathlib import Path

# Ensure the destination folder for all saved figures is present
# (no error if it already exists).
Path("./figs").mkdir(parents=True, exist_ok=True)

for _banner_line in ("=" * 60, "REVERSE DIFFUSION PROCESS - HIGH RESOLUTION FIGURES", "=" * 60):
    print(_banner_line)

# ============================================================================
# CONFIGURATION
# ============================================================================

# Path to your training images
TRAINING_DATA_PATH = "./training_data"

# Choose a higher resolution pretrained model
# Options:
# - "google/ddpm-celebahq-256" (256x256 faces)
# - "google/ddpm-church-256" (256x256 churches)
# - "google/ddpm-bedroom-256" (256x256 bedrooms)
MODEL_ID = "google/ddpm-celebahq-256"  # Change this based on your preference

# Default working resolution. It is overwritten below with the UNet's
# configured sample_size once the model is loaded; the default is only kept
# if the model config does not describe a square input.
IMAGE_SIZE = 256

# ============================================================================
# LOAD PRETRAINED MODEL (Higher Resolution)
# ============================================================================

print(f"\nLoading pretrained model: {MODEL_ID}")

ddpm = DDPMPipeline.from_pretrained(MODEL_ID)
ddpm = ddpm.to("cuda")

unet = ddpm.unet
scheduler = ddpm.scheduler

# Read the resolution from the checkpoint instead of hard-coding it, so that
# switching MODEL_ID keeps the noise tensors and the resized input images at
# the size the UNet was trained on.
_sample_size = unet.config.sample_size
if isinstance(_sample_size, (tuple, list)):
    if len(set(_sample_size)) == 1:
        IMAGE_SIZE = int(_sample_size[0])
    else:
        # The figures below assume square images; keep the default.
        print(f"⚠ Non-square sample_size {tuple(_sample_size)}; keeping IMAGE_SIZE={IMAGE_SIZE}")
elif _sample_size is not None:
    IMAGE_SIZE = int(_sample_size)

print(f"This model generates higher resolution images ({IMAGE_SIZE}×{IMAGE_SIZE})...")
print("✓ Model loaded successfully")
print(f"  Image size: {IMAGE_SIZE}×{IMAGE_SIZE}")
print(f"  Training timesteps: {scheduler.config.num_train_timesteps}")

# ============================================================================
# LOAD YOUR CUSTOM IMAGES
# ============================================================================

def load_custom_images(data_path, target_size=256, max_images=6):
    """Load up to ``max_images`` RGB images from a folder, resized to a square.

    Files are matched by extension (.jpg, .jpeg, .png, in any letter case),
    only regular files are considered, and they are taken in sorted path
    order. The scan is not recursive.

    Args:
        data_path: Directory to scan.
        target_size: Output width and height in pixels.
        max_images: Maximum number of images to load.

    Returns:
        List of PIL images in RGB mode, each ``target_size`` x ``target_size``.

    Raises:
        FileNotFoundError: If no matching image files are found (including
            when ``data_path`` does not exist).
    """
    print(f"\nLoading images from {data_path}...")

    # Compare lower-cased suffixes rather than globbing each casing
    # separately: on case-insensitive filesystems "*.jpg" and "*.JPG" match
    # the same files, which produced duplicate entries.
    valid_suffixes = {".jpg", ".jpeg", ".png"}
    image_paths = sorted(
        path
        for path in Path(data_path).glob("*")
        if path.suffix.lower() in valid_suffixes and path.is_file()
    )

    if not image_paths:
        raise FileNotFoundError(f"No images found in {data_path}")

    images = []
    for img_path in image_paths[:max_images]:
        # Use a context manager so the file handle is closed; convert()
        # returns a new image whose pixel data is already loaded.
        with Image.open(img_path) as src:
            img = src.convert('RGB')
        # Resize to target size
        img = img.resize((target_size, target_size), Image.LANCZOS)
        images.append(img)
        print(f"  Loaded: {img_path.name}")

    print(f"✓ Loaded {len(images)} images")
    return images

def pil_to_tensor(pil_image):
    """Return a (1, C, H, W) float32 tensor with pixel values mapped to [-1, 1].

    Expects an HxWxC image (PIL image or array-like) with 8-bit values.
    """
    pixels = np.asarray(pil_image, dtype=np.float32)
    scaled = (pixels / 255.0) * 2.0 - 1.0  # [0, 255] -> [0, 1] -> [-1, 1]
    channels_first = torch.from_numpy(scaled).permute(2, 0, 1)  # HWC -> CHW
    return channels_first[None, ...]  # add the batch dimension

def tensor_to_pil(tensor):
    """Convert a CHW image tensor in [-1, 1] to an HWC uint8 NumPy array.

    Despite the name, this returns a ``numpy.ndarray``, not a ``PIL.Image``.
    Both ``plt.imshow`` and ``Image.fromarray`` accept the ndarray directly.

    Args:
        tensor: Image tensor of shape (C, H, W) with values nominally in
            [-1, 1]; out-of-range values are clamped. It may live on any
            device and may require grad.

    Returns:
        ``numpy.ndarray`` of shape (H, W, C) and dtype ``uint8``.
    """
    # Detach so tensors that require grad can still be converted to NumPy,
    # then map [-1, 1] -> [0, 1] and clamp any sampler overshoot.
    tensor = (tensor.detach() / 2 + 0.5).clamp(0, 1)
    # CHW -> HWC on the CPU for NumPy / matplotlib.
    image = tensor.cpu().permute(1, 2, 0).numpy()
    return (image * 255).round().astype("uint8")

# Pull a few images from the training_data folder. If none are available,
# fall back to an empty list so only the model-based figures are produced.
custom_images = []
try:
    custom_images = load_custom_images(TRAINING_DATA_PATH, target_size=IMAGE_SIZE, max_images=3)
except FileNotFoundError as load_error:
    print(f"\n⚠ Warning: {load_error}")
    print("Creating demo with pretrained model only...")

# ============================================================================
# FIGURE 1: FORWARD DIFFUSION ON IMAGES IN TRAINING_DATA
# ============================================================================

if custom_images:
    print("\n" + "="*60)
    print("GENERATING FIGURE 1: Forward Diffusion on Images In Training_Data")
    print("="*60)

    # Noise the first loaded image at several forward timesteps.
    original_img = custom_images[0]
    img_tensor = pil_to_tensor(original_img).to("cuda")

    print("\nApplying forward diffusion...")
    forward_timesteps = [0, 200, 400, 600, 800, 999]
    forward_images = []
    forward_alpha_bars = []

    for t_val in forward_timesteps:
        if t_val == 0:
            # The t=0 panel shows the untouched original, so label it with
            # alpha_bar = 1 (no noise) in both the log and the plot.
            alpha_bar = 1.0
            forward_images.append(original_img)
        else:
            alpha_t = scheduler.alphas_cumprod[t_val]
            alpha_bar = alpha_t.item()

            # Re-seed so every panel uses the same noise sample; only the
            # noise level differs between panels.
            torch.manual_seed(42)
            noise = torch.randn_like(img_tensor)
            # Closed-form forward process: sqrt(a_bar) * x0 + sqrt(1 - a_bar) * eps
            noisy_img = torch.sqrt(alpha_t) * img_tensor + torch.sqrt(1 - alpha_t) * noise

            forward_images.append(Image.fromarray(tensor_to_pil(noisy_img[0])))

        forward_alpha_bars.append(alpha_bar)
        print(f"  t={t_val}: α̅_t={alpha_bar:.4f}")

    # Visualize
    fig, axes = plt.subplots(2, 3, figsize=(12, 8))
    panels = zip(axes.flatten(), forward_images, forward_timesteps, forward_alpha_bars)
    for ax, img, t_val, alpha_bar in panels:
        ax.imshow(img)
        ax.set_title(f"t={t_val}\n$\\bar{{\\alpha}}_t$={alpha_bar:.3f}", fontsize=10)
        ax.axis('off')

    plt.suptitle("Forward Diffusion on a Custom Image (High Resolution)",
                 fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.savefig("./figs/forward_custom_image.png", dpi=150, bbox_inches='tight')
    print("\n✓ Saved: ./figs/forward_custom_image.png")
    plt.show()
    plt.close()

# ============================================================================
# FIGURE 2: REVERSE DIFFUSION PROGRESSION (HIGH RESOLUTION)
# ============================================================================

print("\n" + "=" * 60)
print("GENERATING FIGURE 2: Reverse Diffusion Progression (256×256)")
print("=" * 60)

# Fix the global RNG so the starting noise and DDPM's per-step noise repeat.
torch.manual_seed(42)

# x_T: one sample of pure Gaussian noise at the working resolution.
batch_size = 1
image_shape = (batch_size, 3, IMAGE_SIZE, IMAGE_SIZE)
x_T = torch.randn(image_shape).to("cuda")

# Full-length chain: one UNet evaluation per training timestep.
num_inference_steps = 1000
scheduler.set_timesteps(num_inference_steps)

print(f"\nRunning reverse diffusion with {num_inference_steps} steps...")
print("Generating high-resolution image (256×256)...")

# 0-based loop positions after which the current sample is snapshotted,
# together with the timestep value that was just denoised.
timesteps_to_save = [0, 200, 400, 600, 800, 999]
saved_images = []
saved_timestep_values = []

x_t = x_T
for i, t in enumerate(scheduler.timesteps):
    t_tensor = torch.tensor([t]).to("cuda")
    with torch.no_grad():
        noise_pred = unet(x_t, t_tensor).sample
    x_t = scheduler.step(noise_pred, t, x_t).prev_sample

    if i not in timesteps_to_save:
        continue
    t_value = t.item()
    saved_images.append(x_t[0].cpu())
    saved_timestep_values.append(t_value)
    print(f"  Step {i}/{num_inference_steps}, t={t_value:.0f}")

print("\n✓ Reverse diffusion complete!")

# Lay the snapshots out on a 2x3 grid.
fig, axes = plt.subplots(2, 3, figsize=(12, 8))
panels = zip(axes.flatten(), timesteps_to_save, saved_images, saved_timestep_values)
for ax, step_idx, img_tensor, t_val in panels:
    ax.imshow(tensor_to_pil(img_tensor))
    ax.set_title(f"Step {step_idx}/{num_inference_steps}\nt={t_val:.0f}",
                 fontsize=10)
    ax.axis('off')

plt.suptitle("Reverse Diffusion: Progressive Denoising (256×256, T=1000)",
             fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig("./figs/reverse_diffusion_progression.png", dpi=150, bbox_inches='tight')
print("\n✓ Saved: ./figs/reverse_diffusion_progression.png")
plt.show()
plt.close()

# ============================================================================
# FIGURE 3: FORWARD VS REVERSE (HIGH RESOLUTION)
# ============================================================================

print("\n" + "="*60)
print(f"GENERATING FIGURE 3: Forward vs Reverse Comparison ({IMAGE_SIZE}×{IMAGE_SIZE})")
print("="*60)

# Generate a clean image first
print("\nGenerating clean image with 1000 steps...")
torch.manual_seed(123)
generated_image = ddpm(batch_size=1, num_inference_steps=1000).images[0]
print("✓ Clean image generated")

# Convert to tensor
img_tensor = pil_to_tensor(generated_image).to("cuda")

# Forward process: noise the clean image at evenly spaced timesteps in
# [0, 999]. Re-seeding makes every level reuse the same noise sample.
print("\nApplying forward diffusion...")
num_forward_steps = 50
forward_timesteps = torch.linspace(0, 999, num_forward_steps).long()
forward_images = [img_tensor[0].cpu()]

noisy_img = img_tensor  # defined even if only t=0 is requested
for t in forward_timesteps[1:]:
    alpha_t = scheduler.alphas_cumprod[t]
    torch.manual_seed(42)
    noise = torch.randn_like(img_tensor)
    noisy_img = torch.sqrt(alpha_t) * img_tensor + torch.sqrt(1 - alpha_t) * noise
    forward_images.append(noisy_img[0].cpu())

print(f"✓ Created {len(forward_images)} forward steps")

# Reverse process, starting from the most heavily noised forward sample.
print("\nApplying reverse diffusion (1000 steps)...")
x_T_forward = noisy_img
scheduler.set_timesteps(1000)
num_reverse_steps = len(scheduler.timesteps)

# Checkpoints are counted in *completed* denoising steps, so
# reverse_images[k] is the sample after reverse_checkpoints[k] steps.
# Saving on the 0-based loop index (as before) made the panels labelled
# "step 200/400/600" actually show the sample after 1/201/401 steps.
reverse_checkpoints = [0, 200, 400, 600, 800, num_reverse_steps]
save_after = set(reverse_checkpoints[1:])

reverse_images = [x_T_forward[0].cpu()]
x_t = x_T_forward

for i, t in enumerate(scheduler.timesteps):
    t_tensor = torch.tensor([t]).to("cuda")

    with torch.no_grad():
        noise_pred = unet(x_t, t_tensor).sample

    x_t = scheduler.step(noise_pred, t, x_t).prev_sample

    # i is 0-based, so i + 1 steps have been completed at this point.
    if i + 1 in save_after:
        reverse_images.append(x_t[0].cpu())

print(f"✓ Created {len(reverse_images)} reverse checkpoints")

# Visualize
fig, axes = plt.subplots(2, 5, figsize=(15, 6))

# Top row: forward process at a few of the noised levels.
forward_indices = [0, 12, 24, 36, num_forward_steps - 1]
for ax, img_idx in zip(axes[0], forward_indices):
    ax.imshow(tensor_to_pil(forward_images[img_idx]))
    t_val = forward_timesteps[img_idx].item()
    ax.set_title(f"Forward\nt={t_val:.0f}", fontsize=9)
    ax.axis('off')

# Bottom row: reverse process after the given number of completed steps.
reverse_step_labels = [0, 200, 400, 600, num_reverse_steps]
for ax, step_label in zip(axes[1], reverse_step_labels):
    img_np = tensor_to_pil(reverse_images[reverse_checkpoints.index(step_label)])
    ax.imshow(img_np)
    ax.set_title(f"Reverse\nstep {step_label}", fontsize=9)
    ax.axis('off')

plt.suptitle(f"Forward vs Reverse Process ({IMAGE_SIZE}×{IMAGE_SIZE}, T=1000)",
             fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig("./figs/forward_vs_reverse.png", dpi=150, bbox_inches='tight')
print("\n✓ Saved: ./figs/forward_vs_reverse.png")
plt.show()
plt.close()

# ============================================================================
# FIGURE 4: MULTIPLE HIGH-QUALITY SAMPLES
# ============================================================================

print("\n" + "=" * 60)
print("GENERATING FIGURE 4: Multiple Diverse Samples (256×256)")
print("=" * 60)

print("\nGenerating 6 samples with T=1000...")
num_samples = 6
samples = []

for sample_idx in range(num_samples):
    print(f"  Generating sample {sample_idx + 1}/{num_samples}...")
    # A distinct but reproducible seed per sample yields diverse outputs.
    torch.manual_seed(sample_idx * 100)

    x_T = torch.randn(image_shape).to("cuda")
    scheduler.set_timesteps(1000)

    x_t = x_T
    for t in scheduler.timesteps:
        t_tensor = torch.tensor([t]).to("cuda")
        with torch.no_grad():
            noise_pred = unet(x_t, t_tensor).sample
        x_t = scheduler.step(noise_pred, t, x_t).prev_sample

    samples.append(x_t[0].cpu())

# Show every sample on a 2x3 grid, numbered from 1.
fig, axes = plt.subplots(2, 3, figsize=(12, 8))
for sample_no, (ax, img_tensor) in enumerate(zip(axes.flatten(), samples), start=1):
    ax.imshow(tensor_to_pil(img_tensor))
    ax.set_title(f"Sample {sample_no}", fontsize=11)
    ax.axis('off')

plt.suptitle("Multiple Samples from Reverse Diffusion (256×256, T=1000)",
             fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig("./figs/multiple_samples.png", dpi=150, bbox_inches='tight')
print("\n✓ Saved: ./figs/multiple_samples.png")
plt.show()
plt.close()

# ============================================================================
# FIGURE 5: EFFECT OF NUMBER OF STEPS
# ============================================================================

print("\n" + "="*60)
print("GENERATING FIGURE 5: Effect of Number of Steps")
print("="*60)

# Every run starts from the same initial noise x_T.
torch.manual_seed(42)
x_T = torch.randn(image_shape).to("cuda")

step_counts = [10, 50, 250, 1000]
# Seed for the per-step variance noise drawn inside DDPMScheduler.step.
variance_seed = 0
results = []

for steps in step_counts:
    print(f"\nGenerating with {steps} steps...")
    scheduler.set_timesteps(steps)

    # DDPM sampling is stochastic: step() adds fresh Gaussian noise at each
    # step. Give every run its own identically seeded generator so all runs
    # start from the same RNG state. Previously they shared the global RNG,
    # so each later run saw a different noise stream and the comparison mixed
    # step-count effects with RNG drift.
    generator = torch.Generator(device=x_T.device).manual_seed(variance_seed)

    x_t = x_T.clone()
    for t in scheduler.timesteps:
        t_tensor = torch.tensor([t]).to("cuda")

        with torch.no_grad():
            noise_pred = unet(x_t, t_tensor).sample

        x_t = scheduler.step(noise_pred, t, x_t, generator=generator).prev_sample

    results.append(x_t[0].cpu())
    print(f"  ✓ Completed {steps} steps")

# Visualize
fig, axes = plt.subplots(1, 4, figsize=(16, 4))

for ax, img_tensor, steps in zip(axes, results, step_counts):
    ax.imshow(tensor_to_pil(img_tensor))
    ax.set_title(f"T={steps} steps", fontsize=12)
    ax.axis('off')

plt.suptitle(f"Effect of Number of Inference Steps ({IMAGE_SIZE}×{IMAGE_SIZE})",
             fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig("./figs/steps_comparison.png", dpi=150, bbox_inches='tight')
print("\n✓ Saved: ./figs/steps_comparison.png")
plt.show()
plt.close()

# ============================================================================
# SUMMARY
# ============================================================================

print("\n" + "="*60)
print("SUMMARY - ALL HIGH-RESOLUTION FIGURES GENERATED")
print("="*60)
print(f"\nGenerated figures ({IMAGE_SIZE}×{IMAGE_SIZE} resolution):")
if custom_images:
    print("  ✓ ./figs/forward_custom_image.png (using images in training_data!)")
print("  ✓ ./figs/reverse_diffusion_progression.png")
print("  ✓ ./figs/forward_vs_reverse.png")
print("  ✓ ./figs/multiple_samples.png")
print("  ✓ ./figs/steps_comparison.png")
# steps_comparison.png deliberately varies the step count, so a blanket
# "all figures use T=1000" claim would be wrong.
print("\nSampling figures use T=1000 steps; steps_comparison.png compares T=10/50/250/1000.")
print(f"Model used: {MODEL_ID}")
print("="*60)

