#!/usr/bin/env python
# coding: utf-8

"""
Face Generation using DCGAN (Deep Convolutional GAN)
Demonstrates convolutional architecture for realistic face generation
"""

import os
import torch
from torch import nn as nn
from torch.utils.data import Dataset, DataLoader
from torchsummary import summary
from torch_snippets import *
from torch_snippets.torch_loader import Report
from torchvision import transforms
import torch.optim as optim
import matplotlib.pyplot as plt
import torchvision.utils as vutils
import cv2
import numpy as np

# Pick the compute device: CUDA when PyTorch reports it available, CPU otherwise.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Folder that receives the figures written by this script
os.makedirs('./figs', exist_ok=True)

# Banner: rule, title, rule -- one per line
print(
    "=" * 60,
    "DCGAN: Face Generation with Deep Convolutional GAN",
    "=" * 60,
    sep="\n",
)

# Face detection setup
print("\n1. Setting up face detection...")
# Frontal-face Haar cascade loaded from the data directory shipped with OpenCV
face_cascade = cv2.CascadeClassifier(
    cv2.data.haarcascades + 'haarcascade_frontalface_default.xml'
)

# Crop faces from images
# Runs only when ./cropped_faces does not exist yet: the first detected face of
# each source image is saved as cropped_faces/<index>.jpg, where <index> is the
# image's position in the combined females + males listing.
print("\n2. Preprocessing images...")
if not os.path.exists('cropped_faces'):
    os.mkdir('cropped_faces')
    print("   Cropping faces from images...")
    images = Glob('./females/*.jpg') + Glob('./males/*.jpg')
    cropped_count = 0
    for i, image_path in enumerate(images):
        try:
            # NOTE(review): torch_snippets' read(path, 1) is expected to return
            # the image in RGB channel order (the write below converts RGB->BGR
            # on that assumption) -- confirm against the installed version.
            img = read(image_path, 1)
            # Use the RGB variant so the red/blue luminance weights are not
            # swapped in the grayscale image handed to the detector.
            gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
            # Positional arguments are scaleFactor=1.3 and minNeighbors=5.
            faces = face_cascade.detectMultiScale(gray, 1.3, 5)
            if len(faces) > 0:
                # Only take first face
                x, y, w, h = faces[0]
                img2 = img[y:(y+h), x:(x+w), :]
                # cv2.imwrite expects BGR channel order.
                cv2.imwrite(f'cropped_faces/{i}.jpg', cv2.cvtColor(img2, cv2.COLOR_RGB2BGR))
                cropped_count += 1
        except Exception as e:
            # Deliberately best-effort: one unreadable image must not abort the run.
            print(f"   Warning: Could not process image {i}: {e}")
    print(f"   ✓ Cropped {cropped_count} faces")
else:
    print("   ✓ Using existing cropped faces")

