import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import fetch_openml
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from sklearn.manifold import TSNE
from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score, silhouette_score
from sklearn.neighbors import NearestNeighbors
from scipy.stats import multivariate_normal
import warnings
import os
# NOTE(review): this silences every warning category for the whole process,
# not just this module's own; presumably meant to keep console output clean.
warnings.filterwarnings('ignore')

class GaussianMixtureModel:
    """Full-covariance Gaussian mixture model fitted with the EM algorithm.

    All density computations are carried out in log space. In even moderately
    high-dimensional spaces the raw Gaussian densities underflow to zero,
    which would otherwise make every responsibility uniform and the
    log-likelihood meaningless.

    Args:
        n_components: Number of mixture components.
        max_iter: Maximum number of EM iterations run by ``fit``.
        tol: ``fit`` stops once the absolute change in total log-likelihood
            between two consecutive iterations drops below this value.
        reg_covar: Value added to the diagonal of every covariance matrix to
            keep it positive definite.

    Attributes set by ``fit``:
        weights_: Mixing weights, shape ``(n_components,)``.
        means_: Component means, shape ``(n_components, n_features)``.
        covariances_: Component covariances, shape
            ``(n_components, n_features, n_features)``.
        converged_: Whether the tolerance was reached within ``max_iter``.
        n_iter_: Number of EM iterations actually performed.
    """

    def __init__(self, n_components, max_iter=100, tol=1e-6, reg_covar=1e-6):
        self.n_components = n_components
        self.max_iter = max_iter
        self.tol = tol
        self.reg_covar = reg_covar

    def _initialize_parameters(self, X, kmeans_labels):
        """Initialise weights, means and covariances from a hard clustering.

        Args:
            X: Data array of shape ``(n_samples, n_features)``.
            kmeans_labels: Integer cluster id per sample, in
                ``[0, n_components)``.
        """
        kmeans_labels = np.asarray(kmeans_labels)
        n_samples, n_features = X.shape
        identity = np.eye(n_features)

        # minlength guarantees one weight per component even when the
        # highest-numbered clusters received no samples.
        counts = np.bincount(kmeans_labels, minlength=self.n_components)
        self.weights_ = counts / n_samples

        self.means_ = np.zeros((self.n_components, n_features))
        self.covariances_ = np.zeros((self.n_components, n_features, n_features))
        for k in range(self.n_components):
            members = X[kmeans_labels == k]

            if len(members) > 0:
                self.means_[k] = members.mean(axis=0)
            else:
                # Empty cluster: fall back to a randomly chosen data point.
                self.means_[k] = X[np.random.randint(0, n_samples)]

            if len(members) > 1:
                self.covariances_[k] = np.cov(members.T) + self.reg_covar * identity
            else:
                # A covariance cannot be estimated from fewer than two points.
                self.covariances_[k] = identity

    def _weighted_log_densities(self, X):
        """Return ``log(weight_k * N(x_i | mean_k, cov_k))`` for all i, k.

        The result has shape ``(n_samples, n_components)``. A component whose
        covariance is rejected as singular gets ``reg_covar`` added to its
        diagonal (in place) and is evaluated again.
        """
        log_prob = np.empty((X.shape[0], self.n_components))

        # A zero weight legitimately maps to -inf; silence the divide warning.
        with np.errstate(divide='ignore'):
            log_weights = np.log(self.weights_)

        for k in range(self.n_components):
            try:
                log_pdf = multivariate_normal.logpdf(
                    X, self.means_[k], self.covariances_[k])
            except np.linalg.LinAlgError:
                # Handle singular covariance matrix
                self.covariances_[k] += self.reg_covar * np.eye(
                    self.covariances_[k].shape[0])
                log_pdf = multivariate_normal.logpdf(
                    X, self.means_[k], self.covariances_[k])
            log_prob[:, k] = log_weights[k] + log_pdf

        return log_prob

    @staticmethod
    def _logsumexp_rows(log_prob):
        """Return ``log(sum(exp(row)))`` for every row, computed stably."""
        row_max = log_prob.max(axis=1, keepdims=True)
        # Rows that are entirely -inf would otherwise produce inf - inf = nan.
        row_max = np.where(np.isfinite(row_max), row_max, 0.0)
        with np.errstate(divide='ignore'):
            return row_max[:, 0] + np.log(np.exp(log_prob - row_max).sum(axis=1))

    def _e_step(self, X):
        """Expectation step.

        Returns:
            Responsibilities of shape ``(n_samples, n_components)``; every row
            sums to one.
        """
        log_prob = self._weighted_log_densities(X)
        log_norm = self._logsumexp_rows(log_prob)

        with np.errstate(invalid='ignore'):
            responsibilities = np.exp(log_prob - log_norm[:, np.newaxis])

        # If no component assigns any mass to a sample, spread it uniformly
        # instead of propagating NaNs into the M-step.
        degenerate = ~np.isfinite(log_norm)
        if degenerate.any():
            responsibilities[degenerate] = 1.0 / self.n_components

        return responsibilities

    def _m_step(self, X, responsibilities):
        """Maximization step: re-estimate weights, means and covariances."""
        n_samples, n_features = X.shape
        identity = np.eye(n_features)

        Nk = responsibilities.sum(axis=0)
        self.weights_ = Nk / n_samples

        # Components with (numerically) no assigned mass keep their previous
        # mean and covariance rather than being divided by ~0.
        min_mass = 10 * np.finfo(float).eps
        for k in range(self.n_components):
            if Nk[k] <= min_mass:
                continue
            resp_k = responsibilities[:, k:k + 1]
            self.means_[k] = (resp_k * X).sum(axis=0) / Nk[k]
            diff = X - self.means_[k]
            self.covariances_[k] = (resp_k * diff).T @ diff / Nk[k]
            self.covariances_[k] += self.reg_covar * identity

    def _compute_log_likelihood(self, X):
        """Return the total log-likelihood of ``X`` under the current model."""
        return float(np.sum(self._logsumexp_rows(self._weighted_log_densities(X))))

    def fit(self, X, kmeans_labels):
        """Fit the mixture with EM, starting from a hard clustering.

        Args:
            X: Data array of shape ``(n_samples, n_features)``.
            kmeans_labels: Initial integer cluster id per sample.

        Returns:
            ``self``, to allow call chaining.
        """
        X = np.asarray(X, dtype=float)
        self._initialize_parameters(X, kmeans_labels)

        self.converged_ = False
        self.n_iter_ = 0
        prev_log_likelihood = -np.inf

        for iteration in range(self.max_iter):
            responsibilities = self._e_step(X)
            self._m_step(X, responsibilities)

            # Evaluated after the M-step, i.e. for the freshly updated model.
            log_likelihood = self._compute_log_likelihood(X)
            self.n_iter_ = iteration + 1

            if abs(log_likelihood - prev_log_likelihood) < self.tol:
                self.converged_ = True
                print(f"EM converged after {iteration + 1} iterations")
                break

            prev_log_likelihood = log_likelihood

        return self

    def predict(self, X):
        """Return the index of the most responsible component per sample."""
        responsibilities = self._e_step(np.asarray(X, dtype=float))
        return np.argmax(responsibilities, axis=1)

    def predict_proba(self, X):
        """Return per-sample component responsibilities (rows sum to one)."""
        return self._e_step(np.asarray(X, dtype=float))

    def sample(self, n_samples=1):
        """Draw ``n_samples`` points from the fitted mixture.

        Each draw first picks a component according to ``weights_`` and then
        samples from that component's Gaussian. Uses NumPy's global RNG.
        """
        samples = []

        for _ in range(n_samples):
            component = np.random.choice(self.n_components, p=self.weights_)
            samples.append(np.random.multivariate_normal(
                self.means_[component], self.covariances_[component]))

        return np.array(samples)

