import pandas as pd
import numpy as np
import torch
from torch_geometric.data import Data, Batch
from torch.utils.data import Dataset
from typing import Optional, List, Sequence, Tuple
import os
import pickle
from protein_utils import extract_ca_coordinates, create_protein_graph, process_protein_dataset, process_protein_dataset_from_dir
from protein_utils import load_path_homology_features
import random
try:
    from sklearn.model_selection import train_test_split
    SKLEARN_AVAILABLE = True
except ImportError:
    SKLEARN_AVAILABLE = False
    print("Warning: sklearn unavailable, will use random split instead of stratified split")

def merge_labels(label):
    """
    Collapse fine-grained class labels into their merged group name.

    Three groups are merged; every other label is returned untouched:
    - C7, C8, C9        -> 'C7-C9'
    - C10 through C17   -> 'C10-C17'
    - D6, D7, D8        -> 'D6-D8'

    Args:
        label: Original label (string)

    Returns:
        The merged group name if the label belongs to a group, otherwise
        the label itself.
    """
    # (merged name, member labels) pairs, checked in order. Tuples are used
    # instead of sets so membership relies on equality only and never
    # requires the incoming label to be hashable.
    merge_table = (
        ('C7-C9', ('C7', 'C8', 'C9')),
        ('C10-C17', ('C10', 'C11', 'C12', 'C13', 'C14', 'C15', 'C16', 'C17')),
        ('D6-D8', ('D6', 'D7', 'D8')),
    )
    for merged_name, members in merge_table:
        if label in members:
            return merged_name
    # Not part of any merged group: pass through unchanged.
    return label

def print_label_distribution(df_train, df_val, df_test, target_col='label', is_sampled=False):
    """
    Print per-label sample counts for the Training, Validation, and Test sets,
    followed by a class-balance analysis of the training set.

    The report is written to stdout only; nothing is returned. If the target
    column is missing from any of the three DataFrames a warning is printed
    and the function returns early. If the training set is empty the balance
    analysis is skipped (there is nothing to analyse).

    Args:
        df_train: Training DataFrame
        df_val: Validation DataFrame
        df_test: Test DataFrame
        target_col: Target column name, default is 'label'
        is_sampled: Whether sampled data is used; when True an additional
            "sampling effect" section compares each class count against a
            perfectly balanced distribution.
    """
    def _percent(part, whole):
        # An empty split reports 0% rather than dividing by zero.
        return part / whole * 100 if whole > 0 else 0

    def _count(labels, label):
        # np.sum accepts both an element-wise boolean array and a scalar
        # comparison result; int() keeps later arithmetic in Python ints.
        return int(np.sum(labels == label))

    print("\n" + "="*70)
    print("Dataset label distribution statistics" + (" (after sampling)" if is_sampled else " (original data)"))
    print("="*70)

    # The target column must exist in every split, not only in training,
    # otherwise the label extraction below would raise KeyError.
    for split_name, frame in (("training", df_train), ("validation", df_val), ("test", df_test)):
        if target_col not in frame.columns:
            print(f"Warning: Target column '{target_col}' does not exist in {split_name} set")
            print(f"Available columns: {list(frame.columns)}")
            return

    # Get labels from all datasets
    train_labels = df_train[target_col].values
    val_labels = df_val[target_col].values
    test_labels = df_test[target_col].values

    # Get all unique labels across the three splits
    all_unique_labels = sorted(set(train_labels).union(val_labels, test_labels))

    print(f"Target column: {target_col}")
    print(f"Found unique labels: {all_unique_labels}")
    if is_sampled:
        print("📊 Training set uses sampled data, validation and test sets maintain original distribution")
    print()

    # Create statistics table
    print(f"{'Label':<8} {'Training':<12} {'Validation':<12} {'Test':<12} {'Total':<12}")
    print("-" * 70)

    total_train = len(df_train)
    total_val = len(df_val)
    total_test = len(df_test)
    total_all = total_train + total_val + total_test

    # One pass over the labels: print the table row and remember the training
    # count so the balance analysis does not have to recount.
    train_label_counts = []
    for label in all_unique_labels:
        train_count = _count(train_labels, label)
        val_count = _count(val_labels, label)
        test_count = _count(test_labels, label)
        total_count = train_count + val_count + test_count
        train_label_counts.append((label, train_count))

        # Display statistics with percentages
        train_str = f"{train_count} ({_percent(train_count, total_train):.1f}%)"
        val_str = f"{val_count} ({_percent(val_count, total_val):.1f}%)"
        test_str = f"{test_count} ({_percent(test_count, total_test):.1f}%)"

        print(f"{label:<8} {train_str:<12} {val_str:<12} {test_str:<12} {total_count:<12}")

    print("-" * 70)
    print(f"{'Total':<8} {total_train:<12} {total_val:<12} {total_test:<12} {total_all:<12}")

    # Check if data distribution is balanced
    print("\n" + "="*70)
    print("Data distribution balance analysis (Training)")
    print("="*70)

    # Most frequent class first; the sort is stable so ties keep label order.
    train_label_counts.sort(key=lambda pair: pair[1], reverse=True)

    if total_train == 0:
        # Every ratio below is relative to the training size; with no
        # training samples they are undefined, so skip the analysis.
        print("Training set is empty, skipping balance analysis")
    elif len(train_label_counts) > 1:
        max_label, max_count = train_label_counts[0]
        min_label, min_count = train_label_counts[-1]
        # A class absent from training makes the ratio unbounded.
        imbalance_ratio = max_count / min_count if min_count > 0 else float('inf')

        print(f"Most frequent class: {max_label} ({max_count} samples, {_percent(max_count, total_train):.2f}%)")
        print(f"Least frequent class: {min_label} ({min_count} samples, {_percent(min_count, total_train):.2f}%)")
        print(f"Class imbalance ratio: {imbalance_ratio:.2f}:1")

        # Calculate standard deviation to evaluate distribution uniformity
        counts = [count for _, count in train_label_counts]
        mean_count = np.mean(counts)
        std_count = np.std(counts)
        cv = std_count / mean_count if mean_count > 0 else 0  # Coefficient of variation

        print(f"Average per classSamples: {mean_count:.1f}")
        print(f"Standard deviation: {std_count:.1f}")
        print(f"Coefficient of variation (CV): {cv:.3f} (smaller is more balanced)")

        if imbalance_ratio > 10:
            print("⚠️  Warning: Data severely imbalanced, recommend using class weights or resampling techniques")
        elif imbalance_ratio > 5:
            print("⚠️  Note: Data moderately imbalanced, recommend using class weights")
        elif imbalance_ratio > 2:
            print("ℹ️  Tip: Data slightly imbalanced, class weights may help improve performance")
        else:
            print("✅ Data distribution relatively balanced")

        if is_sampled:
            print("\n📊 Sampling effect evaluation:")
            # Ideal balanced distribution: equal samples per class.
            # total_train > 0 here, so ideal_count is strictly positive.
            ideal_count = total_train / len(all_unique_labels)
            print(f"Ideal per classSamples: {ideal_count:.1f}")

            # Calculate deviation of each class from ideal value
            deviations = []
            for label, count in train_label_counts:
                deviation = abs(count - ideal_count) / ideal_count * 100
                deviations.append(deviation)
                if deviation > 20:  # Deviation exceeds 20%
                    print(f"  ⚠️  Class {label}: deviation {deviation:.1f}% (current: {count}, ideal: {ideal_count:.1f})")

            avg_deviation = np.mean(deviations)
            print(f"Average deviation: {avg_deviation:.2f}%")
            if avg_deviation < 10:
                print("✅ Sampling effect good, data distribution close to balanced")
            elif avg_deviation < 20:
                print("ℹ️  Sampling effect average, data distribution basically balanced")
            else:
                print("⚠️  Sampling effect poor, data distribution still imbalanced")

    print("="*70)
    print()

