import pandas as pd
import numpy as np
import torch
from torch_geometric.data import Data, Batch
from torch.utils.data import Dataset
from typing import Optional, List, Sequence, Tuple
import os
import pickle
from protein_utils import extract_ca_coordinates, create_protein_graph, load_esm2_features
import random
from torch_geometric.loader import DataLoader
# Import label merging function
try:
    from protein_dataset import merge_labels
except ImportError:
    # protein_dataset is not importable: provide a local copy of the merging rules.
    def merge_labels(label):
        """Collapse a fine-grained label into its merged class name.

        'C7'..'C9' become 'C7-C9', 'C10'..'C17' become 'C10-C17' and
        'D6'..'D8' become 'D6-D8'. Any other value is returned unchanged.
        """
        merge_table = (
            ('C7-C9', ('C7', 'C8', 'C9')),
            ('C10-C17', ('C10', 'C11', 'C12', 'C13', 'C14', 'C15', 'C16', 'C17')),
            ('D6-D8', ('D6', 'D7', 'D8')),
        )
        for merged_name, members in merge_table:
            if label in members:
                return merged_name
        return label

def process_protein_dataset_by_ids(protein_ids: List[str], pdb_dir: str, distance_threshold: float = 6.0, output_file: Optional[str] = None, esm2_dir: str = 'data/ESM2') -> Tuple[List["Data"], List[str]]:
    """
    Build one protein graph per ID from the structure files in ``pdb_dir``.

    For every ID the function looks for ``<id>.pdb`` and then ``<id>.pdbqt``
    inside ``pdb_dir``. IDs without a file, or whose processing raises, are
    reported on stdout and skipped, so the two returned lists stay aligned
    with each other but may be shorter than ``protein_ids``.

    Args:
        protein_ids: List of protein IDs (duplicates are processed once per occurrence)
        pdb_dir: PDB file directory
        distance_threshold: Connection threshold forwarded to create_protein_graph
        output_file: Optional path; when given, the list of graphs is pickled
            there. Missing parent directories are created.
        esm2_dir: ESM2 feature file directory

    Returns:
        Tuple[List of protein graphs, List of successfully processed protein IDs]
    """
    protein_graphs = []
    successful_ids = []

    for protein_id in protein_ids:
        # Try the supported file extensions in priority order.
        candidates = (os.path.join(pdb_dir, protein_id + ext) for ext in ('.pdb', '.pdbqt'))
        pdb_file = next((path for path in candidates if os.path.exists(path)), None)

        if pdb_file is None:
            print(f"Warning: Cannot find PDB file for protein {protein_id}")
            continue

        try:
            # Extract CA coordinates
            ca_coords, residue_names = extract_ca_coordinates(pdb_file)

            # Create protein graph
            node_attr, edge_index, edge_attr, triangle_index, triangle_dis, codebar_features = create_protein_graph(
                ca_coords, residue_names, distance_threshold, pdb_file
            )

            # Load ESM2 features
            esm2_features = load_esm2_features(protein_id, esm2_dir)

            data = Data(
                x=node_attr,
                edge_index=edge_index,
                edge_attr=edge_attr,
                triangle_index=triangle_index,
                triangle_dis=triangle_dis,
                codebar_features=codebar_features,
                esm2_features=esm2_features,
                num_nodes=len(ca_coords)
            )
        except Exception as e:
            # Deliberate best-effort: one broken structure must not abort the whole run.
            print(f"Error: Error processing protein {protein_id}: {e}")
            continue

        protein_graphs.append(data)
        successful_ids.append(protein_id)

    if output_file:
        # Create the target directory up front; otherwise a missing folder
        # would discard all the work done above with a FileNotFoundError.
        output_dir = os.path.dirname(output_file)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        with open(output_file, 'wb') as f:
            pickle.dump(protein_graphs, f)
        print(f"Saving {len(protein_graphs)} protein graphs to {output_file}")

    return protein_graphs, successful_ids

