#!/usr/bin/env python
# coding: utf-8

"""
Handwritten digit generation using a vanilla GAN.

Demonstrates basic GAN concepts on the MNIST dataset.
"""

import os
import torch
import torch.optim as optim
from torch_snippets.torch_loader import Report
from torch import nn as nn
from torch_snippets import *
# Run on the GPU when PyTorch reports CUDA as available, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
from torchvision.utils import make_grid, save_image
from torchvision.datasets import MNIST
from torchvision import transforms

# All figures written by this script go under ./figs; create it up front.
os.makedirs('./figs', exist_ok=True)

# Start-up banner (three lines: rule, title, rule).
print("=" * 60, "VANILLA GAN: Handwritten Digit Generation", "=" * 60, sep="\n")

# Data preparation: convert images to tensors, then normalise with
# mean 0.5 / std 0.5.
print("\n1. Loading MNIST dataset...")
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(mean=(0.5,), std=(0.5,))]
)

# Training split, downloaded to ~/data on first use.
_train_set = MNIST('~/data', train=True, download=True, transform=transform)

# Shuffled batches of 128; the trailing partial batch is dropped so every
# batch has the same size.
data_loader = torch.utils.data.DataLoader(
    _train_set,
    batch_size=128,
    shuffle=True,
    drop_last=True,
)
print(f"   Dataset size: {len(data_loader.dataset)} images")

# Define Discriminator
class Discriminator(nn.Module):
    """Fully connected classifier that scores flattened images.

    The network is a stack of ``Linear -> LeakyReLU(0.2) -> Dropout`` stages
    followed by a single-unit ``Linear`` and a ``Sigmoid``, so every input row
    is mapped to one value in the range (0, 1).

    Args:
        in_features: Length of each flattened input vector. The default, 784,
            is the size of a flattened 28x28 image.
        hidden_sizes: Output width of each hidden stage, in order.
        dropout: Drop probability used after every hidden stage.
    """

    def __init__(self, in_features=784, hidden_sizes=(1024, 512, 256), dropout=0.3):
        super().__init__()
        layers = []
        width = in_features
        # Layers are appended in the same order as a hand-written Sequential,
        # so with the defaults the state_dict keys stay model.0, model.3, ...
        for hidden in hidden_sizes:
            layers.append(nn.Linear(width, hidden))
            layers.append(nn.LeakyReLU(0.2))
            layers.append(nn.Dropout(dropout))
            width = hidden
        layers.append(nn.Linear(width, 1))
        layers.append(nn.Sigmoid())
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return one score per row of ``x``, shaped ``(len(x), 1)``."""
        return self.model(x)

# Define Generator
class Generator(nn.Module):
    """Fully connected network that maps latent vectors to flattened images.

    The network is a stack of ``Linear -> LeakyReLU(0.2)`` stages followed by
    a final ``Linear`` and a ``Tanh``, so every output value lies in [-1, 1].

    Args:
        latent_dim: Length of each input latent vector (default 100).
        out_features: Length of each output vector. The default, 784, is the
            size of a flattened 28x28 image.
        hidden_sizes: Output width of each hidden stage, in order.
    """

    def __init__(self, latent_dim=100, out_features=784, hidden_sizes=(256, 512, 1024)):
        super().__init__()
        layers = []
        width = latent_dim
        # Layers are appended in the same order as a hand-written Sequential,
        # so with the defaults the state_dict keys stay model.0, model.2, ...
        for hidden in hidden_sizes:
            layers.append(nn.Linear(width, hidden))
            layers.append(nn.LeakyReLU(0.2))
            width = hidden
        layers.append(nn.Linear(width, out_features))
        layers.append(nn.Tanh())
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return one generated vector per row of ``x``."""
        return self.model(x)

print("\n2. Building models...")
from torchsummary import summary

# Instantiate both networks (discriminator first) and move them to the
# selected device.
discriminator = Discriminator().to(device)
generator = Generator().to(device)

# Print a summary of each network, driven by a single all-zero dummy row whose
# width matches that network's input size.
for _heading, _net, _width in (
    ("\n   Discriminator architecture:", discriminator, 784),
    ("\n   Generator architecture:", generator, 100),
):
    print(_heading)
    summary(_net, torch.zeros(1, _width).to(device))

# Helper functions
def noise(size, latent_dim=100):
    """Sample standard-normal latent vectors for the generator.

    Args:
        size: Number of latent vectors (rows) to draw.
        latent_dim: Length of each vector. Defaults to 100, the input width
            of the generator built by this script.

    Returns:
        A ``(size, latent_dim)`` tensor located on ``device``.
    """
    # Sample directly on the target device instead of sampling on the CPU and
    # copying. On CUDA this draws from the CUDA random generator.
    return torch.randn(size, latent_dim, device=device)