class ProteinDataset(Dataset):
    """Protein dataset class"""
    
    def __init__(
        self,
        df: pd.DataFrame,
        protein_graphs: Sequence[Data],
        target: str = None,
        transform=None,
        mean_train=None,
        std_train=None,
        normalize_labels=True,
        label_mapping=None,
    ):
        """
        Initialize protein dataset
        
        Args:
            df: DataFrame containing protein information
            protein_graphs: List of protein graphs
            target: Target column name
            transform: Data transformation
            mean_train: Training mean (for normalization)
            std_train: Training standard deviation (for normalization)
            normalize_labels: Whether to normalize labels
        """
        self.df = df
        self.protein_graphs = protein_graphs
        self.target = target
        self.transform = transform
        self.mean_train = mean_train
        self.std_train = std_train
        self.normalize_labels = normalize_labels
        self.label_mapping = label_mapping
        
        # Process labels
        if target is not None:
            if not torch.is_tensor(self.df[target]):
                # If label mapping exists (classification task), apply mapping first then convert to tensor
                if label_mapping is not None:
                    # Map original labels to consecutive indices 0, 1, 2, ...
                    raw_labels = self.df[target].values
                    mapped_labels = []
                    for label in raw_labels:
                        if label in label_mapping:
                            mapped_labels.append(label_mapping[label])
                        else:
                            print(f"Warning: Unknown label {label}, using 0 as default")
                            mapped_labels.append(0)
                    labels = torch.tensor(mapped_labels, dtype=torch.long)
                else:
                    # Regression task: try to convert to numeric type
                    try:
                        # Try direct conversion to numeric type
                        labels = torch.tensor(self.df[target].values, dtype=torch.get_default_dtype())
                    except (TypeError, ValueError):
                        # If failed, try converting to float first
                        labels = torch.tensor(pd.to_numeric(self.df[target], errors='coerce').values, 
                                            dtype=torch.get_default_dtype())
                        # Check for NaN values
                        if torch.isnan(labels).any():
                            raise ValueError(f"Cannot convert label column '{target}' to numeric type, and NaN values exist. Please check data or provide label_mapping.")
            else:
                labels = self.df[target]
            
            # Apply normalization (only for regression tasks, and after label mapping is processed)
            if label_mapping is None and normalize_labels and mean_train is not None and std_train is not None:
                # Normalization for regression task
                labels = (labels - mean_train) / std_train
            
            self.labels = labels
        else:
            self.labels = None
    
    def __len__(self):
        return len(self.protein_graphs)
    
    def __getitem__(self, idx):
        """Get single protein sample"""
        protein_graph = self.protein_graphs[idx]
        
        if self.labels is not None:
            label = self.labels[idx]
            return protein_graph, label
        else:
            return protein_graph
    
    @staticmethod
    def collate(samples):
        """Batch processing function"""
        if isinstance(samples[0], tuple):
            graphs, labels = map(list, zip(*samples))
            
            # First check and unify node feature dimensions
            node_feature_dims = [graph.x.size(1) if hasattr(graph, 'x') and graph.x is not None else 0 for graph in graphs]
            if len(set(node_feature_dims)) > 1:
                # Node feature dimensions are inconsistent, need to unify
                max_dim = max(node_feature_dims)
                min_dim = min(node_feature_dims)
                print(f"Warning: Node feature dimensions are inconsistent: min={min_dim}, max={max_dim}")
                print(f"  Individual dimensions: {node_feature_dims}")
                print(f"  Unifying to dimension {max_dim} by padding smaller features")
                
                # Unify node feature dimensions for all graphs
                for i, graph in enumerate(graphs):
                    if hasattr(graph, 'x') and graph.x is not None:
                        current_dim = graph.x.size(1)
                        if current_dim < max_dim:
                            # Need padding
                            num_nodes = graph.x.size(0)
                            padding = torch.zeros(num_nodes, max_dim - current_dim, 
                                                dtype=graph.x.dtype, device=graph.x.device)
                            graph.x = torch.cat([graph.x, padding], dim=1)
                            print(f"  Graph {i}: Padded node features from {current_dim} to {max_dim} dims")
                        elif current_dim > max_dim:
                            # Need truncation (this should not happen, but for safety)
                            graph.x = graph.x[:, :max_dim]
                            print(f"  Graph {i}: Truncated node features from {current_dim} to {max_dim} dims")
            
            # Check codebar features
            has_codebar = all(hasattr(graph, 'codebar_features') for graph in graphs)
            if has_codebar:
                print(f"Debug: Found codebar features in {len(graphs)} graphs")
                for i, graph in enumerate(graphs):
                    if hasattr(graph, 'codebar_features'):
                        print(f"  Graph {i}: codebar_features shape {graph.codebar_features.shape}")
            
            # Check path_homology features
            # path_homology is graph-level (protein-level) feature, not graph structure feature
            # Each protein has a 360-dimensional path_homology feature vector
            has_path_homology = any(hasattr(graph, 'path_homology_features') for graph in graphs)
            if has_path_homology:
                graphs_with_ph = sum(1 for g in graphs if hasattr(g, 'path_homology_features') and g.path_homology_features is not None)
                print(f"Debug: Found path_homology features (graph-level) in {graphs_with_ph}/{len(graphs)} graphs, will process all graphs")
                for i, graph in enumerate(graphs):
                    if hasattr(graph, 'path_homology_features') and graph.path_homology_features is not None:
                        print(f"  Graph {i}: path_homology_features (graph-level) shape {graph.path_homology_features.shape}")
                    else:
                        print(f"  Graph {i}: path_homology_features missing, will use zero vector")
            
            # Check ESM2 features
            # ESM2 is graph-level (protein-level) feature, not graph structure feature
            # Each protein has a 1280-dimensional ESM2 feature vector
            has_esm2 = any(hasattr(graph, 'esm2_features') for graph in graphs)
            if has_esm2:
                graphs_with_esm2 = sum(1 for g in graphs if hasattr(g, 'esm2_features') and g.esm2_features is not None)
                print(f"Debug: Found esm2_features (graph-level) in {graphs_with_esm2}/{len(graphs)} graphs, will process all graphs")
                for i, graph in enumerate(graphs):
                    if hasattr(graph, 'esm2_features') and graph.esm2_features is not None:
                        print(f"  Graph {i}: esm2_features (graph-level) shape {graph.esm2_features.shape}")
                    else:
                        print(f"  Graph {i}: esm2_features missing, will use zero vector")
            
            # Manually process codebar features for batching, ensure correct dimensions
            if has_codebar:
                print(f"Debug: Manually processing codebar features for {len(graphs)} graphs")
                
                # Collect codebar features and node counts for each graph
                codebar_features_list = []
                node_counts = []
                
                for i, graph in enumerate(graphs):
                    if hasattr(graph, 'codebar_features') and graph.codebar_features is not None:
                        # Ensure codebar features are 2D tensor [num_residues, 35]
                        if graph.codebar_features.dim() == 1:
                            # If 1D, reshape to 2D
                            if graph.codebar_features.size(0) % 35 == 0:
                                num_residues = graph.codebar_features.size(0) // 35
                                graph.codebar_features = graph.codebar_features.view(num_residues, 35)
                            else:
                                print(f"Warning: Graph {i} codebar features cannot be reshaped to [num_residues, 35]")
                                # Create zero features
                                graph.codebar_features = torch.zeros(graph.num_nodes, 35, dtype=torch.float32)
                        
                        # Verify feature dimensions
                        if graph.codebar_features.size(1) != 35:
                            print(f"Warning: Graph {i} codebar features have wrong second dimension: {graph.codebar_features.size(1)}")
                            # Fix second dimension
                            if graph.codebar_features.size(1) > 35:
                                graph.codebar_features = graph.codebar_features[:, :35]
                            else:
                                padding = torch.zeros(graph.codebar_features.size(0), 35 - graph.codebar_features.size(1), 
                                                   dtype=torch.float32, device=graph.codebar_features.device)
                                graph.codebar_features = torch.cat([graph.codebar_features, padding], dim=1)
                        
                        # Verify if first dimension matches node count
                        if graph.codebar_features.size(0) != graph.num_nodes:
                            print(f"Warning: Graph {i} codebar features first dimension ({graph.codebar_features.size(0)}) "
                                  f"doesn't match num_nodes ({graph.num_nodes})")
                            # Fix first dimension
                            if graph.codebar_features.size(0) > graph.num_nodes:
                                graph.codebar_features = graph.codebar_features[:graph.num_nodes]
                            else:
                                padding = torch.zeros(graph.num_nodes - graph.codebar_features.size(0), 35, 
                                                   dtype=torch.float32, device=graph.codebar_features.device)
                                graph.codebar_features = torch.cat([graph.codebar_features, padding], dim=0)
                        
                        print(f"  Graph {i}: Final codebar_features shape {graph.codebar_features.shape}")
                        codebar_features_list.append(graph.codebar_features)
                        node_counts.append(graph.num_nodes)
                    else:
                        # If no codebar features, create zero features
                        print(f"  Graph {i}: Creating zero codebar features for {graph.num_nodes} nodes")
                        zero_features = torch.zeros(graph.num_nodes, 35, dtype=torch.float32)
                        codebar_features_list.append(zero_features)
                        node_counts.append(graph.num_nodes)
                
                # Calculate total number of residues
                total_residues = sum(node_counts)
                print(f"Debug: Total residues across all graphs: {total_residues}")
                print(f"Debug: Individual graph residue counts: {node_counts}")
                
                # Manually concatenate all codebar features
                manual_codebar_features = torch.cat(codebar_features_list, dim=0)
                print(f"Debug: Manually concatenated codebar features shape: {manual_codebar_features.shape}")
                print(f"Debug: Expected shape: [{total_residues}, 35]")
                
                # Verify manual concatenation result
                if manual_codebar_features.size(0) != total_residues:
                    print(f"Error: Manual concatenation failed! Expected {total_residues}, got {manual_codebar_features.size(0)}")
                if manual_codebar_features.size(1) != 35:
                    print(f"Error: Manual concatenation failed! Expected 35, got {manual_codebar_features.size(1)}")
                
                # Temporarily remove codebar features to prevent PyTorch Geometric automatic batching
                for graph in graphs:
                    if hasattr(graph, 'codebar_features'):
                        graph._codebar_features_backup = graph.codebar_features
                        delattr(graph, 'codebar_features')
            
            # Ensure Batch.from_data_list correctly creates batch indices
            batched_graph = Batch.from_data_list(graphs)
            
            # Verify if batch indices are correctly created
            if not hasattr(batched_graph, 'batch') or batched_graph.batch is None:
                # Manually create batch indices
                batch_indices = []
                for i, graph in enumerate(graphs):
                    batch_indices.extend([i] * graph.num_nodes)
                batched_graph.batch = torch.tensor(batch_indices, dtype=torch.long)
            
            # Manually process path_homology features for batching
            # path_homology is graph-level (protein-level) feature, each protein has a 360-dimensional vector
            # Collect path_homology features from all graphs (before Batch.from_data_list)
            ph_features_list = []
            for i, graph in enumerate(graphs):
                if hasattr(graph, 'path_homology_features') and graph.path_homology_features is not None:
                    # Ensure path_homology features are 1D vector, must be 360-dimensional
                    ph_feat = graph.path_homology_features
                    if ph_feat.dim() == 1:
                        if ph_feat.size(0) == 360:
                            ph_features_list.append(ph_feat)
                        else:
                            # Incorrect dimension, this is an error case
                            print(f"Error: Graph {i} path_homology_features has wrong size: {ph_feat.size(0)}, expected 360. Using zero vector.")
                            print(f"  This indicates a problem in feature loading. Please check the file format.")
                            ph_features_list.append(torch.zeros(360, dtype=torch.float32))
                    else:
                        # If 2D, take the first dimension
                        if ph_feat.size(0) > 0 and ph_feat.size(1) == 360:
                            ph_features_list.append(ph_feat[0])
                        else:
                            # Incorrect dimension
                            print(f"Error: Graph {i} path_homology_features has wrong shape: {ph_feat.shape}, expected (1, 360) or (360,). Using zero vector.")
                            ph_features_list.append(torch.zeros(360, dtype=torch.float32))
                else:
                    # If no path_homology features, create zero vector
                    print(f"Warning: Graph {i} missing path_homology_features, using zero vector")
                    ph_features_list.append(torch.zeros(360, dtype=torch.float32))
            
            # Manually process ESM2 features for batching
            # ESM2 is graph-level (protein-level) feature, each protein has a 1280-dimensional vector
            # Collect ESM2 features from all graphs (before Batch.from_data_list)
            esm2_features_list = []
            for i, graph in enumerate(graphs):
                if hasattr(graph, 'esm2_features') and graph.esm2_features is not None:
                    # Ensure ESM2 features are 1D vector
                    esm2_feat = graph.esm2_features
                    if esm2_feat.dim() == 1:
                        if esm2_feat.size(0) == 1280:
                            esm2_features_list.append(esm2_feat)
                        else:
                            print(f"Warning: Graph {i} esm2_features has wrong size: {esm2_feat.size(0)}, expected 1280. Using zero vector.")
                            esm2_features_list.append(torch.zeros(1280, dtype=torch.float32))
                    else:
                        # If 2D, take the first dimension
                        if esm2_feat.size(0) > 0 and esm2_feat.size(1) == 1280:
                            esm2_features_list.append(esm2_feat[0])
                        else:
                            print(f"Warning: Graph {i} esm2_features has wrong shape: {esm2_feat.shape}. Using zero vector.")
                            esm2_features_list.append(torch.zeros(1280, dtype=torch.float32))
                else:
                    # If no ESM2 features, create zero vector
                    print(f"Warning: Graph {i} missing esm2_features, using zero vector")
                    esm2_features_list.append(torch.zeros(1280, dtype=torch.float32))
            
            # Stack path_homology features into shape [batch_size, 360]
            try:
                batched_ph_features = torch.stack(ph_features_list, dim=0)
                print(f"Debug: Manually batched path_homology features shape: {batched_ph_features.shape}")
                assert batched_ph_features.shape == (len(graphs), 360), f"Expected shape ({len(graphs)}, 360), got {batched_ph_features.shape}"
            except Exception as e:
                print(f"Error stacking path_homology features: {e}")
                print(f"  Number of graphs: {len(graphs)}")
                print(f"  Number of ph_features: {len(ph_features_list)}")
                for i, ph_feat in enumerate(ph_features_list):
                    print(f"    Graph {i}: shape {ph_feat.shape if isinstance(ph_feat, torch.Tensor) else type(ph_feat)}")
                # Create zero vector as fallback
                batched_ph_features = torch.zeros(len(graphs), 360, dtype=torch.float32)
                print(f"  Created fallback zero features with shape: {batched_ph_features.shape}")
            
            # Stack ESM2 features into shape [batch_size, 1280]
            try:
                batched_esm2_features = torch.stack(esm2_features_list, dim=0)
                print(f"Debug: Manually batched esm2_features shape: {batched_esm2_features.shape}")
                assert batched_esm2_features.shape == (len(graphs), 1280), f"Expected shape ({len(graphs)}, 1280), got {batched_esm2_features.shape}"
            except Exception as e:
                print(f"Error stacking esm2_features: {e}")
                print(f"  Number of graphs: {len(graphs)}")
                print(f"  Number of esm2_features: {len(esm2_features_list)}")
                for i, esm2_feat in enumerate(esm2_features_list):
                    print(f"    Graph {i}: shape {esm2_feat.shape if isinstance(esm2_feat, torch.Tensor) else type(esm2_feat)}")
                # Create zero vector as fallback
                batched_esm2_features = torch.zeros(len(graphs), 1280, dtype=torch.float32)
                print(f"  Created fallback zero features with shape: {batched_esm2_features.shape}")
            
            # Check for any abnormal values
            if torch.isnan(batched_ph_features).any():
                print("Warning: NaN values detected in path_homology features")
            if torch.isinf(batched_ph_features).any():
                print("Warning: Inf values detected in path_homology features")
            if torch.isnan(batched_esm2_features).any():
                print("Warning: NaN values detected in esm2_features")
            if torch.isinf(batched_esm2_features).any():
                print("Warning: Inf values detected in esm2_features")
            
            # Temporarily remove path_homology and ESM2 features to prevent PyTorch Geometric automatic batching
            for graph in graphs:
                if hasattr(graph, 'path_homology_features'):
                    graph._path_homology_features_backup = graph.path_homology_features
                    delattr(graph, 'path_homology_features')
                if hasattr(graph, 'esm2_features'):
                    graph._esm2_features_backup = graph.esm2_features
                    delattr(graph, 'esm2_features')
            
            # Manually add codebar features to batched graph
            if has_codebar:
                print(f"Debug: Adding manually processed codebar features to batched graph")
                batched_graph.codebar_features = manual_codebar_features
                print(f"Debug: Final batched codebar_features shape: {batched_graph.codebar_features.shape}")
                
                # Verify if final dimensions are correct
                if batched_graph.codebar_features.size(0) == batched_graph.x.size(0):
                    print(f"Debug: Codebar features dimension matches node count: {batched_graph.codebar_features.size(0)}")
                else:
                    print(f"Warning: Final dimension mismatch! Codebar: {batched_graph.codebar_features.size(0)}, Nodes: {batched_graph.x.size(0)}")
                
                if batched_graph.codebar_features.size(1) == 35:
                    print(f"Debug: Codebar features second dimension is correct: 35")
                else:
                    print(f"Warning: Final second dimension wrong! Expected 35, got {batched_graph.codebar_features.size(1)}")
                
                # Restore original graph's codebar features
                for i, graph in enumerate(graphs):
                    if hasattr(graph, '_codebar_features_backup'):
                        graph.codebar_features = graph._codebar_features_backup
                        delattr(graph, '_codebar_features_backup')
            
            # Process path_homology and ESM2 features for batching (graph-level features)
            # Ensure batched graph always has path_homology_features and esm2_features
            # Batched shapes should be [batch_size, 360] and [batch_size, 1280]
            
            # Set path_homology_features
            if not hasattr(batched_ph_features, 'shape'):
                print(f"Error: batched_ph_features is not a tensor! Type: {type(batched_ph_features)}")
                batched_ph_features = torch.zeros(len(graphs), 360, dtype=torch.float32)
            
            batched_graph.path_homology_features = batched_ph_features
            print(f"Debug: Set path_homology_features on batched_graph, shape: {batched_graph.path_homology_features.shape}")
            
            # Set esm2_features
            if not hasattr(batched_esm2_features, 'shape'):
                print(f"Error: batched_esm2_features is not a tensor! Type: {type(batched_esm2_features)}")
                batched_esm2_features = torch.zeros(len(graphs), 1280, dtype=torch.float32)
            
            batched_graph.esm2_features = batched_esm2_features
            print(f"Debug: Set esm2_features on batched_graph, shape: {batched_graph.esm2_features.shape}")
            
            # Final verification
            assert hasattr(batched_graph, 'path_homology_features'), "batched_graph must have path_homology_features attribute"
            assert batched_graph.path_homology_features is not None, "batched_graph.path_homology_features must not be None"
            assert batched_graph.path_homology_features.shape == (len(graphs), 360), f"Expected shape ({len(graphs)}, 360), got {batched_graph.path_homology_features.shape}"
            assert hasattr(batched_graph, 'esm2_features'), "batched_graph must have esm2_features attribute"
            assert batched_graph.esm2_features is not None, "batched_graph.esm2_features must not be None"
            assert batched_graph.esm2_features.shape == (len(graphs), 1280), f"Expected shape ({len(graphs)}, 1280), got {batched_graph.esm2_features.shape}"
            print(f"Debug: Final verification passed - path_homology_features shape: {batched_graph.path_homology_features.shape}, esm2_features shape: {batched_graph.esm2_features.shape}")
            
            # Restore original graph's features
            for i, graph in enumerate(graphs):
                if hasattr(graph, '_path_homology_features_backup'):
                    graph.path_homology_features = graph._path_homology_features_backup
                    delattr(graph, '_path_homology_features_backup')
                if hasattr(graph, '_esm2_features_backup'):
                    graph.esm2_features = graph._esm2_features_backup
                    delattr(graph, '_esm2_features_backup')
            
            return batched_graph, torch.tensor(labels)
        else:
            # First check and unify node feature dimensions
            node_feature_dims = [graph.x.size(1) if hasattr(graph, 'x') and graph.x is not None else 0 for graph in samples]
            if len(set(node_feature_dims)) > 1:
                # Node feature dimensions are inconsistent, need to unify
                max_dim = max(node_feature_dims)
                min_dim = min(node_feature_dims)
                print(f"Warning: Node feature dimensions are inconsistent: min={min_dim}, max={max_dim}")
                print(f"  Individual dimensions: {node_feature_dims}")
                print(f"  Unifying to dimension {max_dim} by padding smaller features")
                
                # Unify node feature dimensions for all graphs
                for i, graph in enumerate(samples):
                    if hasattr(graph, 'x') and graph.x is not None:
                        current_dim = graph.x.size(1)
                        if current_dim < max_dim:
                            # Need padding
                            num_nodes = graph.x.size(0)
                            padding = torch.zeros(num_nodes, max_dim - current_dim, 
                                                dtype=graph.x.dtype, device=graph.x.device)
                            graph.x = torch.cat([graph.x, padding], dim=1)
                            print(f"  Graph {i}: Padded node features from {current_dim} to {max_dim} dims")
                        elif current_dim > max_dim:
                            # Need truncation (this should not happen, but for safety)
                            graph.x = graph.x[:, :max_dim]
                            print(f"  Graph {i}: Truncated node features from {current_dim} to {max_dim} dims")
            
            # Process path_homology and ESM2 features for batching (no-label branch also needs this)
            # path_homology is graph-level (protein-level) feature, each graph has a 360-dimensional vector
            ph_features_list = []
            for i, graph in enumerate(samples):
                if hasattr(graph, 'path_homology_features') and graph.path_homology_features is not None:
                    ph_feat = graph.path_homology_features
                    if ph_feat.dim() == 1:
                        if ph_feat.size(0) == 360:
                            ph_features_list.append(ph_feat)
                        else:
                            # Incorrect dimension, this is an error case
                            print(f"Error: Graph {i} path_homology_features has wrong size: {ph_feat.size(0)}, expected 360. Using zero vector.")
                            print(f"  This indicates a problem in feature loading. Please check the file format.")
                            ph_features_list.append(torch.zeros(360, dtype=torch.float32))
                    else:
                        if ph_feat.size(0) > 0 and ph_feat.size(1) == 360:
                            ph_features_list.append(ph_feat[0])
                        else:
                            # Incorrect dimension
                            print(f"Error: Graph {i} path_homology_features has wrong shape: {ph_feat.shape}, expected (1, 360) or (360,). Using zero vector.")
                            ph_features_list.append(torch.zeros(360, dtype=torch.float32))
                else:
                    print(f"Warning: Graph {i} missing path_homology_features, using zero vector")
                    ph_features_list.append(torch.zeros(360, dtype=torch.float32))
            
            # ESM2 is graph-level (protein-level) feature, each graph has a 1280-dimensional vector
            esm2_features_list = []
            for i, graph in enumerate(samples):
                if hasattr(graph, 'esm2_features') and graph.esm2_features is not None:
                    esm2_feat = graph.esm2_features
                    if esm2_feat.dim() == 1:
                        if esm2_feat.size(0) == 1280:
                            esm2_features_list.append(esm2_feat)
                        else:
                            print(f"Warning: Graph {i} esm2_features has wrong size: {esm2_feat.size(0)}, expected 1280. Using zero vector.")
                            esm2_features_list.append(torch.zeros(1280, dtype=torch.float32))
                    else:
                        if esm2_feat.size(0) > 0 and esm2_feat.size(1) == 1280:
                            esm2_features_list.append(esm2_feat[0])
                        else:
                            print(f"Warning: Graph {i} esm2_features has wrong shape: {esm2_feat.shape}. Using zero vector.")
                            esm2_features_list.append(torch.zeros(1280, dtype=torch.float32))
                else:
                    print(f"Warning: Graph {i} missing esm2_features, using zero vector")
                    esm2_features_list.append(torch.zeros(1280, dtype=torch.float32))
            
            # Stack features
            try:
                batched_ph_features = torch.stack(ph_features_list, dim=0)
                print(f"Debug: Manually batched path_homology features shape: {batched_ph_features.shape}")
            except Exception as e:
                print(f"Error stacking path_homology features: {e}")
                batched_ph_features = torch.zeros(len(samples), 360, dtype=torch.float32)
                print(f"  Created fallback zero features with shape: {batched_ph_features.shape}")
            
            try:
                batched_esm2_features = torch.stack(esm2_features_list, dim=0)
                print(f"Debug: Manually batched esm2_features shape: {batched_esm2_features.shape}")
            except Exception as e:
                print(f"Error stacking esm2_features: {e}")
                batched_esm2_features = torch.zeros(len(samples), 1280, dtype=torch.float32)
                print(f"  Created fallback zero features with shape: {batched_esm2_features.shape}")
            
            # Temporarily remove features to prevent PyTorch Geometric automatic batching
            for graph in samples:
                if hasattr(graph, 'path_homology_features'):
                    graph._path_homology_features_backup = graph.path_homology_features
                    delattr(graph, 'path_homology_features')
                if hasattr(graph, 'esm2_features'):
                    graph._esm2_features_backup = graph.esm2_features
                    delattr(graph, 'esm2_features')
            
            batched_graph = Batch.from_data_list(samples)
            
            # Verify if batch indices are correctly created
            if not hasattr(batched_graph, 'batch') or batched_graph.batch is None:
                # Manually create batch indices
                batch_indices = []
                for i, graph in enumerate(samples):
                    batch_indices.extend([i] * graph.num_nodes)
                batched_graph.batch = torch.tensor(batch_indices, dtype=torch.long)
            
            # Set features to batched_graph
            batched_graph.path_homology_features = batched_ph_features
            batched_graph.esm2_features = batched_esm2_features
            print(f"Debug: Set path_homology_features and esm2_features on batched_graph (no labels)")
            print(f"  path_homology_features shape: {batched_graph.path_homology_features.shape}")
            print(f"  esm2_features shape: {batched_graph.esm2_features.shape}")
            
            # Verify if setting was successful
            if not hasattr(batched_graph, 'path_homology_features') or batched_graph.path_homology_features is None:
                print(f"Error: Failed to set path_homology_features attribute in no-labels branch!")
                batched_graph.path_homology_features = torch.zeros(len(samples), 360, dtype=torch.float32)
            if not hasattr(batched_graph, 'esm2_features') or batched_graph.esm2_features is None:
                print(f"Error: Failed to set esm2_features attribute in no-labels branch!")
                batched_graph.esm2_features = torch.zeros(len(samples), 1280, dtype=torch.float32)
            
            # Restore original graph's features
            for i, graph in enumerate(samples):
                if hasattr(graph, '_path_homology_features_backup'):
                    graph.path_homology_features = graph._path_homology_features_backup
                    delattr(graph, '_path_homology_features_backup')
                if hasattr(graph, '_esm2_features_backup'):
                    graph.esm2_features = graph._esm2_features_backup
                    delattr(graph, '_esm2_features_backup')
            
            return batched_graph

