#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Create a small training dataset with cats and dogs for fine-tuning demo.

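This script obtains 20 images (10 cats, 10 dogs) -- from the Oxford-IIIT Pet
dataset on Hugging Face, falling back to CIFAR-10 and then to solid-colour
placeholder images -- and writes them, together with a captions.json file,
to ./training_data.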
"""

import os
import json
from pathlib import Path
from PIL import Image
import random

# Dataset sources: main() tries them in order of preference and falls back to the next one when a source fails.
def _hf_sample_is_cat(sample, label_names=None):
    """Return True if an Oxford-IIIT Pet sample is a cat, False if it is a dog.

    Prefers an explicit boolean ``dog`` column when the sample has one; otherwise
    falls back to the breed name (from a string ``label``, or an integer ``label``
    resolved through *label_names*).

    Raises:
        ValueError: if the species cannot be determined from the sample.
    """
    dog_flag = sample.get('dog')
    if dog_flag is not None:
        return not bool(dog_flag)

    label = sample.get('label')
    if isinstance(label, int) and label_names and 0 <= label < len(label_names):
        label = label_names[label]
    if isinstance(label, str) and label:
        # Oxford-IIIT Pet naming convention: cat breeds are capitalized
        # ("Abyssinian"), dog breeds are lowercase ("beagle").
        # NOTE(review): only used when no `dog` column exists — confirm for the repo in use.
        return label[0].isupper()

    raise ValueError(f"cannot determine cat/dog for label {label!r}")


def create_dataset_from_huggingface():
    """Build the dataset from 10 cat and 10 dog photos of the Oxford-IIIT Pet dataset.

    Samples are drawn in random order, converted to RGB, resized to 512x512,
    saved as JPEGs in ``./training_data`` and captioned in
    ``./training_data/captions.json``.

    Returns:
        True if at least one image was saved; False if the method failed
        (``datasets`` not installed, download error, unrecognised label schema,
        no usable samples, ...).
    """
    try:
        from datasets import load_dataset
        print("Downloading images from Hugging Face datasets...")

        # Load Oxford-IIIT Pet dataset (high quality cat/dog images).
        # trust_remote_code is not passed: this repo appears to ship plain data
        # files, and newer `datasets` releases warn on / reject the flag.
        # NOTE(review): restore it only if loading fails on an older `datasets`.
        dataset = load_dataset("pcuenq/oxford-pets", split="train")

        # If `label` is a ClassLabel, keep its names so integer labels can be
        # mapped back to breed names for classification.
        label_feature = (dataset.features or {}).get('label')
        label_names = getattr(label_feature, 'names', None)

        output_dir = Path("./training_data")
        output_dir.mkdir(exist_ok=True, parents=True)

        captions = {}
        per_class = 10
        counts = {'cat': 0, 'dog': 0}

        # Visit the dataset in random order
        indices = list(range(len(dataset)))
        random.shuffle(indices)

        for idx in indices:
            if all(n >= per_class for n in counts.values()):
                break

            sample = dataset[idx]
            species = 'cat' if _hf_sample_is_cat(sample, label_names) else 'dog'
            if counts[species] >= per_class:
                # Classify first so surplus samples skip the conversion and resize below
                continue

            counts[species] += 1
            filename = f"{species}_{counts[species]:02d}.jpg"

            image = sample['image']
            if image.mode != 'RGB':
                image = image.convert('RGB')
            image = image.resize((512, 512), Image.LANCZOS)

            image.save(output_dir / filename, "JPEG", quality=95)
            captions[filename] = f"a photo of a {species}"
            print(f"  Saved: {filename}")

        if not captions:
            print("✗ Hugging Face method found no usable images")
            return False
        if any(n < per_class for n in counts.values()):
            print(f"  ⚠ Only found {counts['cat']} cats and {counts['dog']} dogs")

        # Save captions
        with open(output_dir / "captions.json", 'w') as f:
            json.dump(captions, f, indent=2)

        print(f"\n✓ Successfully created dataset with {len(captions)} images")
        print(f"  Location: {output_dir}")
        return True

    except Exception as e:
        # Any failure means "try the next source" for main()
        print(f"✗ Hugging Face method failed: {e}")
        return False


def create_dataset_from_cifar():
    """Build the dataset from 10 cat and 10 dog images sampled from CIFAR-10.

    The CIFAR-10 training split is downloaded to ``./cifar_data``. The chosen
    images are upscaled from 32x32 to 512x512, which makes them blurry, and
    saved with a ``captions.json`` in ``./training_data``.

    Returns:
        True if at least one image was saved; False if the method failed
        (e.g. torchvision not installed or the download failed).
    """
    try:
        import torchvision

        print("Downloading images from CIFAR-10 dataset...")

        # Download CIFAR-10 (no transform, so samples come back as PIL images)
        dataset = torchvision.datasets.CIFAR10(
            root='./cifar_data',
            train=True,
            download=True
        )

        output_dir = Path("./training_data")
        output_dir.mkdir(exist_ok=True, parents=True)

        # CIFAR-10 classes: 3=cat, 5=dog
        class_ids = {'cat': 3, 'dog': 5}
        per_class = 10

        # Select by label first so only the chosen images are decoded, instead of
        # building a PIL image for every sample inspected during a shuffled scan.
        targets = getattr(dataset, 'targets', None)
        if targets is None:
            targets = [dataset[i][1] for i in range(len(dataset))]

        captions = {}
        for species, class_id in class_ids.items():
            candidates = [i for i, t in enumerate(targets) if t == class_id]
            chosen = random.sample(candidates, min(per_class, len(candidates)))

            for number, idx in enumerate(chosen, start=1):
                image, _ = dataset[idx]
                # Resize CIFAR image (32x32 -> 512x512)
                image = image.resize((512, 512), Image.LANCZOS)

                filename = f"{species}_{number:02d}.jpg"
                image.save(output_dir / filename, "JPEG", quality=95)
                captions[filename] = f"a photo of a {species}"
                print(f"  Saved: {filename}")

        if not captions:
            print("✗ CIFAR-10 method found no cat/dog images")
            return False

        # Save captions
        with open(output_dir / "captions.json", 'w') as f:
            json.dump(captions, f, indent=2)

        print(f"\n✓ Successfully created dataset with {len(captions)} images")
        print(f"  Location: {output_dir}")
        return True

    except Exception as e:
        # Any failure means "try the next source" for main()
        print(f"✗ CIFAR-10 method failed: {e}")
        return False


def create_synthetic_dataset():
    """Write solid-colour placeholder JPEGs standing in for cats and dogs.

    Produces cat_01.jpg..cat_10.jpg (orange) followed by dog_01.jpg..dog_10.jpg
    (brown), each 512x512, in ./training_data, plus a captions.json mapping each
    file to its caption. The images carry no real content and should be replaced
    before real training.

    Returns:
        Always True; errors (e.g. an unwritable directory) propagate to the caller.
    """
    print("Creating synthetic placeholder images...")
    print("(You should replace these with real images for actual training)")

    target = Path("./training_data")
    target.mkdir(exist_ok=True, parents=True)

    # (species, RGB fill colour): cats orange-ish, dogs brown-ish
    palette = (
        ("cat", (255, 150, 50)),
        ("dog", (150, 100, 50)),
    )

    caption_map = {}
    for species, fill in palette:
        for number in range(1, 11):
            name = f"{species}_{number:02d}.jpg"
            placeholder = Image.new('RGB', (512, 512), color=fill)
            placeholder.save(target / name, "JPEG")
            caption_map[name] = f"a photo of a {species}"
            print(f"  Created placeholder: {name}")

    with open(target / "captions.json", 'w') as handle:
        json.dump(caption_map, handle, indent=2)

    print("\n⚠ Created 20 placeholder images")
    print(f"  Location: {target}")
    print("\n  IMPORTANT: Replace these with real cat/dog images!")
    print("  You can manually add images and edit captions.json")
    return True


def download_from_url(urls=None, timeout=30):
    """Download up to 20 images from public URLs into ./training_data.

    Each image is converted to RGB, resized to 512x512, saved as
    ``image_NN.jpg`` and given the generic caption "a photo of an animal".

    Args:
        urls: Image URLs to fetch. When None, the built-in list below is used
            (currently empty, so the method skips itself).
        timeout: Per-request timeout in seconds, so a stalled server cannot
            hang the script indefinitely.

    Returns:
        True if at least one image was downloaded and captions.json written,
        False otherwise.
    """
    try:
        import urllib.request
        from io import BytesIO

        print("Downloading sample images from public URLs...")

        # Sample public domain image URLs (Creative Commons)
        # Note: These URLs might change or become unavailable
        default_urls = [
            # Add public domain cat/dog image URLs here
            # For example from Unsplash, Pexels, or Wikimedia Commons
        ]
        sample_urls = list(urls) if urls is not None else default_urls

        if not sample_urls:
            print("No URLs configured. Skipping this method.")
            return False

        output_dir = Path("./training_data")
        output_dir.mkdir(exist_ok=True, parents=True)

        captions = {}

        for idx, url in enumerate(sample_urls[:20]):
            filename = f"image_{idx+1:02d}.jpg"
            try:
                with urllib.request.urlopen(url, timeout=timeout) as response:
                    img_data = response.read()

                image = Image.open(BytesIO(img_data)).convert('RGB')
                image = image.resize((512, 512), Image.LANCZOS)
                image.save(output_dir / filename, "JPEG", quality=95)
            except Exception as e:
                # Best effort: one bad URL or undecodable image should not abort the rest
                print(f"  Failed to download {url}: {e}")
                continue

            # Generic caption (the URL gives no reliable cat/dog label)
            captions[filename] = "a photo of an animal"
            print(f"  Downloaded: {filename}")

        if not captions:
            return False

        with open(output_dir / "captions.json", 'w') as f:
            json.dump(captions, f, indent=2)

        print(f"\n✓ Downloaded {len(captions)} images")
        print(f"  Location: {output_dir}")
        return True

    except Exception as e:
        print(f"✗ URL download method failed: {e}")
        return False


def main():
    """Create ./training_data using the first dataset source that succeeds.

    Sources are tried in order of preference. After a success, the result is
    checked with verify_dataset() and next steps are printed. If every source
    fails, manual setup instructions are printed instead.
    """
    print("="*80)
    print("Creating Training Dataset: Cats & Dogs")
    print("="*80)
    print("\nThis script will create a small dataset for fine-tuning demo.")
    print("Trying multiple methods to obtain images...\n")

    # Try methods in order of preference
    methods = [
        ("Hugging Face Datasets", create_dataset_from_huggingface),
        ("CIFAR-10", create_dataset_from_cifar),
        ("Synthetic Placeholders", create_synthetic_dataset),
    ]

    for method_name, method_func in methods:
        print(f"\n{'='*80}")
        print(f"Trying: {method_name}")
        print(f"{'='*80}")

        try:
            succeeded = method_func()
        except Exception as e:
            # A method that raises instead of returning False (e.g. on a disk or
            # permission error) counts as a failure. The remaining methods and the
            # manual instructions below are still reached instead of crashing.
            print(f"✗ {method_name} raised an unexpected error: {e}")
            succeeded = False

        if succeeded:
            print(f"\n{'='*80}")
            print("SUCCESS!")
            print(f"{'='*80}")

            # Verify the dataset and surface problems rather than ignoring the result
            if not verify_dataset():
                print("\n⚠ Verification reported problems - review ./training_data/ before training.")

            print("\nNext steps:")
            print("1. Review the images in ./training_data/")
            print("2. Edit captions.json to improve descriptions")
            print("3. Run: python finetune_2080_optimized.py")
            return

        print(f"\n{method_name} didn't work, trying next method...")

    print("\n" + "="*80)
    print("FALLBACK: Manual Setup Instructions")
    print("="*80)
    print("\nPlease manually create the dataset:")
    print("\n1. Create directory: ./training_data/")
    print("2. Add 20 images (any format: jpg, png)")
    print("3. Create captions.json with format:")
    print('   {')
    print('     "image1.jpg": "a photo of a cat",')
    print('     "image2.jpg": "a photo of a dog",')
    print('     ...')
    print('   }')
    print("\nOr search online for 'cats and dogs dataset' and download manually.")


def verify_dataset(dataset_dir="./training_data", min_images=10):
    """Print a summary of the dataset and report whether it looks usable.

    Args:
        dataset_dir: Directory holding the images and ``captions.json``.
        min_images: Minimum number of images required for the dataset to pass.

    Returns:
        True if the directory exists, contains at least *min_images* images
        (.jpg/.jpeg/.png, any letter case), and any ``captions.json`` present
        is a readable JSON object. False otherwise.
    """
    dataset_dir = Path(dataset_dir)

    if not dataset_dir.is_dir():
        print("\n✗ Dataset directory not found!")
        return False

    # Match extensions case-insensitively so e.g. "IMG_01.JPG" is counted too
    image_suffixes = {".jpg", ".jpeg", ".png"}
    image_files = sorted(
        p for p in dataset_dir.iterdir()
        if p.is_file() and p.suffix.lower() in image_suffixes
    )

    captions_file = dataset_dir / "captions.json"

    print("\nDataset Verification:")
    print(f"  Directory: {dataset_dir}")
    print(f"  Images found: {len(image_files)}")
    print(f"  Captions file: {'✓' if captions_file.exists() else '✗'}")

    captions_ok = True
    if captions_file.exists():
        try:
            with open(captions_file, 'r', encoding='utf-8') as f:
                captions = json.load(f)
        except (OSError, ValueError) as e:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError
            print(f"  ✗ Could not read captions.json: {e}")
            captions_ok = False
        else:
            if not isinstance(captions, dict):
                print("  ✗ captions.json must be a JSON object mapping filename -> caption")
                captions_ok = False
            else:
                print(f"  Captions count: {len(captions)}")

                # Show sample captions
                print("\n  Sample captions:")
                for filename, caption in list(captions.items())[:3]:
                    print(f"    {filename}: \"{caption}\"")

                # Cross-check captions against the images actually present
                image_names = {p.name for p in image_files}
                missing = sorted(set(captions) - image_names)
                if missing:
                    print(f"  ⚠ {len(missing)} caption(s) refer to missing images, e.g. {missing[0]}")
                uncaptioned = len(image_names - set(captions))
                if uncaptioned:
                    print(f"  ⚠ {uncaptioned} image(s) have no caption")

    if not captions_ok:
        print("\n⚠ Warning: fix captions.json before training.")
        return False

    if len(image_files) >= min_images:
        print(f"\n✓ Dataset looks good! You have {len(image_files)} images.")
        return True

    print(f"\n⚠ Warning: Only {len(image_files)} images found. Consider adding more.")
    return False


# Run the dataset builder only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
