#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Fix and Validate captions.json

This script:
1. Checks for common JSON errors
2. Fixes trailing commas
3. Validates the JSON structure
4. Creates a backup before fixing
"""

import json
import re
from pathlib import Path
import shutil

def _single_to_double_quoted(body):
    """Re-quote the inside of a single-quoted string as a JSON string literal."""

    def _requote(match):
        token = match.group(0)
        if token == "\\'":
            return "'"      # \' is not a valid JSON escape; a bare ' is fine
        if token == '"':
            return '\\"'    # a bare " must be escaped once the delimiters change
        return token        # keep every other escape sequence untouched

    return '"' + re.sub(r'\\.|"', _requote, body, flags=re.DOTALL) + '"'


def _repair_json_text(text):
    """Repair trailing commas and single-quoted strings outside of string values.

    Double-quoted strings are matched first and copied verbatim, so caption
    text that merely looks like a trailing comma or a single-quoted key is
    never modified.

    Returns:
        Tuple ``(repaired_text, removed_commas, fixed_quotes)`` where the two
        flags tell which kind of repair was actually applied.
    """
    token_re = re.compile(
        r'"(?:\\.|[^"\\])*"'                  # double-quoted string: left as is
        r"|'(?P<single>(?:\\.|[^'\\])*)'"     # single-quoted string: re-quoted
        r'|,(?P<tail>\s*[}\]])',              # comma right before } or ]
        re.DOTALL,
    )
    applied = {'commas': False, 'quotes': False}

    def _repair(match):
        if match.group('single') is not None:
            applied['quotes'] = True
            return _single_to_double_quoted(match.group('single'))
        if match.group('tail') is not None:
            applied['commas'] = True
            return match.group('tail')
        return match.group(0)

    repaired = token_re.sub(_repair, text)
    return repaired, applied['commas'], applied['quotes']


def fix_json_file(filepath):
    """
    Fix common JSON errors in a captions file and rewrite it pretty-printed.

    A copy named ``<stem>_backup.json`` is written next to the file before
    anything else happens. The text is then repaired (BOM, trailing commas,
    single-quoted strings), parsed, and, only if it parses to a JSON object,
    written back with two-space indentation. On any failure the original file
    is left unchanged and diagnostics are printed.

    Args:
        filepath: Path (str or Path) of the captions JSON file.

    Returns:
        True if the file was parsed and rewritten, False if it does not
        exist, still cannot be parsed, or is not a JSON object.
    """

    filepath = Path(filepath)

    if not filepath.exists():
        print(f"❌ File not found: {filepath}")
        return False

    # Create backup
    backup_path = filepath.parent / f"{filepath.stem}_backup.json"
    shutil.copy2(filepath, backup_path)
    print(f"✓ Backup created: {backup_path}")

    # Read the file
    with open(filepath, 'r', encoding='utf-8') as f:
        content = f.read()

    print(f"\nOriginal content length: {len(content)} characters")

    # Common fixes
    print("\nApplying fixes...")

    # A BOM survives decoding with plain 'utf-8' and makes json.loads fail.
    had_bom = content.startswith('\ufeff')
    if had_bom:
        content = content[1:]

    # Trailing commas and single quotes are repaired in one string-aware pass
    # so that the contents of caption strings are never rewritten.
    content, removed_commas, fixed_quotes = _repair_json_text(content)

    if removed_commas:
        print("  ✓ Removed trailing commas")
    if fixed_quotes:
        print("  ✓ Fixed single quotes to double quotes")
    if had_bom:
        print("  ✓ Removed BOM")

    try:
        data = json.loads(content)
    except json.JSONDecodeError as e:
        print("\n❌ Still have JSON error after fixes:")
        print(f"   {e}")
        print(f"\n   The error is at line {e.lineno}, column {e.colno}")
        print(f"   Character position: {e.pos}")

        # Show problematic area (line numbers refer to the repaired text,
        # which has the same line breaks as the file)
        lines = content.split('\n')
        if e.lineno <= len(lines):
            print(f"\n   Problematic line {e.lineno}:")
            print(f"   {lines[e.lineno - 1]}")
            if e.lineno > 1:
                print(f"   Previous line {e.lineno - 1}:")
                print(f"   {lines[e.lineno - 2]}")

        print("\n   Common issues:")
        print("   - Trailing comma after last item")
        print("   - Missing comma between items")
        print("   - Unclosed quotes")
        print("   - Single quotes instead of double quotes")

        return False

    # The rest of the script treats the captions as a filename -> caption
    # mapping; refuse anything else before touching the file on disk.
    if not isinstance(data, dict):
        print(f"\n❌ Expected a JSON object mapping image names to captions, "
              f"but found a {type(data).__name__}")
        return False

    # Re-format nicely and write the fixed file
    formatted = json.dumps(data, indent=2, ensure_ascii=False)
    with open(filepath, 'w', encoding='utf-8') as f:
        f.write(formatted)

    print("\n✓ Successfully fixed and validated JSON!")
    print(f"✓ Found {len(data)} entries")

    # Show sample
    print("\nSample entries:")
    for key, value in list(data.items())[:3]:
        print(f"  {key}: \"{value}\"")

    if len(data) > 3:
        print(f"  ... and {len(data) - 3} more")

    return True


def validate_captions_with_images(captions_file, image_dir):
    """
    Validate that captions match existing images and print a report.

    Image files are the regular files directly inside ``image_dir`` whose
    suffix, compared case-insensitively, is one of .jpg, .jpeg, .png, .bmp
    or .gif. Caption keys are compared with image file names exactly.

    Args:
        captions_file: Path of a JSON file holding a filename -> caption object.
        image_dir: Directory that contains the images.

    Returns:
        True if at least one image has a caption, otherwise False (also when
        the captions file does not contain a JSON object).

    Raises:
        OSError: If the captions file cannot be opened.
        json.JSONDecodeError: If the captions file is not valid JSON.
    """

    captions_file = Path(captions_file)
    image_dir = Path(image_dir)

    # Load captions. The file is written as UTF-8 elsewhere in this script;
    # 'utf-8-sig' reads that and also tolerates a leading BOM.
    with open(captions_file, 'r', encoding='utf-8-sig') as f:
        captions = json.load(f)

    if not isinstance(captions, dict):
        print(f"❌ Expected a JSON object in {captions_file}, "
              f"but found a {type(captions).__name__}")
        return False

    # Find images: match the suffix case-insensitively (so "x.Jpeg" counts)
    # and skip directories that happen to carry an image extension.
    image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.gif'}
    image_names = set()
    if image_dir.is_dir():
        image_names = {
            entry.name
            for entry in image_dir.iterdir()
            if entry.is_file() and entry.suffix.lower() in image_extensions
        }
    caption_names = set(captions)

    # Check matches
    print(f"\n{'='*80}")
    print("Validation Report")
    print(f"{'='*80}")
    print(f"Images found: {len(image_names)}")
    print(f"Captions found: {len(caption_names)}")

    # Images without captions
    missing_captions = image_names - caption_names
    if missing_captions:
        print(f"\n⚠️  Images without captions ({len(missing_captions)}):")
        for name in sorted(missing_captions)[:5]:
            print(f"   - {name}")
        if len(missing_captions) > 5:
            print(f"   ... and {len(missing_captions) - 5} more")

    # Captions without images
    missing_images = caption_names - image_names
    if missing_images:
        print(f"\n⚠️  Captions without images ({len(missing_images)}):")
        for name in sorted(missing_images)[:5]:
            print(f"   - {name}")
        if len(missing_images) > 5:
            print(f"   ... and {len(missing_images) - 5} more")

    # Matching pairs
    matching = image_names & caption_names
    print(f"\n✓ Matching pairs: {len(matching)}")

    if len(matching) < 10:
        print(f"\n⚠️  Warning: Only {len(matching)} images with captions.")
        print("   You need at least 10-20 for training.")

    print(f"{'='*80}")

    return bool(matching)


def create_template_captions(image_dir, output_file):
    """
    Create a template captions.json from images in directory.

    Image files are the regular files directly inside ``image_dir`` whose
    suffix, compared case-insensitively, is one of .jpg, .jpeg, .png, .bmp
    or .gif. Each image gets a starter caption: a fixed phrase when the file
    name contains one of the keywords cat, dog, kitten or puppy as a separate
    word (optionally plural), otherwise "a photo of <file stem>" with
    underscores turned into spaces.

    Args:
        image_dir: Directory to scan for images.
        output_file: Path of the JSON file to write (UTF-8, two-space indent).

    Returns:
        True if a template was written, False if no images were found.
    """

    image_dir = Path(image_dir)
    output_file = Path(output_file)

    # Find all images. A single directory listing avoids the duplicates that
    # separate lower/upper-case globs produce on case-insensitive file
    # systems, and also picks up mixed-case suffixes such as ".Jpeg".
    image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.gif'}
    images = []
    if image_dir.is_dir():
        images = sorted(
            entry
            for entry in image_dir.iterdir()
            if entry.is_file() and entry.suffix.lower() in image_extensions
        )

    if not images:
        print(f"❌ No images found in {image_dir}")
        return False

    print(f"Found {len(images)} images")

    # Checked in this order; the first keyword found wins.
    keyword_captions = (
        ('cat', "a photo of a cat"),
        ('dog', "a photo of a dog"),
        ('kitten', "a photo of a kitten"),
        ('puppy', "a photo of a puppy"),
    )

    # Create captions
    captions = {}
    for img in images:
        # Try to generate caption from filename
        name = img.stem.lower()

        caption = f"a photo of {name.replace('_', ' ')}"
        for keyword, text in keyword_captions:
            # Require non-letters around the keyword so that names such as
            # "vacation" or "dogma" are not captioned as a cat or a dog.
            if re.search(rf'(?<![a-z]){keyword}s?(?![a-z])', name):
                caption = text
                break

        captions[img.name] = caption

    # Save
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(captions, f, indent=2, ensure_ascii=False)

    print(f"✓ Created template: {output_file}")
    print(f"✓ Generated {len(captions)} captions")
    print("\nPlease review and edit the captions for better results!")

    return True


def main():
    """
    Run the interactive fix-and-validate workflow on the default paths.

    If ./training_data/captions.json is missing, offer to create a template
    from the images in ./training_data and stop. Otherwise repair the file
    and, when that succeeds, validate it against the images. The "ready for
    training" message is printed only if at least one caption matches an
    image.
    """
    print("="*80)
    print("Fix and Validate captions.json")
    print("="*80)

    captions_file = "./training_data/captions.json"
    image_dir = "./training_data"

    # Check if file exists
    if not Path(captions_file).exists():
        print(f"\n❌ File not found: {captions_file}")
        print("\nWould you like to create a template?")
        response = input("Create template from images? (yes/no) [yes]: ").strip().lower()

        # Anything other than an explicit refusal (including just pressing
        # Enter) accepts the default "yes".
        if response not in ('no', 'n'):
            create_template_captions(image_dir, captions_file)
        return

    # Try to fix the file
    print(f"\nAttempting to fix: {captions_file}")
    print("="*80)

    success = fix_json_file(captions_file)

    if success:
        # Validate against images
        print("\n" + "="*80)
        has_matches = validate_captions_with_images(captions_file, image_dir)

        if has_matches:
            print("\n✅ All done! You can now run training:")
            print("   python finetune_2080_optimized.py")
        else:
            print("\n⚠️  captions.json is valid JSON, but none of its entries "
                  "match an image file.")
            print(f"   Make sure the keys are the image file names in {image_dir}.")
    else:
        print("\n" + "="*80)
        print("Manual Fix Required")
        print("="*80)
        print("\nPlease manually check your captions.json file.")
        print("\nCommon JSON format:")
        print('''
{
  "image1.jpg": "description here",
  "image2.jpg": "another description",
  "image3.jpg": "last description"
}

Important:
- Use double quotes " not single quotes '
- No comma after the last item
- Each line needs a comma except the last one
''')

        print("\nOr delete captions.json and run this script again to create a template.")


# Run the interactive workflow only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
