#!/usr/bin/env python
# coding: utf-8

# # Understand the theory behind diffusion models - Generate figures for slides

import warnings
warnings.filterwarnings("ignore", category=FutureWarning)

import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from itertools import accumulate
from diffusers import StableDiffusionPipeline

# Create figs directory if it doesn't exist
os.makedirs("./figs", exist_ok=True)

# Pick the device at runtime instead of hard-coding CUDA so the script still
# runs (slowly) without a GPU. float16 weights are kept for CUDA only; on CPU
# we load float32, since half-precision CPU inference is generally poorly
# supported.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device.startswith("cuda") else torch.float32

# Generate a sample image
print("Generating dog image...")
text2img_pipe = StableDiffusionPipeline.from_pretrained(
    "stablediffusionapi/deliberate-v2",
    torch_dtype=dtype,
).to(device)

# Fixed seed so reruns on the same device produce the same image.
generator = torch.Generator(device).manual_seed(2)
prompt = "high quality, a happy dog running on the grass"
image = text2img_pipe(prompt=prompt, generator=generator).images[0]

# Display the original image
fig = plt.figure(figsize=(6, 6))
plt.imshow(image)
plt.title("Original Dog Image (x_0)", fontsize=14)
plt.axis('off')
plt.tight_layout()
plt.show()
plt.close(fig)  # release the figure; later demos open their own

image.save("./figs/dog_step0.png")  # Save original as step 0
print("Saved: ./figs/dog_step0.png")

# ## Demo 1: Forward Diffusion - Iterative Process

def forward_diffusion_iterative(image, num_steps=16, beta=0.1, return_all=False, rng=None):
    """
    Apply the forward diffusion process one step at a time.

    Formula: x_t = sqrt(1 - beta) * x_{t-1} + sqrt(beta) * epsilon,
    with a fresh epsilon ~ N(0, I) drawn at every step.

    Args:
        image: Numpy array holding the clean image x_0. Integer arrays
            (e.g. uint8) are always treated as 0-255; float arrays whose max
            exceeds 1.0 are also rescaled by 1/255; other float arrays are
            assumed to already be in [0, 1].
        num_steps: Number of noising steps to apply (<= 0 yields no steps).
        beta: Constant per-step noise variance, must lie in [0, 1].
        return_all: If True, also return the final unclipped state.
        rng: Optional ``numpy.random.Generator`` for reproducible noise.
            When None, the global ``np.random`` state is used.

    Returns:
        A list with one display image per step, each clipped to [0, 1].
        If ``return_all`` is True, a tuple ``(images, current_image)`` where
        ``current_image`` is the final unclipped array in model space
        (data mapped to [-1, 1], plus the accumulated noise).

    Raises:
        ValueError: If ``beta`` is outside [0, 1] (its square roots would be NaN).
    """
    if not 0.0 <= beta <= 1.0:
        raise ValueError(f"beta must be in [0, 1], got {beta}")

    # Integer images are 0-255 even when their max happens to be <= 1
    # (e.g. a nearly black uint8 image), so check dtype as well as range.
    if np.issubdtype(image.dtype, np.integer) or image.max() > 1.0:
        image = image / 255.0

    # Transform to [-1, 1] range (standard in diffusion models).
    # astype() returns a new array, so the caller's input is never mutated.
    current_image = image.astype(np.float64) * 2 - 1

    # The per-step coefficients are constant, so compute them once.
    signal_scale = np.sqrt(1 - beta)
    noise_scale = np.sqrt(beta)

    images = []
    for _ in range(num_steps):
        if rng is None:
            epsilon = np.random.randn(*current_image.shape)
        else:
            epsilon = rng.standard_normal(current_image.shape)

        # Apply one step of diffusion
        current_image = signal_scale * current_image + noise_scale * epsilon

        # Store for visualization (map back to [0, 1] and clip noise overshoot)
        images.append(np.clip((current_image + 1) / 2, 0, 1))

    if return_all:
        return images, current_image
    return images

