import os
import random
from typing import Dict, List, Optional, Tuple

import numpy as np
import pandas as pd


def _allocate_class_counts(n_samples: int, train_ratio: float, val_ratio: float) -> Tuple[int, int, int]:
    """Return (n_train, n_val, n_test) for one class of n_samples >= 3 rows; each count is at least 1."""
    n_train = max(1, int(n_samples * train_ratio))
    n_val = max(1, int(n_samples * val_ratio))
    n_test = max(1, n_samples - n_train - n_val)  # Remainder goes to the test set

    total_allocated = n_train + n_val + n_test
    if total_allocated > n_samples:
        # The "at least one" minimums over-allocated: scale train/val down
        # proportionally and give the test set whatever is left (at least 1).
        n_train = max(1, int(n_train * n_samples / total_allocated))
        n_val = max(1, int(n_val * n_samples / total_allocated))
        n_test = n_samples - n_train - n_val
        if n_test < 1:
            n_test = 1
            n_train = max(1, n_train - 1)
    return n_train, n_val, n_test


def _concat_and_shuffle(parts: List[pd.DataFrame], template: pd.DataFrame, random_seed: int) -> pd.DataFrame:
    """Concatenate per-class slices and shuffle them; with no slices, return an empty frame with template's columns."""
    if not parts:
        # pd.concat([]) raises "No objects to concatenate"
        return template.iloc[:0].reset_index(drop=True)
    merged = pd.concat(parts, ignore_index=True)
    return merged.sample(frac=1, random_state=random_seed).reset_index(drop=True)


def split_dataset_first(df: pd.DataFrame, train_ratio: float = 0.8, val_ratio: float = 0.1, test_ratio: float = 0.1, random_seed: int = 42, target_col: str = 'label') -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """
    Split the original dataset (before any resampling) into train/val/test sets.

    The split is stratified per label: each class is shuffled independently and
    divided according to the ratios, with at least one sample requested for each
    of the three sets whenever the class has >= 3 samples. Smaller classes cannot
    cover all sets: a single sample goes to training; two samples go one to
    training and one to validation.

    Rows whose label is missing (NaN/None) cannot be stratified; they are
    dropped with a printed warning.

    Args:
        df: Original data DataFrame
        train_ratio: Training set ratio
        val_ratio: Validation set ratio
        test_ratio: Test set ratio (only checked in the sum; the test set gets
            whatever each class has left after train/val)
        random_seed: Seed for the per-class and final shuffles; also seeds the
            global ``random`` and ``numpy.random`` generators
        target_col: Target column name (for stratified split)

    Returns:
        df_train, df_val, df_test: Split DataFrames with a fresh RangeIndex.
        A split that receives no rows (e.g. every class has < 3 samples) is an
        empty DataFrame with the input's columns.

    Raises:
        ValueError: If the ratios do not sum to 1.0 or target_col is missing.
    """
    # Seed the global generators (kept for callers relying on this side effect);
    # every pandas sample() call below is seeded explicitly via random_state.
    random.seed(random_seed)
    np.random.seed(random_seed)

    # Check if ratios are reasonable
    total_ratio = train_ratio + val_ratio + test_ratio
    if abs(total_ratio - 1.0) > 1e-6:
        raise ValueError(f"Sum of ratios must be 1.0, current is {total_ratio}")

    print(f"Original dataset size: {len(df)} samples")

    if target_col not in df.columns:
        raise ValueError(f"Target column '{target_col}' does not exist in DataFrame")

    # NaN never equals itself, so rows with a missing label would match no
    # per-label mask below and silently vanish; sorted() can also raise on NaN
    # mixed with string labels. Drop them explicitly and report it.
    missing_label = df[target_col].isna()
    if missing_label.any():
        print(f"Warning: Dropping {int(missing_label.sum())} rows with a missing '{target_col}' value")
        df = df[~missing_label]

    unique_labels = df[target_col].unique()
    print(f"Found {len(unique_labels)} unique labels: {sorted(unique_labels)}")

    # Group by label
    label_groups = {}
    for label in unique_labels:
        label_groups[label] = df[df[target_col] == label]
        print(f"  Label {label}: {len(label_groups[label])} samples")

    train_data, val_data, test_data = [], [], []
    for label, group_df in label_groups.items():
        n_samples = len(group_df)
        # random_state is an int, so each shuffle is independent of the others
        group_shuffled = group_df.sample(frac=1, random_state=random_seed).reset_index(drop=True)

        if n_samples < 3:
            print(f"Warning: Label {label} has only {n_samples} samples, cannot guarantee samples in all datasets")
            # Training always gets a sample; a second one goes to validation
            train_data.append(group_shuffled.iloc[:1])
            if n_samples == 2:
                val_data.append(group_shuffled.iloc[1:])
            continue

        n_train, n_val, n_test = _allocate_class_counts(n_samples, train_ratio, val_ratio)
        train_data.append(group_shuffled.iloc[:n_train])
        val_data.append(group_shuffled.iloc[n_train:n_train + n_val])
        test_data.append(group_shuffled.iloc[n_train + n_val:n_train + n_val + n_test])

    # Merge data from all classes (empty splits become empty frames, not errors)
    df_train = _concat_and_shuffle(train_data, df, random_seed)
    df_val = _concat_and_shuffle(val_data, df, random_seed)
    df_test = _concat_and_shuffle(test_data, df, random_seed)

    n_total = len(df_train) + len(df_val) + len(df_test)
    print("\nDataset split completed (stratified split, ensuring each class has samples in all datasets):")
    for split_name, split_df in (("Training set", df_train), ("Validation set", df_val), ("Test set", df_test)):
        share = len(split_df) / n_total * 100 if n_total else 0.0  # avoid ZeroDivisionError on empty input
        print(f"  {split_name}: {len(split_df)} samples ({share:.1f}%)")

    # Verify each class has samples in all datasets
    print("\nVerifying class distribution:")
    all_unique_labels_set = set(unique_labels)
    missing_in_train = all_unique_labels_set - set(df_train[target_col].unique())
    missing_in_val = all_unique_labels_set - set(df_val[target_col].unique())
    missing_in_test = all_unique_labels_set - set(df_test[target_col].unique())

    if missing_in_train:
        print(f"⚠️  Warning: The following labels are missing in training set: {sorted(missing_in_train)}")
    if missing_in_val:
        print(f"⚠️  Warning: The following labels are missing in validation set: {sorted(missing_in_val)}")
    if missing_in_test:
        print(f"⚠️  Warning: The following labels are missing in test set: {sorted(missing_in_test)}")

    if not missing_in_train and not missing_in_val and not missing_in_test:
        print(f"✓ All {len(unique_labels)} labels have samples in training, validation, and test sets")
    else:
        print("⚠️  Note: Some labels are missing in certain datasets, possibly due to insufficient samples (<3 samples) for some classes")

    return df_train, df_val, df_test


