#!/usr/bin/env python
# coding: utf-8

import torch
from torch import nn as nn
from torch.utils.data import Dataset, DataLoader
from torchsummary import summary
from torch_snippets import *
from torch_snippets.torch_loader import Report
from torch.utils.data import TensorDataset, DataLoader
from torchvision import transforms
from torchvision.models import vgg16
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np
import cv2
from random_warp import get_training_data
import logging
logging.getLogger().setLevel(logging.WARNING)

# Pick the compute device once for the whole script: CUDA device 0 when a GPU
# is visible to PyTorch, otherwise the CPU.
gpu = torch.cuda.is_available()
if gpu:
    device = torch.device(0)
else:
    device = torch.device('cpu')

face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')

def crop_face(img):
    """Detect a face in ``img`` and return it cropped and resized to 256x256.

    Args:
        img: H x W x 3 image. It is treated as RGB, matching how ``crop_images``
            handles the crop it gets back (it converts RGB -> BGR before writing).

    Returns:
        ``(face, True)`` where ``face`` is the detected region resized to
        256x256, or ``(img, False)`` (the input, unchanged) when no face is found.
        If several faces are detected, the one with the largest box area is used.
    """
    gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    # scaleFactor=1.3, minNeighbors=5
    faces = face_cascade.detectMultiScale(gray, 1.3, 5)
    if len(faces) == 0:
        return img, False
    # Previously the last detection won arbitrarily; prefer the largest box,
    # which is normally the subject of the photo.
    x, y, w, h = max(faces, key=lambda box: box[2] * box[3])
    face = img[y:y + h, x:x + w, :]
    return cv2.resize(face, (256, 256)), True

# Output folders written by crop_images(); created only when missing.
for _dirname in ('cropped_faces_personA', 'cropped_faces_personB'):
    if not os.path.exists(_dirname):
        os.mkdir(_dirname)
del _dirname

def crop_images(folder):
    """Crop the detected face out of every ``*.jpg`` in ``folder``.

    Each crop is written to ``cropped_faces_<folder>/<i>.jpg``, where ``i`` is
    the image's position in the glob result. Images in which ``crop_face``
    finds no face are skipped.

    Args:
        folder: directory name containing the source ``.jpg`` images.
    """
    out_dir = 'cropped_faces_' + folder
    # cv2.imwrite returns False instead of raising when the target directory
    # does not exist, so make sure it is there rather than writing nothing.
    os.makedirs(out_dir, exist_ok=True)
    images = Glob(folder + '/*.jpg', silent=True)
    for i, path in enumerate(images):
        img2, face_detected = crop_face(read(path, 1))
        if not face_detected:
            continue
        out_path = f'{out_dir}/{i}.jpg'
        # The crop is handled as RGB; OpenCV writes BGR, hence the conversion.
        if not cv2.imwrite(out_path, cv2.cvtColor(img2, cv2.COLOR_RGB2BGR)):
            logging.warning('Failed to write cropped face to %s', out_path)

# Crop faces for both identities (person A first, then person B).
for _person in ('personA', 'personB'):
    crop_images(_person)
del _person

class ImageDataset(Dataset):
    """Two-identity face dataset for the shared-encoder autoencoder.

    Every image of both people is loaded into memory up front (read in mode 1
    and divided by 255). ``__getitem__`` ignores its index and returns one
    randomly chosen image of each person. ``collate_fn`` warps the batch with
    ``get_training_data`` and returns channels-first tensors on ``device``.
    """
    def __init__(self, items_A, items_B):
        def load_all(paths):
            # One array holding every image, stacked along a new leading axis.
            return np.concatenate([read(path, 1)[None] for path in paths]) / 255.

        self.items_A = load_all(items_A)
        self.items_B = load_all(items_B)

    def __len__(self):
        # An "epoch" is bounded by the smaller of the two identities.
        n_a, n_b = len(self.items_A), len(self.items_B)
        return n_a if n_a < n_b else n_b

    def __getitem__(self, ix):
        # Sampling is random; ``ix`` is intentionally unused.
        return choose(self.items_A, verbose=False), choose(self.items_B, verbose=False)

    def collate_fn(self, batch):
        faces_A, faces_B = list(zip(*batch))
        warped_A, target_A = get_training_data(faces_A, len(faces_A))
        warped_B, target_B = get_training_data(faces_B, len(faces_B))

        def to_channels_first(arr):
            # Move the last axis to position 1 (channels-last -> channels-first)
            # and put the float tensor on the training device.
            return torch.Tensor(arr).permute(0, 3, 1, 2).to(device)

        return tuple(to_channels_first(arr) for arr in (warped_A, warped_B, target_A, target_B))


