#!/usr/bin/env python
# coding: utf-8

from torch_snippets import *
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch_snippets.torch_loader import Report
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
import os

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

# Create the output directory for figures and weights. Attempt the mkdir
# directly (EAFP) and treat "already exists" as success, instead of a
# separate exists() check that can race with another process.
try:
    os.mkdir('figs')
except FileExistsError:
    pass
else:
    print("Created 'figs' directory")

# Load the MNIST train/test splits; both share a single ToTensor transform.
to_tensor = transforms.ToTensor()
train_dataset = datasets.MNIST(root='MNIST/', train=True, transform=to_tensor, download=True)
test_dataset = datasets.MNIST(root='MNIST/', train=False, transform=to_tensor, download=True)

# Training batches are shuffled each epoch; test batches keep a fixed order.
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)


class VAE(nn.Module):
    """Fully connected variational autoencoder.

    Encoder: x_dim -> h_dim1 -> h_dim2, then two linear heads producing the
    latent mean and log-variance (each of size z_dim).
    Decoder: z_dim -> h_dim2 -> h_dim1 -> x_dim, followed by a sigmoid so every
    output lies in [0, 1].

    Args:
        x_dim: number of input features once an input batch is flattened.
        h_dim1: width of the outer hidden layer (encoder and decoder).
        h_dim2: width of the inner hidden layer (encoder and decoder).
        z_dim: latent dimensionality.
    """

    def __init__(self, x_dim, h_dim1, h_dim2, z_dim):
        super().__init__()
        # Kept so forward() flattens inputs to the configured width rather
        # than a hard-coded 784 (a plain int, so it is not in the state_dict).
        self.x_dim = x_dim

        # Encoder layers
        self.d1 = nn.Linear(x_dim, h_dim1)
        self.d2 = nn.Linear(h_dim1, h_dim2)
        self.d31 = nn.Linear(h_dim2, z_dim)  # mean
        self.d32 = nn.Linear(h_dim2, z_dim)  # log variance

        # Decoder layers
        self.d4 = nn.Linear(z_dim, h_dim2)
        self.d5 = nn.Linear(h_dim2, h_dim1)
        self.d6 = nn.Linear(h_dim1, x_dim)

    def encoder(self, x):
        """Map a flattened batch (batch, x_dim) to (mean, log_var), each (batch, z_dim)."""
        h = F.relu(self.d1(x))
        h = F.relu(self.d2(h))
        return self.d31(h), self.d32(h)

    def sampling(self, mean, log_var):
        """Reparameterization trick: z = mean + std * eps, with eps ~ N(0, I).

        std = exp(0.5 * log_var). Drawing eps separately keeps z
        differentiable with respect to mean and log_var.
        """
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mean + eps * std

    def decoder(self, z):
        """Map latents (batch, z_dim) to outputs (batch, x_dim) in [0, 1]."""
        h = F.relu(self.d4(z))
        h = F.relu(self.d5(h))
        # torch.sigmoid is the canonical op (F.sigmoid was flagged as
        # deprecated in older PyTorch releases).
        return torch.sigmoid(self.d6(h))

    def forward(self, x):
        """Encode, sample a latent, and decode.

        ``x`` may have any shape whose total size is a multiple of ``x_dim``
        (e.g. image batches (batch, 1, 28, 28) with x_dim=784). It is
        reshaped to (-1, x_dim); ``reshape`` also accepts non-contiguous input.

        Returns:
            (reconstruction, mean, log_var)
        """
        mean, log_var = self.encoder(x.reshape(-1, self.x_dim))
        z = self.sampling(mean, log_var)
        return self.decoder(z), mean, log_var


def loss_function(recon_x, x, mean, log_var):
    """VAE loss = reconstruction loss + KL divergence.

    Args:
        recon_x: decoder output of shape (batch, features).
        x: target batch. It is flattened to recon_x's feature width, so
            image batches such as (batch, 1, 28, 28) are accepted.
        mean, log_var: encoder outputs parameterising the diagonal Gaussian
            q(z|x) = N(mean, exp(log_var)).

    Returns:
        (total, recon, kld):
            recon is the summed squared error.
            kld is the closed-form KL(q(z|x) || N(0, I)), summed over the batch
            and latent dimensions.
            total = recon + kld.
    """
    # Derive the feature width from the reconstruction instead of hard-coding 784.
    recon = F.mse_loss(recon_x, x.reshape(-1, recon_x.size(-1)), reduction='sum')
    kld = -0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp())
    return recon + kld, recon, kld


