#!/usr/bin/env python
# coding: utf-8

"""
Face Generation using Conditional GAN (cGAN) - FIXED VERSION
Demonstrates controlled generation with gender conditioning
FIXES: Memory leaks, performance issues, label consistency, checkpointing
"""

import os
import torch
from torch import nn as nn
from torch.utils.data import Dataset, DataLoader
from torchsummary import summary
from torch_snippets import *
from torch_snippets.torch_loader import Report
from torchvision.utils import make_grid, save_image
from PIL import Image
import cv2
import numpy as np
import matplotlib.pyplot as plt
import torch.optim as optim
from torchvision import transforms
import torchvision.utils as vutils
import gc

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# Make sure the folders for figures and checkpoints exist before anything is written
for _out_dir in ('./figs', './checkpoints'):
    os.makedirs(_out_dir, exist_ok=True)

_banner = "=" * 60
print(_banner)
print("CONDITIONAL GAN: Gender-Conditioned Face Generation (FIXED)")
print(_banner)

# Load image paths
print("\n1. Loading image paths...")
female_images = Glob('./females/*.jpg')
male_images = Glob('./males/*.jpg')
print(f"   Female images: {len(female_images)}")
print(f"   Male images: {len(male_images)}")

# Face detection (OpenCV's bundled frontal-face Haar cascade)
face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')


def _crop_faces(image_paths, out_dir, label):
    """Crop the first detected face of each image and save it into ``out_dir``.

    Args:
        image_paths: Source image paths. The position of a path in this
            sequence becomes the output file name (``<index>.jpg``).
        out_dir: Existing directory that receives the crops.
        label: Class name used only in warning messages.

    Returns:
        Number of images for which a face was found and a crop was written.
    """
    written = 0
    for i, path in enumerate(image_paths):
        # Best effort: one unreadable image must not abort the whole preprocessing pass
        try:
            img = read(path, 1)
            # The crop is converted RGB->BGR before writing, i.e. `img` is handled
            # as RGB, so the grayscale conversion must use the RGB channel order too.
            gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
            faces = face_cascade.detectMultiScale(gray, 1.3, 5)
            if len(faces) == 0:
                continue
            # Only the first detection is kept, one crop per source image
            x, y, w, h = faces[0]
            crop = img[y:(y+h), x:(x+w), :]
            cv2.imwrite(f'{out_dir}/{i}.jpg', cv2.cvtColor(crop, cv2.COLOR_RGB2BGR))
            written += 1
        except Exception as e:
            print(f"   Warning: Could not process {label} image {i}: {e}")
    return written


# Crop faces
print("\n2. Preprocessing images...")
os.makedirs('cropped_faces_female', exist_ok=True)
os.makedirs('cropped_faces_male', exist_ok=True)

# Each class is cropped only when its folder has no crops yet, and each class
# reports on its own whether existing crops are being reused.
if len(Glob('cropped_faces_female/*.jpg')) == 0:
    print("   Cropping female faces...")
    female_count = _crop_faces(female_images, 'cropped_faces_female', 'female')
    print(f"   ✓ Cropped {female_count} female faces")
else:
    print("   ✓ Using existing cropped female faces")

if len(Glob('cropped_faces_male/*.jpg')) == 0:
    print("   Cropping male faces...")
    male_count = _crop_faces(male_images, 'cropped_faces_male', 'male')
    print(f"   ✓ Cropped {male_count} male faces")
else:
    print("   ✓ Using existing cropped male faces")

# Data transformation: resize + center-crop to 64x64, convert to a tensor, then
# normalize every channel with mean 0.5 / std 0.5.
_CHANNEL_STATS = (0.5, 0.5, 0.5)
_pipeline = [
    transforms.Resize(64),
    transforms.CenterCrop(64),
    transforms.ToTensor(),
    transforms.Normalize(_CHANNEL_STATS, _CHANNEL_STATS),
]
transform = transforms.Compose(_pipeline)