class ProteinDataset(Dataset):
    """Dataset pairing pre-built protein graphs with optional per-sample labels.

    Sample ``i`` is ``protein_graphs[i]``; when a ``target`` column is given,
    its label is row ``i`` of ``df[target]``, so ``df`` and ``protein_graphs``
    must be in the same order and have the same length.
    """

    def __init__(
        self,
        df: pd.DataFrame,
        protein_graphs: Sequence["Data"],
        target: Optional[str] = None,
        transform=None,
        mean_train=None,
        std_train=None,
        normalize_labels=True,
        label_mapping=None,
    ):
        """
        Initialize protein dataset

        Args:
            df: DataFrame containing protein information
            protein_graphs: List of protein graphs, aligned row-by-row with ``df``
            target: Target column name; ``None`` builds an unlabeled dataset
            transform: Data transformation (stored only; not applied by this class)
            mean_train: Training mean (for standardization)
            std_train: Training standard deviation (for standardization)
            normalize_labels: Whether to standardize numeric labels; only
                used when no ``label_mapping`` is given and both statistics
                are provided
            label_mapping: Optional dict mapping raw label values to integer
                class indices. Labels missing from the mapping are reported
                on stdout and mapped to class 0.

        Raises:
            ValueError: If the labels are strings, no ``label_mapping`` is
                given, and they cannot all be parsed as numbers.
        """
        self.df = df
        self.protein_graphs = protein_graphs
        self.target = target
        # NOTE(review): kept for API compatibility; __getitem__ does not apply it.
        self.transform = transform
        self.mean_train = mean_train
        self.std_train = std_train
        self.normalize_labels = normalize_labels
        self.label_mapping = label_mapping

        if target is None:
            self.labels = None
            return

        # Labels are looked up by position, so a length mismatch means some
        # graphs would be paired with the wrong label (or none at all).
        if len(df) != len(protein_graphs):
            print(
                f"Warning: {len(df)} label rows but {len(protein_graphs)} protein graphs; "
                f"labels may be misaligned with graphs"
            )

        self.labels = self._build_labels(self.df[target].values)

    def _build_labels(self, label_values):
        """Convert the raw values of the target column into a label tensor."""
        # Strings show up either as an object-dtype array or, with pandas'
        # dedicated string dtypes, only through the element type.
        is_string_type = (
            label_values.dtype == 'object'
            or (len(label_values) > 0 and isinstance(label_values[0], str))
        )

        if is_string_type:
            # Classification labels given as strings.
            if self.label_mapping is not None:
                return self._map_labels(label_values)
            return self._parse_numeric_strings(label_values)

        # Numeric labels: regression targets or already-encoded classes.
        if torch.is_tensor(label_values):
            labels = label_values
        else:
            labels = torch.tensor(label_values, dtype=torch.get_default_dtype())

        if self.label_mapping is not None:
            # Map original class ids to consecutive indices 0, 1, 2, ...
            return self._map_labels(labels.tolist())

        if self.normalize_labels and self.mean_train is not None and self.std_train is not None:
            # Standardization for regression tasks.
            labels = (labels - self.mean_train) / self.std_train
        return labels

    def _map_labels(self, raw_labels):
        """Translate raw labels into class indices (unknown labels become 0)."""
        mapped_labels = []
        for label in raw_labels:
            if label in self.label_mapping:
                mapped_labels.append(self.label_mapping[label])
            else:
                print(f"Warning: Unknown label {label}, using 0 as default")
                mapped_labels.append(0)
        return torch.tensor(mapped_labels, dtype=torch.long)

    @staticmethod
    def _parse_numeric_strings(label_values):
        """Parse string labels such as '1.5' into a floating-point tensor."""
        try:
            # Wrapping in a Series guarantees a Series result; to_numeric on a
            # bare ndarray returns an ndarray, which has no .isna().
            numeric_labels = pd.to_numeric(pd.Series(label_values), errors='coerce')
        except (ValueError, TypeError) as err:
            raise ValueError(
                f"Cannot convert string labels to tensor without label_mapping: {err}"
            ) from err
        if numeric_labels.isna().any():
            raise ValueError("Cannot convert string labels to numeric without label_mapping")
        return torch.tensor(
            numeric_labels.to_numpy(dtype='float64'), dtype=torch.get_default_dtype()
        )

    def __len__(self):
        """Return the number of protein graphs."""
        return len(self.protein_graphs)

    def __getitem__(self, idx):
        """Get single protein sample: ``(graph, label)`` if labeled, else ``graph``."""
        protein_graph = self.protein_graphs[idx]

        if self.labels is None:
            return protein_graph
        return protein_graph, self.labels[idx]

    @staticmethod
    def collate(samples):
        """Batch processing function.

        Accepts a list of ``(graph, label)`` tuples or a list of bare graphs
        and merges the graphs with ``Batch.from_data_list``; labels, when
        present, are stacked into a single tensor.
        """
        if isinstance(samples[0], tuple):
            graphs, labels = map(list, zip(*samples))
            return Batch.from_data_list(graphs), torch.stack(labels)
        return Batch.from_data_list(samples)