def train_batch(data, model, optimizer, loss_function):
    """Perform one optimisation step on a single batch.

    The model is switched to training mode and the batch is moved to the
    module-level ``device``. A forward pass is followed by backprop of the
    total loss and one optimizer step.

    Returns:
        (total_loss, recon_loss, kld, mean of log_var, mean of mean).
        The loss terms are returned as produced by ``loss_function``
        (not detached).
    """
    model.train()
    batch = data.to(device)

    optimizer.zero_grad()
    reconstruction, mu, logvar = model(batch)
    total, recon_term, kl_term = loss_function(reconstruction, batch, mu, logvar)
    total.backward()
    optimizer.step()

    return total, recon_term, kl_term, logvar.mean(), mu.mean()


@torch.no_grad()
def validate_batch(data, model, loss_function):
    """Evaluate the model on one batch without tracking gradients.

    The model is switched to eval mode and the batch is moved to the
    module-level ``device``.

    Returns:
        (total_loss, recon_loss, kld, mean of log_var, mean of mean),
        the same layout as ``train_batch``.
    """
    model.eval()
    batch = data.to(device)
    reconstruction, mu, logvar = model(batch)
    total, recon_term, kl_term = loss_function(reconstruction, batch, mu, logvar)
    return total, recon_term, kl_term, logvar.mean(), mu.mean()


def save_generated_images(vae, epoch, z_dim=50, n_samples=64, nrow=8):
    """Decode random latent vectors and save them as one image grid.

    Latent codes are drawn from N(0, I) and decoded with ``vae.decoder``.
    The grid is written to ``./figs/vae_generated_epoch_{epoch:03d}.png``.
    The model is left in eval mode.

    Args:
        vae: model exposing ``decoder(z)``. Its output is reshaped to
            (n_samples, 1, 28, 28), so it must yield 784 values per sample.
        epoch: epoch number used in the title and file name.
        z_dim: latent size fed to the decoder; must match the model.
        n_samples: number of images to generate.
        nrow: number of images per grid row.
    """
    vae.eval()
    with torch.no_grad():
        # Sample directly on `device`. The decoder output then lives on the
        # same device as z, so no extra .to(device) is needed.
        z = torch.randn(n_samples, z_dim, device=device)
        samples = vae.decoder(z)
        grid = make_grid(samples.view(n_samples, 1, 28, 28), nrow=nrow).permute(1, 2, 0)

    # Work on the Figure object rather than pyplot's implicit "current figure",
    # so the figure saved is guaranteed to be the one built here.
    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(grid.cpu().numpy(), cmap='gray')
    ax.axis('off')
    ax.set_title(f'VAE Generated Samples - Epoch {epoch}', fontsize=14)

    fig.tight_layout()
    filepath = f'./figs/vae_generated_epoch_{epoch:03d}.png'
    fig.savefig(filepath, bbox_inches='tight', dpi=150)
    plt.close(fig)
    print(f"Saved: {filepath}")


def save_reconstruction_comparison(vae, test_loader, epoch, n_samples=8):
    """Save a two-row grid: original test images on top, reconstructions below.

    The first batch of ``test_loader`` is used. With the VAE defined in this
    file, the forward pass draws a random latent via ``sampling``, so the
    reconstructions vary between calls. The model is left in eval mode.
    The grid is written to ``./figs/vae_reconstruction_epoch_{epoch:03d}.png``.

    Args:
        vae: model whose forward returns (reconstruction, mean, log_var).
        test_loader: iterable of (images, labels) batches.
        epoch: epoch number used in the title and file name.
        n_samples: maximum number of images per row.
    """
    vae.eval()
    with torch.no_grad():
        data, _ = next(iter(test_loader))
        data = data.to(device)
        recon, _, _ = vae(data)

        # Clamp to the batch size so both rows hold the same number of images
        # and each grid row is exactly one row of the comparison.
        n = min(n_samples, data.size(0))
        original = data[:n]
        # Reshape flat reconstructions back to the input image shape.
        reconstructed = recon[:n].view_as(original)

        comparison = torch.cat([original, reconstructed])
        grid = make_grid(comparison, nrow=n).permute(1, 2, 0)

    # Use the Figure object explicitly instead of pyplot's "current figure".
    fig, ax = plt.subplots(figsize=(12, 3))
    ax.imshow(grid.cpu().numpy(), cmap='gray')
    ax.axis('off')
    ax.set_title(f'Top: Original, Bottom: Reconstructed - Epoch {epoch}', fontsize=12)

    fig.tight_layout()
    filepath = f'./figs/vae_reconstruction_epoch_{epoch:03d}.png'
    fig.savefig(filepath, bbox_inches='tight', dpi=150)
    plt.close(fig)
    print(f"Saved: {filepath}")