def split_dataset_auto(df: pd.DataFrame, pdb_dir: str, train_ratio: float = 0.8, val_ratio: float = 0.1, test_ratio: float = 0.1, random_seed: int = 42):
    """
    Randomly split a dataset into Training, Validation, and Test sets.

    Only rows whose 'protein_id' has a matching structure file
    (<protein_id>.pdb or <protein_id>.pdbqt) in pdb_dir are kept. The kept
    rows are shuffled with random_seed and cut into three consecutive
    slices; the test slice receives whatever remains after the train and
    validation slices, so rounding leftovers end up in the test set.

    Args:
        df: DataFrame containing all data; must have a 'protein_id' column
        pdb_dir: Directory containing all PDB files
        train_ratio: Training ratio
        val_ratio: Validation ratio
        test_ratio: Test ratio
        random_seed: Random seed (also seeds the global `random` and
            `numpy.random` generators as a side effect)

    Returns:
        df_train, df_val, df_test: Split DataFrames

    Raises:
        ValueError: If the three ratios do not sum to 1.0
    """
    # Set random seed
    random.seed(random_seed)
    np.random.seed(random_seed)

    # Check if ratios are reasonable
    total_ratio = train_ratio + val_ratio + test_ratio
    if abs(total_ratio - 1.0) > 1e-6:
        raise ValueError(f"Sum of ratios must be 1.0, current is {total_ratio}")

    # Get all available PDB files.
    # Strip only the trailing extension: chaining str.replace('.pdb', '')
    # before str.replace('.pdbqt', '') turns "x.pdbqt" into "xqt", which
    # would never match a protein ID from the CSV.
    pdb_files = [
        os.path.splitext(filename)[0]
        for filename in os.listdir(pdb_dir)
        if filename.endswith(('.pdb', '.pdbqt'))
    ]

    print(f"Found {len(pdb_files)} PDB files")

    # Check if all protein IDs in CSV have corresponding PDB files
    available_ids = set(pdb_files)
    csv_ids = set(df['protein_id'].values)

    # Find protein IDs with PDB files
    valid_ids = csv_ids.intersection(available_ids)
    missing_ids = csv_ids - available_ids

    if missing_ids:
        print(f"Warning: The following protein IDs exist in CSV but have no corresponding PDB files: {missing_ids}")
        print(f"Will only use proteins with PDB files: {len(valid_ids)} proteins")

    # Filter data, only keep proteins with PDB files
    df_filtered = df[df['protein_id'].isin(valid_ids)].reset_index(drop=True)

    print(f"Valid data: {len(df_filtered)} proteins")

    # Randomly shuffle data
    df_shuffled = df_filtered.sample(frac=1, random_state=random_seed).reset_index(drop=True)

    # Calculate split points
    n_total = len(df_shuffled)
    n_train = int(n_total * train_ratio)
    n_val = int(n_total * val_ratio)

    # Split data
    df_train = df_shuffled.iloc[:n_train].copy()
    df_val = df_shuffled.iloc[n_train:n_train + n_val].copy()
    df_test = df_shuffled.iloc[n_train + n_val:].copy()

    def _percent(count):
        # Avoid ZeroDivisionError when no protein survived the filtering
        return count / n_total * 100 if n_total else 0.0

    print(f"Dataset split completed:")
    print(f"  Training: {len(df_train)} samples ({_percent(len(df_train)):.1f}%)")
    print(f"  Validation: {len(df_val)} samples ({_percent(len(df_val)):.1f}%)")
    print(f"  Test: {len(df_test)} samples ({_percent(len(df_test)):.1f}%)")

    # Print label distribution statistics
    print_label_distribution(df_train, df_val, df_test)

    return df_train, df_val, df_test