def print_label_distribution(df: pd.DataFrame, title: str = "", target_col: str = 'label'):
    """
    Print per-label row counts and percentages of a DataFrame.

    Args:
        df: DataFrame to summarize
        title: Optional heading, printed with an underline before the statistics
        target_col: Name of the label column

    If ``target_col`` is not a column of ``df``, a warning is printed and the
    function returns without printing statistics. Labels are listed in the
    sorted order produced by ``np.unique``.
    """
    if title:
        print(f"\n{title}")
        print("=" * 50)

    if target_col not in df.columns:
        print(f"Warning: Target column '{target_col}' does not exist")
        return

    labels = df[target_col].values
    unique_labels, counts = np.unique(labels, return_counts=True)

    print("Label distribution:")
    for label, count in zip(unique_labels, counts):
        percentage = count / len(labels) * 100
        print(f"  Label {label}: {count} samples ({percentage:.2f}%)")

    print(f"Total: {len(labels)} samples")


def apply_sampling_strategy(df_train: pd.DataFrame, target_col: str = 'label', random_seed: int = 42,
                            label_factors: Optional[Dict] = None, default_factor: float = 20) -> pd.DataFrame:
    """
    Resample the training set label by label to rebalance classes.

    Default strategy (``label_factors=None``):
    - C1 data: undersample to 60% of original
    - C2 data: no sampling (keep as is)
    - C3 and D2: oversample to 5x original
    - Other data: expand to 20x original

    A factor < 1 undersamples without replacement (keeping at least one row).
    A factor == 1 keeps the group unchanged. A factor > 1 oversamples: every
    original row is kept once and only the additional rows are drawn with
    replacement. Drawing the whole target size with replacement would leave
    out roughly e^-factor of the original samples (about 0.7% at 5x).

    Args:
        df_train: Training set DataFrame
        target_col: Target column name
        random_seed: Seed for every sampling call and the final shuffle
        label_factors: Optional mapping label -> size factor; when given it
            replaces the default C1/C2/C3/D2 mapping
        default_factor: Factor for labels not listed in ``label_factors``

    Returns:
        Sampled, shuffled training set DataFrame with a fresh RangeIndex.
        Rows whose label is missing (NaN) form no group and are not included.

    Raises:
        KeyError: If ``target_col`` is not a column of ``df_train``.
    """
    if label_factors is None:
        label_factors = {'C1': 0.6, 'C2': 1, 'C3': 5, 'D2': 5}

    print("\nStarting to apply sampling strategy...")
    print(f"Original training set size: {len(df_train)} samples")

    # Print original distribution (for the configured label column)
    print_label_distribution(df_train, "Original training set label distribution", target_col=target_col)

    if target_col not in df_train.columns:
        raise KeyError(f"Target column '{target_col}' does not exist in training set")

    # 1. Group by label, sorted by label; observed=True skips unused categories
    label_groups = {label: group for label, group in df_train.groupby(target_col, sort=True, observed=True)}

    print("\nStep 1: Grouping by label")
    for label, group in label_groups.items():
        print(f"  Label {label}: {len(group)} samples")

    # 2. Apply sampling strategy
    sampled_data = []
    for label, group in label_groups.items():
        n_orig = len(group)
        factor = label_factors.get(label, default_factor)
        if factor < 1:
            sample_size = max(int(n_orig * factor), 1)
            sampled_group = group.sample(n=sample_size, random_state=random_seed)
            print(f"Label {label}: Undersampling {n_orig} -> {len(sampled_group)} ({factor * 100:.0f}%)")
        elif factor == 1:
            sampled_group = group.copy()
            print(f"Label {label}: No sampling, keeping {n_orig} samples")
        else:
            target_size = max(int(n_orig * factor), n_orig)
            # Keep all originals; draw only the surplus with replacement
            extra = group.sample(n=target_size - n_orig, replace=True, random_state=random_seed)
            sampled_group = pd.concat([group, extra])
            print(f"Label {label}: Oversampling {n_orig} -> {len(sampled_group)} ({factor}x)")
        sampled_data.append(sampled_group)

    # 3. Merge all sampled data (pd.concat([]) would raise on an empty training set)
    if sampled_data:
        df_train_sampled = pd.concat(sampled_data, ignore_index=True)
    else:
        df_train_sampled = df_train.iloc[:0]

    # 4. Shuffle data
    df_train_sampled = df_train_sampled.sample(frac=1, random_state=random_seed).reset_index(drop=True)

    print("\nSampling completed!")
    print(f"Training set size after sampling: {len(df_train_sampled)} samples")
    print_label_distribution(df_train_sampled, "Training set label distribution after sampling", target_col=target_col)

    return df_train_sampled


