#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
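Generate (or explain how to generate) the images for the diffusion models lecture slides.

This script renders the alpha-bar noise-schedule plot itself and, for every
other image referenced in the slides, prints the snippet to add to the four
main scripts so that they save their figures to ./slide_images/.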
"""

import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path

# Every figure produced (or requested) by this guide is collected in a single
# folder, resolved relative to the current working directory.
OUTPUT_DIR = Path("./slide_images")
OUTPUT_DIR.mkdir(exist_ok=True)

# Banner emitted as soon as the module is loaded.
print("Generating images for lecture slides...", "=" * 80, sep="\n")

# ============================================================================
# IMAGE 1: Forward Diffusion Progression
# ============================================================================

def generate_forward_diffusion_progression():
    """
    Explain how to produce forward_diffusion_progression.png.

    Nothing is rendered here: the figure comes from
    understand_the_theory_behind_diffusion_model.py once the save snippet
    printed below is appended to that script.
    """
    for line in (
        "\n1. Generating forward_diffusion_progression.png",
        "   Source: understand_the_theory_behind_diffusion_model.py",
        "   Shows: Clean image → progressively noisier → pure noise",
        "   Add this code to your script:",
    ):
        print(line)

    # The pasted snippet calls `fig.savefig`, so the target script must
    # already hold its figure in a variable named `fig`.
    save_snippet = """
# At the end of understand_the_theory_behind_diffusion_model.py, add:
fig.savefig('slide_images/forward_diffusion_progression.png', 
            dpi=300, bbox_inches='tight')
print("Saved: forward_diffusion_progression.png")
"""
    print(save_snippet)


# ============================================================================
# IMAGE 2: Alpha Bar Schedule
# ============================================================================

def generate_alpha_bar_schedule(T=1000, beta_start=0.0001, beta_end=0.02,
                                output_path=None):
    """
    Plot alpha_bar_t for a linear beta schedule and save it as a slide image.

    The schedule is beta_t = linspace(beta_start, beta_end, T),
    alpha_t = 1 - beta_t and alpha_bar_t = cumprod(alpha_t). The curve is
    drawn over t = 0..T, with t = 0 standing for the clean image
    (alpha_bar_0 = 1), so the point labelled t=T really sits at x = T.

    Args:
        T: Number of diffusion steps (defaults match the lecture scripts).
        beta_start: First beta of the linear schedule.
        beta_end: Last beta of the linear schedule.
        output_path: Destination PNG; defaults to
            OUTPUT_DIR / "alpha_bar_schedule.png".

    Returns:
        The path the figure was written to.

    Raises:
        ValueError: If T < 1 (an empty schedule has no alpha_bar_T).
    """
    print("\n2. Generating alpha_bar_schedule.png")
    if T < 1:
        raise ValueError(f"T must be >= 1, got {T}")
    if output_path is None:
        output_path = OUTPUT_DIR / "alpha_bar_schedule.png"

    betas = np.linspace(beta_start, beta_end, T)
    alpha_bar = np.cumprod(1.0 - betas)
    final_alpha_bar = alpha_bar[-1]

    # Prepend alpha_bar_0 = 1 so that array index == timestep t (0..T).
    timesteps = np.arange(T + 1)
    curve = np.concatenate(([1.0], alpha_bar))

    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        ax.plot(timesteps, curve, linewidth=2.5, color='blue')
        ax.set_xlabel('Timestep $t$', fontsize=14)
        ax.set_ylabel(r'$\bar{\alpha}_t$', fontsize=14)
        ax.set_title(r'Noise Schedule: $\bar{\alpha}_t$ decreases over time',
                     fontsize=16)
        ax.grid(True, alpha=0.3)
        ax.set_xlim(0, T)
        ax.set_ylim(0, 1.05)

        # Reference lines for the two extremes of the schedule.
        ax.axhline(y=1, color='green', linestyle='--', alpha=0.5, linewidth=1.5)
        ax.text(50, 1.02, 'No noise', fontsize=12, color='green')
        ax.axhline(y=0, color='red', linestyle='--', alpha=0.5, linewidth=1.5)
        ax.text(50, 0.03, 'Pure noise', fontsize=12, color='red')

        ax.plot([0], [1.0], 'go', markersize=10, label=r'$t=0$: clean image')
        # With the default schedule alpha_bar_T is about 4e-5, which a fixed
        # three-decimal format would show as "0.000"; use scientific notation.
        ax.plot([T], [final_alpha_bar], 'ro', markersize=10,
                label=f'$t=T$: noise ($\\bar{{\\alpha}}_T$ = {final_alpha_bar:.2e})')

        ax.legend(fontsize=12, loc='upper right')
        fig.tight_layout()
        fig.savefig(output_path, dpi=300, bbox_inches='tight')
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)

    print(f"   ✓ Saved: {output_path}")
    return output_path


# ============================================================================
# IMAGE 3: Text-to-Image Results
# ============================================================================

def generate_text_to_image_results():
    """
    Print how to save text_to_image_results.png from
    understand_how_stable_diffusion_works.py.

    No figure is created here; the user is shown a save snippet to paste
    into that script.
    """
    intro = "\n".join([
        "\n3. Generating text_to_image_results.png",
        "   Source: understand_how_stable_diffusion_works.py",
        "   Shows: Text-to-image and image-to-image results",
        "   Add this code to your script:",
    ])
    print(intro)

    # Saved through the pyplot state machine, i.e. whatever figure is
    # current in the target script when this snippet runs.
    print("""
# At the end of understand_how_stable_diffusion_works.py, modify:
plt.savefig('slide_images/text_to_image_results.png', 
            dpi=300, bbox_inches='tight')
print("Saved: text_to_image_results.png")
""")


# ============================================================================
# IMAGE 4: Training Loss Curve
# ============================================================================

def generate_training_loss_curve():
    """
    Print the changes needed in finetune_2080_optimized.py to save
    training_loss_curve.png.

    Nothing is trained or plotted here. The printed snippet records one loss
    value per optimisation step and, once training is done, plots that
    history and saves it to slide_images/training_loss_curve.png.
    """
    print("\n4. Generating training_loss_curve.png")
    print("   Source: finetune_2080_optimized.py")
    # The snippet appends one value per step and labels the x-axis
    # "Training Step", so describe the curve in steps, not epochs.
    print("   Shows: Training loss over training steps")

    # The snippet is pasted into a different script, so it carries its own
    # imports and creates the output folder itself. loss_history has to
    # exist before the loop that appends to it.
    code = """
# Add to finetune_2080_optimized.py:

# Near the top (skip any import the script already has):
import os
import matplotlib.pyplot as plt

# Before the training loop starts:
loss_history = []

# Inside training loop, after computing loss:
loss_history.append(loss.item())

# After training completes:
os.makedirs('slide_images', exist_ok=True)
plt.figure(figsize=(10, 6))
plt.plot(loss_history, linewidth=2)
plt.xlabel('Training Step', fontsize=14)
plt.ylabel('Loss (MSE)', fontsize=14)
plt.title('Training Loss: UNet Learning to Predict Noise', fontsize=16)
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig('slide_images/training_loss_curve.png', dpi=300, bbox_inches='tight')
plt.close()
print("Saved: training_loss_curve.png")
"""
    print("   Add this code to your script:")
    print(code)


# ============================================================================
# IMAGE 5: Fine-tuned Results
# ============================================================================

def generate_finetuned_results():
    """
    Describe how to save finetuned_results.png from inference_finetuned.py.

    Only instructions are printed; no image is produced by this function.
    """
    # The snippet relies on `Path` and `plt` being available in the target
    # script.
    snippet = """
# At the end of inference_finetuned.py, modify the save:
grid_path = Path('slide_images') / "finetuned_results.png"
plt.savefig(grid_path, dpi=300, bbox_inches='tight')
print(f"Saved: {grid_path}")
"""
    print(
        "\n5. Generating finetuned_results.png",
        "   Source: inference_finetuned.py",
        "   Shows: Images generated by fine-tuned model",
        "   Add this code to your script:",
        snippet,
        sep="\n",
    )


# ============================================================================
# IMAGE 6: Base vs Fine-tuned Comparison
# ============================================================================

def generate_comparison_image():
    """
    Print instructions for building base_vs_finetuned_comparison.png.

    The comparison is assembled by hand. The user saves one image from the
    base model (script 2) and one from the fine-tuned model (script 4) under
    the file names the printed snippet loads. The snippet then places them
    side by side and writes the result to slide_images/.
    """
    print("\n6. Generating base_vs_finetuned_comparison.png")
    print("   Manual creation required:")
    # The snippet below opens these exact file names, so the user must be
    # told to save the two inputs under them.
    print("   - Generate images with base model (script 2)")
    print("     and save one as base_model_output.png")
    print("   - Generate images with fine-tuned model (script 4)")
    print("     and save one as finetuned_model_output.png")
    print("   - Use the following code to create comparison:")

    # `\\n` is intentional: the printed code must contain the two-character
    # escape `\n` inside its title strings.
    code = """
import os

import matplotlib.pyplot as plt
from PIL import Image

# Load the two images saved in the previous steps
base_img = Image.open('base_model_output.png')
finetuned_img = Image.open('finetuned_model_output.png')

# Create comparison
fig, axes = plt.subplots(1, 2, figsize=(12, 6))

axes[0].imshow(base_img)
axes[0].set_title('Base Model\\n(General Knowledge)', fontsize=14)
axes[0].axis('off')

axes[1].imshow(finetuned_img)
axes[1].set_title('Fine-tuned Model\\n(Specialized for Cats/Dogs)', fontsize=14)
axes[1].axis('off')

plt.tight_layout()
os.makedirs('slide_images', exist_ok=True)
plt.savefig('slide_images/base_vs_finetuned_comparison.png',
            dpi=300, bbox_inches='tight')
plt.close(fig)
print("Saved: base_vs_finetuned_comparison.png")
"""
    print(code)


# ============================================================================
# MAIN
# ============================================================================

def main():
    """
    Run the slide-image guide end to end.

    Renders the one figure this module can produce itself (the alpha-bar
    schedule) and prints the manual steps for every other image. Finishes
    with a status table, based on what already exists in OUTPUT_DIR, and
    the remaining next steps.
    """
    rule = "=" * 80

    print("\nImage Generation Guide for Lecture Slides")
    print(rule)

    # The first step renders a plot; the rest only print instructions.
    steps = (
        generate_alpha_bar_schedule,
        generate_forward_diffusion_progression,
        generate_text_to_image_results,
        generate_training_loss_curve,
        generate_finetuned_results,
        generate_comparison_image,
    )
    for step in steps:
        step()

    print("\n" + rule)
    print("Summary of Required Images")
    print(rule)

    # (file name, producer, description) for every image the slides use.
    required = (
        ("forward_diffusion_progression.png", "Script 1", "Forward diffusion process visualization"),
        ("alpha_bar_schedule.png", "Generated", "✓ Alpha-bar schedule plot"),
        ("text_to_image_results.png", "Script 2", "Text-to-image generation results"),
        ("training_loss_curve.png", "Script 3", "Training loss over time"),
        ("finetuned_results.png", "Script 4", "Fine-tuned model outputs"),
        ("base_vs_finetuned_comparison.png", "Manual", "Side-by-side comparison"),
    )

    print(f"\n{'Image':<40} {'Source':<15} {'Description':<40}")
    print("-" * 95)
    for filename, producer, description in required:
        # ✓ = already present in OUTPUT_DIR, ○ = still missing.
        marker = "✓" if (OUTPUT_DIR / filename).exists() else "○"
        print(f"{marker} {filename:<38} {producer:<15} {description:<40}")

    print("\n" + rule)
    print("Next Steps:")
    print(rule)
    for step_text in (
        "1. Run the four scripts with the modifications shown above",
        "2. All images will be saved to ./slide_images/",
        "3. Copy images to your slides directory",
        "4. Compile slides: pdflatex diffusion_lecture_slides.tex",
    ):
        print(step_text)
    print(rule)


if __name__ == "__main__":
    # Run the guide only when this file is executed directly, not on import.
    main()