def _load_cached_graphs(cache_file: str, expected_count: Optional[int], validate: bool, split_label: str):
    """
    Return the graph list pickled in cache_file, or None if it is unusable.

    The file is unpickled at most once; validation (when enabled) is run on
    the object that is then returned.

    Args:
        cache_file: Path of the pickle cache
        expected_count: Number of graphs the cache must hold (None skips the check)
        validate: If False, the cache is accepted as soon as the file exists
        split_label: 'training', 'validation' or 'test'; only used in log messages

    Returns:
        The unpickled object, or None when the cache is missing, fails
        validation, or cannot be opened
    """
    if not os.path.exists(cache_file):
        return None

    # NOTE: pickle.load can execute arbitrary code embedded in the file; only
    # point this at cache files produced by this pipeline.
    if validate:
        try:
            # Check file size
            if os.path.getsize(cache_file) < 1024:  # Less than 1KB considered invalid
                print(f"Cache file {cache_file} too small, may be corrupted")
                return None

            with open(cache_file, 'rb') as f:
                graphs = pickle.load(f)

            # Check data type and basic structure
            if not isinstance(graphs, list) or len(graphs) == 0:
                print(f"Cache file {cache_file} data format invalid")
                return None

            # Check basic structure of first sample
            first_sample = graphs[0]
            for attr in ('x', 'edge_index', 'edge_attr', 'triangle_index', 'triangle_dis'):
                if not hasattr(first_sample, attr):
                    print(f"Cache file {cache_file} missing required attribute: {attr}")
                    return None

            # Check sample count (if expected count is provided)
            if expected_count is not None and len(graphs) != expected_count:
                print(f"Cache file {cache_file} sample count mismatch: expected {expected_count}, got {len(graphs)}")
                return None
        except Exception as e:
            # Any failure while inspecting the cache just means "reprocess"
            print(f"Cache file {cache_file} validation failed: {e}")
            return None

        print(f"Cache file {cache_file} validation passed")
        print(f"Loading pre-processed {split_label} protein graphs...")
        print(f"Successfully loaded {len(graphs)} {split_label} protein graphs")
        return graphs

    # Validation disabled: trust the file as long as it can be opened
    print(f"Loading pre-processed {split_label} protein graphs...")
    try:
        with open(cache_file, 'rb') as f:
            graphs = pickle.load(f)
        print(f"Successfully loaded {len(graphs)} {split_label} protein graphs")
    except (PermissionError, FileNotFoundError) as e:
        print(f"Unable to load {split_label} graph file, reprocessing: {e}")
        return None
    return graphs


def _load_or_process_split(df_split, graphs_file, split_label, pdb_dir,
                           distance_threshold, force_reprocess, validate_cache):
    """
    Get the protein graphs for one split, from cache when possible.

    When the graphs have to be rebuilt, df_split is reduced to the proteins
    that process_protein_dataset_by_ids reports as successfully processed.

    Returns:
        (df_split, graphs): the (possibly filtered) DataFrame and its graphs
    """
    graphs = None
    if not force_reprocess:
        # Skipping the cache entirely avoids unpickling a file that would be
        # discarded anyway (e.g. when the old cache could not be deleted).
        graphs = _load_cached_graphs(graphs_file, len(df_split), validate_cache, split_label)

    if graphs is None:
        print(f"Processing {split_label} PDB files...")
        graphs, successful_ids = process_protein_dataset_by_ids(
            df_split['protein_id'].tolist(), pdb_dir, distance_threshold, graphs_file
        )
        # Filter DataFrame, only keep successfully processed proteins
        df_split = df_split[df_split['protein_id'].isin(successful_ids)].reset_index(drop=True)
        print(f"{split_label.capitalize()} set filtered: {len(df_split)} samples after processing")

    return df_split, graphs