def print_label_distribution(df_train, df_val, df_test, target_col='label'):
    """Print the size of each split followed by its per-label sample counts.

    Counts come from ``value_counts()`` on ``target_col`` and are listed in
    sorted label order, for the training, validation and test frames in turn.
    """
    print("\nLabel distribution statistics:")
    print(f"Training: {len(df_train)} samples")
    print(f"Validation: {len(df_val)} samples")
    print(f"Test: {len(df_test)} samples")

    sections = (
        ("\nTraining set label distribution:", df_train),
        ("\nValidation set label distribution:", df_val),
        ("\nTest set label distribution:", df_test),
    )
    for heading, split_df in sections:
        # Count first, then print the heading, so a missing column fails
        # before anything is written for that split.
        label_counts = split_df[target_col].value_counts().sort_index()
        print(heading)
        for label, count in label_counts.items():
            print(f"  Label {label}: {count} samples")

def _read_protein_label_csv(csv_file):
    """Read the (protein_id, label) table, tolerating a missing header row."""
    try:
        df = pd.read_csv(csv_file)
        # A header is assumed only if some column mentions both concepts.
        lower_cols = [str(c).lower() for c in df.columns]
        has_header = any('protein' in c for c in lower_cols) and any('label' in c for c in lower_cols)
        if not has_header:
            df = pd.read_csv(csv_file, header=None, names=['protein_id', 'label'])
            print("Detected CSV file without header, automatically added column names: protein_id, label")
    except Exception as e:
        # If reading fails, try no-header method
        print(f"Trying to read without header: {e}")
        df = pd.read_csv(csv_file, header=None, names=['protein_id', 'label'])

    # Ensure column names are standardized
    if 'protein_id' not in df.columns or 'label' not in df.columns:
        if len(df.columns) == 2:
            df.columns = ['protein_id', 'label']
        else:
            raise ValueError(f"CSV file must contain two columns (protein_id and label), current columns: {list(df.columns)}")
    return df


def _build_esm2_only_graphs(df, esm2_dir):
    """Build placeholder graphs that carry only the ESM2 features of each protein."""
    graphs = []
    for pid in df['protein_id'].tolist():
        # NOTE(review): IDs are lower-cased here but not in the PDB path of
        # process_protein_dataset_by_ids; confirm which form the feature
        # files on disk use.
        esm = load_esm2_features(str(pid).lower(), esm2_dir)
        graphs.append(
            Data(
                x=torch.zeros(1, 1, dtype=torch.float32),  # Placeholder node
                edge_index=torch.zeros(2, 0, dtype=torch.long),
                num_nodes=1,
                esm2_features=esm,
            )
        )
    return graphs