class MNISTClusteringSystem:
    """End-to-end MNIST pipeline: PCA, K-means, GMM, t-SNE and image generation.

    The stages are meant to be run in the order used by
    ``run_complete_pipeline``; each stage stores its result on the instance
    and returns ``self`` so calls can be chained.

    Args:
        n_gaussians: Number of K-means clusters / GMM components.
        pca_dim: Number of principal components kept by ``apply_pca``.
        output_dir: Directory for saved figures; created on construction.
    """

    def __init__(self, n_gaussians=15, pca_dim=50, output_dir='./figs'):
        self.n_gaussians = n_gaussians
        self.pca_dim = pca_dim
        self.output_dir = output_dir
        self.pca = None
        self.kmeans = None
        self.gmm = None
        self.tsne = None
        self.X_original = None
        self.X_pca = None
        self.X_tsne = None
        self.y_true = None
        self.kmeans_labels = None
        self.gmm_labels = None

        # Create output directory if it doesn't exist
        os.makedirs(self.output_dir, exist_ok=True)

    def load_data(self, n_samples=5000):
        """Fetch MNIST from OpenML and keep a random subset of ``n_samples``.

        Pixel values are scaled to ``[0, 1]`` and targets are cast to ``int``.
        The subset is drawn with NumPy's global RNG, without replacement.
        """
        print("Loading MNIST data...")
        mnist = fetch_openml('mnist_784', version=1, as_frame=False, parser='auto')

        # Sample subset for faster computation
        indices = np.random.choice(len(mnist.data), n_samples, replace=False)
        self.X_original = mnist.data[indices].astype(np.float32) / 255.0
        self.y_true = mnist.target[indices].astype(int)

        print(f"Loaded {n_samples} MNIST samples")
        return self

    def apply_pca(self):
        """Project ``X_original`` onto ``pca_dim`` principal components."""
        print(f"Applying PCA (reducing to {self.pca_dim} dimensions)...")
        self.pca = PCA(n_components=self.pca_dim, random_state=42)
        self.X_pca = self.pca.fit_transform(self.X_original)

        explained_variance = np.sum(self.pca.explained_variance_ratio_)
        print(f"PCA explained variance: {explained_variance:.3f}")
        return self

    def fit_kmeans(self):
        """Cluster the PCA-reduced data with K-means."""
        print(f"Fitting K-means with {self.n_gaussians} clusters...")
        self.kmeans = KMeans(n_clusters=self.n_gaussians, random_state=42, n_init=10)
        self.kmeans_labels = self.kmeans.fit_predict(self.X_pca)
        return self

    def fit_gmm(self):
        """Fit the custom EM mixture model, initialised from the K-means labels."""
        print(f"Fitting GMM with {self.n_gaussians} components using EM...")
        self.gmm = GaussianMixtureModel(n_components=self.n_gaussians)
        self.gmm.fit(self.X_pca, self.kmeans_labels)
        self.gmm_labels = self.gmm.predict(self.X_pca)
        return self

    def apply_tsne(self, perplexity=30):
        """Embed the PCA-reduced data in 2-D with t-SNE for plotting."""
        print("Applying t-SNE for visualization...")
        self.tsne = TSNE(n_components=2, perplexity=perplexity, random_state=42)
        self.X_tsne = self.tsne.fit_transform(self.X_pca)
        return self

    def _score_labels(self, labels):
        """Return ARI, NMI and silhouette of ``labels`` as a dict."""
        n_groups = len(np.unique(labels))
        # The silhouette is only defined for 2 <= n_groups <= n_samples - 1.
        if 1 < n_groups < len(labels):
            silhouette = silhouette_score(self.X_pca, labels)
        else:
            silhouette = float('nan')
        return {
            'ari': adjusted_rand_score(self.y_true, labels),
            'nmi': normalized_mutual_info_score(self.y_true, labels),
            'silhouette': silhouette,
        }

    def evaluate_clustering(self):
        """Score the clusterings against the true digit labels.

        Returns:
            Dict with a ``'kmeans'`` entry and, when the GMM has been fitted,
            a ``'gmm'`` entry; each maps to ``{'ari', 'nmi', 'silhouette'}``.
        """
        print("\n=== Clustering Evaluation ===")

        results = {'kmeans': self._score_labels(self.kmeans_labels)}
        display_names = {'kmeans': 'K-means'}

        # The GMM assignments are computed by fit_gmm; report them as well.
        if self.gmm_labels is not None:
            results['gmm'] = self._score_labels(self.gmm_labels)
            display_names['gmm'] = 'GMM'

        for key, scores in results.items():
            print(f"{display_names[key]} - ARI: {scores['ari']:.3f}, "
                  f"NMI: {scores['nmi']:.3f}, Silhouette: {scores['silhouette']:.3f}")

        return results

    @staticmethod
    def _discrete_cmap(n_groups):
        """Pick a colormap with at least ``n_groups`` distinguishable colours."""
        if n_groups <= 10:
            return 'tab10'
        if n_groups <= 20:
            return 'tab20'
        return 'nipy_spectral'

    def visualize_clusters(self, save_fig=True):
        """Plot the t-SNE embedding coloured by true labels and by cluster.

        One panel is drawn for the true labels, one for K-means and, when
        available, one for the GMM assignments.
        """
        panels = [
            ('True Labels', self.y_true),
            ('K-means Clusters', self.kmeans_labels),
        ]
        if self.gmm_labels is not None:
            panels.append(('GMM Clusters', self.gmm_labels))

        fig, axes = plt.subplots(1, len(panels), figsize=(18, 5), squeeze=False)

        for ax, (title, labels) in zip(axes[0], panels):
            # tab10 alone would map several of e.g. 15 clusters to one colour.
            cmap = self._discrete_cmap(len(np.unique(labels)))
            scatter = ax.scatter(self.X_tsne[:, 0], self.X_tsne[:, 1],
                                 c=labels, cmap=cmap, alpha=0.7, s=20)
            ax.set_title(title)
            ax.set_xlabel('t-SNE 1')
            ax.set_ylabel('t-SNE 2')
            plt.colorbar(scatter, ax=ax)

        plt.tight_layout()

        if save_fig:
            filepath = os.path.join(self.output_dir, 'clustering_visualization.png')
            plt.savefig(filepath, dpi=150, bbox_inches='tight')
            print(f"Saved clustering visualization to {filepath}")

        plt.show()
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)

    def generate_images(self, method='weighted_neighbors', n_samples=10):
        """Generate ``n_samples`` images with the chosen method.

        Args:
            method: ``'weighted_neighbors'``, ``'gaussian_mixture'`` or
                ``'interpolation'``.
            n_samples: Number of images to generate.

        Returns:
            Array with one flattened image per row.

        Raises:
            ValueError: If ``method`` is not one of the supported names.
        """
        print(f"\nGenerating {n_samples} images using '{method}' method...")

        generators = {
            'weighted_neighbors': self._generate_weighted_neighbors,
            'gaussian_mixture': self._generate_gaussian_mixture,
            'interpolation': self._generate_interpolation,
        }
        if method not in generators:
            raise ValueError("Method must be 'weighted_neighbors', 'gaussian_mixture', or 'interpolation'")
        return generators[method](n_samples)

    def _decode(self, pca_samples):
        """Map PCA-space samples back to image space in one batch."""
        if len(pca_samples) == 0:
            return np.array([])
        return self.pca.inverse_transform(np.asarray(pca_samples))

    def _generate_from_neighbors(self, n_samples, n_neighbors):
        """Blend the ``n_neighbors`` nearest *other* points of random seeds.

        The seed is part of the fitted data, so a plain query returns it at
        distance 0 and inverse-distance weighting would put essentially all
        weight on it, reproducing an existing image. One extra neighbour is
        therefore requested and the seed itself is dropped.
        """
        n_points = len(self.X_pca)
        nn_model = NearestNeighbors(n_neighbors=min(n_neighbors + 1, n_points),
                                    metric='euclidean')
        nn_model.fit(self.X_pca)

        pca_samples = []
        for _ in range(n_samples):
            # Pick a random sample as seed
            seed_idx = np.random.randint(0, n_points)
            distances, indices = nn_model.kneighbors(self.X_pca[seed_idx:seed_idx + 1])

            keep = indices[0] != seed_idx
            if not keep.any():
                # Only the seed exists; nothing else to blend with.
                keep[:] = True
            neighbor_dist = distances[0][keep][:n_neighbors]
            neighbor_idx = indices[0][keep][:n_neighbors]

            # Weights inversely proportional to distance (np.average normalises).
            weights = 1.0 / (neighbor_dist + 1e-10)
            pca_samples.append(
                np.average(self.X_pca[neighbor_idx], axis=0, weights=weights))

        return self._decode(pca_samples)

    def _generate_weighted_neighbors(self, n_samples):
        """Generate images as a distance-weighted blend of 5 neighbours of a seed."""
        return self._generate_from_neighbors(n_samples, n_neighbors=5)

    def _generate_gaussian_mixture(self, n_samples):
        """Generate images by sampling the GMM in PCA space and decoding."""
        return self._decode(self.gmm.sample(n_samples))

    def _generate_interpolation(self, n_samples):
        """Generate images by interpolating between a seed's 2 nearest neighbours."""
        return self._generate_from_neighbors(n_samples, n_neighbors=2)

    def display_generated_images(self, generated_images, method_name, save_fig=True):
        """Show (and optionally save) generated images in a grid of up to 5 columns.

        Args:
            generated_images: Sequence of flattened 28x28 images.
            method_name: Used in the figure title and the output file name.
            save_fig: Whether to write the figure into ``output_dir``.
        """
        n_images = len(generated_images)
        if n_images == 0:
            print(f"No images to display for {method_name}")
            return

        cols = min(5, n_images)
        rows = (n_images + cols - 1) // cols

        # squeeze=False always yields a 2-D axes array, including the
        # single-image case where subplots would otherwise return a bare Axes.
        fig, axes = plt.subplots(rows, cols, figsize=(cols * 2, rows * 2), squeeze=False)

        # Every cell is axis-free, whether it ends up holding an image or not.
        for ax in axes.flat:
            ax.axis('off')

        for i, (ax, flat_image) in enumerate(zip(axes.flat, generated_images)):
            # Clip values to [0, 1] and reshape to 28x28
            image = np.clip(flat_image, 0, 1).reshape(28, 28)
            ax.imshow(image, cmap='gray')
            ax.set_title(f'Generated {i+1}')

        plt.suptitle(f'Generated Images - {method_name}', fontsize=14)
        plt.tight_layout()

        if save_fig:
            # Create filename from method name
            filename = method_name.lower().replace(' ', '_') + '_generated.png'
            filepath = os.path.join(self.output_dir, filename)
            plt.savefig(filepath, dpi=150, bbox_inches='tight')
            print(f"Saved {method_name} generated images to {filepath}")

        plt.show()
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)

    def run_complete_pipeline(self, n_samples=5000):
        """Run every stage in order and return the evaluation results.

        Loads data, applies PCA, fits K-means and the GMM, computes t-SNE,
        evaluates and plots the clusterings, then generates and displays 8
        images with each of the three generation methods.
        """
        print("=== MNIST Clustering and Generation Pipeline ===\n")

        self.load_data(n_samples)
        self.apply_pca()
        self.fit_kmeans()
        self.fit_gmm()
        self.apply_tsne()

        evaluation_results = self.evaluate_clustering()
        self.visualize_clusters(save_fig=True)

        # Generate and display images using all three methods
        for method in ('weighted_neighbors', 'gaussian_mixture', 'interpolation'):
            generated_images = self.generate_images(method=method, n_samples=8)
            method_display_name = method.replace('_', ' ').title()
            self.display_generated_images(generated_images, method_display_name, save_fig=True)

        print(f"\nAll figures saved to: {self.output_dir}/")

        return evaluation_results

# Example usage
def main():
    """Run the demo pipeline on 1000 MNIST samples and print a summary."""
    pipeline = MNISTClusteringSystem(n_gaussians=15, pca_dim=50, output_dir='./figs')
    pipeline.run_complete_pipeline(n_samples=1000)

    summary_lines = (
        "\n=== Pipeline Complete ===",
        "The system has:",
        "1. Loaded MNIST data",
        "2. Applied PCA for dimensionality reduction",
        "3. Used K-means to initialize GMM",
        "4. Fit GMM using custom EM algorithm",
        "5. Applied t-SNE for visualization",
        "6. Evaluated clustering performance",
        "7. Generated new images using three different methods",
        "8. Saved all visualizations to ./figs directory",
    )
    for line in summary_lines:
        print(line)

# Run the demo pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