# Initialize model. The architecture summary below is formatted from the same
# values passed to the constructor, so the two cannot drift apart.
x_dim, h_dim1, h_dim2, z_dim = 784, 512, 256, 50
vae = VAE(x_dim=x_dim, h_dim1=h_dim1, h_dim2=h_dim2, z_dim=z_dim).to(device)
optimizer = optim.AdamW(vae.parameters(), lr=1e-3)

print("\nVAE Architecture:")
print(f"- Latent dimension: {z_dim}")
print(f"- Encoder: {x_dim} → {h_dim1} → {h_dim2} → {z_dim}")
print(f"- Decoder: {z_dim} → {h_dim2} → {h_dim1} → {x_dim}")

# Training loop. Each epoch runs, in order:
#   1. one optimisation pass over the training set;
#   2. one evaluation pass over the test set;
#   3. per-epoch averages, then sample images.
# Every batch is logged at a fractional epoch position in (epoch, epoch + 1].
n_epochs = 10
log = Report(n_epochs)

print(f"\nTraining for {n_epochs} epochs...")

# Loader lengths do not change between epochs, so compute them once.
train_steps = len(train_loader)
val_steps = len(test_loader)

for epoch in range(n_epochs):
    # Training pass
    for step, (batch, _) in enumerate(train_loader, start=1):
        loss, recon, kld, log_var, mean = train_batch(batch, vae, optimizer, loss_function)
        log.record(epoch + step / train_steps,
                   train_loss=loss, train_kld=kld, train_recon=recon,
                   train_log_var=log_var, train_mean=mean, end='\r')

    # Validation pass
    for step, (batch, _) in enumerate(test_loader, start=1):
        loss, recon, kld, log_var, mean = validate_batch(batch, vae, loss_function)
        log.record(epoch + step / val_steps,
                   val_loss=loss, val_kld=kld, val_recon=recon,
                   val_log_var=log_var, val_mean=mean, end='\r')

    log.report_avgs(epoch + 1)

    # Snapshot samples and reconstructions after every epoch
    print(f"\nEpoch {epoch + 1}: Saving generated images...")
    save_generated_images(vae, epoch + 1, z_dim=50, n_samples=64)
    save_reconstruction_comparison(vae, test_loader, epoch + 1)

# Save training plots: a 2x2 grid of per-epoch curves from the Report `log`.
print("\nSaving training plots...")

fig, axes = plt.subplots(2, 2, figsize=(12, 10))

# Each entry: (row, col) of the panel, logged metric keys, panel title.
plot_panels = [
    ((0, 0), ['train_loss', 'val_loss'], 'Total Loss (Reconstruction + KL)'),
    ((0, 1), ['train_kld', 'val_kld'], 'KL Divergence'),
    ((1, 0), ['train_recon', 'val_recon'], 'Reconstruction Loss'),
    # Batch-means of the encoder's latent mean and log-variance.
    ((1, 1), ['train_mean', 'train_log_var'], 'Latent Statistics (Mean & Log Variance)'),
]
for (row, col), keys, title in plot_panels:
    panel = axes[row, col]
    log.plot_epochs(keys, ax=panel)
    panel.set_title(title)
    panel.grid(True, alpha=0.3)
axes[1, 1].set_ylabel('Value')

# Call tight_layout/savefig on `fig` explicitly. The plt.* versions act on
# pyplot's implicit "current figure", which helpers such as plot_epochs could
# change.
fig.tight_layout()
curves_path = './figs/vae_training_curves.png'
fig.savefig(curves_path, bbox_inches='tight', dpi=150)
plt.close(fig)
print(f"Saved: {curves_path}")

# Save the final model weights (state_dict only) next to the figures.
model_path = './figs/vae_model_final.pth'
torch.save(vae.state_dict(), model_path)
print(f"Saved: {model_path}")

# Summary of the produced artifacts. Plain strings: these lines interpolate
# nothing, so they need no f-prefix.
print("\n=== Training Complete ===")
print("All results saved to: ./figs/")
print("- Generated images: vae_generated_epoch_XXX.png")
print("- Reconstructions: vae_reconstruction_epoch_XXX.png")
print("- Training curves: vae_training_curves.png")
print("- Model weights: vae_model_final.pth")