# ==================== Loss Functions ====================

class VGGPerceptualLoss(nn.Module):
    """Perceptual loss: L1 distance between frozen VGG16 feature maps.

    Uses the first 16 layers of torchvision's ImageNet-pretrained VGG16
    ``features`` stack with all parameters frozen. Both inputs are normalized
    with the ImageNet channel mean/std before feature extraction, so they are
    presumably 3-channel RGB batches in [0, 1] (N, 3, H, W) — the (1, 3, 1, 1)
    mean/std buffers fix 3 channels at dim 1.
    """
    def __init__(self):
        super().__init__()
        # ``weights=True`` is not a weights enum; torchvision >= 0.13 only
        # accepts it through a deprecated legacy shim (with a warning). Name the
        # weights explicitly instead — these are the same ImageNet weights.
        vgg = vgg16(weights="IMAGENET1K_V1").features[:16].eval()
        for param in vgg.parameters():
            param.requires_grad = False
        self.vgg = vgg.to(device)
        # VGG expects ImageNet-normalized input.
        self.register_buffer('mean', torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device))
        self.register_buffer('std', torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device))

    def forward(self, pred, target):
        """Return the mean absolute difference of VGG features of ``pred`` and ``target``."""
        pred_norm = (pred - self.mean) / self.std
        target_norm = (target - self.mean) / self.std
        return nn.functional.l1_loss(self.vgg(pred_norm), self.vgg(target_norm))


def color_consistency_loss(pred, target):
    """Mean squared error between per-image, per-channel spatial means.

    Dims 2 and 3 are averaged away, so only the average colour of each channel
    of each image is compared; how the colour is distributed spatially does
    not matter.
    """
    spatial_dims = (2, 3)
    mean_pred = pred.mean(dim=spatial_dims)
    mean_target = target.mean(dim=spatial_dims)
    return nn.functional.mse_loss(mean_pred, mean_target)


# ==================== U-Net Components ====================

class AttentionGate(nn.Module):
    """Additive attention gate that re-weights an encoder skip connection.

    ``g`` and ``x`` are each projected to ``F_int`` channels (1x1 conv +
    GroupNorm), summed and passed through ReLU. They are then reduced to a single
    channel squashed by a sigmoid, and the result multiplies ``x``.

    Args:
        F_g: channel count of the gating signal ``g``.
        F_l: channel count of the skip tensor ``x``.
        F_int: channel count of the intermediate projection.
    """
    def __init__(self, F_g, F_l, F_int):
        super().__init__()

        def project(in_ch):
            # 1x1 conv into the shared F_int space, then GroupNorm.
            return nn.Sequential(
                nn.Conv2d(in_ch, F_int, kernel_size=1, stride=1, padding=0, bias=True),
                nn.GroupNorm(num_groups=min(8, F_int), num_channels=F_int),
            )

        self.W_g = project(F_g)
        self.W_x = project(F_l)
        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.GroupNorm(num_groups=1, num_channels=1),
            nn.Sigmoid(),
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        # g: gating signal (decoder side); x: skip connection (encoder side).
        combined = self.relu(self.W_g(g) + self.W_x(x))
        weights = self.psi(combined)  # single channel in (0, 1)
        return x * weights