def _build_label_mapping(all_labels):
    """
    Map every original training label to a contiguous class index.

    Labels are first merged with merge_labels(); the sorted merged labels
    are numbered 0, 1, 2, ... and each original label receives the index of
    its merged label. Per-label and per-class sample counts are printed.

    Args:
        all_labels: Array of training labels (df[target].values)

    Returns:
        dict mapping original label -> class index
    """
    unique_labels = sorted(set(all_labels))
    print(f"Found unique labels: {unique_labels}")

    # First merge labels
    unique_merged_labels = sorted({merge_labels(label) for label in unique_labels})
    print(f"Merged unique labels: {unique_merged_labels}")

    # Create label mapping: map merged labels to 0, 1, 2, ...
    label_mapping = {}
    for mapped_idx, merged_label in enumerate(unique_merged_labels):
        for original_label in unique_labels:
            if merge_labels(original_label) == merged_label:
                label_mapping[original_label] = mapped_idx

    print(f"Label mapping: {label_mapping}")
    print(f"Number of classes after mapping: {len(set(label_mapping.values()))}")

    # Count samples for each class (by original label)
    label_counts = {label: (all_labels == label).sum() for label in label_mapping}
    for label in sorted(label_mapping):
        print(f"Label {label} -> merged to {merge_labels(label)} -> class {label_mapping[label]}: {label_counts[label]} samples")

    # Count merged class samples
    for mapped_class, merged_label_name in enumerate(unique_merged_labels):
        count = sum(
            label_counts[original_label]
            for original_label, mapped_label in label_mapping.items()
            if mapped_label == mapped_class
        )
        print(f"Merged class {mapped_class} ({merged_label_name}): {count} samples")

    return label_mapping


def load_protein_data_auto(config):
    """
    Automatically load and split protein data, and build data loaders.

    If the three split CSVs (<dataset>_train/val/test.csv) already exist in
    the output directory they are reused; otherwise the main CSV is split
    with split_dataset_auto() and the result is saved. Protein graphs are
    loaded from pickle caches in the output directory when usable, and
    rebuilt from the PDB files otherwise.

    Args:
        config: Nested configuration dictionary. Keys read here:
            - config['training']['dataset']: Dataset name
            - config['training']['batch_size']: Batch size
            - config['training']['target']: Target column name (optional)
            - config['training']['normalize_labels']: default True
            - config['training']['force_reprocess']: ignore/delete graph caches (default False)
            - config['training']['enable_cache_validation']: default True
            - config['data']['pdb_dir']: PDB file directory (default data/<dataset>/pdb_files)
            - config['data']['csv_file']: CSV file path (default data/<dataset>/<dataset>.csv)
            - config['data']['output_dir']: default data/<dataset>
            - config['data']['train_ratio'] / ['val_ratio'] / ['test_ratio']: defaults 0.8 / 0.1 / 0.1
            - config['data']['random_seed']: Random seed (default 42)
            - config['data']['esm2_only']: delegate to load_protein_data_ph_only (default False)
            - config['protein']['distance_threshold']: default 6.0
            - config['model']['criterion']: required when a target is set;
              'cross_entropy' selects classification, anything else regression

    Returns:
        (train_loader, val_loader, test_loader, mean_train, std_train).
        mean_train and std_train are None unless the task is regression.
        When config['data']['esm2_only'] is set, the return value of
        load_protein_data_ph_only(config) is returned instead.

    Raises:
        FileNotFoundError: If the CSV file or the PDB directory does not exist
    """
    # Simplified branch that does not need PDB/graph structure.
    # NOTE(review): the flag is called 'esm2_only' but dispatches to the
    # path-homology-only loader; confirm this is the intended key.
    if config.get('data', {}).get('esm2_only', False):
        return load_protein_data_ph_only(config)

    training_cfg = config['training']
    dataset_name = training_cfg['dataset']
    target = training_cfg.get('target')
    batch_size = training_cfg['batch_size']
    distance_threshold = config.get('protein', {}).get('distance_threshold', 6.0)

    # Get data paths
    data_cfg = config['data']
    pdb_dir = data_cfg.get('pdb_dir', f'data/{dataset_name}/pdb_files')
    csv_file = data_cfg.get('csv_file', f'data/{dataset_name}/{dataset_name}.csv')

    # Get split parameters
    train_ratio = data_cfg.get('train_ratio', 0.8)
    val_ratio = data_cfg.get('val_ratio', 0.1)
    test_ratio = data_cfg.get('test_ratio', 0.1)
    random_seed = data_cfg.get('random_seed', 42)

    print("Loading protein data with automatic splitting...")

    # Check if files exist
    if not os.path.exists(csv_file):
        raise FileNotFoundError(f"CSV file does not exist: {csv_file}")
    if not os.path.exists(pdb_dir):
        raise FileNotFoundError(f"PDB directory does not exist: {pdb_dir}")

    # Check if pre-split datasets already exist
    output_dir = data_cfg.get('output_dir', f'data/{dataset_name}')
    split_names = ('train', 'val', 'test')
    existing_datasets = [
        split for split in split_names
        if os.path.exists(f'{output_dir}/{dataset_name}_{split}.csv')
    ]
    needs_split = len(existing_datasets) < len(split_names)

    if not needs_split:
        print(f"✅ Found pre-split datasets: {', '.join(existing_datasets)}")
        print(f"Directly loading pre-split datasets, skipping re-split...")

        # Directly load pre-split datasets
        df_train = pd.read_csv(f'{output_dir}/{dataset_name}_train.csv')
        df_val = pd.read_csv(f'{output_dir}/{dataset_name}_val.csv')
        df_test = pd.read_csv(f'{output_dir}/{dataset_name}_test.csv')

        print(f"Pre-split dataset loading completed:")
        print(f"  Training: {len(df_train)} samples")
        print(f"  Validation: {len(df_val)} samples")
        print(f"  Test: {len(df_test)} samples")
        print(f"  Total: {len(df_train) + len(df_val) + len(df_test)} samples")
    else:
        print(f"⚠️  Complete pre-split dataset not found, performing automatic split...")
        print(f"Existing datasets: {existing_datasets if existing_datasets else 'none'}")

        # Read CSV file
        print(f"Reading CSV file: {csv_file}")
        df = pd.read_csv(csv_file)
        print(f"Total samples in CSV: {len(df)}")

        # Automatically split dataset
        df_train, df_val, df_test = split_dataset_auto(
            df, pdb_dir, train_ratio, val_ratio, test_ratio, random_seed
        )

    # Print label distribution statistics (if target column exists)
    if target is not None:
        if target in df_train.columns:
            print_label_distribution(df_train, df_val, df_test, target_col=target)
        else:
            print(f"Warning: Target column '{target}' not found in dataset")
            print(f"Available columns: {list(df_train.columns)}")

    # Try to create output directory, use temporary directory if failed
    try:
        os.makedirs(output_dir, exist_ok=True)
        for split in split_names:
            os.makedirs(f'{output_dir}/pdb_files/{split}', exist_ok=True)
        print(f"Output directory created successfully: {output_dir}")
    except PermissionError as e:
        print(f"Insufficient permissions, cannot create directory {output_dir}: {e}")
        # Use temporary directory as fallback; from here on the graph caches
        # (and newly split CSVs) live in the temporary directory.
        import tempfile
        output_dir = tempfile.mkdtemp(prefix=f"{dataset_name}_")
        for split in split_names:
            os.makedirs(f'{output_dir}/pdb_files/{split}', exist_ok=True)
        print(f"Using temporary directory: {output_dir}")

    # Only save CSV files when re-splitting dataset
    if needs_split:
        print(f"Saving newly split dataset to CSV files...")
        split_frames = {'train': df_train, 'val': df_val, 'test': df_test}
        try:
            csv_paths = {
                split: f'{output_dir}/{dataset_name}_{split}.csv'
                for split in split_names
            }

            # Delete existing files first so a stale partial split never
            # survives next to the new one
            for csv_path in csv_paths.values():
                if os.path.exists(csv_path):
                    os.remove(csv_path)
                    print(f"Deleted old file: {csv_path}")

            # Save new CSV files
            for split, frame in split_frames.items():
                frame.to_csv(csv_paths[split], index=False)
            print(f"CSV files saved successfully: {output_dir}")

        except PermissionError as e:
            print(f"Insufficient permissions when saving CSV files: {e}")
            # Try using temporary directory (only for the CSVs; output_dir
            # itself is left unchanged)
            import tempfile
            temp_dir = tempfile.mkdtemp(prefix=f"{dataset_name}_csv_")
            for split, frame in split_frames.items():
                frame.to_csv(f'{temp_dir}/{dataset_name}_{split}.csv', index=False)
            print(f"CSV files saved to temporary directory: {temp_dir}")

        except Exception as e:
            print(f"Error occurred while saving CSV files: {e}")
            raise
    else:
        print(f"Using existing CSV files, no need to re-save")

    print("Processing protein graphs...")

    # Graph cache files, one per split
    train_graphs_file = f'{output_dir}/protein_graphs_train.pkl'
    val_graphs_file = f'{output_dir}/protein_graphs_val.pkl'
    test_graphs_file = f'{output_dir}/protein_graphs_test.pkl'

    # If config specifies force reprocessing, delete old files
    force_reprocess = config.get('training', {}).get('force_reprocess', False)
    if force_reprocess:
        print("Force reprocessing: deleting old graph files...")
        for graph_file in (train_graphs_file, val_graphs_file, test_graphs_file):
            if os.path.exists(graph_file):
                try:
                    os.remove(graph_file)
                    print(f"Deleted old file: {graph_file}")
                except PermissionError as e:
                    print(f"Unable to delete file {graph_file}: {e}")

    # Get cache validation configuration
    enable_cache_validation = config.get('training', {}).get('enable_cache_validation', True)

    # Load (or rebuild) the graphs of each split
    df_train, train_graphs = _load_or_process_split(
        df_train, train_graphs_file, 'training', pdb_dir,
        distance_threshold, force_reprocess, enable_cache_validation
    )
    df_val, val_graphs = _load_or_process_split(
        df_val, val_graphs_file, 'validation', pdb_dir,
        distance_threshold, force_reprocess, enable_cache_validation
    )
    df_test, test_graphs = _load_or_process_split(
        df_test, test_graphs_file, 'test', pdb_dir,
        distance_threshold, force_reprocess, enable_cache_validation
    )

    # Process label mapping (classification) or normalization (regression)
    label_mapping = None
    mean_train = None
    std_train = None

    if target is not None:
        if config['model']['criterion'] == 'cross_entropy':
            # Classification task: the mapping is derived from training labels only
            label_mapping = _build_label_mapping(df_train[target].values)
        else:
            # Regression task: normalization statistics come from the training set only
            train_vals = df_train[target].values
            mean_train = np.mean(train_vals)
            std_train = np.std(train_vals)

    # Create datasets
    normalize_labels = config['training'].get('normalize_labels', True)

    def _make_dataset(frame, graphs):
        return ProteinDataset(
            frame, graphs, target=target,
            mean_train=mean_train, std_train=std_train,
            normalize_labels=normalize_labels,
            label_mapping=label_mapping
        )

    train_dataset = _make_dataset(df_train, train_graphs)
    val_dataset = _make_dataset(df_val, val_graphs)
    test_dataset = _make_dataset(df_test, test_graphs)

    # Create data loaders
    from torch_geometric.loader import DataLoader

    # NOTE(review): drop_last=True is also applied to the validation and test
    # loaders, so samples in a final incomplete batch are never evaluated.
    # Kept as-is because callers may rely on fixed-size batches; confirm.
    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True,
        collate_fn=ProteinDataset.collate, drop_last=True
    )

    val_loader = DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False,
        collate_fn=ProteinDataset.collate, drop_last=True
    )

    test_loader = DataLoader(
        test_dataset, batch_size=batch_size, shuffle=False,
        collate_fn=ProteinDataset.collate, drop_last=True
    )

    return train_loader, val_loader, test_loader, mean_train, std_train