# Dataset with gender labels
class Faces(Dataset):
    """Cropped face images paired with a binary gender label.

    Label convention used throughout this script: female = 1, male = 0.
    """

    def __init__(self, folders):
        """Index the images of both classes.

        Args:
            folders: Two-item sequence ``(female_folder, male_folder)``. Each
                entry is handed to ``Glob`` to list the image files.
        """
        super().__init__()
        self.folderfemale = folders[0]
        self.foldermale = folders[1]
        female_paths = sorted(Glob(self.folderfemale))
        male_paths = sorted(Glob(self.foldermale))
        self.images = female_paths + male_paths
        # Labels come from the folder an image was listed from, not from a
        # substring test on the path: any parent directory containing "female"
        # would otherwise turn every male image into label 1.
        self.labels = [1] * len(female_paths) + [0] * len(male_paths)

    def __len__(self):
        """Return the total number of indexed images."""
        return len(self.images)

    def __getitem__(self, ix):
        """Return ``(image_tensor, label)`` for index ``ix``.

        The image goes through the module-level ``transform``; the label is a
        0-dim ``torch.long`` tensor (1 = female, 0 = male).
        """
        with Image.open(self.images[ix]) as img:
            # Force 3 channels so grayscale/RGBA files do not break the 3-channel Normalize
            image = transform(img.convert('RGB'))
        return image, torch.tensor(self.labels[ix], dtype=torch.long)

# Load dataset - num_workers kept low to prevent issues
print("\n3. Loading dataset...")
ds = Faces(folders=['cropped_faces_female', 'cropped_faces_male'])
dataloader = DataLoader(
    ds,
    batch_size=64,
    shuffle=True,
    num_workers=2,
    # Pinned host memory only helps host->GPU copies; skip it on CPU-only runs
    pin_memory=(device == "cuda"),
)
print(f"   Dataset size: {len(ds)} images")

# Weight initialization
def weights_init(m):
    """Initialise one module in place; meant to be used with ``Module.apply``.

    Modules whose class name contains ``Conv`` get weights drawn from
    N(0, 0.02). Modules whose class name contains ``BatchNorm`` get weights
    drawn from N(1, 0.02) and a zero bias. Every other module is left as is.
    """
    layer_kind = type(m).__name__
    if 'Conv' in layer_kind:
        nn.init.normal_(m.weight.data, mean=0.0, std=0.02)
        return
    if 'BatchNorm' in layer_kind:
        nn.init.normal_(m.weight.data, mean=1.0, std=0.02)
        nn.init.constant_(m.bias.data, 0)