def _load_or_process_split_graphs(df, split_name, graphs_file, pdb_dir, distance_threshold, esm2_dir, force_reprocess):
    """Return ``(graphs, df)`` for one split, using the pickle cache when it is usable.

    A cache whose length differs from the current split cannot be aligned with
    the labels in ``df`` and is therefore rebuilt instead of being used.
    """
    if os.path.exists(graphs_file) and not force_reprocess:
        print(f"Loading pre-processed {split_name} protein graphs...")
        # The cache is a local pickle written by this pipeline; never point
        # this at files from an untrusted source.
        with open(graphs_file, 'rb') as f:
            graphs = pickle.load(f)
        if len(graphs) == len(df):
            return graphs, df
        print(
            f"Warning: cached {split_name} graphs ({len(graphs)}) do not match "
            f"current samples ({len(df)}); reprocessing"
        )

    print(f"Processing {split_name} PDB files...")
    # Make sure the cache directory exists before the (slow) processing starts.
    cache_dir = os.path.dirname(graphs_file)
    if cache_dir:
        os.makedirs(cache_dir, exist_ok=True)
    graphs, successful_ids = process_protein_dataset_by_ids(
        df['protein_id'].tolist(), pdb_dir, distance_threshold, graphs_file, esm2_dir
    )
    # Filter DataFrame, only keep successfully processed proteins
    df = df[df['protein_id'].isin(successful_ids)].reset_index(drop=True)
    print(f"{split_name.capitalize()} set filtered: {len(df)} samples after processing")
    return graphs, df