def _read_protein_label_csv(csv_file: str) -> pd.DataFrame:
    """Read a protein/label CSV, with or without a header row, into 'protein_id'/'label' columns."""
    try:
        df = pd.read_csv(csv_file)
    except ValueError as e:
        # pandas parse errors (ParserError, EmptyDataError) subclass ValueError;
        # I/O errors such as a missing file are not caught and propagate as-is.
        print(f"Trying to read without header: {e}")
        df = pd.read_csv(csv_file, header=None)
    else:
        # Heuristic: a real header names both a protein column and a label column
        lower_cols = [str(c).lower() for c in df.columns]
        if not (any('protein' in c for c in lower_cols) and any('label' in c for c in lower_cols)):
            # The first data row was consumed as a header; re-read without one
            df = pd.read_csv(csv_file, header=None)
            print("Detected CSV file without header, automatically added column names: protein_id, label")

    # Ensure column names are standardized. Header-less reads have integer
    # column names, so only an exactly-two-column file gets renamed; anything
    # else is rejected instead of being silently mis-mapped.
    if 'protein_id' not in df.columns or 'label' not in df.columns:
        if len(df.columns) == 2:
            df.columns = ['protein_id', 'label']
        else:
            raise ValueError(f"CSV file must contain two columns (protein_id and label), current columns: {list(df.columns)}")
    return df