def _read_labels_csv(csv_file: str) -> pd.DataFrame:
    """
    Load a protein_id / label CSV, with or without a header row.

    The file is first read assuming a header. If some column name contains
    'protein' and some column name contains 'label' (case-insensitive), the
    header is accepted and those columns are renamed to 'protein_id' and
    'label'. Otherwise, or if that attempt raises, the file is re-read as a
    headerless two-column file with those names.

    Args:
        csv_file: Path of the CSV file

    Returns:
        DataFrame whose 'protein_id' column is cast to str and stripped of
        surrounding whitespace; labels are left untouched
    """
    fallback_names = ['protein_id', 'label']
    try:
        probe = pd.read_csv(csv_file)
        lowered = [name.lower() for name in probe.columns]
        looks_like_header = (
            any('protein' in name for name in lowered)
            and any('label' in name for name in lowered)
        )
        if not looks_like_header:
            # First row is data, not a header
            df = pd.read_csv(csv_file, header=None, names=fallback_names)
        else:
            # Standardize column names; each target name is handed out once
            renames = {}
            for column, lowered_name in zip(probe.columns, lowered):
                if 'protein' in lowered_name and 'protein_id' not in renames.values():
                    renames[column] = 'protein_id'
                if 'label' in lowered_name and 'label' not in renames.values():
                    renames[column] = 'label'
            df = probe.rename(columns=renames) if renames else probe
    except Exception:
        # Fallback: read as no header with two columns
        df = pd.read_csv(csv_file, header=None, names=fallback_names)

    # Clean whitespace around IDs (labels may be string categories and are kept as read)
    df['protein_id'] = df['protein_id'].astype(str).str.strip()
    return df

def _shuffled_positional_split(df: pd.DataFrame, train_ratio: float, val_ratio: float,
                               random_seed: int) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """Shuffle ``df`` and slice it by row position into (train, val, test); test gets the remainder."""
    # Side effect: the global `random` and NumPy RNGs are seeded as well.
    random.seed(random_seed)
    np.random.seed(random_seed)
    shuffled = df.sample(frac=1, random_state=random_seed).reset_index(drop=True)
    n_total = len(shuffled)
    n_train = int(n_total * train_ratio)
    n_val = int(n_total * val_ratio)
    return (
        shuffled.iloc[:n_train].copy(),
        shuffled.iloc[n_train:n_train + n_val].copy(),
        shuffled.iloc[n_train + n_val:].copy(),
    )


def stratified_split_dataframe(df: pd.DataFrame, target_col: str, train_ratio: float = 0.8, 
                                 val_ratio: float = 0.1, test_ratio: float = 0.1, 
                                 random_seed: int = 42) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """
    Split a labelled DataFrame into train/validation/test, class by class.

    Allocation rules (applied independently to every class, after shuffling
    the class rows with ``random_seed``):
      - 1 sample:   training only
      - 2 samples:  1 training, 1 validation, 0 test
      - 3 samples:  2 training, 1 validation, 0 test
      - 4+ samples: ``int(n * train_ratio)`` training and ``int(n * val_ratio)``
        validation (each at least 1, clamped to the class size); the
        remainder goes to test

    Fallbacks to a plain shuffled positional split:
      - when ``SKLEARN_AVAILABLE`` is False (returns immediately, without
        validating ``target_col`` or the ratios and without printing the
        distribution table)
      - when no class has 2 or more samples

    Args:
        df: DataFrame containing labels
        target_col: Target column name
        train_ratio: Training ratio
        val_ratio: Validation ratio
        test_ratio: Test ratio (only used to validate that the ratios sum to 1)
        random_seed: Random seed

    Returns:
        (df_train, df_val, df_test): Three split DataFrames

    Raises:
        ValueError: if ``target_col`` is missing or the ratios do not sum to 1.0
            (only checked when ``SKLEARN_AVAILABLE`` is True).
    """
    if not SKLEARN_AVAILABLE:
        print("⚠️  Warning: sklearn unavailable, using random split instead of stratified split")
        return _shuffled_positional_split(df, train_ratio, val_ratio, random_seed)

    if target_col not in df.columns:
        raise ValueError(f"Target column '{target_col}' does not exist")

    # Validate ratios
    total_ratio = train_ratio + val_ratio + test_ratio
    if abs(total_ratio - 1.0) > 1e-6:
        raise ValueError(f"Sum of ratios must be 1.0, current is {total_ratio}")

    print(f"\nUsing stratified split for dataset（stratified split）...")
    print(f"  Total samples: {len(df)}")
    print(f"  Split ratio: Training={train_ratio}, Validation={val_ratio}, Test={test_ratio}")

    # Partition classes by sample count; value_counts ignores NaN labels.
    label_counts = df[target_col].value_counts()
    single_sample_labels = label_counts[label_counts == 1].index.tolist()
    multi_sample_labels = label_counts[label_counts >= 2].index.tolist()

    if single_sample_labels:
        df_single = df[df[target_col].isin(single_sample_labels)].copy()
        print(f"⚠️  Found {len(single_sample_labels)} classes with only 1 sample, will put all in training set")
        print(f"  Single-sample classes: {single_sample_labels}")
        print(f"  Single-sample class sample count: {len(df_single)}")
    else:
        df_single = pd.DataFrame(columns=df.columns)

    if not multi_sample_labels:
        # Nothing to stratify on: every class has a single sample.
        print("⚠️  Warning: All classes have only 1 sample, using random split instead of stratified split")
        df_train, df_val, df_test = _shuffled_positional_split(
            df, train_ratio, val_ratio, random_seed
        )
    else:
        train_parts = []
        val_parts = []
        test_parts = []

        for label in multi_sample_labels:
            # Every class is shuffled with the same seed, so the split is reproducible.
            label_data = (
                df[df[target_col] == label]
                .sample(frac=1, random_state=random_seed)
                .reset_index(drop=True)
            )
            n_samples = len(label_data)

            if n_samples == 2:
                n_train_samples, n_val_samples = 1, 1
            elif n_samples == 3:
                # Prioritize training: 2 training, 1 validation, no test sample.
                n_train_samples, n_val_samples = 2, 1
            else:
                # Allocate by ratio, at least one sample each, never exceeding
                # what is left of the class.
                n_train_samples = min(n_samples, max(1, int(n_samples * train_ratio)))
                n_val_samples = min(n_samples - n_train_samples,
                                    max(1, int(n_samples * val_ratio)))

            train_end = n_train_samples
            val_end = train_end + n_val_samples

            train_parts.append(label_data.iloc[:train_end].copy())
            if val_end > train_end:
                val_parts.append(label_data.iloc[train_end:val_end].copy())
            if val_end < n_samples:
                test_parts.append(label_data.iloc[val_end:].copy())

        def _merge(parts):
            # An empty split is still returned as a frame with the input's columns.
            if parts:
                return pd.concat(parts, ignore_index=True)
            return pd.DataFrame(columns=df.columns)

        df_train_multi = _merge(train_parts)
        df_val = _merge(val_parts)
        df_test = _merge(test_parts)

        # Single-sample classes cannot be split, so they all go to training.
        if len(df_single) > 0:
            df_train = pd.concat([df_train_multi, df_single], ignore_index=True)
            print(f"✓ Added {len(df_single)} single-sample class samples added to training set")
        else:
            df_train = df_train_multi

    print(f"  Training: {len(df_train)} samples")
    print(f"  Validation: {len(df_val)} samples")
    print(f"  Test: {len(df_test)} samples")

    # Print class distribution (one counting pass per split instead of one mask per label).
    print(f"\nClass distribution:")
    print(f"{'Label':<10} {'Training':<12} {'Validation':<12} {'Test':<12} {'Total':<12}")
    print("-" * 60)
    train_counts = df_train[target_col].value_counts().to_dict()
    val_counts = df_val[target_col].value_counts().to_dict()
    test_counts = df_test[target_col].value_counts().to_dict()
    total_counts = label_counts.to_dict()
    for label in sorted(df[target_col].unique()):
        train_count = int(train_counts.get(label, 0))
        val_count = int(val_counts.get(label, 0))
        test_count = int(test_counts.get(label, 0))
        total_count = int(total_counts.get(label, 0))
        print(f"{label:<10} {train_count:<12} {val_count:<12} {test_count:<12} {total_count:<12}")

    return df_train, df_val, df_test

def apply_runtime_sampling(df_train: pd.DataFrame, target_col: str = 'label') -> pd.DataFrame:
    """
    Rebalance the training set by per-label under/oversampling.

    Sampling strategy:
    - C1 data: undersample (without replacement) to 60% of original, at least 1 row
    - C2 data: no sampling (keep as is)
    - C3 and D2: oversample (with replacement) to 5x original
    - Other data: oversample (with replacement) to 20x original

    Labels are processed in sorted order, every draw uses ``random_state=42``,
    and the combined result is shuffled with the same seed, so the output is
    deterministic for a given input.

    Args:
        df_train: Training DataFrame
        target_col: Target column name

    Returns:
        Sampled training DataFrame with a fresh 0..n-1 index. If ``target_col``
        is missing, ``df_train`` itself is returned unchanged. If ``df_train``
        has no rows, an empty copy is returned.
    """
    if target_col not in df_train.columns:
        print(f"Warning: Target column '{target_col}' does not exist, skipping sampling")
        return df_train

    print(f"\nOriginal training set size: {len(df_train)} samples")

    if len(df_train) == 0:
        # pd.concat raises on an empty list of groups, so there is nothing to
        # sample: hand back an empty copy instead of crashing.
        print(f"\nSampling completed! Training set size: 0 samples")
        return df_train.copy()

    undersample_factors = {'C1': 0.6}
    keep_as_is = {'C2'}
    oversample_factors = {'C3': 5, 'D2': 5}
    default_oversample_factor = 20

    sampled_data = []

    for label in sorted(df_train[target_col].unique()):
        group = df_train[df_train[target_col] == label]
        original_size = len(group)

        if label in undersample_factors:
            factor = undersample_factors[label]
            # Never drop a class completely.
            sample_size = max(int(original_size * factor), 1)
            sampled_group = group.sample(n=sample_size, random_state=42, replace=False)
            print(f"  Label {label}: undersampling {original_size} -> {len(sampled_group)} ({factor * 100:.0f}%)")
        elif label in keep_as_is:
            sampled_group = group.copy()
            print(f"  Label {label}: no sampling, keeping {original_size} samples")
        else:
            factor = oversample_factors.get(label, default_oversample_factor)
            sample_size = max(int(original_size * factor), original_size)
            sampled_group = group.sample(n=sample_size, replace=True, random_state=42)
            print(f"  Label {label}: oversampling {original_size} -> {len(sampled_group)} ({factor}x)")

        sampled_data.append(sampled_group)

    # Merge all sampled groups, then shuffle so labels are interleaved.
    df_train_sampled = pd.concat(sampled_data, ignore_index=True)
    df_train_sampled = df_train_sampled.sample(frac=1, random_state=42).reset_index(drop=True)

    print(f"\nSampling completed! Training set size: {len(df_train_sampled)} samples")

    return df_train_sampled