def discriminator_train_step(real_data, fake_data):
    """Run one optimiser step for the discriminator.

    Gradients from the real batch (target 1) and the fake batch (target 0)
    are accumulated by two separate backward passes and then applied in a
    single ``d_optimizer.step()``.

    Args:
        real_data: Batch of real samples to feed to the discriminator.
        fake_data: Batch of generated samples to feed to the discriminator.

    Returns:
        The sum of the real-batch and fake-batch losses, as a tensor.
    """
    d_optimizer.zero_grad()

    # Real samples should be scored as 1. ones_like builds the target with the
    # prediction's shape, dtype and device, avoiding a CPU allocation plus a
    # host-to-device copy on every step.
    prediction_real = discriminator(real_data)
    error_real = loss(prediction_real, torch.ones_like(prediction_real))
    error_real.backward()

    # Generated samples should be scored as 0.
    prediction_fake = discriminator(fake_data)
    error_fake = loss(prediction_fake, torch.zeros_like(prediction_fake))
    error_fake.backward()

    d_optimizer.step()
    return error_real + error_fake

def generator_train_step(fake_data):
    """Run one optimiser step for the generator.

    The generated batch is scored by the discriminator and the loss is taken
    against a target of 1, i.e. the generator is rewarded when its samples
    are scored as real. Only ``g_optimizer`` is stepped here.

    Args:
        fake_data: Batch of generated samples to feed to the discriminator.

    Returns:
        The generator loss, as a tensor.
    """
    g_optimizer.zero_grad()
    prediction = discriminator(fake_data)
    # ones_like builds the target with the prediction's shape, dtype and
    # device, avoiding a CPU allocation plus a host-to-device copy per step.
    error = loss(prediction, torch.ones_like(prediction))
    error.backward()
    g_optimizer.step()
    return error

# Training configuration: run length, binary cross-entropy objective, one Adam
# optimiser per network (discriminator first, same learning rate for both),
# and a Report logger created with the epoch count.
num_epochs = 200
loss = nn.BCELoss()
d_optimizer, g_optimizer = (
    optim.Adam(_net.parameters(), lr=0.0002)
    for _net in (discriminator, generator)
)
log = Report(num_epochs)

# Training loop
print(f"\n3. Training for {num_epochs} epochs...")
print("   (This may take a while...)\n")

# Batches per epoch; does not change between epochs, so compute it once.
N = len(data_loader)

for epoch in range(num_epochs):
    for i, (images, _) in enumerate(data_loader):
        # Flatten every image in the batch to a single vector.
        real_data = images.view(len(images), -1).to(device)
        batch_size = len(real_data)

        # Discriminator update. Only the discriminator is stepped here, so
        # the fake batch is generated without recording an autograd graph.
        with torch.no_grad():
            fake_data = generator(noise(batch_size))
        d_loss = discriminator_train_step(real_data, fake_data)

        # Generator update. A fresh fake batch is generated with autograd
        # enabled so the loss can back-propagate into the generator.
        fake_data = generator(noise(batch_size))
        g_loss = generator_train_step(fake_data)

        # Log at a fractional epoch position so each batch gets its own point.
        log.record(epoch+(1+i)/N, d_loss=d_loss.item(), g_loss=g_loss.item(), end='\r')
    log.report_avgs(epoch+1)

    # Save sample images every 20 epochs
    if (epoch + 1) % 20 == 0:
        with torch.no_grad():
            z = torch.randn(64, 100).to(device)
            # 64 flattened outputs reshaped to single-channel 28x28 images.
            sample_images = generator(z).cpu().view(64, 1, 28, 28)
            grid = make_grid(sample_images, nrow=8, normalize=True)
            save_image(grid, f'./figs/generated_digits_epoch_{epoch+1}.png')

# Plot training losses
# Draw the d_loss / g_loss values recorded in `log` during training.
log.plot_epochs(['d_loss', 'g_loss'])
import matplotlib.pyplot as plt
# Save matplotlib's current figure to disk, then close it.
# NOTE(review): this relies on plot_epochs leaving its figure open and
# current; if it shows or closes the figure itself, the saved PNG may be
# blank -- confirm against the torch_snippets version in use.
plt.savefig('./figs/training_losses.png', dpi=150, bbox_inches='tight')
plt.close()

# Generate final samples
print("\n4. Generating final samples...")
# Inference only: disable autograd so no graph is recorded for the generator
# call, matching how the periodic samples are produced during training.
with torch.no_grad():
    z = torch.randn(64, 100).to(device)
    # 64 flattened outputs reshaped to single-channel 28x28 images on the CPU.
    sample_images = generator(z).cpu().view(64, 1, 28, 28)
grid = make_grid(sample_images, nrow=8, normalize=True)

# Save the image grid to a file
save_image(grid, './figs/generated_digits_final.png')

# Optional: still display the image
# The grid is already a CPU tensor without gradients; move the channel axis
# last for display.
show(grid.permute(1, 2, 0), sz=5)

# Closing summary: a banner followed by the list of files written to ./figs/.
_rule = "=" * 60
for _line in (
    "\n" + _rule,
    "TRAINING COMPLETE!",
    _rule,
    "\nGenerated files in ./figs/:",
    "  - generated_digits_final.png (final samples)",
    "  - training_losses.png (loss curves)",
    "  - generated_digits_epoch_*.png (intermediate samples)",
    _rule,
):
    print(_line)