def process_dataset_with_sampling(csv_file: str, output_dir: str, train_ratio: float = 0.8, val_ratio: float = 0.1, test_ratio: float = 0.1, random_seed: int = 42):
    """
    Process dataset: split first, then resample only the training set.

    Reads ``csv_file``, splits it with ``split_dataset_first`` (stratified on
    'label'), resamples the training split with ``apply_sampling_strategy``,
    and writes protein_classification_{train,val,test}.csv into ``output_dir``
    (created if needed). Validation and test sets are saved unsampled.

    Args:
        csv_file: Original CSV file path (two columns, header optional)
        output_dir: Output directory
        train_ratio: Training set ratio
        val_ratio: Validation set ratio
        test_ratio: Test set ratio
        random_seed: Random seed for the train/val/test split

    Returns:
        (df_train_sampled, df_val, df_test)

    Raises:
        ValueError: If the CSV does not have exactly the protein_id/label
            columns (or two columns to rename), or the ratios are invalid.
        Errors from pd.read_csv opening the file (e.g. a missing file) propagate.
    """
    print("=" * 60)
    print("Starting dataset processing: split first, then sample training set")
    print("=" * 60)

    # 1. Read original data
    print(f"Reading original data: {csv_file}")
    df = _read_protein_label_csv(csv_file)
    print(f"Original data size: {len(df)} samples")

    # 2. First perform data split (use stratified split, ensure each class has samples in all datasets)
    print("\nStep 1: Splitting original data (stratified split)")
    df_train, df_val, df_test = split_dataset_first(df, train_ratio, val_ratio, test_ratio, random_seed, target_col='label')

    # 3. Apply sampling strategy to training set
    # NOTE(review): random_seed is not forwarded here, so training-set
    # resampling uses apply_sampling_strategy's default seed.
    print("\nStep 2: Applying sampling strategy to training set")
    df_train_sampled = apply_sampling_strategy(df_train)

    # 4. Create output directory
    os.makedirs(output_dir, exist_ok=True)

    # 5. Save processed data
    train_output_path = os.path.join(output_dir, 'protein_classification_train.csv')
    val_output_path = os.path.join(output_dir, 'protein_classification_val.csv')
    test_output_path = os.path.join(output_dir, 'protein_classification_test.csv')

    df_train_sampled.to_csv(train_output_path, index=False)
    df_val.to_csv(val_output_path, index=False)
    df_test.to_csv(test_output_path, index=False)

    print("\nData saved successfully:")
    print(f"  Training set: {train_output_path} ({len(df_train_sampled)} samples)")
    print(f"  Validation set: {val_output_path} ({len(df_val)} samples)")
    print(f"  Test set: {test_output_path} ({len(df_test)} samples)")

    # 6. Print final statistics
    print("\nFinal dataset statistics:")
    print(f"  Original data: {len(df)} samples")
    print(f"  Training set: {len(df_train_sampled)} samples")
    print(f"  Validation set: {len(df_val)} samples")
    print(f"  Test set: {len(df_test)} samples")
    print(f"  Total: {len(df_train_sampled) + len(df_val) + len(df_test)} samples")

    return df_train_sampled, df_val, df_test


if __name__ == "__main__":
    import argparse

    # Data sampling script. Running without arguments reproduces the original
    # hard-coded configuration (one_label.csv -> data_output, 0.8/0.1/0.1, seed 42).
    #
    # Sampling strategy:
    # - C1 data: undersample to 60% of original
    # - C2 data: no sampling (keep as is)
    # - C3 and D2: oversample to 5x original
    # - Other data: expand to 20x original
    # - Validation and test sets: no sampling (keep as is)
    parser = argparse.ArgumentParser(
        description="Split a protein/label CSV into train/val/test sets and resample the training set."
    )
    parser.add_argument("--csv-file", default="one_label.csv",
                        help="Input CSV file (default: %(default)s in the current directory)")
    parser.add_argument("--output-dir", default="data_output",
                        help="Directory for the output CSV files (default: %(default)s)")
    parser.add_argument("--train-ratio", type=float, default=0.8, help="Training set ratio (default: %(default)s)")
    parser.add_argument("--val-ratio", type=float, default=0.1, help="Validation set ratio (default: %(default)s)")
    parser.add_argument("--test-ratio", type=float, default=0.1, help="Test set ratio (default: %(default)s)")
    parser.add_argument("--seed", type=int, default=42,
                        help="Random seed passed to process_dataset_with_sampling (default: %(default)s)")
    args = parser.parse_args()

    csv_file = args.csv_file
    output_dir = args.output_dir

    print("=" * 70)
    print("Data Sampling Script")
    print("=" * 70)
    print(f"Input file: {csv_file}")
    print(f"Output directory: {output_dir}")
    print("=" * 70)

    df_train, df_val, df_test = process_dataset_with_sampling(
        csv_file=csv_file,
        output_dir=output_dir,
        train_ratio=args.train_ratio,
        val_ratio=args.val_ratio,
        test_ratio=args.test_ratio,
        random_seed=args.seed,
    )

    print("\n" + "=" * 70)
    print("Sampling completed!")
    print("=" * 70)