def load_protein_data_with_sampling(config):
    """
    Load protein data using sampling strategy (runtime dynamic sampling)

    Pipeline: read the CSV, split it with ``data_sampling.split_dataset_first``,
    resample only the training split with ``data_sampling.apply_sampling_strategy``,
    build graphs (or ESM2-only placeholders), encode/standardize labels and
    wrap everything in data loaders.

    Important Description:
    - Training: Runtime dynamic sampling to balance class distribution
    - Validation: Keep original distribution (not sampled), for true evaluation of model generalization performance
    - Test: Keep original distribution (not sampled), for final evaluation of model performance

    Sampling strategy (as described by the original author; implemented in
    data_sampling.apply_sampling_strategy, not visible from this module):
    - C1 data: undersample to 60% of original
    - C2 data: no sampling (keep as is)
    - C3 and D2: oversample to 5x original
    - Other data: expand to 20x original

    Args:
        config: Configuration dictionary. Required keys:
            ``training.dataset``, ``training.batch_size`` and, when a target
            is set, ``model.criterion``. Optional keys read here:
            ``training.target``, ``training.force_reprocess``,
            ``training.normalize_labels``, ``training.drop_last_eval``
            (default False: keep every validation/test sample),
            ``protein.distance_threshold`` and, under ``data``:
            ``esm2_only``, ``csv_file``, ``train_ratio``, ``val_ratio``,
            ``test_ratio``, ``random_seed``, ``pdb_dir``, ``esm2_dir``.

    Returns:
        Tuple ``(train_loader, val_loader, test_loader, mean_train, std_train)``;
        the two statistics are ``None`` unless a regression target is used.

    Raises:
        FileNotFoundError: If the CSV file does not exist.
        ValueError: If the CSV cannot be interpreted as (protein_id, label).
        ImportError: If data_sampling cannot be imported.
    """
    training_cfg = config['training']
    data_cfg = config.get('data', {})

    dataset_name = training_cfg['dataset']  # not used further; kept so the key stays mandatory
    target = training_cfg.get('target')
    batch_size = training_cfg['batch_size']
    distance_threshold = config.get('protein', {}).get('distance_threshold', 6.0)

    # Check if ESM2-only mode
    esm2_only = data_cfg.get('esm2_only', False)

    # Get original data file path and split parameters
    csv_file = data_cfg.get('csv_file', 'one_label.csv')
    train_ratio = data_cfg.get('train_ratio', 0.8)
    val_ratio = data_cfg.get('val_ratio', 0.1)
    test_ratio = data_cfg.get('test_ratio', 0.1)
    random_seed = data_cfg.get('random_seed', 42)

    # PDB file path - use unified PDB directory (only used in non-ESM2-only mode)
    pdb_dir = data_cfg.get('pdb_dir', 'data/PDB')
    esm2_dir = data_cfg.get('esm2_dir', 'data/ESM2')

    print("=" * 70)
    print("Loading protein data with dynamic sampling strategy...")
    print("=" * 70)

    # Check if original CSV file exists
    if not os.path.exists(csv_file):
        raise FileNotFoundError(f"Original data CSV file does not exist: {csv_file}")

    # Read original data
    print(f"\nStep 1: Reading original data file: {csv_file}")
    df = _read_protein_label_csv(csv_file)
    print(f"Original data size: {len(df)} samples")

    # Import sampling functions
    try:
        from data_sampling import split_dataset_first, apply_sampling_strategy
    except ImportError as err:
        raise ImportError("Cannot import sampling functions, please ensure data_sampling.py file exists") from err

    # Step 2: split first so that resampling can never leak into val/test
    print("\nStep 2: Splitting original data (stratified split)")
    target_col = target if target else 'label'
    df_train, df_val, df_test = split_dataset_first(
        df, train_ratio, val_ratio, test_ratio, random_seed, target_col=target_col
    )

    # Step 3: Apply sampling strategy to training set
    print("\nStep 3: Applying sampling strategy to training set (runtime dynamic sampling)")
    df_train = apply_sampling_strategy(df_train, target_col=target_col)

    print("\nStep 4: Data preparation completed")
    print(f"  Training: {len(df_train)} samples (sampled)")
    print(f"  Validation: {len(df_val)} samples (not sampled)")
    print(f"  Test: {len(df_test)} samples (not sampled)")

    # Print label distribution statistics
    if target is not None:
        print_label_distribution(df_train, df_val, df_test, target_col=target)

    # Choose different processing methods based on esm2_only mode
    if esm2_only:
        # ESM2-only mode: only need ESM2 features, no PDB files needed
        print("\n" + "=" * 70)
        print("Building ESM2-only datasets (no PDB/graph features needed)...")
        print("=" * 70)

        print("Building ESM2 features for training set...")
        train_graphs = _build_esm2_only_graphs(df_train, esm2_dir)
        print(f"✓ Training set: {len(train_graphs)} samples")

        print("Building ESM2 features for validation set...")
        val_graphs = _build_esm2_only_graphs(df_val, esm2_dir)
        print(f"✓ Validation set: {len(val_graphs)} samples")

        print("Building ESM2 features for test set...")
        test_graphs = _build_esm2_only_graphs(df_test, esm2_dir)
        print(f"✓ Test set: {len(test_graphs)} samples")
    else:
        # Full mode: need to process PDB files and graph structures
        print("\n" + "=" * 70)
        print("Processing protein graphs...")
        print("=" * 70)

        # Cache files for the processed graphs. Because the training split is
        # resampled at runtime the cache can go stale; force_reprocess rebuilds it.
        train_graphs_file = 'data_output/protein_graphs_train_dynamic_sampled.pkl'
        val_graphs_file = 'data_output/protein_graphs_val_dynamic_sampled.pkl'
        test_graphs_file = 'data_output/protein_graphs_test_dynamic_sampled.pkl'

        # Check if force reprocessing
        force_reprocess = config.get('training', {}).get('force_reprocess', False)
        if force_reprocess:
            print("Force reprocessing: deleting old graph files...")
            for graph_file in (train_graphs_file, val_graphs_file, test_graphs_file):
                if os.path.exists(graph_file):
                    try:
                        os.remove(graph_file)
                        print(f"Deleted old file: {graph_file}")
                    except PermissionError as e:
                        # Not fatal: the file is overwritten when the split is reprocessed.
                        print(f"Unable to delete file {graph_file}: {e}")

        train_graphs, df_train = _load_or_process_split_graphs(
            df_train, 'training', train_graphs_file, pdb_dir, distance_threshold, esm2_dir, force_reprocess
        )
        val_graphs, df_val = _load_or_process_split_graphs(
            df_val, 'validation', val_graphs_file, pdb_dir, distance_threshold, esm2_dir, force_reprocess
        )
        test_graphs, df_test = _load_or_process_split_graphs(
            df_test, 'test', test_graphs_file, pdb_dir, distance_threshold, esm2_dir, force_reprocess
        )

    # Process label mapping (classification) or standardization statistics (regression)
    label_mapping = None
    mean_train = None
    std_train = None

    if target is not None:
        if config['model']['criterion'] == 'cross_entropy':
            # Classification task: first merge labels, then create label mapping
            print("\nStep: Applying label merging strategy")
            print("Merging rules: C7-C9 -> C7-C9, C10-C17 -> C10-C17, D6-D8 -> D6-D8")

            # Merge labels for Training, Validation, Test
            df_train[target] = df_train[target].apply(merge_labels)
            df_val[target] = df_val[target].apply(merge_labels)
            df_test[target] = df_test[target].apply(merge_labels)

            # The class set is defined by the training split only; labels seen
            # solely in val/test fall back to class 0 inside ProteinDataset.
            all_labels = df_train[target].values
            unique_labels = sorted(set(all_labels))
            print(f"Merged unique labels ({len(unique_labels)}): {unique_labels}")

            # Map merged labels, in sorted order, to 0, 1, 2, ...
            label_mapping = {label: idx for idx, label in enumerate(unique_labels)}

            print(f"Label mapping: {label_mapping}")
            print(f"Number of classes after mapping: {len(set(label_mapping.values()))}")

            # Statistics of each class count (after merging)
            print("\nMerged class statistics:")
            for label in sorted(label_mapping.keys()):
                count = (all_labels == label).sum()
                print(f"  Label {label} -> class {label_mapping[label]}: {count} samples")
        else:
            # Regression task: standardization parameters come from the training split only
            train_vals = df_train[target].values
            mean_train = np.mean(train_vals)
            std_train = np.std(train_vals)

    # Create datasets
    normalize_labels = training_cfg.get('normalize_labels', True)
    train_dataset, val_dataset, test_dataset = (
        ProteinDataset(
            split_df, split_graphs, target=target,
            mean_train=mean_train, std_train=std_train,
            normalize_labels=normalize_labels,
            label_mapping=label_mapping
        )
        for split_df, split_graphs in (
            (df_train, train_graphs),
            (df_val, val_graphs),
            (df_test, test_graphs),
        )
    )

    # Create data loaders
    from torch_geometric.loader import DataLoader

    # Evaluation must see every sample: dropping the last partial batch skews
    # metrics and yields an empty loader when a split is smaller than one
    # batch. Set training.drop_last_eval=True to restore the old behaviour.
    drop_last_eval = training_cfg.get('drop_last_eval', False)

    # NOTE(review): torch_geometric's DataLoader appears to replace a
    # user-supplied collate_fn with its own collater - confirm whether
    # ProteinDataset.collate is actually invoked.
    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True,
        collate_fn=ProteinDataset.collate, drop_last=True
    )

    val_loader = DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False,
        collate_fn=ProteinDataset.collate, drop_last=drop_last_eval
    )

    test_loader = DataLoader(
        test_dataset, batch_size=batch_size, shuffle=False,
        collate_fn=ProteinDataset.collate, drop_last=drop_last_eval
    )

    return train_loader, val_loader, test_loader, mean_train, std_train
