#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Quick fix for dtype mismatch error in fine-tuning script.

This patches your existing finetune_model.py to fix the float32/float16 issue.
"""

import re
from pathlib import Path

def fix_training_step(filepath):
    """Patch a training script so ``pixel_values`` is cast to float16.

    The script is searched for the statement pair::

        pixel_values = batch["pixel_values"].to(config.device)
        captions = batch["captions"]

    and a ``pixel_values = pixel_values.to(dtype=torch.float16)`` line is
    inserted right after it, using the indentation of the matched line so
    the patch is valid at any nesting depth.

    The operation is idempotent: if the cast already follows the matched
    statements, the file is left untouched.  A backup named
    ``<stem>_backup.py`` is written next to the file, but only when the
    file is actually about to be modified, so re-running the tool never
    overwrites the backup with patched content.

    Args:
        filepath: Path (``str`` or ``Path``) of the script to patch.

    Returns:
        ``True`` if the file was patched or already contains the cast,
        ``False`` if the file does not exist or the pattern was not found.
    """
    filepath = Path(filepath)

    if not filepath.exists():
        print(f"❌ File not found: {filepath}")
        return False

    # newline='' turns off newline translation so that the backup and the
    # patched file keep the line endings of the original file.
    with open(filepath, 'r', encoding='utf-8', newline='') as f:
        content = f.read()

    # The statements that send the batch to the device; group 1 is the
    # whitespace (line break + indentation) between them.
    target = re.compile(
        r'pixel_values = batch\["pixel_values"\]\.to\(config\.device\)'
        r'(\s+)captions = batch\["captions"\]'
    )
    # Blank / comment-only lines followed by the cast this tool inserts.
    existing_cast = re.compile(
        r'(?:[ \t]*(?:#[^\r\n]*)?\r?\n)*'
        r'[ \t]*pixel_values = pixel_values\.to\(dtype=torch\.float16\)'
    )

    if target.search(content) is None:
        print("⚠ Could not find the pattern to fix")
        print("  Manual fix required - see instructions below")
        return False

    def add_cast(match):
        """Return the matched text followed by the float16 cast."""
        matched = match.group(0)
        if existing_cast.match(content, match.end()):
            # Already patched at this location; leave it alone.
            return matched

        # Reuse the leading whitespace of the line where the match begins.
        line_start = content.rfind('\n', 0, match.start()) + 1
        line_head = content[line_start:match.start()]
        indent = line_head[:len(line_head) - len(line_head.lstrip(' \t'))]
        eol = '\r\n' if '\r\n' in match.group(1) else '\n'
        return (
            f"{matched}{eol}"
            f"{eol}"
            f"{indent}# Convert to float16 to match VAE dtype{eol}"
            f"{indent}pixel_values = pixel_values.to(dtype=torch.float16)"
        )

    # A callable replacement is inserted literally (no backslash/group
    # expansion of the replacement text).
    patched = target.sub(add_cast, content)

    if patched == content:
        print("✓ dtype conversion already present; no changes made")
        return True

    # Back up the unmodified content only now that a change is certain.
    backup_path = filepath.parent / f"{filepath.stem}_backup.py"
    with open(backup_path, 'w', encoding='utf-8', newline='') as f:
        f.write(content)
    print(f"✓ Backup created: {backup_path}")

    with open(filepath, 'w', encoding='utf-8', newline='') as f:
        f.write(patched)
    print("✓ Fixed dtype conversion in training_step")
    print(f"✓ File patched successfully: {filepath}")
    return True


def print_manual_fix():
    """Write step-by-step manual patching instructions to stdout.

    The output shows the snippet to look for in ``training_step``, the
    snippet to replace it with, and the single line that differs.
    """
    heavy_rule = "=" * 80
    light_rule = "-" * 80

    # Snippet as it appears in the unpatched script.
    snippet_before = "\n".join((
        "",
        "    # Get batch data",
        "    pixel_values = batch[\"pixel_values\"].to(config.device)",
        "    captions = batch[\"captions\"]",
        "    ",
        "    # Encode images to latent space",
        "    with torch.no_grad():",
        "        latents = vae.encode(pixel_values).latent_dist.sample()",
        "",
    ))

    # Same snippet with the float16 cast inserted.
    snippet_after = "\n".join((
        "",
        "    # Get batch data",
        "    pixel_values = batch[\"pixel_values\"].to(config.device)",
        "    captions = batch[\"captions\"]",
        "    ",
        "    # Convert to float16 to match VAE dtype",
        "    pixel_values = pixel_values.to(dtype=torch.float16)",
        "    ",
        "    # Encode images to latent space",
        "    with torch.no_grad():",
        "        latents = vae.encode(pixel_values).latent_dist.sample()",
        "",
    ))

    sections = (
        "\n" + heavy_rule,
        "MANUAL FIX INSTRUCTIONS",
        heavy_rule,
        "\nFind this code in your training_step function:",
        light_rule,
        snippet_before,
        light_rule,
        "\nReplace with:",
        light_rule,
        snippet_after,
        light_rule,
        "\nThe key addition is:",
        "  pixel_values = pixel_values.to(dtype=torch.float16)",
        heavy_rule,
    )
    for section in sections:
        print(section)


def main():
    """Interactively locate a training script and apply the dtype patch.

    Each well-known script name that exists in the current directory is
    offered to the user in turn.  The first one that is not declined
    (answers ``n`` / ``no`` decline; anything else, including an empty
    answer, accepts) is passed to ``fix_training_step``.  If no candidate
    is patched this way, the user is asked for a filename.  Whenever a
    patch attempt does not succeed, the manual fix instructions are printed.
    """
    print("="*80)
    print("Dtype Mismatch Fix Utility")
    print("="*80)
    print("\nThis fixes: RuntimeError: Input type (float) and bias type (c10::Half)")
    print("should be the same\n")

    # Try to find the file
    possible_names = [
        "finetune_model.py",
        "finetune_2080_optimized.py",
        "finetune_stable_diffusion.py"
    ]

    def report_success(name):
        """Tell the user the patch worked and how to resume training."""
        print("\n✅ File patched successfully!")
        print("\nYou can now run training again:")
        print(f"  python {name}")

    offered_any = False
    for name in possible_names:
        if not Path(name).exists():
            continue

        offered_any = True
        print(f"Found: {name}")
        response = input("Patch this file? (yes/no) [yes]: ").strip().lower()

        # Accept the short form too, so "n" does not silently patch the file.
        if response in ('n', 'no'):
            continue

        if fix_training_step(name):
            report_success(name)
        else:
            print_manual_fix()
        return

    if offered_any:
        # Candidates existed but every one was declined.
        print("❌ No training script selected")
    else:
        print("❌ Could not find training script")
        print("\nSearched for:")
        for name in possible_names:
            print(f"  - {name}")

    print("\nPlease specify the file to patch:")
    filepath = input("Enter filename: ").strip()

    # fix_training_step reports a missing file or pattern itself and returns
    # False; in every failure case fall back to the manual instructions.
    if filepath and fix_training_step(filepath):
        report_success(filepath)
    else:
        print_manual_fix()

# Run the interactive patcher only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