# Conditional Discriminator
class Discriminator(nn.Module):
    """Conditional discriminator: scores an image given its class label.

    Five stride-2 convolutions halve the spatial size each time, so a 64x64
    input ends as a 64-channel 2x2 map (256 features after flattening). Those
    features are concatenated with a learned label embedding and passed
    through a small MLP that ends in a sigmoid.
    """

    def __init__(self, emb_size=32):
        """Build the network.

        Args:
            emb_size: Width of the label embedding. The default of 32 gives
                the same layer shapes as before (256 + 32 = 288 MLP inputs).
        """
        super(Discriminator, self).__init__()
        # Honour the constructor argument, as Generator does, instead of always using 32
        self.emb_size = emb_size
        self.label_embeddings = nn.Embedding(2, self.emb_size)
        self.model = nn.Sequential(
            nn.Conv2d(3, 64, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 64*2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64*2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64*2, 64*4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64*4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64*4, 64*8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64*8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64*8, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Flatten()
        )
        # 64 channels * 2 * 2 spatial positions (for 64x64 inputs) + label embedding
        self.model2 = nn.Sequential(
            nn.Linear(64 * 2 * 2 + self.emb_size, 100),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(100, 1),
            nn.Sigmoid()
        )
        self.apply(weights_init)

    def forward(self, input, labels):
        """Return the per-sample "real" probability.

        Args:
            input: Image batch; the MLP width assumes 3x64x64 images.
            labels: Integer class ids (0 or 1), one per image.

        Returns:
            Tensor of shape ``(batch, 1)`` with values in (0, 1).
        """
        x = self.model(input)
        y = self.label_embeddings(labels)
        input = torch.cat([x, y], 1)
        final_output = self.model2(input)
        return final_output

# Conditional Generator
class Generator(nn.Module):
    """Conditional generator: turns latent noise plus a class label into an image.

    The label is embedded, reshaped to a 1x1 feature map and concatenated with
    the noise along the channel axis. Transposed convolutions then grow the
    1x1 input to a 3-channel 64x64 image squashed by ``tanh``.
    """

    def __init__(self, emb_size=32):
        """Build the network.

        Args:
            emb_size: Width of the label embedding appended to the 100-dim noise.
        """
        super(Generator, self).__init__()
        self.emb_size = emb_size
        self.label_embeddings = nn.Embedding(2, self.emb_size)

        # (in_channels, out_channels, stride, padding) for every ConvT -> BN -> ReLU stage
        stage_specs = [
            (100 + self.emb_size, 64*8, 1, 0),
            (64*8, 64*4, 2, 1),
            (64*4, 64*2, 2, 1),
            (64*2, 64, 2, 1),
        ]
        layers = []
        for c_in, c_out, stride, pad in stage_specs:
            layers.append(nn.ConvTranspose2d(c_in, c_out, 4, stride, pad, bias=False))
            layers.append(nn.BatchNorm2d(c_out))
            layers.append(nn.ReLU(True))
        # Output stage: no batch norm, tanh activation
        layers.append(nn.ConvTranspose2d(64, 3, 4, 2, 1, bias=False))
        layers.append(nn.Tanh())
        self.model = nn.Sequential(*layers)

        self.apply(weights_init)

    def forward(self, input_noise, labels):
        """Generate one image per (noise, label) pair.

        Args:
            input_noise: Latent batch of shape ``(batch, 100, 1, 1)``.
            labels: Integer class ids (0 or 1), one per noise vector.

        Returns:
            The output of the transposed-convolution stack for the batch.
        """
        condition = self.label_embeddings(labels).view(len(labels), self.emb_size, 1, 1)
        return self.model(torch.cat([input_noise, condition], 1))

print("\n4. Building conditional models...")
discriminator = Discriminator().to(device)
generator = Generator().to(device)

# Dummy batches of 32 that are only fed to `summary` to print the layer tables
_probe_images = torch.zeros(32, 3, 64, 64).to(device)
_probe_noise = torch.zeros(32, 100, 1, 1).to(device)
_probe_labels = torch.zeros(32).long().to(device)

print("\n   Conditional Discriminator architecture:")
summary(discriminator, _probe_images, _probe_labels)
print("\n   Conditional Generator architecture:")
summary(generator, _probe_noise, _probe_labels)

# Helper functions
def noise(size):
    """Return ``size`` standard-normal latent vectors shaped ``(size, 100, 1, 1)`` on ``device``."""
    return torch.randn(size, 100, 1, 1, device=device)

def discriminator_train_step(real_data, real_labels, fake_data, fake_labels):
    """Run one optimizer step of the discriminator.

    Real samples are scored against a target of 1 and fake samples against a
    target of 0. Both losses are back-propagated before a single
    ``d_optimizer`` step.

    Args:
        real_data: Batch of real images.
        real_labels: Class ids belonging to ``real_data``.
        fake_data: Batch of generated images.
        fake_labels: Class ids the fake images were generated for.

    Returns:
        Sum of the real and fake losses as a tensor.
    """
    d_optimizer.zero_grad()

    # Real branch: the discriminator should answer 1
    real_target = torch.ones(len(real_data), 1, device=device)
    real_loss = loss(discriminator(real_data, real_labels), real_target)
    real_loss.backward()

    # Fake branch: the discriminator should answer 0
    fake_target = torch.zeros(len(fake_data), 1, device=device)
    fake_loss = loss(discriminator(fake_data, fake_labels), fake_target)
    fake_loss.backward()

    d_optimizer.step()
    return real_loss + fake_loss

def generator_train_step(fake_data, fake_labels):
    """Run one optimizer step of the generator.

    The generated batch is scored by the discriminator against a target of 1,
    so the generator is updated towards fooling it.

    Args:
        fake_data: Generated images, still attached to the generator's graph.
        fake_labels: Class ids the images were generated for.

    Returns:
        The generator loss as a tensor.
    """
    g_optimizer.zero_grad()
    wants_real = torch.ones(len(fake_data), 1, device=device)
    g_error = loss(discriminator(fake_data, fake_labels), wants_real)
    g_error.backward()
    g_optimizer.step()
    return g_error

# Setup training: BCE objective and one Adam optimizer per network (same hyper-parameters)
loss = nn.BCELoss()
_adam_kwargs = dict(lr=0.0002, betas=(0.5, 0.999))
d_optimizer = optim.Adam(discriminator.parameters(), **_adam_kwargs)
g_optimizer = optim.Adam(generator.parameters(), **_adam_kwargs)

# Fixed latent batch reused for the per-epoch sample grids:
# first half conditioned on label 0, second half on label 1
fixed_noise = torch.randn(64, 100, 1, 1, device=device)
_half = len(fixed_noise) // 2
fixed_fake_labels = torch.LongTensor([0] * _half + [1] * _half).to(device)

n_epochs = 25
img_list = []

# Check for existing checkpoint and, if present, resume from the epoch after it
start_epoch = 0
checkpoint_path = './checkpoints/latest_checkpoint.pth'
if os.path.exists(checkpoint_path):
    print(f"\n📁 Found checkpoint! Loading from {checkpoint_path}...")
    # map_location remaps stored tensors onto the current device, so a checkpoint
    # written on a GPU machine can still be resumed on a CPU-only one.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    generator.load_state_dict(checkpoint['generator_state_dict'])
    discriminator.load_state_dict(checkpoint['discriminator_state_dict'])
    g_optimizer.load_state_dict(checkpoint['g_optimizer_state_dict'])
    d_optimizer.load_state_dict(checkpoint['d_optimizer_state_dict'])
    # The stored value is the last finished epoch (0-based), so continue with the next one
    start_epoch = checkpoint['epoch'] + 1
    print(f"   ✓ Resuming from epoch {start_epoch}")

# Training loop: alternates one discriminator step and one generator step per batch,
# then writes a sample grid every epoch and a checkpoint every 5 epochs.
print(f"\n5. Training from epoch {start_epoch} to {n_epochs}...")
print("   (This may take a while...)\n")

log = Report(n_epochs)
for epoch in range(start_epoch, n_epochs):
    N = len(dataloader)  # batches per epoch, used for the fractional epoch position below
    for bx, (images, labels) in enumerate(dataloader):
        real_data, real_labels = images.to(device), labels.to(device)
        
        # Train Discriminator on the real batch plus an equally sized fake batch
        # generated for random class labels.
        fake_labels = torch.LongTensor(np.random.randint(0, 2, len(real_data))).to(device)
        fake_data = generator(noise(len(real_data)), fake_labels)
        # Detach so the discriminator update does not back-propagate into the generator
        fake_data = fake_data.detach()
        d_loss = discriminator_train_step(real_data, real_labels, fake_data, fake_labels)
        
        # Train Generator on a fresh fake batch (new labels and new noise),
        # kept attached to the graph so gradients reach the generator.
        fake_labels = torch.LongTensor(np.random.randint(0, 2, len(real_data))).to(device)
        fake_data = generator(noise(len(real_data)), fake_labels)
        g_loss = generator_train_step(fake_data, fake_labels)
        
        # Fractional epoch position of this batch, e.g. 3.5 = halfway through the 4th epoch
        pos = epoch + (1+bx)/N
        log.record(pos, d_loss=d_loss.detach(), g_loss=g_loss.detach(), end='\r')
        
        # Release cached CUDA memory and run the garbage collector every 50 batches
        if bx % 50 == 0:
            torch.cuda.empty_cache()
            gc.collect()
    
    log.report_avgs(epoch+1)
    
    # Generate and save samples from the fixed noise/labels; no_grad because this is inference only
    with torch.no_grad():
        fake = generator(fixed_noise, fixed_fake_labels).cpu()  # Move to CPU immediately
        grid = vutils.make_grid(fake, padding=2, normalize=True)
        
        # Save the image grid to a file
        save_image(grid, f'./figs/generated_faces_epoch_{epoch+1:02d}.png')
        
        # Channels-last copy for matplotlib; also kept in img_list for every epoch
        imgs = grid.permute(1, 2, 0)
        img_list.append(imgs)
        
        # Only render the matplotlib figure every 5 epochs to save time
        if (epoch + 1) % 5 == 0:
            plt.figure(figsize=(10, 10))
            plt.imshow(imgs)
            plt.axis('off')
            plt.title(f'Epoch {epoch+1}')
            plt.savefig(f'./figs/display_epoch_{epoch+1:02d}.png', bbox_inches='tight')
            plt.close()
    
    # Save a checkpoint every 5 epochs and after the final epoch.
    # 'epoch' stores the 0-based index of the epoch that just finished.
    if (epoch + 1) % 5 == 0 or epoch == n_epochs - 1:
        checkpoint = {
            'epoch': epoch,
            'generator_state_dict': generator.state_dict(),
            'discriminator_state_dict': discriminator.state_dict(),
            'g_optimizer_state_dict': g_optimizer.state_dict(),
            'd_optimizer_state_dict': d_optimizer.state_dict(),
        }
        # Written twice: the rolling "latest" file used for resuming, and a per-epoch copy
        torch.save(checkpoint, checkpoint_path)
        torch.save(checkpoint, f'./checkpoints/checkpoint_epoch_{epoch+1}.pth')
        print(f"   💾 Checkpoint saved at epoch {epoch+1}")
    
    # Clear cache after each epoch
    torch.cuda.empty_cache()
    gc.collect()

# Save training loss plot
log.plot_epochs(['d_loss', 'g_loss'])
plt.savefig('./figs/training_losses.png', dpi=150, bbox_inches='tight')
plt.close()

# Generate final conditioned samples
print("\n6. Generating gender-specific samples...")

# Pure inference: run under no_grad, like the per-epoch sampling, so no autograd
# graph is built and kept alive for these batches.
with torch.no_grad():
    # Female faces (label 1, the dataset's convention)
    female_noise = torch.randn(32, 100, 1, 1, device=device)
    female_labels = torch.ones(32, dtype=torch.long, device=device)  # 1 = female
    female_faces = generator(female_noise, female_labels).cpu()
    female_grid = vutils.make_grid(female_faces, nrow=8, normalize=True)
    save_image(female_grid, './figs/generated_female_faces.png')

    # Male faces (label 0, the dataset's convention)
    male_noise = torch.randn(32, 100, 1, 1, device=device)
    male_labels = torch.zeros(32, dtype=torch.long, device=device)  # 0 = male
    male_faces = generator(male_noise, male_labels).cpu()
    male_grid = vutils.make_grid(male_faces, nrow=8, normalize=True)
    save_image(male_grid, './figs/generated_male_faces.png')

# Display final results: female grid on the left, male grid on the right
fig, axes = plt.subplots(1, 2, figsize=(14, 7))
fig.suptitle('Conditional GAN: Gender-Controlled Face Generation', fontsize=16, fontweight='bold')

_panels = (
    ('Generated Female Faces', female_grid),
    ('Generated Male Faces', male_grid),
)
for _ax, (_title, _grid) in zip(axes, _panels):
    # make_grid output is channels-first; matplotlib expects channels-last
    _ax.imshow(_grid.permute(1, 2, 0))
    _ax.set_title(_title, fontsize=14)
    _ax.axis('off')

plt.tight_layout()
plt.savefig('./figs/gender_comparison.png', dpi=150, bbox_inches='tight')
plt.show()

# Clear final memory
torch.cuda.empty_cache()
gc.collect()

# Closing summary of everything the run produced
_rule = "=" * 60
_summary_lines = (
    "\n" + _rule,
    "TRAINING COMPLETE!",
    _rule,
    "\nGenerated files in ./figs/:",
    "  - generated_female_faces.png (female-only samples)",
    "  - generated_male_faces.png (male-only samples)",
    "  - gender_comparison.png (side-by-side comparison)",
    "  - training_losses.png (loss curves)",
    "  - generated_faces_epoch_*.png (samples per epoch)",
    "\nCheckpoints saved in ./checkpoints/",
    _rule,
)
for _line in _summary_lines:
    print(_line)