class ConvBlock(nn.Module):
    """Two (Conv2d -> GroupNorm -> LeakyReLU(0.2)) stages.

    Only the first convolution uses ``stride``; the second always uses stride 1.
    Convolutions are bias-free since each feeds a GroupNorm, which would cancel
    a bias anyway. GroupNorm uses ``min(32, out_channels)`` groups, so
    ``out_channels`` must be divisible by that number.
    """
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super().__init__()
        n_groups = min(32, out_channels)
        layers = []
        for conv_in, conv_stride in ((in_channels, stride), (out_channels, 1)):
            layers += [
                nn.Conv2d(conv_in, out_channels, kernel_size, conv_stride, padding, bias=False),
                nn.GroupNorm(num_groups=n_groups, num_channels=out_channels),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x)


class DownBlock(nn.Module):
    """Encoder stage: ConvBlock followed by 2x2 max-pooling.

    ``forward`` returns ``(pooled, features)``, where ``features`` is the
    pre-pooling ConvBlock output kept for the decoder's skip connection.
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = ConvBlock(in_channels, out_channels)
        self.pool = nn.MaxPool2d(2)

    def forward(self, x):
        features = self.conv(x)
        return self.pool(features), features


class UpBlock(nn.Module):
    """Decoder stage: 2x transposed-conv upsampling, then merge with a skip.

    The skip tensor is optionally re-weighted by an AttentionGate (gated by the
    upsampled tensor). It is then concatenated with the upsampled tensor along
    channels and fused by a ConvBlock back down to ``out_channels``.
    """
    def __init__(self, in_channels, out_channels, use_attention=True):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
        self.use_attention = use_attention
        if use_attention:
            # Gate and skip both have out_channels; project to half of that.
            self.attention = AttentionGate(F_g=out_channels, F_l=out_channels, F_int=out_channels // 2)
        # Concatenating upsampled + skip doubles the channel count.
        self.conv = ConvBlock(out_channels * 2, out_channels)

    def forward(self, x, skip):
        upsampled = self.up(x)
        gated_skip = self.attention(g=upsampled, x=skip) if self.use_attention else skip
        return self.conv(torch.cat([upsampled, gated_skip], dim=1))


class UNetEncoder(nn.Module):
    """Shared U-Net encoder: four DownBlocks followed by a ConvBlock bottleneck.

    ``forward`` returns ``(bottleneck, [skip1, skip2, skip3, skip4])`` with the
    skips ordered shallow (64 channels) to deep (512 channels). Spatial sizes in
    the comments assume a 64x64 input.
    """
    def __init__(self):
        super().__init__()
        self.down1 = DownBlock(3, 64)      # 64x64 -> 32x32
        self.down2 = DownBlock(64, 128)    # 32x32 -> 16x16
        self.down3 = DownBlock(128, 256)   # 16x16 -> 8x8
        self.down4 = DownBlock(256, 512)   # 8x8 -> 4x4
        self.bottleneck = ConvBlock(512, 1024)

    def forward(self, x):
        skips = []
        for stage in (self.down1, self.down2, self.down3, self.down4):
            x, skip = stage(x)
            skips.append(skip)
        return self.bottleneck(x), skips


class UNetDecoder(nn.Module):
    """Person-specific U-Net decoder (four UpBlocks + 1x1 conv + sigmoid).

    Consumes the encoder's bottleneck and its four skips, deepest skip first.
    It outputs 3 channels squashed to (0, 1) by a sigmoid. Spatial sizes in the
    comments assume the encoder saw a 64x64 input.
    """
    def __init__(self, use_attention=True):
        super().__init__()
        self.up1 = UpBlock(1024, 512, use_attention=use_attention)  # 4x4 -> 8x8
        self.up2 = UpBlock(512, 256, use_attention=use_attention)   # 8x8 -> 16x16
        self.up3 = UpBlock(256, 128, use_attention=use_attention)   # 16x16 -> 32x32
        self.up4 = UpBlock(128, 64, use_attention=use_attention)    # 32x32 -> 64x64
        self.final = nn.Sequential(
            nn.Conv2d(64, 3, kernel_size=1),
            nn.Sigmoid(),
        )

    def forward(self, x, skips):
        # The encoder lists skips shallow-to-deep; pair them with stages deepest first.
        skip1, skip2, skip3, skip4 = skips
        for stage, skip in ((self.up1, skip4), (self.up2, skip3), (self.up3, skip2), (self.up4, skip1)):
            x = stage(x, skip)
        return self.final(x)


class UNetAutoencoder(nn.Module):
    """Shared-encoder, two-decoder U-Net used for face swapping.

    A single encoder feeds either ``decoder_A`` or ``decoder_B``; feeding a face
    of one person through the other person's decoder produces the swap.

    Args:
        use_attention: passed to both decoders to enable attention-gated skips.
    """
    def __init__(self, use_attention=True):
        super().__init__()
        self.encoder = UNetEncoder()
        self.decoder_A = UNetDecoder(use_attention=use_attention)
        self.decoder_B = UNetDecoder(use_attention=use_attention)

    def forward(self, x, select='A'):
        """Encode ``x`` and decode it with the decoder named by ``select``.

        Raises:
            ValueError: if ``select`` is neither ``'A'`` nor ``'B'``. Previously
                any other value (e.g. a typo) silently used decoder B.
        """
        if select == 'A':
            decoder = self.decoder_A
        elif select == 'B':
            decoder = self.decoder_B
        else:
            raise ValueError(f"select must be 'A' or 'B', got {select!r}")
        bottleneck, skips = self.encoder(x)
        return decoder(bottleneck, skips)


# ==================== Training Functions ====================

def train_batch(model, data, criterion_l1, criterion_perceptual, optimizers,
                perceptual_weight=0.1, color_weight=0.5):
    """Run one optimization step for each identity and return both losses.

    Each path's loss is ``L1 + perceptual_weight * perceptual +
    color_weight * color_consistency``.

    Args:
        model: autoencoder called as ``model(x, 'A')`` / ``model(x, 'B')``.
        data: ``(imgA, imgB, targetA, targetB)`` as produced by
            ``ImageDataset.collate_fn``.
        criterion_l1: pixel loss ``(pred, target) -> tensor``.
        criterion_perceptual: perceptual loss ``(pred, target) -> tensor``.
        optimizers: ``(optA, optB)`` for the A and B paths.
        perceptual_weight: weight of the perceptual term.
        color_weight: weight of the colour-consistency term.

    Returns:
        ``(lossA, lossB)`` as Python floats.
    """
    optA, optB = optimizers
    imgA, imgB, targetA, targetB = data

    def path_loss(pred, target):
        # Weighted sum of the three loss components for one path.
        return (criterion_l1(pred, target)
                + perceptual_weight * criterion_perceptual(pred, target)
                + color_weight * color_consistency_loss(pred, target))

    # Update the two paths one after the other. When both optimizers hold the
    # shared encoder's parameters, a single accumulated backward followed by
    # optA.step() and optB.step() applies the summed A+B encoder gradient
    # twice. Zeroing and stepping each path separately gives each optimizer
    # only its own path's gradient. Note lossB is computed after optA's step.
    optA.zero_grad()
    lossA = path_loss(model(imgA, 'A'), targetA)
    lossA.backward()
    optA.step()

    optB.zero_grad()
    lossB = path_loss(model(imgB, 'B'), targetB)
    lossB.backward()
    optB.step()

    return lossA.item(), lossB.item()


# ==================== Main Training ====================

print("Initializing model with color-preserving improvements...")
print("- Removed problematic color shift normalization")
print("- Replaced InstanceNorm with GroupNorm for better color preservation")
print("- Added perceptual loss for semantic content")
print("- Added color consistency loss to prevent color artifacts\n")

# Initialize model
model = UNetAutoencoder(use_attention=True).to(device)

# Dataset and dataloader (batch size 32; the last batch may be smaller).
dataset = ImageDataset(Glob('cropped_faces_personA',silent=True),
                       Glob('cropped_faces_personB',silent=True))
dataloader = DataLoader(dataset, 32, collate_fn=dataset.collate_fn)

# One optimizer per identity; each covers the shared encoder plus its decoder.
optimizers = (
    optim.Adam([{'params': model.encoder.parameters()},
                {'params': model.decoder_A.parameters()}],
               lr=5e-5, betas=(0.5, 0.999)),
    optim.Adam([{'params': model.encoder.parameters()},
                {'params': model.decoder_B.parameters()}],
               lr=5e-5, betas=(0.5, 0.999)),
)

# Loss functions
criterion_l1 = nn.L1Loss()
criterion_perceptual = VGGPerceptualLoss()

# Training loop
n_epochs = 10000       # total number of training epochs
save_every = 100       # epochs between checkpoints / sample figures
log = Report(n_epochs)

if not os.path.exists('checkpoint'):
    os.mkdir('checkpoint')

if not os.path.exists('figs'):
    os.mkdir('figs')


def _save_comparison_grid(rows, titles, suptitle, path):
    """Save a grid with one row per image batch in ``rows`` to ``path``.

    Every batch in ``rows`` must hold at least ``len(rows[0])`` images; each
    image is permuted from (C, H, W) to (H, W, C) and clamped to [0, 1] for
    display. ``titles[r]`` labels the first image of row ``r``.
    """
    n_cols = len(rows[0])
    # squeeze=False keeps ``axes`` 2-D even when there is a single column.
    fig, axes = plt.subplots(len(rows), n_cols, figsize=(n_cols*2, 6), squeeze=False)
    for r, (images, title) in enumerate(zip(rows, titles)):
        for c in range(n_cols):
            ax = axes[r, c]
            ax.imshow(images[c].cpu().permute(1, 2, 0).clamp(0, 1))
            ax.axis('off')
            if c == 0:
                ax.set_title(title, fontsize=10)
    plt.suptitle(suptitle, fontsize=12)
    plt.tight_layout()
    plt.savefig(path, bbox_inches='tight', dpi=150)
    plt.close(fig)


for ex in range(n_epochs):
    N = len(dataloader)
    for bx, data in enumerate(dataloader):
        lossA, lossB = train_batch(model, data, criterion_l1, criterion_perceptual,
                                   optimizers, perceptual_weight=0.1, color_weight=0.5)
        log.record(ex+(1+bx)/N, lossA=lossA, lossB=lossB, end='\r')

    log.report_avgs(ex+1)

    if (ex+1) % save_every != 0:
        continue

    # Save checkpoint
    state = {
        'state': model.state_dict(),
        'epoch': ex
    }
    torch.save(state, './checkpoint/unet_autoencoder.pth')

    # Visualize results on the last batch of this epoch
    model.eval()
    with torch.no_grad():
        a, b, A, B = data
        # The final batch of an epoch can hold fewer than 5 images.
        bs = min(5, len(a))

        # A to B transformation
        print('Saving A to B transformation...')
        _a = model(a[:bs], 'A')
        _b = model(a[:bs], 'B')
        _save_comparison_grid([A[:bs], _a, _b],
                              ['Target A', 'A→A (reconstruct)', 'A→B (swap)'],
                              f'Epoch {ex+1}: Person A to B',
                              f'./figs/epoch_{ex+1}_A_to_B.png')

        # B to A transformation
        print('Saving B to A transformation...')
        _a = model(b[:bs], 'A')
        _b = model(b[:bs], 'B')
        _save_comparison_grid([B[:bs], _a, _b],
                              ['Target B', 'B→A (swap)', 'B→B (reconstruct)'],
                              f'Epoch {ex+1}: Person B to A',
                              f'./figs/epoch_{ex+1}_B_to_A.png')

    model.train()

log.plot_epochs()
plt.savefig('./figs/training_loss.png', bbox_inches='tight', dpi=150)

print("\n=== Training Complete ===")
print("Color artifact fixes applied:")
print("1. ✓ Removed color shift normalization")
print("2. ✓ GroupNorm instead of InstanceNorm (preserves color relationships)")
print("3. ✓ Perceptual loss (weight=0.1) for semantic content")
print("4. ✓ Color consistency loss (weight=0.5) prevents channel imbalance")
print("5. ✓ Attention gates control feature flow")
print("\nResults saved in ./figs/")