# Data transformation
# Pipeline applied to every image: resize, centre-crop to 64x64, convert to a
# tensor, then normalise each of the three channels with mean 0.5 and std 0.5.
transform = transforms.Compose(
    [
        transforms.Resize(size=64),
        transforms.CenterCrop(size=64),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

# Dataset class
class Faces(Dataset):
    """Unlabelled image dataset built from the files found under a folder.

    Each item is the module-level ``transform`` applied to one image.
    """

    def __init__(self, folder):
        """Collect and sort the paths returned by ``Glob(folder)``.

        Args:
            folder: Location handed to ``Glob`` to enumerate the image files.
        """
        super().__init__()
        self.folder = folder
        self.images = sorted(Glob(folder))

    def __len__(self):
        """Return the number of image paths collected."""
        return len(self.images)

    def __getitem__(self, ix):
        """Load image ``ix``, force it to RGB and return the transformed result.

        Args:
            ix: Index into the sorted list of image paths.

        Returns:
            Whatever the module-level ``transform`` returns for the image.
        """
        image_path = self.images[ix]
        # Force three channels so grayscale / palette / RGBA files still match
        # the 3-channel Normalize in ``transform``. convert() loads the pixels
        # and returns a new image, so the file can be closed right away rather
        # than whenever the lazily-opened image is garbage-collected.
        with Image.open(image_path) as raw:
            image = raw.convert('RGB')
        return transform(image)

# Load dataset
print("\n3. Loading dataset...")
ds = Faces(folder='cropped_faces/')
# Shuffled mini-batches of 64 images, loaded with 8 worker processes.
# NOTE(review): the workers are started from module top level with no
# `if __name__ == "__main__":` guard -- confirm this is fine on platforms
# that start worker processes by re-importing the main module.
dataloader = DataLoader(dataset=ds, batch_size=64, num_workers=8, shuffle=True)
print(f"   Dataset size: {len(ds)} images")

# Weight initialization
def weights_init(m):
    """Re-initialise one module in place; intended for ``Module.apply``.

    Modules whose class name contains ``Conv`` get weights drawn from
    N(0, 0.02). Modules whose class name contains ``BatchNorm`` get weights
    drawn from N(1, 0.02) and a zero bias. Any other module is left untouched.

    Args:
        m: The module to initialise.
    """
    layer_kind = type(m).__name__
    if 'Conv' in layer_kind:
        nn.init.normal_(m.weight.data, mean=0.0, std=0.02)
        return
    if 'BatchNorm' in layer_kind:
        nn.init.normal_(m.weight.data, mean=1.0, std=0.02)
        nn.init.constant_(m.bias.data, val=0)

# Discriminator architecture
class Discriminator(nn.Module):
    """Convolutional critic that maps a 3-channel image to a sigmoid score.

    Four stride-2 4x4 convolutions (64, 128, 256, 512 channels) are followed
    by a final 4x4 convolution with stride 1 and no padding, then a sigmoid.
    For a 64x64 input the spatial size goes 32 -> 16 -> 8 -> 4 -> 1, giving
    one value in (0, 1) per image.
    """

    def __init__(self):
        """Build the layer stack and apply ``weights_init`` to every submodule."""
        super().__init__()
        widths = [64, 64 * 2, 64 * 4, 64 * 8]

        # First stage has no batch norm.
        layers = [
            nn.Conv2d(3, widths[0], 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # Remaining downsampling stages: conv -> batch norm -> leaky ReLU.
        for c_in, c_out in zip(widths, widths[1:]):
            layers += [
                nn.Conv2d(c_in, c_out, 4, 2, 1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        # Collapse the remaining 4x4 map to a single value and squash it.
        layers += [
            nn.Conv2d(widths[-1], 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        ]

        self.model = nn.Sequential(*layers)
        self.apply(weights_init)

    def forward(self, input):
        """Return the score map produced by the layer stack for ``input``."""
        return self.model(input)

# Generator architecture
class Generator(nn.Module):
    """Transposed-convolution network that maps latent noise to an image.

    A 100-channel input is projected by a 4x4 transposed convolution
    (stride 1, no padding) to 512 channels, then upsampled by stride-2
    transposed convolutions through 256, 128 and 64 channels to a 3-channel
    output passed through tanh. For a 1x1 latent input the spatial size goes
    4 -> 8 -> 16 -> 32 -> 64, with output values in (-1, 1).
    """

    def __init__(self):
        """Build the layer stack and apply ``weights_init`` to every submodule."""
        super().__init__()
        widths = [64 * 8, 64 * 4, 64 * 2, 64]

        # Projection of the latent vector onto the widest feature map.
        layers = [
            nn.ConvTranspose2d(100, widths[0], 4, 1, 0, bias=False),
            nn.BatchNorm2d(widths[0]),
            nn.ReLU(True),
        ]
        # Upsampling stages: transposed conv -> batch norm -> ReLU.
        for c_in, c_out in zip(widths, widths[1:]):
            layers += [
                nn.ConvTranspose2d(c_in, c_out, 4, 2, 1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(True),
            ]
        # Final upsampling to three channels, squashed by tanh.
        layers += [
            nn.ConvTranspose2d(widths[-1], 3, 4, 2, 1, bias=False),
            nn.Tanh(),
        ]

        self.model = nn.Sequential(*layers)
        self.apply(weights_init)

    def forward(self, input):
        """Return the image tensor generated from ``input``."""
        return self.model(input)

print("\n4. Building models...")
# Discriminator first, then generator; both are moved onto the chosen device.
discriminator = Discriminator()
discriminator.to(device)
generator = Generator()
generator.to(device)

# Print a layer summary of each network, probing it with an all-zero input.
for heading, net, probe_shape in (
    ("\n   Discriminator architecture:", discriminator, (1, 3, 64, 64)),
    ("\n   Generator architecture:", generator, (1, 100, 1, 1)),
):
    print(heading)
    summary(net, torch.zeros(*probe_shape).to(device))

# Training functions
def discriminator_train_step(real_data, fake_data):
    """Run one optimisation step of the discriminator.

    Uses the module-level ``discriminator``, ``loss`` and ``d_optimizer``.

    Args:
        real_data: Batch of real images; scored against a target of 1.
        fake_data: Batch of generated images; scored against a target of 0.
            It is not detached here, so pass a tensor that is already
            detached if gradients must not reach the generator.

    Returns:
        Tensor holding the sum of the real-batch and fake-batch losses.
    """
    d_optimizer.zero_grad()

    # Flatten rather than squeeze(): squeeze() turns a single-sample
    # prediction into a 0-d tensor, which no longer matches a 1-element
    # target and makes the loss raise on a final batch of size one.
    prediction_real = discriminator(real_data).reshape(-1)
    # *_like builds the target with the prediction's shape, dtype and device.
    error_real = loss(prediction_real, torch.ones_like(prediction_real))
    error_real.backward()

    prediction_fake = discriminator(fake_data).reshape(-1)
    error_fake = loss(prediction_fake, torch.zeros_like(prediction_fake))
    error_fake.backward()

    # Gradients from both halves have accumulated; apply them in one step.
    d_optimizer.step()
    return error_real + error_fake

def generator_train_step(fake_data):
    """Run one optimisation step of the generator.

    Uses the module-level ``discriminator``, ``loss`` and ``g_optimizer``.
    The discriminator's scores for ``fake_data`` are compared against a target
    of 1, and the resulting gradients are applied through ``g_optimizer``.

    Args:
        fake_data: Batch of generated images that is still attached to the
            generator's autograd graph.

    Returns:
        Tensor holding the generator loss for this batch.
    """
    g_optimizer.zero_grad()

    # Flatten rather than squeeze(): squeeze() turns a single-sample
    # prediction into a 0-d tensor, which no longer matches a 1-element
    # target and makes the loss raise on a final batch of size one.
    prediction = discriminator(fake_data).reshape(-1)
    # ones_like builds the target with the prediction's shape, dtype and device.
    error = loss(prediction, torch.ones_like(prediction))
    error.backward()

    g_optimizer.step()
    return error

# Setup training
loss = nn.BCELoss()
d_optimizer = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
g_optimizer = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))

# Single source of truth for values that were previously repeated as literals.
n_epochs = 25      # number of passes over the dataloader
latent_dim = 100   # must match the generator's input channel count
sample_every = 5   # save a sample grid every this many epochs

# Training loop
print(f"\n5. Training for {n_epochs} epochs...")
print("   (This may take a while...)\n")

log = Report(n_epochs)
# Batches per epoch; constant across epochs, so computed once.
N = len(dataloader)
for epoch in range(n_epochs):
    for i, batch_images in enumerate(dataloader):
        real_data = batch_images.to(device)
        batch_size = len(real_data)

        # Discriminator step: these fakes must not carry gradients into the
        # generator, so generate them without recording an autograd graph.
        with torch.no_grad():
            fake_data = generator(torch.randn(batch_size, latent_dim, 1, 1, device=device))
        d_loss = discriminator_train_step(real_data, fake_data)

        # Generator step: fresh fakes, this time attached to the generator's
        # graph so the loss can be backpropagated into it.
        fake_data = generator(torch.randn(batch_size, latent_dim, 1, 1, device=device))
        g_loss = generator_train_step(fake_data)

        # The x-position is the fractional epoch reached after this batch.
        log.record(epoch+(1+i)/N, d_loss=d_loss.item(), g_loss=g_loss.item(), end='\r')
    log.report_avgs(epoch+1)

    # Save intermediate results every `sample_every` epochs
    if (epoch + 1) % sample_every == 0:
        # Sample in eval mode without gradients, then restore training mode.
        generator.eval()
        with torch.no_grad():
            noise = torch.randn(64, latent_dim, 1, 1, device=device)
            sample_images = generator(noise).cpu()
        generator.train()
        grid = vutils.make_grid(sample_images, nrow=8, normalize=True)
        vutils.save_image(grid, f'./figs/generated_faces_epoch_{epoch+1}.png')

# Save training loss plot
log.plot_epochs(['d_loss', 'g_loss'])
plt.savefig('./figs/training_losses.png', dpi=150, bbox_inches='tight')
plt.close()

# Generate final samples
print("\n6. Generating final samples...")
from torchvision.utils import save_image

generator.eval()
# Inference only: skip autograd bookkeeping, as the periodic sampling during
# training already does, so no graph is kept alive for the 64 samples.
with torch.no_grad():
    noise = torch.randn(64, 100, 1, 1, device=device)
    sample_images = generator(noise).cpu()
# 8x8 grid with pixel values rescaled for display
grid = vutils.make_grid(sample_images, nrow=8, normalize=True)

# Save the image grid to a file
save_image(grid, './figs/generated_faces_final.png')

# Display the image (grid is already a gradient-free CPU tensor; move the
# channel axis last for plotting)
show(grid.permute(1, 2, 0), sz=10, title='Generated Faces')

print("\n" + "=" * 60)
print("TRAINING COMPLETE!")
print("=" * 60)
print("\nGenerated files in ./figs/:")
print("  - generated_faces_final.png (final samples)")
print("  - training_losses.png (loss curves)")
print("  - generated_faces_epoch_*.png (intermediate samples)")
print("=" * 60)
