#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Crop and Resize Images for Training Dataset

This script processes the images in training_data/ in place:
1. Backs up the originals to training_data_backup/
2. Crops each image with the chosen strategy
3. Resizes it to the target size (512x512 pixels by default)

Available strategies:
- Center crop (default; keeps the central region)
- Smart crop (keeps the most detailed region, measured by image entropy)
- Cover crop (fills the target size, may crop edges)
- Contain crop (fits the entire image, may add padding)
- Simple resize (stretches to the target size, may distort)
"""

import os
from pathlib import Path
from PIL import Image, ImageOps
import shutil

# ============================================================================
# CONFIGURATION
# ============================================================================

class Config:
    """Settings shared by every processing step.

    The values live as class attributes; the module-level ``config`` instance
    created below is what the rest of the script reads, and interactive mode
    reassigns ``crop_strategy`` / ``target_size`` on that instance.
    """

    # --- Paths ---------------------------------------------------------------
    input_dir: str = "./training_data"          # images are rewritten in place
    backup_dir: str = "./training_data_backup"  # originals are copied here first

    # --- Output --------------------------------------------------------------
    target_size: tuple = (512, 512)             # (width, height) in pixels

    # One of 'center', 'smart', 'cover', 'contain', 'simple'.
    crop_strategy: str = 'center'

    # RGB fill drawn around the fitted image by the 'contain' strategy (white).
    padding_color: tuple = (255, 255, 255)

    # Quality passed to Pillow when the output file is .jpg/.jpeg.
    jpeg_quality: int = 95


# Single shared settings object used throughout the script.
config = Config()

# ============================================================================
# CROPPING STRATEGIES
# ============================================================================

def center_crop(image, target_size):
    """
    Center crop to the target's aspect ratio, keeping the middle of the image.
    Best for: General purpose, keeps subject centered

    For a square target (the default 512x512) the crop window is the largest
    centered square. For a non-square target the window takes the target's
    aspect ratio instead, so the final resize does not distort the image.

    Args:
        image: PIL image to crop.
        target_size: ``(width, height)`` of the returned image.

    Returns:
        A new image of exactly ``target_size``.
    """
    width, height = image.size
    target_width, target_height = target_size

    # Compare aspect ratios by integer cross-multiplication (exact, no float
    # rounding): is the source wider than the target shape?
    if width * target_height > height * target_width:
        # Too wide: keep the full height, trim left and right.
        crop_width = max(1, height * target_width // target_height)
        crop_height = height
    else:
        # Too tall (or already matching): keep the full width, trim top/bottom.
        crop_width = width
        crop_height = max(1, width * target_height // target_width)

    left = (width - crop_width) // 2
    top = (height - crop_height) // 2
    image = image.crop((left, top, left + crop_width, top + crop_height))

    # Resize to target size
    return image.resize(target_size, Image.LANCZOS)


def smart_crop(image, target_size):
    """
    Smart crop using image entropy to focus on interesting regions.
    Best for: Automatically focusing on subjects

    A window with the target's aspect ratio (a square for the default
    512x512) is slid along the image's longer axis at about 11 evenly spaced
    positions, always including both edges. The position whose crop has the
    highest ``Image.entropy()`` wins; ties keep the earliest position. That
    crop is then resized to ``target_size``.

    If anything fails, the reason is printed and the result of
    ``center_crop(image, target_size)`` is returned instead. One example is a
    Pillow build without ``Image.entropy``.
    """
    try:
        width, height = image.size
        target_width, target_height = target_size

        # Largest window of the target aspect ratio that fits in the image.
        if width * target_height > height * target_width:
            window_width = max(1, height * target_width // target_height)
            window_height = height
        else:
            window_width = width
            window_height = max(1, width * target_height // target_width)

        # One axis has zero slack, so this is a single row or column of boxes.
        boxes = [
            (left, top, left + window_width, top + window_height)
            for left in _candidate_offsets(width - window_width)
            for top in _candidate_offsets(height - window_height)
        ]

        if len(boxes) == 1:
            best_box = boxes[0]
        else:
            # Entropy errors are not swallowed here: they reach the handler
            # below, which falls back to a center crop instead of silently
            # keeping the left/top-edge crop.
            best_box = max(boxes, key=lambda box: image.crop(box).entropy())

        return image.crop(best_box).resize(target_size, Image.LANCZOS)

    except Exception as e:
        print(f"    Smart crop failed, falling back to center crop: {e}")
        return center_crop(image, target_size)


def _candidate_offsets(slack, steps=10):
    """Return evenly spaced offsets covering ``0..slack`` inclusive.

    Roughly ``steps + 1`` values are produced; both ends are always present,
    so the window flush against the far (right/bottom) edge is evaluated too.
    """
    stride = max(1, slack // steps)
    offsets = list(range(0, slack + 1, stride))
    if offsets[-1] != slack:
        offsets.append(slack)
    return offsets


def cover_crop(image, target_size):
    """
    Cover crop - fills entire target size, may crop edges.
    Best for: Ensuring no padding, subject fills frame

    The image is scaled uniformly until both sides are at least the target
    size, then the centered ``target_size`` region is cut out.

    Args:
        image: PIL image to process.
        target_size: ``(width, height)`` of the returned image.

    Returns:
        A new image of exactly ``target_size``.
    """
    width, height = image.size
    target_width, target_height = target_size

    # Scale so the image covers the target in both dimensions.
    scale = max(target_width / width, target_height / height)

    # Truncating can fall one pixel short of the target through float error
    # (e.g. int(49 * (512 / 49)) == 511), which would push the crop box
    # outside the image and leave a black edge. Clamp to the target size.
    new_width = max(target_width, int(width * scale))
    new_height = max(target_height, int(height * scale))

    image = image.resize((new_width, new_height), Image.LANCZOS)

    # Center crop to exact target size
    left = (new_width - target_width) // 2
    top = (new_height - target_height) // 2
    return image.crop((left, top, left + target_width, top + target_height))


def contain_crop(image, target_size, padding_color=(255, 255, 255)):
    """
    Contain crop - fits entire image with padding if needed.
    Best for: Keeping entire subject visible with background

    The image is scaled uniformly to fit inside ``target_size`` and pasted
    centered on an RGB canvas filled with ``padding_color``.

    Args:
        image: PIL image to process.
        target_size: ``(width, height)`` of the returned canvas.
        padding_color: RGB fill for the uncovered area (default white).

    Returns:
        A new ``'RGB'`` image of exactly ``target_size``.
    """
    width, height = image.size
    target_width, target_height = target_size

    # Scale so the whole image fits inside the target.
    scale = min(target_width / width, target_height / height)

    # Round rather than truncate, because truncation can lose a pixel to
    # float error. Clamp to the canvas, and keep at least one pixel so very
    # elongated images (e.g. 10000x1) do not resize to a zero dimension,
    # which Pillow rejects.
    new_width = max(1, min(target_width, round(width * scale)))
    new_height = max(1, min(target_height, round(height * scale)))

    image = image.resize((new_width, new_height), Image.LANCZOS)

    # Create new image with padding and paste the resized image in the center.
    result = Image.new('RGB', target_size, padding_color)
    left = (target_width - new_width) // 2
    top = (target_height - new_height) // 2
    result.paste(image, (left, top))

    return result


def simple_resize(image, target_size):
    """
    Stretch the image to exactly ``target_size``, ignoring its aspect ratio.
    Best for: Quick processing when distortion is acceptable

    Returns a new image resampled with the Lanczos filter.
    """
    stretched = image.resize(target_size, resample=Image.LANCZOS)
    return stretched

# ============================================================================
# MAIN PROCESSING
# ============================================================================

def process_image(input_path, output_path, strategy='center'):
    """Process a single image with the specified strategy and save it.

    Args:
        input_path: path of the source image.
        output_path: destination path (``pathlib.Path`` or str). Its suffix
            selects the output format; it may equal ``input_path`` to
            overwrite the file in place.
        strategy: one of 'center', 'smart', 'cover', 'contain', 'simple'.
            Unknown names print a notice and fall back to center crop.

    Returns:
        ``(True, orig_width, orig_height)`` on success, where the size is
        measured after applying the EXIF orientation. Returns
        ``(False, 0, 0)`` on any error; the error is printed, not raised.
    """
    # Looked up at call time so runtime changes to ``config`` are honoured.
    handlers = {
        'center': lambda img: center_crop(img, config.target_size),
        'smart': lambda img: smart_crop(img, config.target_size),
        'cover': lambda img: cover_crop(img, config.target_size),
        'contain': lambda img: contain_crop(img, config.target_size, config.padding_color),
        'simple': lambda img: simple_resize(img, config.target_size),
    }

    try:
        output_path = Path(output_path)

        # Load the pixels and close the file before saving, so overwriting the
        # source in place never writes to a file this process still holds open.
        with Image.open(input_path) as source:
            # Apply the EXIF Orientation tag (common on phone photos). Saving
            # drops EXIF, so without this rotated images would be cropped and
            # written sideways.
            image = ImageOps.exif_transpose(source)

        # Convert to RGB if needed (handles RGBA, grayscale, etc.)
        if image.mode != 'RGB':
            image = image.convert('RGB')

        orig_width, orig_height = image.size

        handler = handlers.get(strategy)
        if handler is None:
            print(f"    Unknown strategy '{strategy}', using center crop")
            handler = handlers['center']
        processed = handler(image)

        # Save processed image
        if output_path.suffix.lower() in ('.jpg', '.jpeg'):
            processed.save(output_path, 'JPEG', quality=config.jpeg_quality)
        else:
            processed.save(output_path)

        return True, orig_width, orig_height

    except Exception as e:
        print(f"    Error: {e}")
        return False, 0, 0


def process_directory():
    """Back up, then crop/resize in place, every image in ``config.input_dir``.

    Originals are first copied to ``config.backup_dir``. If that directory
    already exists, the user must type "yes" before it is deleted and
    recreated. Each image is then overwritten via ``process_image`` using
    ``config.crop_strategy``. Progress and a summary are printed; nothing is
    returned.
    """
    input_dir = Path(config.input_dir)
    backup_dir = Path(config.backup_dir)

    # is_dir() rather than exists(): listing the contents needs a directory.
    if not input_dir.is_dir():
        print(f"❌ Directory not found: {input_dir}")
        print("   Create it and add images first!")
        return

    images = _find_images(input_dir)
    if not images:
        print(f"❌ No images found in {input_dir}")
        return

    rule = '=' * 80

    print(f"\n{rule}")
    print("Image Processing Tool")
    print(rule)
    print(f"Input directory: {input_dir}")
    print(f"Found {len(images)} images")
    print(f"Target size: {config.target_size[0]}x{config.target_size[1]}")
    print(f"Crop strategy: {config.crop_strategy}")

    # Create backup
    print(f"\n{rule}")
    print("Creating backup...")
    print(rule)

    if backup_dir.exists():
        print(f"⚠️  Backup directory already exists: {backup_dir}")
        response = input("Overwrite existing backup? (yes/no) [no]: ").strip().lower()
        if response != 'yes':
            print("Cancelled. Remove or rename existing backup first.")
            return
        shutil.rmtree(backup_dir)

    backup_dir.mkdir(exist_ok=True, parents=True)

    # Copy originals to backup
    for img_path in images:
        shutil.copy2(img_path, backup_dir / img_path.name)

    print(f"✓ Backed up {len(images)} images to {backup_dir}")

    # Process images
    print(f"\n{rule}")
    print("Processing images...")
    print(f"{rule}\n")

    processed_count = 0
    failed_count = 0
    sample = None  # (name, width, height) of the first successfully processed image

    for img_path in images:
        print(f"Processing: {img_path.name}")

        success, orig_width, orig_height = process_image(
            img_path,
            img_path,  # Overwrite original (backup exists)
            config.crop_strategy,
        )

        if not success:
            failed_count += 1
            continue

        processed_count += 1
        # Re-read the saved file to report its real size; close it right away.
        with Image.open(img_path) as new_img:
            new_width, new_height = new_img.size
        print(f"  ✓ {orig_width}x{orig_height} → {new_width}x{new_height}")
        if sample is None:
            sample = (img_path.name, new_width, new_height)

    # Summary
    print(f"\n{rule}")
    print("Summary")
    print(rule)
    print(f"✓ Processed: {processed_count}")
    if failed_count > 0:
        print(f"✗ Failed: {failed_count}")
    print(f"📁 Originals backed up to: {backup_dir}")
    print(f"📁 Processed images in: {input_dir}")
    print(f"{rule}\n")

    # Show sample (only an image that was actually processed)
    print("Sample processed image:")
    if sample is not None:
        name, width, height = sample
        print(f"  {name}: {width}x{height} pixels")


def _find_images(input_dir):
    """Return the image files directly inside ``input_dir``, sorted by path.

    Suffixes are compared case-insensitively (``.JPG``, ``.Jpeg`` ...), and
    each file is listed exactly once. Globbing both ``*.jpg`` and ``*.JPG``
    returns the same file twice on case-insensitive file systems, which made
    it get backed up and re-cropped twice.
    """
    image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.gif', '.webp'}
    return sorted(
        path for path in input_dir.iterdir()
        if path.is_file() and path.suffix.lower() in image_extensions
    )


# ============================================================================
# INTERACTIVE MODE
# ============================================================================

def interactive_mode():
    """Prompt for strategy and target size, confirm, then process the directory.

    Stores the answers in ``config.crop_strategy`` and ``config.target_size``
    before calling ``process_directory``. Answering "no" or "n" at the final
    prompt cancels without touching any files.
    """
    print("="*80)
    print("Image Crop & Resize Tool")
    print("="*80)
    print(f"\nThis tool will process all images in '{config.input_dir}/'")
    print(f"Original images will be backed up to '{config.backup_dir}/'\n")

    # Choose strategy
    print("Choose cropping strategy:")
    print("1. Center crop (default) - Keeps center of image")
    print("2. Smart crop - Focuses on high-detail regions")
    print("3. Cover crop - Fills frame, may crop edges")
    print("4. Contain crop - Keeps entire image, adds padding if needed")
    print("5. Simple resize - Stretches to fit (may distort)")

    choice = input("\nEnter choice (1-5) [default: 1]: ").strip()

    strategy_map = {
        '1': 'center',
        '2': 'smart',
        '3': 'cover',
        '4': 'contain',
        '5': 'simple',
        '': 'center'
    }

    if choice not in strategy_map:
        print(f"Unknown choice '{choice}', using center crop")
    config.crop_strategy = strategy_map.get(choice, 'center')

    # Target size: one positive integer gives a square target.
    current_size = config.target_size[0]
    size_input = input(f"\nTarget size [default: {current_size}]: ").strip()
    if size_input:
        try:
            size = int(size_input)
        except ValueError:
            size = 0
        # Zero or negative sizes would make every resize fail later on.
        if size > 0:
            config.target_size = (size, size)
        else:
            print(f"Invalid size, using {current_size}")

    # Confirm
    print("\nSettings:")
    print(f"  Strategy: {config.crop_strategy}")
    print(f"  Target size: {config.target_size[0]}x{config.target_size[1]}")

    confirm = input("\nProceed? (yes/no) [yes]: ").strip().lower()

    # Accept the short form too: "n" must not fall through to processing,
    # which overwrites the input images.
    if confirm in ('no', 'n'):
        print("Cancelled.")
        return

    # Process
    process_directory()


# ============================================================================
# PREVIEW MODE
# ============================================================================

def preview_strategies():
    """Save and show a side-by-side preview of every strategy on one image.

    Uses the first image in ``config.input_dir`` (sorted by path). It renders
    the original plus the center/smart/cover/contain/simple results at
    ``config.target_size`` in a 2x3 matplotlib grid, writes the grid to
    ``crop_preview.png`` and opens a window. Files in the input directory are
    not modified. Missing matplotlib or any processing error is printed,
    not raised.
    """
    input_dir = Path(config.input_dir)

    # Case-insensitive suffix match; sorted so the chosen sample is stable.
    image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.gif', '.webp'}
    images = []
    if input_dir.is_dir():
        images = sorted(
            path for path in input_dir.iterdir()
            if path.is_file() and path.suffix.lower() in image_extensions
        )

    if not images:
        print("No images found for preview!")
        return

    sample_path = images[0]
    print(f"\nGenerating preview using: {sample_path.name}")

    try:
        import matplotlib.pyplot as plt

        # Load the pixels (upright, per the EXIF orientation) and close the file.
        with Image.open(sample_path) as opened:
            original = ImageOps.exif_transpose(opened)
        if original.mode != 'RGB':
            original = original.convert('RGB')

        target_size = config.target_size
        size_label = f'{target_size[0]}x{target_size[1]}'

        # Title -> strategy; exactly fills the five cells after the original.
        strategies = {
            'Center Crop': lambda img: center_crop(img, target_size),
            'Smart Crop': lambda img: smart_crop(img, target_size),
            'Cover Crop': lambda img: cover_crop(img, target_size),
            'Contain Crop': lambda img: contain_crop(img, target_size, config.padding_color),
            'Simple Resize': lambda img: simple_resize(img, target_size),
        }
        results = {title: make(original.copy()) for title, make in strategies.items()}

        # Display
        _, axes = plt.subplots(2, 3, figsize=(15, 10))
        axes = axes.flatten()

        axes[0].imshow(original)
        axes[0].set_title(f'Original\n{original.size[0]}x{original.size[1]}')
        axes[0].axis('off')

        for ax, (title, img) in zip(axes[1:], results.items()):
            ax.imshow(img)
            ax.set_title(f'{title}\n{size_label}')
            ax.axis('off')

        plt.tight_layout()
        plt.savefig('crop_preview.png', dpi=150, bbox_inches='tight')
        print("\n✓ Preview saved to: crop_preview.png")
        plt.show()

    except ImportError:
        print("matplotlib not installed. Install with: pip install matplotlib")
    except Exception as e:
        print(f"Error generating preview: {e}")


# ============================================================================
# MAIN
# ============================================================================

def main():
    """Command-line entry point.

    With no arguments, runs interactive mode. ``preview`` renders the
    strategy comparison, and ``auto`` processes the directory with the
    current ``config``. Any other argument prints usage and exits with
    status 2, so shell scripts can detect a mistyped command.
    """
    import sys

    if len(sys.argv) <= 1:
        # Interactive mode
        interactive_mode()
        return

    command = sys.argv[1]
    if command == 'preview':
        preview_strategies()
    elif command == 'auto':
        # Auto mode with default settings
        process_directory()
    else:
        print("Usage:")
        print("  python crop_resize_images.py          # Interactive mode")
        print("  python crop_resize_images.py auto     # Auto mode (center crop)")
        print("  python crop_resize_images.py preview  # Preview strategies")
        sys.exit(2)


if __name__ == "__main__":
    main()