# Load and process
print("\nGenerating iterative forward diffusion images...")
original_image = plt.imread("./figs/dog_step0.png")
images_iterative = forward_diffusion_iterative(original_image, num_steps=16, beta=0.1)

# Save and display individual steps for the first slide
print("Saving individual step images...")

# (index into images_iterative, panel title, output path).
# Index 4 ("Step 5") requires num_steps >= 5 above.
snapshots = [
    (0, "Step 1", "./figs/dog_step1.png"),
    (4, "Step 5", "./figs/dog_step5.png"),
    (-1, f"Final Step (Step {len(images_iterative)})", "./figs/dog_step_final.png"),
]
for step_idx, title, path in snapshots:
    fig, ax = plt.subplots(1, 1, figsize=(6, 6))
    ax.imshow(images_iterative[step_idx])
    ax.set_title(title, fontsize=14)
    ax.axis('off')
    plt.tight_layout()
    plt.show()
    plt.close(fig)  # avoid accumulating open figures across snapshots
    plt.imsave(path, images_iterative[step_idx])
    print(f"Saved: {path}")

# Visualize all steps in a near-square grid sized to the number of steps
# (16 steps -> 4x4 at 12x12 inches, as before).
n_panels = len(images_iterative)
n_cols = max(1, int(np.ceil(np.sqrt(n_panels))))
n_rows = max(1, -(-n_panels // n_cols))  # ceiling division
fig, axes = plt.subplots(n_rows, n_cols, figsize=(3 * n_cols, 3 * n_rows), squeeze=False)
axes = axes.ravel()

for step, ax in enumerate(axes, start=1):
    if step <= n_panels:
        ax.imshow(images_iterative[step - 1])
        ax.set_title(f"Step {step}", fontsize=10)
    ax.axis('off')  # also hides any unused trailing panels

plt.suptitle("Forward Diffusion: Iterative Process", fontsize=16)
plt.tight_layout()
plt.savefig("./figs/forward_diffusion_iterative.png", dpi=150, bbox_inches='tight')
plt.show()
plt.close()
print("Saved: ./figs/forward_diffusion_iterative.png")

# ## Demo 2: Forward Diffusion - Direct Sampling (Closed Form)

def forward_diffusion_direct(image, timestep, beta=0.05, num_timesteps=20, rng=None):
    """
    Jump directly to a noised state using the closed-form solution.

    Formula: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * epsilon,
    where alpha_bar_t is the running product of (1 - beta) over the steps taken.

    ``timestep`` is a 0-based index into the schedule: ``timestep=0`` is the
    state after ONE noising step, and ``timestep=num_timesteps - 1`` is the
    final state x_T. To get the state after t steps, pass ``t - 1``.

    Args:
        image: Numpy array holding the clean image x_0. Integer arrays are
            always treated as 0-255; float arrays whose max exceeds 1.0 are
            also rescaled by 1/255; other float arrays are assumed to be in [0, 1].
        timestep: 0-based schedule index, in [0, num_timesteps).
        beta: Constant per-step noise variance, must lie in [0, 1].
        num_timesteps: Length of the noise schedule.
        rng: Optional ``numpy.random.Generator`` for reproducible noise.
            When None, the global ``np.random`` state is used.

    Returns:
        Tuple ``(x_t_display, alpha_bar_t)``: the noised image mapped back and
        clipped to [0, 1] for display, and the cumulative signal fraction used.

    Raises:
        IndexError: If ``timestep`` is outside [0, num_timesteps). Negative
            values are rejected rather than wrapping to the end of the schedule.
        ValueError: If ``beta`` is outside [0, 1].
    """
    if not 0 <= timestep < num_timesteps:
        raise IndexError(f"timestep must be in [0, {num_timesteps}), got {timestep}")
    if not 0.0 <= beta <= 1.0:
        raise ValueError(f"beta must be in [0, 1], got {beta}")

    # Integer images are 0-255 even when their max happens to be <= 1.
    if np.issubdtype(image.dtype, np.integer) or image.max() > 1.0:
        image = image / 255.0
    image = image.astype(np.float64) * 2 - 1  # Transform to [-1, 1]

    # alpha_bar at index `timestep` is the product of the first timestep + 1
    # alphas; only that prefix of the schedule is needed.
    alphas = [1 - beta] * (timestep + 1)
    *_, alpha_bar_t = accumulate(alphas, lambda x, y: x * y)

    # Direct sampling at timestep t
    if rng is None:
        epsilon = np.random.randn(*image.shape)
    else:
        epsilon = rng.standard_normal(image.shape)

    x_t = np.sqrt(alpha_bar_t) * image + np.sqrt(1 - alpha_bar_t) * epsilon

    # Convert back to [0, 1] and clip noise overshoot for display
    x_t_display = np.clip((x_t + 1) / 2, 0, 1)

    return x_t_display, alpha_bar_t

# Compare different timesteps: t = 0 is the clean image, t > 0 is the state
# after t noising steps drawn in one shot from the closed form.
print("\nGenerating direct sampling comparison...")
original_image = plt.imread("./figs/dog_step0.png")
timesteps = [0, 5, 10, 15, 19]

fig, axes = plt.subplots(1, len(timesteps), figsize=(15, 3))

for idx, t in enumerate(timesteps):
    panel = axes[idx]
    if t != 0:
        # The helper's index is 0-based, so t - 1 selects "after t steps".
        img_display, alpha_bar = forward_diffusion_direct(original_image, t - 1, beta=0.05)
    else:
        # No noise has been added yet: the full signal is retained.
        alpha_bar = 1.0
        img_display = original_image if original_image.max() <= 1.0 else original_image / 255.0

    panel_title = f"t={t}\n$\\bar{{\\alpha}}_t$={alpha_bar:.3f}"
    panel.imshow(img_display)
    panel.set_title(panel_title, fontsize=11)
    panel.axis('off')

fig.suptitle("Forward Diffusion: Direct Sampling at Different Timesteps", fontsize=14)
fig.tight_layout()
plt.savefig("./figs/forward_diffusion_direct.png", dpi=150, bbox_inches='tight')
plt.show()
plt.close()
print("Saved: ./figs/forward_diffusion_direct.png")

# ## Demo 3: Visualize the role of alpha_bar

def visualize_alpha_schedule(num_timesteps=20, beta=0.05, save_path="./figs/alpha_schedule.png"):
    """
    Plot how alpha_bar_t and the signal/noise coefficients evolve over time.

    Uses a constant schedule (beta_t = beta for every step) and saves a
    two-panel figure: alpha_bar_t on the left, and sqrt(alpha_bar_t) vs
    sqrt(1 - alpha_bar_t) on the right.

    Args:
        num_timesteps: Number of diffusion steps T.
        beta: Constant per-step noise variance.
        save_path: Where the figure is written (the directory must exist).

    Returns:
        List of length ``num_timesteps`` whose entry i is alpha_bar after
        i + 1 steps, i.e. alpha_bar_{t} for t = i + 1.
    """
    betas = [beta] * num_timesteps
    alphas = [1 - b for b in betas]
    alpha_bars = list(accumulate(alphas, lambda x, y: x * y))

    # alpha_bars[i] already includes i + 1 factors, so it belongs at t = i + 1.
    # Plotting it at t = i would shift the curve one step left of the values
    # shown in the direct-sampling demo (which labels index t - 1 as "t").
    timesteps = list(range(1, num_timesteps + 1))

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

    # Plot alpha_bar_t
    ax1.plot(timesteps, alpha_bars, 'b-', linewidth=2, label='$\\bar{\\alpha}_t$')
    ax1.axhline(y=0, color='r', linestyle='--', alpha=0.3)
    ax1.set_xlabel('Timestep t', fontsize=12)
    ax1.set_ylabel('$\\bar{\\alpha}_t$', fontsize=12)
    ax1.set_title('Cumulative Signal Retention', fontsize=14)
    ax1.grid(True, alpha=0.3)
    ax1.legend(fontsize=11)

    # Plot the coefficients multiplying x_0 (signal) and epsilon (noise)
    signal_coef = [np.sqrt(ab) for ab in alpha_bars]
    noise_coef = [np.sqrt(1 - ab) for ab in alpha_bars]

    ax2.plot(timesteps, signal_coef, 'b-', linewidth=2, label='$\\sqrt{\\bar{\\alpha}_t}$ (signal)')
    ax2.plot(timesteps, noise_coef, 'r-', linewidth=2, label='$\\sqrt{1-\\bar{\\alpha}_t}$ (noise)')
    ax2.set_xlabel('Timestep t', fontsize=12)
    ax2.set_ylabel('Coefficient', fontsize=12)
    ax2.set_title('Signal vs Noise Coefficients', fontsize=14)
    ax2.grid(True, alpha=0.3)
    ax2.legend(fontsize=11)

    fig.tight_layout()
    fig.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.show()
    plt.close(fig)

    return alpha_bars

# Demo 3: plot the decay of alpha_bar_t for the same constant schedule
# (T = 20, beta = 0.05) used by the direct-sampling demos.
print("\nGenerating alpha schedule visualization...")
alpha_bars = visualize_alpha_schedule(num_timesteps=20, beta=0.05)
print("Saved: ./figs/alpha_schedule.png")

# ## Demo 4: Show that at t=T, we get pure Gaussian noise

print("\nGenerating forward endpoint comparison...")
# Total number of forward steps T. forward_diffusion_direct takes a 0-based
# index, so x_T is index T_STEPS - 1. Every "T=20" label below derives from this.
T_STEPS = 20
original_image = plt.imread("./figs/dog_step0.png")
final_noisy, alpha_bar_final = forward_diffusion_direct(
    original_image, T_STEPS - 1, beta=0.05, num_timesteps=T_STEPS
)

fig, axes = plt.subplots(1, 3, figsize=(15, 5))

# Original
img_orig = original_image if original_image.max() <= 1.0 else original_image / 255.0
axes[0].imshow(img_orig)
axes[0].set_title("Original Image\n$x_0$", fontsize=14)
axes[0].axis('off')

# Noisy endpoint of the forward process
axes[1].imshow(final_noisy)
axes[1].set_title(f"After T={T_STEPS} steps\n$x_T$ ($\\bar{{\\alpha}}_T$={alpha_bar_final:.4f})", fontsize=14)
axes[1].axis('off')

# Pure Gaussian, mapped to display space with the same (x + 1) / 2 + clip
# transform used for the diffused images so the two panels are comparable.
pure_noise = np.random.randn(*original_image.shape)
pure_noise = np.clip((pure_noise + 1) / 2, 0, 1)
axes[2].imshow(pure_noise)
axes[2].set_title("Pure Gaussian Noise\n$\\mathcal{N}(0, I)$", fontsize=14)
axes[2].axis('off')

plt.suptitle("Forward Diffusion Endpoint: Signal → Noise", fontsize=16)
plt.tight_layout()
plt.savefig("./figs/forward_endpoint.png", dpi=150, bbox_inches='tight')
plt.show()
plt.close(fig)
print("Saved: ./figs/forward_endpoint.png")

print("\n" + "="*60)
print("Summary Statistics:")
print("="*60)
print(f"At final timestep T={T_STEPS}:")
print(f"  ᾱ_T = {alpha_bar_final:.6f} ≈ 0")
print(f"  √(ᾱ_T) = {np.sqrt(alpha_bar_final):.6f} (signal coefficient)")
print(f"  √(1-ᾱ_T) = {np.sqrt(1-alpha_bar_final):.6f} (noise coefficient)")
print("\nConclusion: x_T ≈ ε (pure noise)")
print("="*60)

# Final report: one "  - ./figs/<name>" line per figure written by this script.
print("\n✓ All figures generated successfully in ./figs/ directory!")
print("\nGenerated files:")
files = [
    "dog_step0.png",
    "dog_step1.png",
    "dog_step5.png",
    "dog_step_final.png",
    "alpha_schedule.png",
    "forward_diffusion_iterative.png",
    "forward_diffusion_direct.png",
    "forward_endpoint.png",
]
print("\n".join(f"  - ./figs/{name}" for name in files))