def load_protein_data_ph_only(config):
    """
    Build train/val/test loaders from a label CSV plus per-protein ESM2 and
    path_homology features, with no dependency on PDB files or graph structure.

    Every protein becomes a one-node placeholder graph that carries the two
    feature vectors as graph-level attributes.

    Pipeline:
      1. Read the label CSV and split it with ``stratified_split_dataframe``.
      2. Resample the training split only, with ``apply_runtime_sampling``.
      3. Swap the validation and test splits.
      4. Write the three splits to
         ``<output_dir>/protein_classification_{train,val,test}.csv``.
      5. Load the features, build the label mapping (cross-entropy only) and
         wrap everything in data loaders.

    Expected configuration:
      data.csv_file: CSV path (can be absolute path), two columns: protein_id, label (header optional)
      data.path_homology_dir: path_homology feature directory; features are looked up by upper-cased protein ID
      data.esm2_dir: ESM2 feature directory; features are looked up by lower-cased protein ID
      data.train_ratio / val_ratio / test_ratio: optional (defaults 0.8 / 0.1 / 0.1)
      data.random_seed: optional (default 42)
      data.use_sampled_data: pre-generated sampling files are no longer read. When true, old
          split files in output_dir are deleted and runtime sampling is used anyway.
      data.output_dir: directory the split CSVs are written to (default './data_output')
      training.target: target column name (default 'label')
      training.batch_size: loader batch size
      training.normalize_labels: optional (default True), forwarded to ProteinDataset
      model.criterion: a label mapping is built only when this equals 'cross_entropy'

    Note: per the original notes of this module, ESM2 features are
    1280-dimensional and path_homology features are 360-dimensional
    graph-level vectors; the dimensions are not checked here.

    Returns:
        (train_loader, val_loader, test_loader, mean_train, std_train);
        ``mean_train`` and ``std_train`` are always None in this function.

    Raises:
        FileNotFoundError: if ``data.csv_file`` does not exist.
        ValueError: if the target column is not found in the CSV.
    """
    target = config['training'].get('target', 'label')
    batch_size = config['training']['batch_size']
    csv_file = config['data']['csv_file']
    path_homology_dir = config['data'].get('path_homology_dir', 'symm1mer_012-8dis')
    esm2_dir = config['data'].get('esm2_dir', 'symm_esm2')
    train_ratio = config['data'].get('train_ratio', 0.8)
    val_ratio = config['data'].get('val_ratio', 0.1)
    test_ratio = config['data'].get('test_ratio', 0.1)
    random_seed = config['data'].get('random_seed', 42)
    use_sampled_data = config['data'].get('use_sampled_data', False)
    output_dir = config['data'].get('output_dir', './data_output')

    banner_60 = "=" * 60
    banner_70 = "=" * 70

    # Pre-generated sampling files are no longer used: when requested, the old
    # files are removed and the function continues with runtime sampling.
    if use_sampled_data:
        stale_cache_files = [
            os.path.join(output_dir, 'protein_classification_train.csv'),
            os.path.join(output_dir, 'protein_classification_val.csv'),
            os.path.join(output_dir, 'protein_classification_test.csv'),
        ]

        print(banner_60)
        print("Detected use_sampled_data=True configuration")
        print("⚠️  Changed to use runtime dynamic sampling, will delete old sampling cache files")
        print(banner_60)

        for cache_file in stale_cache_files:
            if os.path.exists(cache_file):
                try:
                    os.remove(cache_file)
                    print(f"✓ Deleted old sampling cache file: {cache_file}")
                except Exception as e:
                    # Best effort: a file that cannot be deleted is reported, not fatal.
                    print(f"⚠️  Failed to delete file {cache_file}: {e}")

        # Force use runtime sampling
        use_sampled_data = False

    # Using original data file for stratified split and runtime sampling
    if not os.path.exists(csv_file):
        raise FileNotFoundError(f"CSV file does not exist: {csv_file}")

    print("\n" + banner_70)
    print("Using original data file for stratified split and runtime sampling")
    print(banner_70)

    # Read CSV
    df_all = _read_labels_csv(csv_file)
    if target not in df_all.columns:
        raise ValueError(f"Target column '{target}' not found in CSV, existing columns: {list(df_all.columns)}")

    # Report single-sample classes (for information only; they are not removed)
    label_counts_before = df_all[target].value_counts()
    single_sample_labels = label_counts_before[label_counts_before == 1].index.tolist()

    if single_sample_labels:
        print(f"\n⚠️  Single-sample class handling:")
        print(f"  Found {len(single_sample_labels)} classes with only 1 sample")
        print(f"  Single-sample classes: {single_sample_labels}")
        single_proteins = df_all[df_all[target].isin(single_sample_labels)]['protein_id'].tolist()
        print(f"  Single-sample class protein IDs: {single_proteins}")
        print(f"  These samples will all be put in training set (as they may merge with other classes)")

    # Single-sample classes are placed in the training split by stratified_split_dataframe.
    df_train, df_val, df_test = stratified_split_dataframe(
        df_all, target, train_ratio, val_ratio, test_ratio, random_seed
    )

    # Apply runtime dynamic sampling strategy to training set
    print("\n" + banner_70)
    print("Applying runtime dynamic sampling strategy (training set only)")
    print(banner_70)
    print("Sampling strategy:")
    print("  - C1 data: undersample to 60% of original")
    print("  - C2 data: no sampling (keep as is)")
    print("  - C3 and D2: oversample to 5x original")
    print("  - Other data: oversample to 20x original")
    print("  - Validation and test sets: maintain original distribution (no sampling)")
    print(banner_70)

    df_train = apply_runtime_sampling(df_train, target)

    print(f"\n✓ Training set samples after sampling: {len(df_train)}")
    print(f"✓ Validation set samples: {len(df_val)} (not sampled)")
    print(f"✓ Test set samples: {len(df_test)} (not sampled)")

    # Swapping validation and test sets.
    # NOTE(review): the splitter gives 2- and 3-sample classes a validation
    # sample but no test sample, so after this swap those classes appear in
    # the test set and not in the validation set — confirm this is intended.
    print("\n" + banner_70)
    print("⚠️  Swapping validation and test sets")
    print(banner_70)
    df_val, df_test = df_test.copy(), df_val.copy()
    print("✓ Validation and test sets swapped")
    print(f"  After swap - Validation set samples: {len(df_val)}")
    print(f"  After swap - Test set samples: {len(df_test)}")

    # Saving split dataset
    print("\n" + banner_70)
    print("Saving split dataset")
    print(banner_70)
    os.makedirs(output_dir, exist_ok=True)

    train_output_path = os.path.join(output_dir, 'protein_classification_train.csv')
    val_output_path = os.path.join(output_dir, 'protein_classification_val.csv')
    test_output_path = os.path.join(output_dir, 'protein_classification_test.csv')

    df_train.to_csv(train_output_path, index=False)
    df_val.to_csv(val_output_path, index=False)
    df_test.to_csv(test_output_path, index=False)

    print(f"✓ Training set saved: {train_output_path} ({len(df_train)} samples)")
    print(f"✓ Validation set saved: {val_output_path} ({len(df_val)} samples)")
    print(f"✓ Test set saved: {test_output_path} ({len(df_test)} samples)")
    print(banner_70)

    # Build Data objects containing ESM2 and path_homology mixed features
    def build_graphs(df: pd.DataFrame) -> List[Data]:
        # Imported once per call rather than once per protein.
        from protein_utils import load_esm2_features

        graphs = []
        missing_ph_count = 0
        missing_esm2_count = 0
        for pid in df['protein_id'].tolist():
            pid_str = str(pid)

            # A zero-sum vector is counted as "feature file missing"; a real
            # vector whose entries sum to zero would be counted as well.
            ph = load_path_homology_features(pid_str.upper(), path_homology_dir)
            if ph.sum() == 0:
                missing_ph_count += 1

            esm2 = load_esm2_features(pid_str.lower(), esm2_dir)
            if esm2.sum() == 0:
                missing_esm2_count += 1

            # Minimal graph object: one placeholder node, no edges; the two
            # feature vectors ride along as graph-level attributes.
            graphs.append(Data(
                x=torch.zeros(1, 1, dtype=torch.float32),  # Placeholder node features (not actually used)
                edge_index=torch.zeros(2, 0, dtype=torch.long),  # Empty edges (not actually used)
                num_nodes=1,
                path_homology_features=ph,
                esm2_features=esm2
            ))

        if missing_ph_count > 0:
            print(f"Warning: {missing_ph_count} proteins' path_homology feature files not found, will use zero vectors")
        if missing_esm2_count > 0:
            print(f"Warning: {missing_esm2_count} proteins' ESM2 feature files not found, will use zero vectors")
        return graphs

    print("Building datasets with ESM2 and path_homology mixed features (no PDB/graph features)...")
    train_graphs = build_graphs(df_train)
    val_graphs = build_graphs(df_val)
    test_graphs = build_graphs(df_test)

    # Process label mapping (for classification tasks)
    label_mapping = None
    mean_train = None
    std_train = None
    if target is not None:
        if config['model']['criterion'] == 'cross_entropy':
            # The mapping is derived from the (sampled) training labels only.
            all_labels = df_train[target].values
            unique_labels = sorted(set(all_labels))
            print(f"Found unique labels: {unique_labels}")

            # First merge labels
            merged_labels = [merge_labels(label) for label in all_labels]
            unique_merged_labels = sorted(set(merged_labels))
            print(f"Merged unique labels: {unique_merged_labels}")

            # Each merged label gets the next class index 0, 1, 2, ...; every
            # original label that merges into it shares that index.
            label_mapping = {}
            for mapped_idx, merged_label in enumerate(unique_merged_labels):
                for original_label in unique_labels:
                    if merge_labels(original_label) == merged_label:
                        label_mapping[original_label] = mapped_idx

            print(f"Label mapping: {label_mapping}")
            print(f"Number of classes after mapping: {len(set(label_mapping.values()))}")

            # Count samples for each class (by original label)
            for label in sorted(label_mapping.keys()):
                count = (all_labels == label).sum()
                merged_label = merge_labels(label)
                print(f"Label {label} -> merged to {merged_label} -> class {label_mapping[label]}: {count} samples")

    # Create datasets
    normalize_labels = config['training'].get('normalize_labels', True)
    train_dataset = ProteinDataset(
        df_train, train_graphs, target=target,
        mean_train=mean_train, std_train=std_train,
        normalize_labels=normalize_labels,
        label_mapping=label_mapping
    )
    val_dataset = ProteinDataset(
        df_val, val_graphs, target=target,
        mean_train=mean_train, std_train=std_train,
        normalize_labels=normalize_labels,
        label_mapping=label_mapping
    )
    test_dataset = ProteinDataset(
        df_test, test_graphs, target=target,
        mean_train=mean_train, std_train=std_train,
        normalize_labels=normalize_labels,
        label_mapping=label_mapping
    )

    from torch_geometric.loader import DataLoader
    # NOTE(review): drop_last=True on the validation/test loaders discards the
    # final partial batch, so up to batch_size - 1 samples are never evaluated
    # and a split smaller than batch_size yields no batches — confirm intended.
    # NOTE(review): torch_geometric's DataLoader appears to replace a
    # user-supplied collate_fn with its own collater — verify that
    # ProteinDataset.collate is actually applied.
    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True,
        collate_fn=ProteinDataset.collate, drop_last=True
    )
    val_loader = DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False,
        collate_fn=ProteinDataset.collate, drop_last=True
    )
    test_loader = DataLoader(
        test_dataset, batch_size=batch_size, shuffle=False,
        collate_fn=ProteinDataset.collate, drop_last=True
    )

    # Print label distribution
    if target is not None:
        print("\n" + banner_70)
        print("Complete dataset label distribution statistics")
        print(banner_70)
        # NOTE(review): use_sampled_data is always False at this point, so the
        # report is titled as original data although the training split was
        # resampled at runtime — confirm which title is wanted.
        print_label_distribution(df_train, df_val, df_test, target_col=target, is_sampled=use_sampled_data)

    return train_loader, val_loader, test_loader, mean_train, std_train

def process_protein_dataset_by_ids(protein_ids: List[str], pdb_dir: str, distance_threshold: float = 6.0, output_file: str = None) -> Tuple[List[Data], List[str]]:
    """
    Build one protein graph per ID from structure files found in a directory.

    For each ID the file ``<pdb_dir>/<protein_id>.pdb`` is tried first, then
    ``<protein_id>.pdbqt``. IDs without a file, and IDs whose processing raises,
    are reported on stdout and skipped, so the two returned lists stay aligned.

    Args:
        protein_ids: List of protein IDs
        pdb_dir: PDB file directory
        distance_threshold: Connection threshold passed to ``create_protein_graph``
        output_file: Optional path; when given, the graph list is pickled there

    Returns:
        Tuple[protein graph list, list of successfully processed protein IDs]
    """
    structure_extensions = ('.pdb', '.pdbqt')
    graphs = []
    processed_ids = []

    for pid in protein_ids:
        # First existing candidate wins; None when no extension matches.
        candidates = (os.path.join(pdb_dir, pid + ext) for ext in structure_extensions)
        structure_path = next((path for path in candidates if os.path.exists(path)), None)

        if structure_path is None:
            print(f"Warning: Cannot find PDB file for protein {pid}")
            continue

        try:
            from protein_utils import extract_ca_coordinates, create_protein_graph

            coords, residues = extract_ca_coordinates(structure_path)
            graph_parts = create_protein_graph(
                coords, residues, distance_threshold, structure_path
            )
            nodes, edges, edge_feats, triangles, triangle_dists, codebar = graph_parts

            # ESM2 features are not attached to these graphs.
            graph = Data(
                x=nodes,
                edge_index=edges,
                edge_attr=edge_feats,
                triangle_index=triangles,
                triangle_dis=triangle_dists,
                codebar_features=codebar,
                num_nodes=len(coords)
            )
        except Exception as e:
            # Best effort: one bad structure must not abort the whole batch.
            print(f"Error: Error processing protein {pid}: {e}")
            continue

        graphs.append(graph)
        processed_ids.append(pid)

    if output_file:
        with open(output_file, 'wb') as fh:
            pickle.dump(graphs, fh)
        print(f"Saving {len(graphs)} protein graphs to {output_file}")

    return graphs, processed_ids

def _load_or_process_split_graphs(split_name, graphs_file, pdb_dir, distance_threshold):
    """
    Return the protein graphs of one split, preferring the pickled cache file

    Args:
        split_name: Split name used in progress messages ('training', 'validation' or 'test')
        graphs_file: Path of the pickled graph cache for this split
        pdb_dir: Directory containing the PDB files of this split
        distance_threshold: Distance threshold forwarded to process_protein_dataset_from_dir

    Returns:
        The unpickled cache content if graphs_file exists, otherwise the
        return value of process_protein_dataset_from_dir
    """
    if os.path.exists(graphs_file):
        print(f"Loading pre-processed {split_name} protein graphs...")
        # SECURITY: pickle.load can execute arbitrary code embedded in the file;
        # only load cache files that were produced by this pipeline.
        with open(graphs_file, 'rb') as f:
            return pickle.load(f)

    print(f"Processing {split_name} PDB files...")
    # graphs_file is passed along as the third argument
    # (presumably the output/cache path - confirm in protein_utils)
    return process_protein_dataset_from_dir(pdb_dir, distance_threshold, graphs_file)


def load_protein_data(config):
    """
    Load protein data (compatible with old version)

    If config['data'] contains 'pdb_dir' or 'csv_file', the call is delegated
    to load_protein_data_auto. Otherwise the legacy layout is used:
    data/<dataset>/<dataset>_{train,val,test}.csv for the tables,
    data/<dataset>/pdb_files/{train,val,test} for the structures and
    data/<dataset>/protein_graphs_{train,val,test}.pkl for the graph caches.

    Args:
        config: Configuration dictionary. The legacy path reads
            config['training']['dataset'], config['training']['batch_size'],
            and the optional config['training']['target'],
            config['training']['normalize_labels'] (default True) and
            config['protein']['distance_threshold'] (default 6.0)

    Returns:
        Auto-split mode: whatever load_protein_data_auto returns.
        Legacy mode: tuple (train_loader, val_loader, test_loader, mean_train, std_train),
        where mean_train / std_train are computed on the training target column
        and are None when no target is configured

    Raises:
        FileNotFoundError: If one of the train/val/test CSV files is missing
    """
    # Check if auto-split mode is used
    if 'data' in config and ('pdb_dir' in config['data'] or 'csv_file' in config['data']):
        return load_protein_data_auto(config)

    # Original loading method
    training_cfg = config['training']
    dataset_name = training_cfg['dataset']
    target = training_cfg.get('target')
    batch_size = training_cfg['batch_size']
    normalize_labels = training_cfg.get('normalize_labels', True)
    distance_threshold = config.get('protein', {}).get('distance_threshold', 6.0)

    # Data paths - CSV files stored separately for Training, Validation, Test
    data_train_path = f'data/{dataset_name}/{dataset_name}_train.csv'
    data_val_path = f'data/{dataset_name}/{dataset_name}_val.csv'
    data_test_path = f'data/{dataset_name}/{dataset_name}_test.csv'

    # Note: ESM2 features are no longer used, so ESM2 file paths are not needed

    print("Loading protein data...")

    # Fail early, in train/val/test order, if any CSV file is missing
    for split_label, csv_path in (
        ('Training', data_train_path),
        ('Validation', data_val_path),
        ('Test', data_test_path),
    ):
        if not os.path.exists(csv_path):
            raise FileNotFoundError(f"{split_label} CSV file does not exist: {csv_path}")

    # Read separately stored CSV files
    print("Loading separate CSV files for train/val/test...")
    df_train = pd.read_csv(data_train_path)
    df_val = pd.read_csv(data_val_path)
    df_test = pd.read_csv(data_test_path)

    print(f"Train samples: {len(df_train)}")
    print(f"Validation samples: {len(df_val)}")
    print(f"Test samples: {len(df_test)}")

    # Print label distribution statistics
    if target is not None:
        print_label_distribution(df_train, df_val, df_test, target_col=target)

    print("Processing protein graphs...")

    # Load cached graphs or build them from the per-split PDB directories
    train_graphs = _load_or_process_split_graphs(
        'training',
        f'data/{dataset_name}/protein_graphs_train.pkl',
        f'data/{dataset_name}/pdb_files/train',
        distance_threshold,
    )
    val_graphs = _load_or_process_split_graphs(
        'validation',
        f'data/{dataset_name}/protein_graphs_val.pkl',
        f'data/{dataset_name}/pdb_files/val',
        distance_threshold,
    )
    test_graphs = _load_or_process_split_graphs(
        'test',
        f'data/{dataset_name}/protein_graphs_test.pkl',
        f'data/{dataset_name}/pdb_files/test',
        distance_threshold,
    )

    # Normalization parameters come from the training split only, so that
    # no validation/test statistics leak into training
    if target is not None:
        train_vals = df_train[target].values
        mean_train = np.mean(train_vals)
        std_train = np.std(train_vals)
    else:
        mean_train = None
        std_train = None

    # Create datasets (all splits share the training statistics)
    train_dataset = ProteinDataset(
        df_train, train_graphs, target=target,
        mean_train=mean_train, std_train=std_train,
        normalize_labels=normalize_labels
    )
    val_dataset = ProteinDataset(
        df_val, val_graphs, target=target,
        mean_train=mean_train, std_train=std_train,
        normalize_labels=normalize_labels
    )
    test_dataset = ProteinDataset(
        df_test, test_graphs, target=target,
        mean_train=mean_train, std_train=std_train,
        normalize_labels=normalize_labels
    )

    # Create data loaders
    from torch_geometric.loader import DataLoader

    # NOTE(review): torch_geometric's DataLoader may replace a user-supplied
    # collate_fn with its own collater - confirm that ProteinDataset.collate
    # is actually applied with the installed torch_geometric version.
    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True,
        collate_fn=ProteinDataset.collate, drop_last=True
    )

    # Evaluation loaders keep the last partial batch: dropping it would
    # silently exclude samples from the metrics and would yield no batches at
    # all for a split smaller than batch_size.
    val_loader = DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False,
        collate_fn=ProteinDataset.collate, drop_last=False
    )
    test_loader = DataLoader(
        test_dataset, batch_size=batch_size, shuffle=False,
        collate_fn=ProteinDataset.collate, drop_last=False
    )

    return train_loader, val_loader, test_loader, mean_train, std_train
