import os
import sys
import time
import glob
import re
import torch
import ignite
from ignite.engine import (
    Events,
)
from ignite.metrics import Loss, MeanAbsoluteError, Accuracy
from ignite.metrics.metric import Metric
from ignite.handlers import Checkpoint, DiskSaver, TerminateOnNan, EarlyStopping
import numpy as np
from torch_geometric.data import DataLoader
import matplotlib.pyplot as plt
import matplotlib
matplotlib.use('Agg')  # Use non-interactive backend
import sklearn.metrics as sk_metrics

# Enable memory optimizations
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

# Set real-time output
# PYTHONUNBUFFERED is only consulted when an interpreter starts, so the
# assignment below can only influence child processes. To make the banner true
# for THIS process, switch stdout to line buffering explicitly (when the stream
# supports it; replaced/redirected streams may not).
if hasattr(sys.stdout, 'reconfigure'):
    sys.stdout.reconfigure(line_buffering=True)
print("="*60)
print("Real-time output enabled - all print messages will be displayed immediately")
print("="*60)
sys.stdout.flush()
os.environ['PYTHONUNBUFFERED'] = '1'

from protein_dataset import load_protein_data, load_protein_data_ph_only
from protein_dataset_sampling import load_protein_data_with_sampling
from model import QCformer
# Import sampling functions (if sampling file doesn't exist, can be automatically generated)
try:
    from data_sampling import process_dataset_with_sampling
    SAMPLING_AVAILABLE = True
except ImportError:
    SAMPLING_AVAILABLE = False
    print("Warning: Unable to import data_sampling module, automatic sampling data generation will be unavailable")
from focal_loss import FocalLoss, get_balanced_class_weights, AsymmetricLoss
import yaml
import argparse
import pickle

# torch config
# Use float32 as the default dtype for the whole process.
torch.set_default_dtype(torch.float32)

# Mixed precision scaler for memory optimization
# Only created when CUDA is available; on CPU-only hosts `scaler` is None,
# so any code that uses it has to handle the None case.
scaler = torch.amp.GradScaler('cuda') if torch.cuda.is_available() else None

def select_best_gpu():
    """
    Automatically select the most idle GPU.

    Selection strategy:
      1. Prefer GPUs with at least 8 GB of free memory (sufficient for training).
      2. Among those, pick the one with the lowest memory usage rate.
      3. If none qualifies, pick the GPU with the most free memory.

    Free/total memory is taken from ``torch.cuda.mem_get_info`` so that memory
    held by OTHER processes is accounted for. ``memory_allocated`` /
    ``memory_reserved`` only describe this process's caching allocator (they are
    zero at start-up) and are therefore reported for information only.

    Returns:
        torch.device: ``cuda:<index>`` of the chosen GPU, or ``cpu`` when CUDA is
        unavailable, no device is present, or every device query failed.
    """
    if not torch.cuda.is_available():
        return torch.device("cpu")

    cuda_device_count = torch.cuda.device_count()
    if cuda_device_count == 0:
        return torch.device("cpu")

    print(f"Found {cuda_device_count} GPU(s), checking memory usage...")

    gpu_info = []

    for i in range(cuda_device_count):
        try:
            # Get GPU properties
            props = torch.cuda.get_device_properties(i)

            # Device-wide free/total memory as seen by the driver
            free_memory, total_memory = torch.cuda.mem_get_info(i)

            # Memory held by this process only (informational)
            allocated = torch.cuda.memory_allocated(i)
            reserved = torch.cuda.memory_reserved(i)

            memory_usage_percent = ((total_memory - free_memory) / total_memory) * 100
            free_memory_gb = free_memory / 1024**3

            gpu_info.append({
                'id': i,
                'name': props.name,
                'total_memory': total_memory,
                'allocated': allocated,
                'reserved': reserved,
                'free_memory': free_memory,
                'usage_percent': memory_usage_percent,
                'free_memory_gb': free_memory_gb
            })

            print(f"GPU {i}: {props.name}")
            print(f"  Total memory: {total_memory / 1024**3:.2f} GB")
            print(f"  Allocated: {allocated / 1024**3:.2f} GB")
            print(f"  Reserved: {reserved / 1024**3:.2f} GB")
            print(f"  Free memory: {free_memory_gb:.2f} GB")
            print(f"  Memory usage: {memory_usage_percent:.1f}%")

        except Exception as e:
            # Best effort: a device that cannot be queried is simply not a candidate
            print(f"GPU {i} check failed: {e}")
            continue

    # Every query failed: there is nothing to choose from, so do not crash in
    # max()/min() on an empty list.
    if not gpu_info:
        print("Warning: Could not query any GPU, falling back to CPU")
        return torch.device("cpu")

    suitable_gpus = [gpu for gpu in gpu_info if gpu['free_memory_gb'] >= 8.0]

    if suitable_gpus:
        # Among GPUs meeting conditions, select one with lowest memory usage rate
        best_gpu = min(suitable_gpus, key=lambda x: x['usage_percent'])
        print(f"Selected GPU {best_gpu['id']} (meets requirements and has lowest memory usage: {best_gpu['usage_percent']:.1f}%)")
    else:
        # If no GPU meets conditions, select one with most available memory
        best_gpu = max(gpu_info, key=lambda x: x['free_memory_gb'])
        print(f"Warning: No GPU has sufficient free memory (>=8GB)")
        print(f"Selected GPU {best_gpu['id']} (most free memory: {best_gpu['free_memory_gb']:.2f} GB)")

    return torch.device(f"cuda:{best_gpu['id']}")

# Automatically select best GPU
# NOTE: this runs at import time, so importing this module already queries the GPUs
# and prints the selection report. `device` is a module-level global that the
# helpers and metric classes below read.
device = select_best_gpu()
print(f"Training on device: {device}")

# Set selected GPU
# Make the chosen GPU the current CUDA device (skipped when running on CPU).
if device.type == 'cuda':
    torch.cuda.set_device(device.index)
    print(f"Current GPU set to: {device}")

def monitor_gpu_memory():
    """Print memory statistics for the module-level ``device`` and warn when usage is high.

    Prints "GPU monitoring unavailable" when CUDA is not available or the
    selected device is not a CUDA device. Usage is computed from the memory
    reserved by this process relative to the device's total memory.
    """
    if not (torch.cuda.is_available() and device.type == 'cuda'):
        print("GPU monitoring unavailable")
        return

    gpu_index = device.index
    allocated = torch.cuda.memory_allocated(gpu_index) / 1024**3  # GB
    reserved = torch.cuda.memory_reserved(gpu_index) / 1024**3    # GB
    total = torch.cuda.get_device_properties(gpu_index).total_memory / 1024**3  # GB
    free = total - reserved
    usage_percent = (reserved / total) * 100

    print(f"GPU {device.index} Memory: {allocated:.2f}GB allocated, {reserved:.2f}GB reserved, {free:.2f}GB free, {usage_percent:.1f}% used")

    # Two warning levels: above 90% is a warning, above 80% is a note
    if usage_percent > 90:
        print(f"⚠️  Warning: GPU memory usage too high ({usage_percent:.1f}%)")
        return
    if usage_percent > 80:
        print(f"⚠️  Note: GPU memory usage is high ({usage_percent:.1f}%)")

def clear_gpu_cache():
    """Release cached CUDA memory; does nothing unless CUDA is available and
    the module-level ``device`` is a CUDA device."""
    if not torch.cuda.is_available():
        return
    if device.type == 'cuda':
        torch.cuda.empty_cache()

def enable_memory_optimizations():
    """Turn on the cuDNN / attention backend switches used for training.

    Does nothing when CUDA is unavailable. Otherwise sets
    ``cudnn.benchmark = True`` and ``cudnn.deterministic = False``, enables the
    flash scaled-dot-product-attention backend when this torch build exposes
    the toggle, and prints a confirmation line.
    """
    if not torch.cuda.is_available():
        return

    # cuDNN autotuning on, deterministic algorithms off
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = False

    # The flash-SDP toggle is not present in every torch build
    if hasattr(torch.backends.cuda, 'enable_flash_sdp'):
        torch.backends.cuda.enable_flash_sdp(True)

    print("Memory optimizations enabled: cudnn.benchmark, flash attention")

def get_memory_usage():
    """Return ``(allocated, reserved)`` CUDA memory in bytes for the current
    device, or ``(0, 0)`` when CUDA is unavailable."""
    if not torch.cuda.is_available():
        return 0, 0
    return torch.cuda.memory_allocated(), torch.cuda.memory_reserved()

def log_memory_usage(stage=""):
    """Print allocated/reserved CUDA memory in GB, tagged with ``stage``.

    Prints nothing when CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        return
    used_bytes, cached_bytes = get_memory_usage()
    print(f"Memory usage {stage}: {used_bytes / 1024**3:.2f}GB allocated, {cached_bytes / 1024**3:.2f}GB reserved")

def display_class_detailed_metrics(model, data_loader, device, num_classes):
    """
    Run the model over a data loader and print a per-class report.

    For every class the report shows predicted count, true count, precision,
    recall, F1 and AUC-PR (one-vs-rest), followed by overall accuracy,
    macro/micro F1, macro AUC-PR and the confusion matrix.

    Args:
        model: Trained model; it is switched to eval mode and called as ``model(data)``
        data_loader: Iterable yielding ``(data, target)`` pairs
        device: Device the data and targets are moved to
        num_classes: Number of classes

    Returns:
        dict with keys ``overall_accuracy``, ``overall_macro_f1``,
        ``overall_micro_f1``, ``overall_macro_auc_pr``, ``class_auc_pr``
        and ``confusion_matrix``.
    """
    model.eval()
    pred_chunks, target_chunks, output_chunks = [], [], []

    with torch.no_grad():
        for data, target in data_loader:
            data = data.to(device)
            target = target.to(device)

            # Forward pass; 2-D outputs are reduced to class indices with argmax
            output = model(data)
            pred = torch.argmax(output, dim=1) if output.dim() > 1 else output

            pred_chunks.append(pred.cpu())
            target_chunks.append(target.cpu())
            output_chunks.append(output.cpu())

    # Stitch the per-batch pieces together
    all_predictions = torch.cat(pred_chunks, dim=0).numpy()
    all_targets = torch.cat(target_chunks, dim=0).numpy()
    all_outputs = torch.cat(output_chunks, dim=0)

    # Float targets are cast to integers
    if all_targets.dtype in (np.float32, np.float64):
        all_targets = all_targets.astype(np.int64)

    separator = "-" * 70
    print(f"{'Class':<6} {'Predicted':<10} {'True':<10} {'Precision':<8} {'Recall':<8} {'F1':<8} {'AUC-PR':<8}")
    print(separator)

    total_predictions = len(all_predictions)
    total_targets = len(all_targets)

    # Dataset-level metrics
    overall_accuracy = sk_metrics.accuracy_score(all_targets, all_predictions)
    overall_macro_f1 = sk_metrics.f1_score(all_targets, all_predictions, average='macro', zero_division=0)
    overall_micro_f1 = sk_metrics.f1_score(all_targets, all_predictions, average='micro', zero_division=0)

    # Class probabilities for the AUC-PR computation
    softmax_probs = torch.softmax(all_outputs, dim=1).numpy()
    class_auc_pr = []

    # One-vs-rest metrics, one table row per class
    for class_idx in range(num_classes):
        is_true = (all_targets == class_idx).astype(int)
        is_pred = (all_predictions == class_idx).astype(int)

        precision = sk_metrics.precision_score(is_true, is_pred, zero_division=0)
        recall = sk_metrics.recall_score(is_true, is_pred, zero_division=0)
        f1 = sk_metrics.f1_score(is_true, is_pred, zero_division=0)

        class_probs = softmax_probs[:, class_idx]
        try:
            auc_pr = sk_metrics.average_precision_score(is_true, class_probs)
        except Exception as e:
            print(f"Warning: Could not compute AUC-PR for class {class_idx}: {e}")
            auc_pr = 0.0
        class_auc_pr.append(auc_pr)

        pred_count = np.sum(is_pred)
        true_count = np.sum(is_true)

        print(f"{class_idx:<6} {pred_count:<10} {true_count:<10} {precision:<8.4f} {recall:<8.4f} {f1:<8.4f} {auc_pr:<8.4f}")

    macro_auc_pr = np.mean(class_auc_pr)

    print(separator)
    print(f"{'Total':<6} {total_predictions:<10} {total_targets:<10} {'':<8} {'':<8} {'':<8} {'':<8}")
    print(f"Overall accuracy: {overall_accuracy:.4f}")
    print(f"Overall Macro-F1: {overall_macro_f1:.4f}")
    print(f"Overall Micro-F1: {overall_micro_f1:.4f}")
    print(f"Overall Macro AUC-PR: {macro_auc_pr:.4f}")

    # Confusion matrix over the full label set, so classes missing from the
    # data still get a row/column
    class_labels = list(range(num_classes))
    try:
        cm = sk_metrics.confusion_matrix(all_targets, all_predictions, labels=class_labels)
        if cm.shape != (num_classes, num_classes):
            # Pad/crop into a num_classes x num_classes matrix
            padded = np.zeros((num_classes, num_classes), dtype=int)
            side = min(cm.shape[0], cm.shape[1], num_classes)
            padded[:side, :side] = cm[:side, :side]
            cm = padded
    except Exception as e:
        # Fall back to an all-zero matrix when the computation fails
        print(f"Warning: Could not compute confusion matrix: {e}")
        cm = np.zeros((num_classes, num_classes), dtype=int)

    print(f"\nConfusion matrix (rows=true class, columns=predicted class):")
    print("     " + "".join(f"{col:>6}" for col in class_labels))

    n_rows, n_cols = cm.shape[0], cm.shape[1]
    for row in class_labels:
        # Cells outside the matrix bounds are printed as 0
        cells = [
            f"{cm[row, col]:>6}" if row < n_rows and col < n_cols else f"{0:>6}"
            for col in class_labels
        ]
        print(f"{row:>4} " + "".join(cells))

    return {
        'overall_accuracy': overall_accuracy,
        'overall_macro_f1': overall_macro_f1,
        'overall_micro_f1': overall_micro_f1,
        'overall_macro_auc_pr': macro_auc_pr,
        'class_auc_pr': class_auc_pr,
        'confusion_matrix': cm
    }

# Custom metric classes
class MacroF1Score(Metric):
    """Epoch-level Macro-F1 metric.

    Predictions and targets are buffered on the CPU batch by batch and the
    score is computed with scikit-learn over everything seen since ``reset``.
    """

    def __init__(self, output_transform=lambda x: x, device=device):
        # The device is stored on the instance; it is not forwarded to the base constructor
        self._device = device
        super().__init__(output_transform=output_transform)

    def reset(self):
        """Drop everything accumulated so far."""
        self._predictions, self._targets = [], []

    def update(self, output):
        """Buffer one ``(y_pred, y)`` batch; 2-D predictions are reduced with argmax."""
        y_pred, y = output
        labels = torch.argmax(y_pred, dim=1) if y_pred.dim() > 1 else y_pred
        self._predictions.append(labels.cpu())
        self._targets.append(y.cpu())

    def compute(self):
        """Return Macro-F1 over the buffered batches (0.0 when empty or on error)."""
        if not self._predictions:
            return 0.0

        y_hat = torch.cat(self._predictions, dim=0).numpy()
        y_true = torch.cat(self._targets, dim=0).numpy()

        # Float targets are cast to integers
        if y_true.dtype in (np.float32, np.float64):
            y_true = y_true.astype(np.int64)

        try:
            return sk_metrics.f1_score(y_true, y_hat, average='macro', zero_division=0)
        except Exception as e:
            print(f"Error computing macro-F1: {e}")
            return 0.0

class BalancedAccuracy(Metric):
    """Epoch-level Balanced Accuracy metric.

    Predictions and targets are buffered on the CPU batch by batch and the
    score is computed with scikit-learn over everything seen since ``reset``.
    """

    def __init__(self, output_transform=lambda x: x, device=device):
        # The device is stored on the instance; it is not forwarded to the base constructor
        self._device = device
        super().__init__(output_transform=output_transform)

    def reset(self):
        """Drop everything accumulated so far."""
        self._predictions, self._targets = [], []

    def update(self, output):
        """Buffer one ``(y_pred, y)`` batch; 2-D predictions are reduced with argmax."""
        y_pred, y = output
        labels = torch.argmax(y_pred, dim=1) if y_pred.dim() > 1 else y_pred
        self._predictions.append(labels.cpu())
        self._targets.append(y.cpu())

    def compute(self):
        """Return balanced accuracy over the buffered batches (0.0 when empty or on error)."""
        if not self._predictions:
            return 0.0

        y_hat = torch.cat(self._predictions, dim=0).numpy()
        y_true = torch.cat(self._targets, dim=0).numpy()

        # Float targets are cast to integers
        if y_true.dtype in (np.float32, np.float64):
            y_true = y_true.astype(np.int64)

        try:
            return sk_metrics.balanced_accuracy_score(y_true, y_hat)
        except Exception as e:
            print(f"Error computing balanced accuracy: {e}")
            return 0.0

class MicroF1Score(Metric):
    """Epoch-level Micro-F1 metric.

    Predictions and targets are buffered on the CPU batch by batch and the
    score is computed with scikit-learn over everything seen since ``reset``.
    """

    def __init__(self, output_transform=lambda x: x, device=device):
        # The device is stored on the instance; it is not forwarded to the base constructor
        self._device = device
        super().__init__(output_transform=output_transform)

    def reset(self):
        """Drop everything accumulated so far."""
        self._predictions, self._targets = [], []

    def update(self, output):
        """Buffer one ``(y_pred, y)`` batch; 2-D predictions are reduced with argmax."""
        y_pred, y = output
        labels = torch.argmax(y_pred, dim=1) if y_pred.dim() > 1 else y_pred
        self._predictions.append(labels.cpu())
        self._targets.append(y.cpu())

    def compute(self):
        """Return Micro-F1 over the buffered batches (0.0 when empty or on error)."""
        if not self._predictions:
            return 0.0

        y_hat = torch.cat(self._predictions, dim=0).numpy()
        y_true = torch.cat(self._targets, dim=0).numpy()

        # Float targets are cast to integers
        if y_true.dtype in (np.float32, np.float64):
            y_true = y_true.astype(np.int64)

        try:
            return sk_metrics.f1_score(y_true, y_hat, average='micro', zero_division=0)
        except Exception as e:
            print(f"Error computing micro-F1: {e}")
            return 0.0

class MacroAUCScore(Metric):
    """Epoch-level macro-averaged AUC-PR metric.

    Raw model outputs and targets are buffered on the CPU; at compute time the
    outputs are turned into softmax probabilities and a one-vs-rest average
    precision is computed per class, then averaged.
    """

    def __init__(self, output_transform=lambda x: x, device=device):
        # The device is stored on the instance; it is not forwarded to the base constructor
        self._device = device
        super().__init__(output_transform=output_transform)

    def reset(self):
        """Drop everything accumulated so far."""
        self._outputs, self._targets = [], []

    def update(self, output):
        """Buffer one ``(y_pred, y)`` batch without reducing the predictions."""
        y_pred, y = output
        self._outputs.append(y_pred.cpu())
        self._targets.append(y.cpu())

    def compute(self):
        """Return the mean per-class AUC-PR (0.0 when nothing was buffered)."""
        if not self._outputs:
            return 0.0

        logits = torch.cat(self._outputs, dim=0)
        y_true = torch.cat(self._targets, dim=0).numpy()

        # Float targets are cast to integers
        if y_true.dtype in (np.float32, np.float64):
            y_true = y_true.astype(np.int64)

        # Class probabilities; the class count is taken from the output width
        probs = torch.softmax(logits, dim=1).numpy()
        num_classes = probs.shape[1]

        per_class = []
        for class_idx in range(num_classes):
            # One-vs-rest labels and scores for this class
            is_class = (y_true == class_idx).astype(int)
            class_probs = probs[:, class_idx]

            try:
                score = sk_metrics.average_precision_score(is_class, class_probs)
            except Exception as e:
                print(f"Warning: Could not compute AUC-PR for class {class_idx}: {e}")
                score = 0.0
            per_class.append(score)

        try:
            return np.mean(per_class)
        except Exception as e:
            print(f"Error computing macro AUC-PR: {e}")
            return 0.0

def find_optimal_threshold(val_outputs, val_targets, thresholds=None, loss_type='cross_entropy'):
    """
    Report validation metrics for multi-class predictions.

    Despite the name, no threshold scan is performed: multi-class prediction
    uses argmax directly, so the returned threshold is always ``None``.

    Args:
        val_outputs: Model outputs as a 2-D torch tensor (argmax is taken over dim 1)
        val_targets: True labels, either a torch tensor or a NumPy array
        thresholds: Unused; kept for interface compatibility
        loss_type: Unused; kept for interface compatibility

    Returns:
        tuple: ``(None, macro_f1, [])``
    """
    # Multi-class classification does not need threshold optimization, use argmax directly
    print("Multi-class classification detected. No threshold optimization needed.")
    print("Using argmax for class prediction.")

    # Bring targets to a NumPy array first, whatever they arrived as. Comparing
    # a torch dtype with a NumPy dtype is always False, so the type has to be
    # checked explicitly (a NumPy integer array has no ``.cpu()``).
    if isinstance(val_targets, torch.Tensor):
        val_targets = val_targets.detach().cpu().numpy()
    else:
        val_targets = np.asarray(val_targets)

    # Ensure targets are integer type (required by scikit-learn metrics)
    if np.issubdtype(val_targets.dtype, np.floating):
        val_targets = val_targets.astype(np.int64)

    # Ensure targets are 1D (atleast_1d keeps a single squeezed sample usable)
    if val_targets.ndim > 1:
        val_targets = np.atleast_1d(val_targets.squeeze())

    # For multi-class classification, use argmax directly for prediction
    val_preds = torch.argmax(val_outputs, dim=1).cpu().numpy()

    # Calculate various metrics
    accuracy = sk_metrics.accuracy_score(val_targets, val_preds)
    balanced_acc = sk_metrics.balanced_accuracy_score(val_targets, val_preds)
    # zero_division=0 matches the other metric calls and keeps the value (0) without the warning
    macro_f1 = sk_metrics.f1_score(val_targets, val_preds, average='macro', zero_division=0)
    precision = sk_metrics.precision_score(val_targets, val_preds, average='macro', zero_division=0)
    recall = sk_metrics.recall_score(val_targets, val_preds, average='macro', zero_division=0)

    print(f"Multi-class performance metrics:")
    print(f"Accuracy: {accuracy:.4f}")
    print(f"Balanced Accuracy: {balanced_acc:.4f}")
    print(f"Macro-F1: {macro_f1:.4f}")
    print(f"Macro Precision: {precision:.4f}")
    print(f"Macro Recall: {recall:.4f}")

    # Return None as threshold (not needed for multi-class)
    return None, macro_f1, []

def evaluate_with_threshold(outputs, targets, threshold=None, loss_type='cross_entropy'):
    """Evaluate model performance (multi-class version).

    Args:
        outputs: Model outputs as a 2-D torch tensor; argmax/softmax are taken over dim 1
        targets: True labels, either a torch tensor or a NumPy array
        threshold: Ignored for multi-class classification (a warning is printed if given)
        loss_type: Unused; kept for interface compatibility

    Returns:
        dict with overall metrics (``accuracy``, ``balanced_accuracy``,
        ``micro_f1``, ``macro_f1``, ``macro_precision``, ``macro_recall``,
        ``macro_auc_pr``), per-class metrics (``class_f1``, ``class_precision``,
        ``class_recall``, ``class_auc_pr``) and the NumPy ``predictions`` and
        ``targets`` that were scored.
    """
    # Multi-class classification does not need threshold, use argmax directly
    if threshold is not None:
        print("Warning: Threshold provided for multi-class classification, ignoring...")

    # For multi-class classification, use argmax directly for prediction
    preds = torch.argmax(outputs, dim=1).cpu().numpy()

    # Bring targets to a NumPy array first, whatever they arrived as. Comparing
    # a torch dtype with a NumPy dtype is always False, so the type has to be
    # checked explicitly (a NumPy integer array has no ``.cpu()``).
    if isinstance(targets, torch.Tensor):
        targets = targets.detach().cpu().numpy()
    else:
        targets = np.asarray(targets)

    # Ensure targets are integer type (required by scikit-learn metrics)
    if np.issubdtype(targets.dtype, np.floating):
        targets = targets.astype(np.int64)

    # Ensure targets are 1D (atleast_1d keeps a single squeezed sample usable)
    if targets.ndim > 1:
        targets = np.atleast_1d(targets.squeeze())

    # Calculate multi-class metrics
    accuracy = sk_metrics.accuracy_score(targets, preds)
    balanced_acc = sk_metrics.balanced_accuracy_score(targets, preds)
    micro_f1 = sk_metrics.f1_score(targets, preds, average='micro', zero_division=0)
    # zero_division=0 matches the other metric calls and keeps the value (0) without the warning
    macro_f1 = sk_metrics.f1_score(targets, preds, average='macro', zero_division=0)
    precision = sk_metrics.precision_score(targets, preds, average='macro', zero_division=0)
    recall = sk_metrics.recall_score(targets, preds, average='macro', zero_division=0)

    # Calculate per-class metrics
    # NOTE(review): no `labels=` is passed, so these arrays presumably cover only
    # the labels present in targets/preds and may be shorter than
    # outputs.shape[1] - confirm callers tolerate that.
    class_f1 = sk_metrics.f1_score(targets, preds, average=None, zero_division=0)
    class_precision = sk_metrics.precision_score(targets, preds, average=None, zero_division=0)
    class_recall = sk_metrics.recall_score(targets, preds, average=None, zero_division=0)

    # Calculate AUC-PR for each class
    num_classes = outputs.shape[1]
    class_auc_pr = []

    # Get softmax probabilities
    softmax_probs = torch.softmax(outputs, dim=1).cpu().numpy()

    for class_idx in range(num_classes):
        # Create binary labels (current class vs other classes)
        binary_targets = (targets == class_idx).astype(int)
        binary_probs = softmax_probs[:, class_idx]

        # Calculate AUC-PR; a failing class contributes 0.0 instead of aborting the evaluation
        try:
            auc_pr = sk_metrics.average_precision_score(binary_targets, binary_probs)
            class_auc_pr.append(auc_pr)
        except Exception as e:
            print(f"Warning: Could not compute AUC-PR for class {class_idx}: {e}")
            class_auc_pr.append(0.0)

    # Calculate macro-averaged AUC-PR
    macro_auc_pr = np.mean(class_auc_pr)

    return {
        'accuracy': accuracy,
        'balanced_accuracy': balanced_acc,
        'micro_f1': micro_f1,
        'macro_f1': macro_f1,
        'macro_precision': precision,
        'macro_recall': recall,
        'class_f1': class_f1,
        'class_precision': class_precision,
        'class_recall': class_recall,
        'class_auc_pr': class_auc_pr,
        'macro_auc_pr': macro_auc_pr,
        'predictions': preds,
        'targets': targets
    }

# NOTE: plot_threshold_optimization function removed - not used for multi-class classification
# (Multi-class uses argmax directly, no threshold optimization needed)
# If needed for binary classification tasks in the future, this function can be restored

def count_parameters(model):
    """Print and return the in-memory size of the model in MB.

    Despite the wording of the printed message, the value is a byte size
    (element size times element count) summed over all parameters AND buffers,
    divided by 1024 twice.
    """
    tensors = list(model.parameters()) + list(model.buffers())
    size_bytes = sum(t.element_size() * t.nelement() for t in tensors)
    size_mb = size_bytes / 1024 / 1024
    print(f"Total Trainable Params: {size_mb} MB")
    return size_mb

def group_decay(model):
    """Split parameters into two optimizer groups: regular ones, and those
    exempt from weight decay.

    A parameter is exempt when its name contains "bias", "bn" or "norm".
    The first group carries no ``weight_decay`` key; the second sets
    ``weight_decay`` to 0.
    """
    exempt_markers = ("bias", "bn", "norm")
    with_decay = []
    without_decay = []

    for param_name, param in model.named_parameters():
        is_exempt = any(marker in param_name for marker in exempt_markers)
        bucket = without_decay if is_exempt else with_decay
        bucket.append(param)

    return [
        {"params": with_decay},
        {"params": without_decay, "weight_decay": 0},
    ]

def load_config(config_path):
    """Load a YAML configuration file and return the parsed content.

    Args:
        config_path: Path to the YAML file

    Returns:
        Whatever ``yaml.safe_load`` yields for the file content.
    """
    # Pin the encoding: without it the platform's locale encoding is used, so
    # the same config could parse differently (or fail) on another machine.
    with open(config_path, 'r', encoding='utf-8') as file:
        config = yaml.safe_load(file)
    return config

def train(config_path):
    """Train protein model - improved version"""
    # Record training start time
    start_time = time.time()
    print(f"[Training] Training start time: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(start_time))}")
    sys.stdout.flush()
    
    config = load_config(config_path)
    
    print("Configuration:")
    print_dict(config)
    sys.stdout.flush()
    
    # Enable memory optimizations
    enable_memory_optimizations()
    log_memory_usage("after optimizations")
    
    # Load data: decide whether to use sampled data based on configuration
    use_sampled_data = config.get('data', {}).get('use_sampled_data', False)
    esm2_only = config.get('data', {}).get('esm2_only', False)
    disable_sampling = config.get('training', {}).get('disable_sampling', False)
    
    # If configuration enables sampled data, check if sampled files exist
    if use_sampled_data:
        output_dir = config.get('data', {}).get('output_dir', './data_output')
        sampled_train_path = os.path.join(output_dir, 'protein_classification_train.csv')
        if not os.path.exists(sampled_train_path):
            print(f"⚠️  Warning: Sampled data file does not exist: {sampled_train_path}")
            print("="*70)
            print("Automatically running sampling script to generate sampled data...")
            print("="*70)
            
            # Automatically run sampling script to generate sampled data
            try:
                csv_file = config.get('data', {}).get('csv_file')
                train_ratio = config.get('data', {}).get('train_ratio', 0.8)
                val_ratio = config.get('data', {}).get('val_ratio', 0.1)
                test_ratio = config.get('data', {}).get('test_ratio', 0.1)
                random_seed = config.get('data', {}).get('random_seed', 42)
                
                if csv_file and os.path.exists(csv_file):
                    if SAMPLING_AVAILABLE:
                        print(f"Automatically generating sampled data...")
                        print(f"  Input file: {csv_file}")
                        print(f"  Output directory: {output_dir}")
                        print(f"  Split ratio: Training set={train_ratio}, Validation set={val_ratio}, Test set={test_ratio}")
                        print(f"  Sampling strategy: C1 undersample 60%, C2 unchanged, C3 and D2 oversample 5x, others oversample 20x")
                        
                        # Directly call sampling function
                        df_train_sampled, df_val, df_test = process_dataset_with_sampling(
                            csv_file=csv_file,
                            output_dir=output_dir,
                            train_ratio=train_ratio,
                            val_ratio=val_ratio,
                            test_ratio=test_ratio,
                            random_seed=random_seed
                        )
                        
                        if os.path.exists(sampled_train_path):
                            print(f"✓ Sampled data file generated: {sampled_train_path}")
                            use_sampled_data = True
                        else:
                            print(f"⚠️  Sampling completed but file still not generated")
                            print("Will use original data file for random split")
                            use_sampled_data = False
                            config['data']['use_sampled_data'] = False
                    else:
                        print(f"⚠️  Sampling module unavailable, cannot automatically generate sampled data")
                        print("Please run first: python3 data_sampling.py")
                        print("Will use original data file for random split")
                        use_sampled_data = False
                        config['data']['use_sampled_data'] = False
                else:
                    print(f"⚠️  Original CSV file does not exist or is not configured: {csv_file}")
                    print("Will use original data file for random split")
                    use_sampled_data = False
                    config['data']['use_sampled_data'] = False
            except Exception as e:
                print(f"⚠️  Error occurred while automatically generating sampled data: {e}")
                import traceback
                traceback.print_exc()
                print("Will use original data file for random split")
                use_sampled_data = False

    # Priority: dynamic sampling > esm2_only > disable_sampling
    if disable_sampling:
        # Disable sampling: use original data
        path_homology_dir = config.get('data', {}).get('path_homology_dir')
        use_path_homology = config.get('features', {}).get('use_path_homology', False)
        is_ph_only = path_homology_dir and use_path_homology
        
        if is_ph_only:
            print("Loading protein data (PH_only mode - path_homology features only, no sampling, using original data)...")
            print(f"  → CSV file: {config.get('data', {}).get('csv_file', 'N/A')}")
            print(f"  → path_homology directory: {path_homology_dir}")
        elif esm2_only:
            print("Loading protein data (ESM2-only mode, no sampling, using original data)...")
            print(f"  → CSV file: {config.get('data', {}).get('csv_file', 'N/A')}")
            print(f"  → ESM2 directory: {config.get('data', {}).get('esm2_dir', 'N/A')}")
        else:
            print("Loading protein data (no oversampling, using original data)...")
            print(f"  → CSV file: {config.get('data', {}).get('csv_file', 'N/A')}")
        sys.stdout.flush()
        train_loader, val_loader, test_loader, mean_train, std_train = load_protein_data(config)
    elif not use_sampled_data:
        # Use dynamic sampling (runtime sampling)
        # Check if using path_homology features (PH_only mode)
        # PH_only mode: use path_homology features, do not use ESM2 features
        path_homology_dir = config.get('data', {}).get('path_homology_dir')
        use_path_homology = config.get('features', {}).get('use_path_homology', False)
        # Determine if PH_only mode: has path_homology_dir and use_path_homology=True, and actually uses load_protein_data_ph_only function
        # If path_homology_dir exists and use_path_homology=True, prioritize PH_only mode
        is_ph_only = path_homology_dir and use_path_homology
        
        if is_ph_only:
            print("Loading protein data (PH_only mode - path_homology features only, with dynamic sampling strategy)...")
            print("Sampling strategy: C1 undersample 60%, C2 unchanged, C3 and D2 oversample 5x, others oversample 20x")
            print(f"  → CSV file: {config.get('data', {}).get('csv_file', 'N/A')}")
            print(f"  → path_homology directory: {path_homology_dir}")
            sys.stdout.flush()
            train_loader, val_loader, test_loader, mean_train, std_train = load_protein_data_ph_only(config)
        elif esm2_only:
            print("Loading protein data (ESM2-only mode, with dynamic sampling strategy)...")
            print("Sampling strategy: C1 undersample 60%, C2 unchanged, C3 and D2 oversample 5x, others oversample 20x")
            print(f"  → CSV file: {config.get('data', {}).get('csv_file', 'N/A')}")
            print(f"  → ESM2 directory: {config.get('data', {}).get('esm2_dir', 'N/A')}")
            sys.stdout.flush()
            train_loader, val_loader, test_loader, mean_train, std_train = load_protein_data_with_sampling(config)
        else:
            print("Loading protein data with dynamic sampling strategy (runtime sampling)...")
            print("Sampling strategy: C1 undersample 60%, C2 unchanged, C3 and D2 oversample 5x, others oversample 20x")
            print(f"  → CSV file: {config.get('data', {}).get('csv_file', 'N/A')}")
            sys.stdout.flush()
            train_loader, val_loader, test_loader, mean_train, std_train = load_protein_data_with_sampling(config)
    else:
        # Use pre-generated sampled files
        path_homology_dir = config.get('data', {}).get('path_homology_dir')
        use_path_homology = config.get('features', {}).get('use_path_homology', False)
        is_ph_only = path_homology_dir and use_path_homology
        
        if is_ph_only:
            print("Loading protein data (PH_only mode - path_homology features only, using pre-sampled data files)...")
            print(f"  → CSV file: {config.get('data', {}).get('csv_file', 'N/A')}")
            print(f"  → path_homology directory: {path_homology_dir}")
            print(f"  → Sampled data directory: {config.get('data', {}).get('output_dir', './data_output')}")
        elif esm2_only:
            print("Loading protein data (ESM2-only mode, using pre-sampled data files)...")
            print(f"  → CSV file: {config.get('data', {}).get('csv_file', 'N/A')}")
            print(f"  → ESM2 directory: {config.get('data', {}).get('esm2_dir', 'N/A')}")
            print(f"  → Sampled data directory: {config.get('data', {}).get('output_dir', './data_output')}")
        else:
            print("Loading protein data (using pre-sampled data files)...")
            print(f"  → CSV file: {config.get('data', {}).get('csv_file', 'N/A')}")
            print(f"  → Sampled data directory: {config.get('data', {}).get('output_dir', './data_output')}")
        sys.stdout.flush()
        # Ensure configuration is correctly passed
        config['data']['use_sampled_data'] = use_sampled_data
        train_loader, val_loader, test_loader, mean_train, std_train = load_protein_data(config)
    
    print(f"Train samples: {len(train_loader.dataset)}")
    print(f"Validation samples: {len(val_loader.dataset)}")
    print(f"Test samples: {len(test_loader.dataset)}")
    sys.stdout.flush()
    
    # Explicitly do not use oversampling/sampler
    if config['model']['criterion'] == 'cross_entropy':
        print("Oversampling disabled - using original data distribution")
        sys.stdout.flush()
    
    # Calculate class weights to handle data imbalance
    if config['model']['criterion'] == 'cross_entropy':
        print("Calculating class weights for imbalanced data...")
        sys.stdout.flush()
        # Get label distribution of training set
        train_labels = []
        for batch in train_loader:
            _, labels = batch
            train_labels.extend(labels.tolist())
        
        # Use improved class weight calculation method
        print("Original class weights calculation:")
        sys.stdout.flush()
        # Original method: w0 = N/(2*N0), w1 = N/(2*N1)
        unique_labels, counts = torch.unique(torch.tensor(train_labels), return_counts=True)
        N = len(train_labels)  # Total number of samples
        
        # Define number of classes
        num_classes = len(unique_labels)
        print(f"Number of classes: {num_classes}")
        print(f"Class labels: {unique_labels.tolist()}")
        print(f"Class distribution: {counts.tolist()}")
        
        # Multi-class: calculate weights for each class
        original_weights = torch.zeros(len(unique_labels), dtype=torch.float32, device=device)
        for i, label in enumerate(unique_labels):
            count = counts[i].item()
            original_weights[i] = N / (len(unique_labels) * count)
        
        print(f"Original class weights: {original_weights.tolist()}")
        print(f"Original weight ratio (max/min): {original_weights.max()/original_weights.min():.3f}")
        sys.stdout.flush()
        
        # Use improved class weight method
        print("\nEnhanced class weights calculation:")
        sys.stdout.flush()
        
        # Directly use improved class weight method, no longer depends on sampler configuration
        enhanced_weights = get_balanced_class_weights(
            train_labels, 
            method='gentle',  # Use gentler gentle method
            max_ratio=2.0    # Limit maximum weight ratio to 2.0
        )
        enhanced_weights = enhanced_weights.to(device)
        print("Using improved class weights (gentle method)")
        
        print(f"Enhanced weight ratio (max/min): {enhanced_weights.max()/enhanced_weights.min():.3f}")
        
        # Choose which weights to use based on configuration
        if config.get('training', {}).get('use_improved_weights', True):
            # Use more aggressive weight strategy
            class_weights = original_weights  # Use original weights, stronger compensation
            print("Using aggressive class weights for severe imbalance")
        else:
            class_weights = enhanced_weights
            print("Using enhanced class weights")
            
        # Do not enable resampling sampler, only use class weights
        print("Using only class weights for class balancing (no resampling sampler)")
        
        # Update configuration based on actual number of classes in data to ensure correct model output dimensions
        if num_classes > 0:
            print(f"Updating config: num_classes = {num_classes} (from actual data)")
            config['training']['num_classes'] = num_classes
            # If num_outputs exists in configuration, update it too
            if 'num_outputs' in config.get('model', {}):
                config['model']['num_outputs'] = num_classes
                print(f"Updated model.num_outputs to {num_classes}")
    
    # Create model
    print("Creating model...")
    sys.stdout.flush()
    model = QCformer(config)
    print("Model created successfully")
    sys.stdout.flush()
    model = model.to(device)
    print(f"Model moved to device: {device}")
    sys.stdout.flush()
    
    # Reset model weights to avoid pretraining bias
    def reset_weights(m):
        """Re-initialise one sub-module in place; intended for ``model.apply``.

        - ``Linear``: Xavier-uniform weights, zero bias.
        - ``Conv2d``: Kaiming-normal weights (fan_out, ReLU gain), zero bias.
        - ``BatchNorm2d``: weight set to 1 and bias to 0 when present.

        Any other module type is left untouched.

        Args:
            m: The module visited by ``Module.apply``.
        """
        if isinstance(m, torch.nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                torch.nn.init.zeros_(m.bias)
        elif isinstance(m, torch.nn.Conv2d):
            torch.nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            if m.bias is not None:
                torch.nn.init.zeros_(m.bias)
        elif isinstance(m, torch.nn.BatchNorm2d):
            # With affine=False the layer has no learnable scale/shift
            # (weight and bias are None), so there is nothing to reset.
            if m.weight is not None:
                torch.nn.init.constant_(m.weight, 1)
            if m.bias is not None:
                torch.nn.init.constant_(m.bias, 0)
    
    print("Resetting model weights...")
    sys.stdout.flush()
    model.apply(reset_weights)
    print("Model weights reset to avoid pretraining bias")
    sys.stdout.flush()
    
    # Calculate number of parameters
    count_parameters(model)
    
    # Initialize loss recording lists
    train_losses = []
    val_losses = []
    epochs = []
    train_accuracies = []
    val_accuracies = []
    val_micro_f1s = []
    val_macro_f1s = []
    val_balanced_accs = []
    val_macro_auc_prs = []
    
    # Add training loss accumulator
    total_train_loss = 0.0
    total_train_samples = 0
    
    # Setup optimizer
    # Ensure numerical parameters are correctly parsed
    learning_rate = float(config['training']['learning_rate'])
    weight_decay = float(config['training']['weight_decay'])
    
    # For severely imbalanced data, use smaller learning rate
    if learning_rate > 0.0005:
        adjusted_lr = learning_rate * 0.3  # Lower learning rate
        print(f"Adjusted learning rate from {learning_rate} to {adjusted_lr} for severe class imbalance")
        learning_rate = adjusted_lr
    
    optimizer = torch.optim.AdamW(
        group_decay(model),
        lr=learning_rate,
        weight_decay=weight_decay
    )
    
    # Setup learning rate scheduler - use gentler scheduling strategy
    if config['model']['scheduler'] == 'onecycle':
        # OneCycleLR: step-level scheduling, need to correctly handle gradient accumulation
        # steps_per_epoch should be the actual number of optimizer updates, not the number of data loader steps
        actual_steps_per_epoch = len(train_loader) // gradient_accumulation_steps
        if len(train_loader) % gradient_accumulation_steps != 0:
            actual_steps_per_epoch += 1
        
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=learning_rate,
            epochs=int(config['training']['epoch']),
            steps_per_epoch=actual_steps_per_epoch,
            pct_start=0.3,  # First 30% used for warmup
            anneal_strategy='cos'  # Use cosine annealing
        )
        print(f"OneCycleLR scheduler: steps_per_epoch={actual_steps_per_epoch} (with gradient accumulation)")
        print("Note: OneCycleLR will step after each optimizer update")
    elif config['model']['scheduler'] == 'cosine':
        # Cosine annealing scheduler: epoch-level scheduling, T_max is total number of epochs
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=int(config['training']['epoch']),
            eta_min=learning_rate * 0.001  # Minimum learning rate is 0.1% of initial value
        )
        print(f"CosineAnnealingLR scheduler: T_max={int(config['training']['epoch'])}, eta_min={learning_rate * 0.001:.6f}")
        print("Note: CosineAnnealingLR will step at the end of each epoch")
    else:
        scheduler = None
    
    # Setup loss function - optimize multi-class loss function and class weight usage strategy
    loss_type = config.get('training', {}).get('loss_type', 'cross_entropy')
    
    if loss_type == 'focal_loss':
        # Use aggressive Class-Balanced Focal Loss to amplify gradients for hard samples
        focal_alpha = config.get('training', {}).get('focal_alpha', 'effective_number')
        focal_gamma = config.get('training', {}).get('focal_gamma', 2.5)
        focal_beta = config.get('training', {}).get('focal_beta', 0.9999)
        
        if focal_alpha == 'effective_number':
            # Use aggressive Effective Number weighted Focal Loss
            from focal_loss import ClassBalancedFocalLoss
            print(f"Using Class-Balanced Focal Loss with Enhanced Effective Number weights")
            print(f"  - β={focal_beta} (more aggressive weight calculation)")
            print(f"  - γ={focal_gamma} (higher focus parameter, more attention to difficult samples)")
            
            # Calculate Effective Number weights
            unique_labels, counts = torch.unique(torch.tensor(train_labels), return_counts=True)
            criterion = ClassBalancedFocalLoss(
                class_counts=counts, 
                beta=focal_beta, 
                gamma=focal_gamma
            )
            
        else:
            # Use traditional Focal Loss
            print(f"Using Multi-class Focal Loss with α={focal_alpha}, γ={focal_gamma}")
            criterion = FocalLoss(alpha=focal_alpha, gamma=focal_gamma)
            
    elif loss_type == 'asl':
        # Use Asymmetric Loss, gentler than Focal Loss
        gamma_neg = config.get('training', {}).get('asl_gamma_neg', 2.0)
        gamma_pos = config.get('training', {}).get('asl_gamma_pos', 1.0)
        margin = config.get('training', {}).get('asl_margin', 0.05)
        print(f"Using Asymmetric Loss with γ_neg={gamma_neg}, γ_pos={gamma_pos}, margin={margin}")
        criterion = AsymmetricLoss(gamma_neg=gamma_neg, gamma_pos=gamma_pos, margin=margin)
        
    elif config['model']['criterion'] == 'cross_entropy':
        # Use improved CrossEntropyLoss for multi-class classification
        if class_weights is not None:
            # Ensure weights are on the correct device
            if hasattr(class_weights, 'to'):
                class_weights = class_weights.to(device)
            print(f"Using class weights for CrossEntropyLoss: {class_weights.tolist()}")
            criterion = torch.nn.CrossEntropyLoss(weight=class_weights)
        else:
            criterion = torch.nn.CrossEntropyLoss()
    elif config['model']['criterion'] == 'mse':
        criterion = torch.nn.MSELoss()
        else:
            print(f"Warning: Unknown criterion '{config['model']['criterion']}', using CrossEntropy")
        criterion = torch.nn.CrossEntropyLoss()
    
    # Get gradient accumulation steps
    gradient_accumulation_steps = config.get('training', {}).get('gradient_accumulation_steps', 1)
    batch_size = config['training']['batch_size']
    effective_batch_size = batch_size * gradient_accumulation_steps
    print(f"Batch size: {batch_size}, Gradient accumulation steps: {gradient_accumulation_steps}")
    print(f"Effective batch size: {effective_batch_size}")
    
    # Print scheduler configuration information
    if scheduler:
        if isinstance(scheduler, torch.optim.lr_scheduler.OneCycleLR):
            print(f"Scheduler: OneCycleLR (step-level, with gradient accumulation)")
            print(f"  - Actual steps per epoch: {scheduler.steps_per_epoch}")
            print(f"  - Total steps: {scheduler.steps_per_epoch * scheduler.epochs}")
        elif isinstance(scheduler, torch.optim.lr_scheduler.CosineAnnealingLR):
            print(f"Scheduler: CosineAnnealingLR (epoch-level)")
            print(f"  - T_max: {scheduler.T_max} epochs")
            print(f"  - eta_min: {scheduler.eta_min:.6f}")
        else:
            print(f"Scheduler: {type(scheduler).__name__}")
    else:
        print("No scheduler configured")
    
    # Create trainer and evaluator
    def train_step(engine, batch):
        model.train()
        
        # Process based on actual type of batch
        if isinstance(batch, (list, tuple)) and len(batch) == 2:
            # If batch is a tuple or list of (data, target)
            data, target = batch
            data = data.to(device)
            target = target.to(device)
        elif hasattr(batch, 'to'):
            # If batch is a DataBatch object
            data = batch.to(device)
            if hasattr(batch, 'y'):
                target = batch.y.to(device)
            else:
                target = getattr(batch, 'label', None)
                if target is None:
                    raise ValueError("Cannot find target labels in batch")
                target = target.to(device)
        else:
            raise ValueError(f"Unexpected batch type: {type(batch)}")
        
        # Input data validation: check for NaN/Inf
        if hasattr(data, 'path_homology_features') and data.path_homology_features is not None:
            ph_features = data.path_homology_features
            if isinstance(ph_features, torch.Tensor):
                if torch.isnan(ph_features).any() or torch.isinf(ph_features).any():
                    nan_count = torch.isnan(ph_features).sum().item()
                    inf_count = torch.isinf(ph_features).sum().item()
                    print(f"Warning: Invalid path_homology_features detected at iteration {engine.state.iteration}")
                    print(f"  NaN count: {nan_count}, Inf count: {inf_count}")
                    # Replace NaN and Inf with 0
                    ph_features = torch.nan_to_num(ph_features, nan=0.0, posinf=0.0, neginf=0.0)
                    data.path_homology_features = ph_features
        
        if hasattr(data, 'esm2_features') and data.esm2_features is not None:
            esm2_features = data.esm2_features
            if isinstance(esm2_features, torch.Tensor):
                if torch.isnan(esm2_features).any() or torch.isinf(esm2_features).any():
                    nan_count = torch.isnan(esm2_features).sum().item()
                    inf_count = torch.isinf(esm2_features).sum().item()
                    print(f"Warning: Invalid esm2_features detected at iteration {engine.state.iteration}")
                    print(f"  NaN count: {nan_count}, Inf count: {inf_count}")
                    # Replace NaN and Inf with 0
                    esm2_features = torch.nan_to_num(esm2_features, nan=0.0, posinf=0.0, neginf=0.0)
                    data.esm2_features = esm2_features
        
        # Use mixed precision training (fix deprecated API)
        with torch.amp.autocast(device_type='cuda' if torch.cuda.is_available() else 'cpu', enabled=torch.cuda.is_available()):
            # Pass data to model
            output = model(data)
            
            # Ensure output is valid logits
            if torch.isnan(output).any() or torch.isinf(output).any():
                nan_count = torch.isnan(output).sum().item()
                inf_count = torch.isinf(output).sum().item()
                print(f"Warning: Invalid output detected at iteration {engine.state.iteration}")
                print(f"  NaN count: {nan_count}, Inf count: {inf_count}")
                # If output contains NaN/Inf, use zero padding and record
                output = torch.nan_to_num(output, nan=0.0, posinf=1.0, neginf=-1.0)
                # Limit output range to avoid extreme values
                output = torch.clamp(output, min=-10.0, max=10.0)
            
            # Debug info: print output and target shapes
            if engine.state.iteration <= 3:  # Only print for first few iterations
                pass
            
            # Handle classification and regression tasks
            if config['model']['criterion'] == 'cross_entropy' or loss_type.startswith('focal') or loss_type == 'asl':
                # Multi-class task: target needs to be class index
                if target.dtype == torch.float32:
                    target = target.long()  # Convert to long integer
                
                # Ensure target is within valid range
                num_classes = output.size(1)
                if target.max() >= num_classes or target.min() < 0:
                    print(f"Warning: Invalid target values at iteration {engine.state.iteration}")
                    print(f"  Target range: [{target.min().item()}, {target.max().item()}], Expected: [0, {num_classes-1}]")
                    # Limit invalid target values to valid range
                    target = torch.clamp(target, min=0, max=num_classes - 1)
                
                # Ensure output has no extreme values to prevent numerical issues in CrossEntropyLoss calculation
                output_safe = torch.clamp(output, min=-50.0, max=50.0)
                
                try:
                    loss = criterion(output_safe, target)
                except RuntimeError as e:
                    print(f"Error computing loss at iteration {engine.state.iteration}: {e}")
                    print(f"  Output stats: min={output.min().item():.4f}, max={output.max().item():.4f}, "
                          f"mean={output.mean().item():.4f}, std={output.std().item():.4f}")
                    print(f"  Target range: [{target.min().item()}, {target.max().item()}]")
                    # Use a small valid loss value but do not perform backpropagation
                    loss = torch.tensor(0.0, device=output.device, requires_grad=False)
            else:
                # Regression task
                # Limit output and target ranges to prevent extreme values
                output_safe = torch.clamp(output, min=-1e6, max=1e6)
                target_safe = torch.clamp(target, min=-1e6, max=1e6)
                
                try:
                    loss = criterion(output_safe, target_safe)
                except RuntimeError as e:
                    print(f"Error computing loss at iteration {engine.state.iteration}: {e}")
                    loss = torch.tensor(0.0, device=output.device, requires_grad=False)
            
            # Check if loss is NaN/Inf
            skip_backward = False
            if torch.isnan(loss) or torch.isinf(loss) or (not torch.isfinite(loss)):
                print(f"Warning: Invalid loss detected at iteration {engine.state.iteration}, skipping backward pass")
                print(f"  Loss value: {loss.item() if hasattr(loss, 'item') else loss}")
                # If loss is NaN/Inf, skip backward propagation for this batch
                # Use a small valid loss value to avoid errors but do not perform backpropagation
                loss = torch.tensor(0.0, device=loss.device if hasattr(loss, 'device') else output.device, requires_grad=False)
                skip_backward = True
            elif loss.item() > 1e6:
                print(f"Warning: Extremely large loss detected at iteration {engine.state.iteration}: {loss.item():.4f}")
                # For overly large loss, clip it
                loss = torch.clamp(loss, max=1e6)
        
        # Gradient accumulation: divide loss by accumulation steps
        loss = loss / gradient_accumulation_steps
        
        # Use mixed precision scaler for backward propagation (only when loss is valid)
        if not skip_backward:
            try:
                if scaler is not None:
                    scaler.scale(loss).backward()
                else:
                    loss.backward()
            except RuntimeError as e:
                print(f"Error during backward pass at iteration {engine.state.iteration}: {e}")
                # If backward propagation fails, skip this batch
                skip_backward = True
                # Clear any existing gradients
                optimizer.zero_grad()
        
        # Record accumulated loss for logging
        if hasattr(engine.state, 'total_train_loss'):
            engine.state.total_train_loss += loss.item() * gradient_accumulation_steps
        if hasattr(engine.state, 'total_train_samples'):
            engine.state.total_train_samples += target.size(0)
        
        # Only update parameters when accumulation steps are reached (and backward was not skipped)
        if (engine.state.iteration % gradient_accumulation_steps) == 0 and not skip_backward:
            # Gradient clipping - use stricter clipping for severely imbalanced data
            gradient_clip = config.get('training', {}).get('stability', {}).get('gradient_clip', 1.0)
            if gradient_clip > 0:
                # Use stricter gradient clipping for severely imbalanced data
                effective_clip = gradient_clip * 0.5  # Halve gradient clipping threshold
                if scaler is not None:
                    scaler.unscale_(optimizer)
                
                # Check if gradients have NaN/Inf - enhanced diagnostics and handling
                has_nan_grad = False
                nan_param_names = []
                total_nan_count = 0
                total_inf_count = 0
                
                # First check loss value
                current_loss_value = loss.item() * gradient_accumulation_steps if not skip_backward else 0.0
                if not skip_backward and (not torch.isfinite(torch.tensor(current_loss_value)) or current_loss_value > 1e6):
                    print(f"Warning: Abnormal loss value at iteration {engine.state.iteration}: {current_loss_value}")
                
                # Check gradients of all parameters
                for name, param in model.named_parameters():
                    if param.grad is not None:
                        nan_count = torch.isnan(param.grad).sum().item()
                        inf_count = torch.isinf(param.grad).sum().item()
                        
                        if nan_count > 0 or inf_count > 0:
                            has_nan_grad = True
                            nan_param_names.append(name)
                            total_nan_count += nan_count
                            total_inf_count += inf_count
                            
                            # Record detailed gradient information (only for first iteration or every 100 iterations)
                            if engine.state.iteration <= 3 or engine.state.iteration % 100 == 0:
                                grad_norm = param.grad.norm().item() if param.grad.numel() > 0 else 0.0
                                param_norm = param.norm().item() if param.numel() > 0 else 0.0
                                print(f"  Parameter '{name}': NaN={nan_count}, Inf={inf_count}, "
                                      f"grad_norm={grad_norm:.4f}, param_norm={param_norm:.4f}")
                            
                            # Set NaN/Inf gradients to zero and limit gradient range
                            param.grad = torch.nan_to_num(param.grad, nan=0.0, posinf=1e6, neginf=-1e6)
                            # Additional gradient clipping to prevent extreme values
                            torch.clamp_(param.grad, min=-1e6, max=1e6)
                
                # If NaN/Inf gradients detected, print summary information
                if has_nan_grad:
                    print(f"Warning: NaN/Inf gradients detected at iteration {engine.state.iteration}")
                    print(f"  Total NaN count: {total_nan_count}, Total Inf count: {total_inf_count}")
                    print(f"  Affected parameters: {len(nan_param_names)}")
                    if len(nan_param_names) <= 10:
                        print(f"  Parameter names: {nan_param_names}")
                    else:
                        print(f"  First 10 parameter names: {nan_param_names[:10]}...")
                    
                    # Check input data status
                    if hasattr(data, 'path_homology_features') and data.path_homology_features is not None:
                        ph_nan = torch.isnan(data.path_homology_features).sum().item()
                        ph_inf = torch.isinf(data.path_homology_features).sum().item()
                        if ph_nan > 0 or ph_inf > 0:
                            print(f"  Input path_homology_features: NaN={ph_nan}, Inf={ph_inf}")
                    
                    if hasattr(data, 'esm2_features') and data.esm2_features is not None:
                        esm2_nan = torch.isnan(data.esm2_features).sum().item()
                        esm2_inf = torch.isinf(data.esm2_features).sum().item()
                        if esm2_nan > 0 or esm2_inf > 0:
                            print(f"  Input esm2_features: NaN={esm2_nan}, Inf={esm2_inf}")
                    
                    # Check model output status
                    if not skip_backward:
                        output_nan = torch.isnan(output).sum().item()
                        output_inf = torch.isinf(output).sum().item()
                        if output_nan > 0 or output_inf > 0:
                            print(f"  Model output: NaN={output_nan}, Inf={output_inf}")
                    
                    # If NaN gradients occur consecutively multiple times, may need to lower learning rate or skip update
                    if not hasattr(engine.state, 'nan_grad_count'):
                        engine.state.nan_grad_count = 0
                    engine.state.nan_grad_count += 1
                    
                    if engine.state.nan_grad_count >= 10:
                        print(f"  Critical: NaN/Inf gradients detected {engine.state.nan_grad_count} times consecutively!")
                        print(f"  Consider: reducing learning rate or checking data quality")
                        # Reset counter to avoid unlimited growth
                        engine.state.nan_grad_count = 0
                else:
                    # If no NaN gradients, reset counter
                    if hasattr(engine.state, 'nan_grad_count'):
                        engine.state.nan_grad_count = 0
                
                # Only perform gradient clipping when there are no NaN/Inf gradients
                if not has_nan_grad:
                    try:
                        torch.nn.utils.clip_grad_norm_(model.parameters(), effective_clip)
                    except RuntimeError as e:
                        print(f"Error during gradient clipping at iteration {engine.state.iteration}: {e}")
                        # If gradient clipping fails, try clipping parameter by parameter
                        for param in model.parameters():
                            if param.grad is not None:
                                torch.clamp_(param.grad, min=-effective_clip, max=effective_clip)
                
                # Print gradient information for first few iterations
                if engine.state.iteration <= 10:
                    try:
                        total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), float('inf'))
                        print(f"Gradient norm before clipping: {total_norm:.4f}")
                    except RuntimeError:
                        print(f"Warning: Could not compute gradient norm at iteration {engine.state.iteration}")
            
            # Use mixed precision scaler for optimizer step
            if scaler is not None:
                try:
                    # scaler.step() automatically checks for inf/NaN, skips parameter update if detected
                    scaler.step(optimizer)
                    scaler.update()
                except RuntimeError as e:
                    print(f"Error during optimizer step at iteration {engine.state.iteration}: {e}")
                    # Even if error occurs, update scaler state to prevent scaler state inconsistency
                    try:
                        scaler.update()
                    except:
                        pass
            else:
                try:
                    optimizer.step()
                except RuntimeError as e:
                    print(f"Error during optimizer step at iteration {engine.state.iteration}: {e}")
            
            # Clear gradients (regardless of whether parameter update succeeded)
            optimizer.zero_grad()
        elif skip_backward:
            # If backward propagation was skipped, still need to clear gradients (if accumulation steps reached)
            if (engine.state.iteration % gradient_accumulation_steps) == 0:
                optimizer.zero_grad()
            
            # Learning rate scheduling: only call after parameter update
            # OneCycleLR needs to be called every step, other schedulers called at end of epoch
            if scheduler and isinstance(scheduler, torch.optim.lr_scheduler.OneCycleLR):
                scheduler.step()
        
        # Accumulate loss and sample count for correct average loss calculation
        batch_size = target.size(0)
        engine.state.total_train_loss = getattr(engine.state, 'total_train_loss', 0.0) + loss.item() * batch_size * gradient_accumulation_steps
        engine.state.total_train_samples = getattr(engine.state, 'total_train_samples', 0) + batch_size * gradient_accumulation_steps
        
        # Calculate training accuracy
        if config['model']['criterion'] == 'cross_entropy' or loss_type.startswith('focal') or loss_type == 'asl':
            # Multi-class task: use argmax to calculate accuracy
            pred = torch.argmax(output, dim=1)
            correct = (pred == target).sum().item()
            engine.state.total_train_correct = getattr(engine.state, 'total_train_correct', 0) + correct
            engine.state.total_train_samples_for_acc = getattr(engine.state, 'total_train_samples_for_acc', 0) + batch_size * gradient_accumulation_steps
        
        # Periodically monitor GPU memory usage
        if engine.state.iteration % 10 == 0:  # Monitor every 10 iterations, more frequently
            if torch.cuda.is_available():
                allocated, reserved = get_memory_usage()
                if allocated > 70 * 1024**3:  # If exceeds 70GB
                    print(f"⚠️  High memory usage: {allocated / 1024**3:.1f}GB")
        
        # Periodically clear GPU cache
        if engine.state.iteration % 25 == 0:  # Clear every 25 iterations, more frequently
            clear_gpu_cache()
        
        return {'loss': loss.item() * gradient_accumulation_steps, 'y_pred': output, 'y': target, 'batch_size': batch_size}
    
    def eval_step(engine, batch):
        """Run the model on one evaluation batch and return its output for the metrics.

        Args:
            engine: The Ignite engine driving the evaluation (not used directly).
            batch: Either a two-element ``(data, target)`` list/tuple, or a single
                batch object exposing ``.to()`` whose labels are stored on ``.y``
                (with ``.label`` as a fallback).

        Returns:
            dict: ``{'y_pred': output, 'y': target, 'criterion_kwargs': {}}``.
            The loss itself is not computed here; the ``Loss(criterion)`` metric
            attached to the evaluator derives it from this dict.

        Raises:
            ValueError: If the batch type is unsupported or no labels are found.
        """
        model.eval()
        with torch.no_grad():
            if isinstance(batch, (list, tuple)) and len(batch) == 2:
                # (data, target) pair
                data, target = batch
                data = data.to(device)
                target = target.to(device)
            elif hasattr(batch, 'to'):
                # Single batch object carrying its own labels
                data = batch.to(device)
                # Test the value, not just the attribute: `y` may exist but be None
                # (presumably the case for PyTorch Geometric batches without
                # labels - verify), in which case `.label` must still be tried.
                target = getattr(batch, 'y', None)
                if target is None:
                    target = getattr(batch, 'label', None)
                if target is None:
                    raise ValueError("Cannot find target labels in batch")
                target = target.to(device)
            else:
                raise ValueError(f"Unexpected batch type: {type(batch)}")

            # Mixed-precision inference. torch.amp.autocast('cuda', ...) is the
            # non-deprecated spelling of torch.cuda.amp.autocast(...) and matches
            # the torch.amp.GradScaler('cuda') used at module level.
            with torch.amp.autocast('cuda', enabled=torch.cuda.is_available()):
                output = model(data)

            # Classification criteria expect class indices, so float32 labels are
            # cast to long; regression targets are passed through untouched.
            is_classification = (
                config['model']['criterion'] == 'cross_entropy'
                or loss_type.startswith('focal')
                or loss_type == 'asl'
            )
            if is_classification and target.dtype == torch.float32:
                target = target.long()

            # Dict layout consumed by the Ignite metrics attached to the evaluator
            return {'y_pred': output, 'y': target, 'criterion_kwargs': {}}
    
    # Build the training and evaluation engines.
    # Because PyTorch Geometric's DataLoader is used, both engines are driven by
    # the custom step functions defined above.
    from ignite.engine import Engine

    def _pred_and_target(step_output):
        """Pick the (prediction, target) pair out of an eval_step output dict."""
        return step_output['y_pred'], step_output['y']

    trainer = Engine(train_step)

    # Select the task-specific metric classes; the criterion value is always reported first.
    if config['model']['criterion'] == 'cross_entropy' or loss_type.startswith('focal') or loss_type == 'asl':
        # Multi-class task metrics
        task_metric_classes = {
            'accuracy': Accuracy,
            'macro_f1': MacroF1Score,
            'micro_f1': MicroF1Score,
            'balanced_accuracy': BalancedAccuracy,
            'macro_auc_pr': MacroAUCScore,
        }
    else:
        # Regression task metrics
        task_metric_classes = {'mae': MeanAbsoluteError}

    model_evaluation_metrics = {'loss': Loss(criterion)}
    for name, metric_cls in task_metric_classes.items():
        model_evaluation_metrics[name] = metric_cls(output_transform=_pred_and_target)

    # Evaluation engine with every metric attached under its dictionary key
    evaluator = Engine(eval_step)
    for name, metric in model_evaluation_metrics.items():
        metric.attach(evaluator, name)

    # TerminateOnNan is run against the trainer after every completed iteration
    trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())
    
    # Setup checkpoint saving - use smarter directory creation strategy
    checkpoint_dir = f"trained_models/{config['training']['model_name']}"

    # Directories that may hold artefacts from earlier runs. checkpoint_dir already
    # is "./trained_models/<model_name>", so that spelling is not listed a second
    # time; listing it twice would scan and clean the same directory twice and
    # record each of its files twice.
    cache_dirs = [
        checkpoint_dir,
        f"./checkpoints/{config['training']['model_name']}",
        f"./models/{config['training']['model_name']}",
    ]

    # File-name fragments that mark a .pkl file as cached graph-structure data
    graph_pkl_keywords = ('graph', 'protein_graphs', 'coordinates_dict', 'pathcomplex')

    def _is_graph_structure_pkl(filename):
        """Return True for .pkl files whose name marks them as graph-structure data."""
        if not filename.endswith('.pkl'):
            return False
        lowered = filename.lower()
        return any(keyword in lowered for keyword in graph_pkl_keywords)

    # Check if processed graph structure data (pkl files) already exists.
    # isdir (rather than exists) keeps a stray regular file at one of these paths
    # from crashing the unguarded os.listdir call.
    existing_graph_pkl_files = []
    for dir_path in cache_dirs:
        if os.path.isdir(dir_path):
            existing_graph_pkl_files.extend(
                (dir_path, f) for f in os.listdir(dir_path) if _is_graph_structure_pkl(f)
            )

    # If configuration requires cache clearing but detected existing graph structure pkl files, prompt user
    if config.get('training', {}).get('clear_cache', False):
        if existing_graph_pkl_files:
            print(f"Detected existing graph structure pkl files, will keep these files:")
            for dir_path, pkl_file in existing_graph_pkl_files:
                print(f"  - {os.path.join(dir_path, pkl_file)}")
            print("Recommend setting clear_cache: false to avoid accidentally deleting processed graph structure data")

        # Clean each directory but keep graph structure related pkl files
        import shutil
        for dir_path in cache_dirs:
            if not os.path.exists(dir_path):
                continue
            # Best effort: the first failure abandons the rest of this directory,
            # is reported, and the next directory is still processed.
            try:
                for item in os.listdir(dir_path):
                    item_path = os.path.join(dir_path, item)
                    if os.path.isfile(item_path):
                        if _is_graph_structure_pkl(item):
                            print(f"Keeping graph structure file: {item_path}")
                        elif item.endswith('.pkl'):
                            # Evaluation result and other non-graph structure pkl files
                            os.remove(item_path)
                            print(f"Deleted evaluation result file: {item_path}")
                        else:
                            os.remove(item_path)
                            print(f"Deleted non-pkl file: {item_path}")
                    elif os.path.isdir(item_path):
                        shutil.rmtree(item_path)
                        print(f"Deleted subdirectory: {item_path}")
                print(f"Cleaned directory (keeping graph structure pkl files): {dir_path}")
            except PermissionError as e:
                print(f"Unable to clean directory {dir_path}: {e}")
            except Exception as e:
                print(f"Error occurred while cleaning directory {dir_path}: {e}")
    
    # Candidate checkpoint locations, tried in order until one can be created
    ckpt_model_name = config['training']['model_name']
    possible_dirs = [
        checkpoint_dir,  # Original path
        f"./trained_models/{ckpt_model_name}",  # Relative path
        f"./checkpoints/{ckpt_model_name}",  # checkpoints under current directory
        f"./models/{ckpt_model_name}",  # models under current directory
        f"./{ckpt_model_name}_checkpoints"  # Directly in current directory
    ]

    for dir_path in possible_dirs:
        try:
            os.makedirs(dir_path, exist_ok=True)
            checkpoint_dir = dir_path
            print(f"Successfully created checkpoint directory: {checkpoint_dir}")
            break
        except OSError as err:  # PermissionError is a subclass of OSError
            print(f"Failed to create {dir_path}: {err}")
    else:
        # No candidate could be created: fall back to a fresh temporary directory
        import tempfile
        temp_dir = tempfile.mkdtemp(prefix=f"{ckpt_model_name}_")
        checkpoint_dir = temp_dir
        print(f"Using temporary directory: {checkpoint_dir}")
        print(f"WARNING: This directory will be deleted when the system restarts!")

    print(f"Checkpoint directory: {checkpoint_dir}")
    
    # Save best model - use macro_auc_pr as scoring criterion
    def score_function(engine):
        """Return the checkpoint-ranking score for the engine's current metrics.

        The first metric present, in the order macro_auc_pr, macro_f1, micro_f1,
        balanced_accuracy, is returned unchanged (higher is better). When none of
        them is available the negated loss is returned, so that a lower loss
        still ranks as a higher score.
        """
        metrics = engine.state.metrics
        for metric_name in ('macro_auc_pr', 'macro_f1', 'micro_f1', 'balanced_accuracy'):
            if metric_name in metrics:
                return metrics[metric_name]
        # None of the classification metrics exist: rank by negated loss
        return -metrics['loss']
    
    # Keep the three best-scoring model/optimizer snapshots in checkpoint_dir.
    # NOTE: score_name is fixed to 'macro_auc_pr' even when score_function falls
    # back to another metric (or to the negated loss), so the label in the file
    # name is only accurate when macro_auc_pr is actually available.
    best_model_handler = Checkpoint(
        {'model': model, 'optimizer': optimizer},
        DiskSaver(checkpoint_dir, create_dir=True, require_empty=False),
        n_saved=3,
        filename_prefix='best',
        score_function=score_function,
        score_name='macro_auc_pr',  # Use macro_auc_pr as score name
        # This handler is attached to the evaluator, so the `engine` passed in is
        # the evaluator, whose own epoch counter starts over on every
        # evaluator.run() call. Read the epoch from the trainer instead so each
        # checkpoint is labelled with the training epoch that produced it.
        global_step_transform=lambda engine, event_name: trainer.state.epoch,
    )
    
    # EarlyStopping settings: the monitored metric is read from `training`,
    # the patience from `training.stability`; defaults apply when a key is absent.
    es_training_cfg = config.get('training', {})
    early_stopping_monitor = es_training_cfg.get('early_stopping_monitor', 'macro_f1')
    early_stopping_patience = es_training_cfg.get('stability', {}).get('early_stopping_patience', 10)

    print(
        "EarlyStopping configuration:",
        f"  - Patience: {early_stopping_patience}",
        f"  - Monitor: {early_stopping_monitor}",
        sep="\n",
    )
    
    # Create EarlyStopping score_function based on configuration
    def early_stopping_score_function(engine):
        """Return the score EarlyStopping should track for the engine's metrics.

        The metric named by ``early_stopping_monitor`` is used when it is one of
        the supported names and has been computed; ``'loss'`` is returned negated
        so that higher is always better. Otherwise the first available metric, in
        the order macro_auc_pr, macro_f1, balanced_accuracy, micro_f1, is used
        (with a warning), and the negated loss is the last resort.
        """
        metrics = engine.state.metrics
        higher_is_better = ('macro_auc_pr', 'macro_f1', 'balanced_accuracy', 'micro_f1')

        if early_stopping_monitor == 'loss':
            return -metrics['loss']  # Negative value, lower is better
        if early_stopping_monitor in higher_is_better and early_stopping_monitor in metrics:
            return metrics[early_stopping_monitor]

        # Fallback strategy: select available metric by priority
        for candidate in higher_is_better:
            if candidate in metrics:
                print(f"Warning: {early_stopping_monitor} not available, falling back to {candidate}")
                return metrics[candidate]

        print("Warning: No suitable metrics available, falling back to loss")
        return -metrics['loss']
    
    # Diagnostics switches, all read from training.diagnostics (defaults apply when absent)
    diagnostics_cfg = config.get('training', {}).get('diagnostics', {})
    enable_train_eval_metrics = diagnostics_cfg.get('enable_train_eval_metrics', True)
    enable_prediction_histogram = diagnostics_cfg.get('enable_prediction_histogram', True)
    # Default 0.3: use 30% of the training set for the eval-mode diagnosis
    train_eval_subset_size = diagnostics_cfg.get('train_eval_subset_size', 0.3)

    print("Diagnostics configuration:")
    for diag_label, diag_value in (
        ('Train eval metrics', enable_train_eval_metrics),
        ('Prediction histogram', enable_prediction_histogram),
        ('Train eval subset size', train_eval_subset_size),
    ):
        print(f"  - {diag_label}: {diag_value}")
    
    # EarlyStopping is scored on the evaluator's metrics and is given the
    # trainer it acts on
    early_stopping_handler = EarlyStopping(
        trainer=trainer,
        score_function=early_stopping_score_function,
        patience=early_stopping_patience,
    )

    # Both handlers run when an evaluation pass completes; the checkpoint
    # handler is registered first, then early stopping.
    for eval_epoch_handler in (best_model_handler, early_stopping_handler):
        evaluator.add_event_handler(Events.EPOCH_COMPLETED, eval_epoch_handler)

    # Training loop
    print(f"\nStarting training for {config['training']['epoch']} epochs...")
    
    @trainer.on(Events.EPOCH_STARTED)
    def reset_training_metrics(engine):
        """Clear the per-epoch bookkeeping on the trainer state when an epoch starts.

        Empties ``state.metrics``, zeroes the loss and accuracy accumulators, and,
        when the criterion exposes ``update_epoch``, passes it the current epoch.
        """
        state = engine.state
        state.metrics = {}
        # Running sums used for the epoch-average training loss
        state.total_train_loss, state.total_train_samples = 0.0, 0
        # Running sums used for the epoch training accuracy
        state.total_train_correct, state.total_train_samples_for_acc = 0, 0

        # Epoch-aware criteria (e.g. the aggressive Focal Loss) are told the epoch
        if hasattr(criterion, 'update_epoch'):
            criterion.update_epoch(state.epoch)
    
    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_results_iteration(engine):
        """Print the current batch loss on every 100th training iteration."""
        state = engine.state
        if state.iteration % 100 != 0:
            return
        # Loss reported by the step output for this batch (0 if the key is missing)
        batch_loss = state.output.get('loss', 0)
        print(f"Epoch {state.epoch}, Iteration {state.iteration}, Loss: {batch_loss:.4f}")
    
    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results_epoch(engine):
        # Calculate correct average training loss: accumulate by batch size then divide by sample count
        if hasattr(engine.state, 'total_train_loss') and hasattr(engine.state, 'total_train_samples'):
            avg_loss = engine.state.total_train_loss / engine.state.total_train_samples
            # Reset accumulator
            engine.state.total_train_loss = 0.0
            engine.state.total_train_samples = 0
        else:
            # If no accumulator, use default method (not recommended)
            avg_loss = engine.state.metrics.get('loss', 0)
            print("Warning: Using fallback loss calculation method")
        
        # Calculate training accuracy
        if hasattr(engine.state, 'total_train_correct') and hasattr(engine.state, 'total_train_samples_for_acc'):
            train_acc = engine.state.total_train_correct / engine.state.total_train_samples_for_acc
            # Reset accumulator
            engine.state.total_train_correct = 0
            engine.state.total_train_samples_for_acc = 0
        else:
            train_acc = 0.0
        
        train_losses.append(avg_loss)
        train_accuracies.append(train_acc)
        epochs.append(engine.state.epoch)
        
        print(f"Epoch {engine.state.epoch} completed. Average training loss: {avg_loss:.4f}, Training accuracy: {train_acc:.4f}")
        
        # Display difficult sample statistics for aggressive Focal Loss
        if hasattr(criterion, 'get_difficult_sample_ratio'):
            difficult_ratio = criterion.get_difficult_sample_ratio()
            if difficult_ratio > 0:
                print(f"  - Difficult sample ratio: {difficult_ratio:.1f}% (pt < 0.3)")
        
        # Evaluate on validation set
        evaluator.run(val_loader)
        val_metrics = evaluator.state.metrics
        val_losses.append(val_metrics['loss'])
        
        # Calculate validation set accuracy
        val_acc = val_metrics.get('accuracy', 0.0)
        val_accuracies.append(val_acc)
        
        # Calculate validation set micro_f1, macro_f1, balanced_accuracy and macro_auc_pr
        val_micro_f1 = val_metrics.get('micro_f1', 0.0)
        val_macro_f1 = val_metrics.get('macro_f1', 0.0)
        val_balanced_acc = val_metrics.get('balanced_accuracy', 0.0)
        val_macro_auc_pr = val_metrics.get('macro_auc_pr', 0.0)
        val_micro_f1s.append(val_micro_f1)
        val_macro_f1s.append(val_macro_f1)
        val_balanced_accs.append(val_balanced_acc)
        val_macro_auc_prs.append(val_macro_auc_pr)
        
        print(f"Validation loss: {val_metrics['loss']:.4f}, Validation accuracy: {val_acc:.4f}")
        print(f"Validation Micro-F1: {val_micro_f1:.4f}, Macro-F1: {val_macro_f1:.4f}, Balanced Accuracy: {val_balanced_acc:.4f}, Macro AUC-PR: {val_macro_auc_pr:.4f} [Model saved based on Macro AUC-PR]")
        
        # Display detailed information for each class
        print(f"\n--- Detailed information for each class (Epoch {engine.state.epoch}) ---")
        display_class_detailed_metrics(model, val_loader, device, num_classes)
        
        # Display current training progress and time
        current_time = time.time()
        elapsed_time = current_time - training_start_time
        estimated_total_time = elapsed_time / engine.state.epoch * int(config['training']['epoch'])
        remaining_time = estimated_total_time - elapsed_time
        
        print(f"[Progress] Epoch {engine.state.epoch}/{int(config['training']['epoch'])} - Elapsed: {format_duration(elapsed_time)} - Estimated remaining: {format_duration(remaining_time)}")
        sys.stdout.flush()
        
        # Display EarlyStopping monitoring metric
        early_stop_score = early_stopping_score_function(evaluator)
        print(f"EarlyStopping monitoring score ({early_stopping_monitor}): {early_stop_score:.4f}")
        sys.stdout.flush()
        
        # ===== Diagnostics switch (a): Training set eval-mode metrics =====
        if enable_train_eval_metrics:
            print(f"\n--- Training set eval-mode diagnosis (Epoch {engine.state.epoch}) ---")
            sys.stdout.flush()
            train_eval_results = evaluate_train_set_eval_mode(
                model, train_loader, device, train_eval_subset_size, criterion
            )
            
            if train_eval_results:
                train_eval_acc = train_eval_results['accuracy']
                train_eval_micro_f1 = train_eval_results.get('micro_f1', 0.0)
                train_eval_macro_f1 = train_eval_results['macro_f1']
                train_eval_loss = train_eval_results.get('loss', 'N/A')
                
                print(f"Training set eval-mode metrics (using {train_eval_results['sample_count']}/{train_eval_results['total_samples']} samples):")
                print(f"  Accuracy: {train_eval_acc:.4f}")
                print(f"  Micro-F1: {train_eval_micro_f1:.4f}")
                print(f"  Macro-F1: {train_eval_macro_f1:.4f}")
                if train_eval_loss != 'N/A':
                    print(f"  Loss: {train_eval_loss:.4f}")
                sys.stdout.flush()
                
                # Check performance difference between training set and validation set
                train_val_acc_diff = train_eval_acc - val_acc
                train_val_micro_f1_diff = train_eval_micro_f1 - val_micro_f1
                train_val_f1_diff = train_eval_macro_f1 - val_macro_f1
                
                print(f"Training set vs validation set differences:")
                print(f"  Accuracy difference: {train_val_acc_diff:+.4f} (Training set {'higher' if train_val_acc_diff > 0 else 'lower'})")
                print(f"  Micro-F1 difference: {train_val_micro_f1_diff:+.4f} (Training set {'higher' if train_val_micro_f1_diff > 0 else 'lower'})")
                print(f"  Macro-F1 difference: {train_val_f1_diff:+.4f} (Training set {'higher' if train_val_f1_diff > 0 else 'lower'})")
                sys.stdout.flush()
                
                # Warn about abnormal differences
                if abs(train_val_acc_diff) > 0.2:
                    print(f"⚠️  Warning: Training set and validation set accuracy difference too large ({train_val_acc_diff:+.4f})")
                    print(f"   Possible causes: Overfitting, data leakage, or model configuration issues")
                
                if abs(train_val_micro_f1_diff) > 0.2:
                    print(f"⚠️  Warning: Training set and validation set Micro-F1 difference too large ({train_val_micro_f1_diff:+.4f})")
                    print(f"   Possible causes: Improper class imbalance handling, or model bias toward majority class")
                
                if abs(train_val_f1_diff) > 0.2:
                    print(f"⚠️  Warning: Training set and validation set Macro-F1 difference too large ({train_val_f1_diff:+.4f})")
                    print(f"   Possible causes: Improper class imbalance handling, or model bias toward majority class")
                
                sys.stdout.flush()
        
        # ===== Diagnostics switch (b): Prediction distribution diagnosis (plotting functionality removed) =====
        if enable_prediction_histogram:
            print(f"\n--- Prediction distribution diagnosis (Epoch {engine.state.epoch}) ---")
            sys.stdout.flush()
            
            # Use validation set results for prediction distribution analysis
            val_outputs = []
            val_targets = []
            
            model.eval()
            with torch.no_grad():
                for batch in val_loader:
                    data, target = batch
                    data = data.to(device)
                    target = target.to(device)
                    
                    if target.dtype == torch.float32:
                        target = target.long()
                    
                    output = model(data)
                    val_outputs.append(output)
                    val_targets.append(target)
            
            if val_outputs:
                val_outputs = torch.cat(val_outputs, dim=0)
                val_targets = torch.cat(val_targets, dim=0)
                
                # Get predicted classes
                val_preds = torch.argmax(val_outputs, dim=1).cpu().numpy()
                val_targets_np = val_targets.cpu().numpy()
                
                # Calculate prediction distribution
                unique_preds, pred_counts = np.unique(val_preds, return_counts=True)
                pred_distribution = dict(zip(unique_preds, pred_counts))
                
                # Calculate true distribution
                unique_targets, target_counts = np.unique(val_targets_np, return_counts=True)
                target_distribution = dict(zip(unique_targets, target_counts))
                
                # Calculate recall for each class
                per_class_recall = sk_metrics.recall_score(val_targets_np, val_preds, average=None, zero_division=0)
                
                # Calculate confusion matrix (use num_classes from config to ensure all classes are included)
                confusion_matrix = sk_metrics.confusion_matrix(
                    val_targets_np, 
                    val_preds,
                    labels=list(range(num_classes))
                )
                
                print(f"Predicted class distribution:")
                for class_id in sorted(pred_distribution.keys()):
                    count = pred_distribution[class_id]
                    percentage = count / len(val_preds) * 100
                    print(f"  Class {class_id}: {count} samples ({percentage:.1f}%)")
                sys.stdout.flush()
                
                print(f"\nTrue class distribution:")
                for class_id in sorted(target_distribution.keys()):
                    count = target_distribution[class_id]
                    percentage = count / len(val_targets_np) * 100
                    print(f"  Class {class_id}: {count} samples ({percentage:.1f}%)")
                sys.stdout.flush()
                
                print(f"\nRecall for each class:")
                for class_id in sorted(target_distribution.keys()):
                    if class_id < len(per_class_recall):
                        recall = per_class_recall[class_id]
                        print(f"  Class {class_id}: {recall:.4f}")
                sys.stdout.flush()
                
                # Check if all predictions collapse to a single class
                max_pred_class = max(pred_distribution, key=pred_distribution.get)
                max_pred_count = pred_distribution[max_pred_class]
                max_pred_percentage = max_pred_count / len(val_preds) * 100
                
                if max_pred_percentage > 80:
                    print(f"\n⚠️  Warning: Model severely biased toward class {max_pred_class} ({max_pred_percentage:.1f}%)")
                    print(f"   Possible causes: Class imbalance, learning rate too high, or loss function misconfiguration")
                    sys.stdout.flush()
                
                # Note: Plotting functionality for each epoch has been removed, only text diagnostics remain
        
        # Learning rate scheduling - only for non-OneCycleLR schedulers (epoch-level scheduling)
        if scheduler and not isinstance(scheduler, torch.optim.lr_scheduler.OneCycleLR):
            scheduler.step()
            current_lr = optimizer.param_groups[0]['lr']
            print(f"Learning rate: {current_lr:.6f}")
            sys.stdout.flush()
            
            # For CosineAnnealingLR, display current epoch and total epochs
            if isinstance(scheduler, torch.optim.lr_scheduler.OneCycleLR):
                print(f"CosineAnnealingLR: Epoch {engine.state.epoch}/{scheduler.T_max}")
                sys.stdout.flush()
    
    @evaluator.on(Events.EPOCH_STARTED)
    def reset_validation_metrics(engine):
        """Replace the engine's metrics mapping with an empty one at epoch start."""
        engine.state.metrics = dict()
    
    @evaluator.on(Events.ITERATION_COMPLETED)
    def log_validation_results(engine):
        """Per-iteration evaluator hook; currently does nothing."""
        return None
    
    @evaluator.on(Events.EPOCH_COMPLETED)
    def log_validation_results_epoch(engine):
        """Print all evaluator metrics, then highlight the EarlyStopping metric.

        Every entry of ``engine.state.metrics`` is printed with four decimals.
        The metric named by ``early_stopping_monitor`` is then highlighted when
        it is present in the metrics. Otherwise the first available metric in
        the order macro_f1, balanced_accuracy, micro_f1, macro_auc_pr, loss is
        highlighted instead. If none of them is available a placeholder line is
        printed rather than raising ``KeyError``.

        Args:
            engine: The evaluator engine whose ``state.metrics`` are reported.
        """
        val_metrics = engine.state.metrics
        print(f"Validation Results - Epoch {trainer.state.epoch}:")
        for name, value in val_metrics.items():
            print(f"    {name}: {value:.4f}")
        sys.stdout.flush()

        # Display label for each metric that can be highlighted. The insertion
        # order doubles as the fallback priority when the configured monitor
        # is unknown or missing from this run's metrics.
        monitor_labels = {
            'macro_f1': 'Macro-F1',
            'balanced_accuracy': 'Balanced Accuracy',
            'micro_f1': 'Micro-F1',
            'macro_auc_pr': 'Macro AUC-PR',
            'loss': 'Loss',
        }

        if early_stopping_monitor in monitor_labels and early_stopping_monitor in val_metrics:
            highlighted = early_stopping_monitor
        else:
            highlighted = next((key for key in monitor_labels if key in val_metrics), None)

        if highlighted is not None:
            print(f"    *** EarlyStopping monitoring: {monitor_labels[highlighted]} = {val_metrics[highlighted]:.4f} ***")
        else:
            # Previously this path crashed on val_metrics['loss'].
            print("    *** EarlyStopping monitoring: metric unavailable ***")

        sys.stdout.flush()
    
    # Start training
    training_start_time = time.time()
    print(f"\n{'='*60}")
    print(f"Starting training loop...")
    print(f"Total epochs: {int(config['training']['epoch'])}")
    print(f"Training samples: {len(train_loader.dataset)}")
    print(f"Validation samples: {len(val_loader.dataset)}")
    print(f"Test samples: {len(test_loader.dataset)}")
    print(f"Batch size: {config['training']['batch_size']}")
    print(f"Gradient accumulation steps: {gradient_accumulation_steps}")
    print(f"Learning rate: {learning_rate}")
    print(f"Weight decay: {weight_decay}")
    print(f"Loss function: {loss_type}")
    
    # Display initial GPU status
    print(f"\n{'='*50}")
    print("GPU status check")
    print(f"{'='*50}")
    monitor_gpu_memory()
    
    # Force clear GPU cache
    print("Force clearing GPU cache...")
    clear_gpu_cache()
    log_memory_usage("after cache cleanup")
    
    # Estimate training time (based on empirical values)
    estimated_epoch_time = 300  # Assume approximately 5 minutes per epoch
    estimated_total_training_time = estimated_epoch_time * int(config['training']['epoch'])
    print(f"Estimated training time: {format_duration(estimated_total_training_time)} (based on empirical estimation)")
    print(f"{'='*60}\n")
    sys.stdout.flush()
    
    try:
        trainer.run(train_loader, max_epochs=int(config['training']['epoch']))
        
        # Calculate training loop duration
        training_end_time = time.time()
        training_duration = training_end_time - training_start_time
        print(f"\n{'='*60}")
        print(f"Training loop completed!")
        print(f"Training loop duration: {format_duration(training_duration)}")
        print(f"{'='*60}")
        sys.stdout.flush()
        
    except KeyboardInterrupt:
        # Handle user interruption
        training_end_time = time.time()
        training_duration = training_end_time - training_start_time
        print(f"\n{'='*60}")
        print(f"Training interrupted by user!")
        print(f"Training loop runtime: {format_duration(training_duration)}")
        print(f"{'='*60}")
        sys.stdout.flush()
        
    except Exception as e:
        # Handle other exceptions
        training_end_time = time.time()
        training_duration = training_end_time - training_start_time
        print(f"\n{'='*60}")
        print(f"Error occurred during training: {e}")
        print(f"Training loop runtime: {format_duration(training_duration)}")
        print(f"{'='*60}")
        sys.stdout.flush()
        raise  # Re-raise exception
    
    # Save final model
    final_model_path = os.path.join(checkpoint_dir, 'final_model.pth')
    torch.save({
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'config': config,
        'epochs': epochs,
        'train_losses': train_losses,
        'val_losses': val_losses,
        'train_accuracies': train_accuracies,
        'val_accuracies': val_accuracies,
        'val_micro_f1s': val_micro_f1s,
        'val_macro_f1s': val_macro_f1s,
        'val_balanced_accs': val_balanced_accs,
        'val_macro_auc_prs': val_macro_auc_prs,
        'class_weights': class_weights.tolist() if class_weights is not None else None,
        'loss_type': loss_type
    }, final_model_path)
    
    print(f"Final model saved to {final_model_path}")
    
    # Plot training curves (including macro_f1, balanced_accuracy and macro_auc_pr)
    plot_training_curves(epochs, train_losses, val_losses, train_accuracies, val_accuracies, 
                         val_macro_f1s, val_balanced_accs, val_macro_auc_prs, config['training']['model_name'], checkpoint_dir)
    
    # Perform validation set evaluation
    if config['model']['criterion'] == 'cross_entropy' or loss_type.startswith('focal') or loss_type == 'asl':
        print("\nPerforming threshold optimization on validation set...")
        
        # Multi-class classification does not need threshold optimization, evaluate performance directly
        print("\nMulti-class classification detected. Evaluating performance...")
        
        # Collect all outputs and targets from validation set
        val_outputs = []
        val_targets = []
        
        with torch.no_grad():
            for batch in val_loader:
                data, target = batch
                data = data.to(device)
                target = target.to(device)
                
                # Multi-class task: ensure target is long type
                if target.dtype == torch.float32:
                    target = target.long()
                
                output = model(data)
                val_outputs.append(output)
                val_targets.append(target)
        
        # Concatenate results from all batches
        val_outputs = torch.cat(val_outputs, dim=0)
        val_targets = torch.cat(val_targets, dim=0)
        
        # Directly evaluate multi-class performance
        best_threshold, best_metric, threshold_metrics = find_optimal_threshold(val_outputs, val_targets, loss_type=loss_type)
        
        # Use evaluate_with_threshold to get detailed metrics (including AUC-PR)
        val_detailed_metrics = evaluate_with_threshold(val_outputs, val_targets, None, loss_type=loss_type)
        
        # Save multi-class evaluation results
        threshold_results = {
            'best_metric': best_metric,
            'threshold_metrics': threshold_metrics,
            'detailed_metrics': val_detailed_metrics
        }
        
        threshold_path = os.path.join(checkpoint_dir, 'multi_class_evaluation.pkl')
        with open(threshold_path, 'wb') as f:
            pickle.dump(threshold_results, f)
        print(f"Multi-class evaluation results saved to {threshold_path}")
        
        print(f"\nValidation set detailed results:")
        print(f"  Accuracy: {val_detailed_metrics['accuracy']:.4f}")
        print(f"  Balanced Accuracy: {val_detailed_metrics['balanced_accuracy']:.4f}")
        print(f"  Micro-F1: {val_detailed_metrics['micro_f1']:.4f}")
        print(f"  Macro-F1: {val_detailed_metrics['macro_f1']:.4f}")
        print(f"  Macro Precision: {val_detailed_metrics['macro_precision']:.4f}")
        print(f"  Macro Recall: {val_detailed_metrics['macro_recall']:.4f}")
        print(f"  Macro AUC-PR: {val_detailed_metrics['macro_auc_pr']:.4f}")
        sys.stdout.flush()
        
        # Print per-class performance on validation set
        print(f"\nValidation set per-class performance:")
        for i, (f1, prec, rec, auc_pr) in enumerate(zip(val_detailed_metrics['class_f1'], 
                                                       val_detailed_metrics['class_precision'], 
                                                       val_detailed_metrics['class_recall'],
                                                       val_detailed_metrics['class_auc_pr'])):
            print(f"  Class {i}: F1={f1:.4f}, Precision={prec:.4f}, Recall={rec:.4f}, AUC-PR={auc_pr:.4f}")
        sys.stdout.flush()
        
        print(f"\nBest Macro-F1 on validation set: {best_metric:.4f}")
        print(f"Best Macro AUC-PR on validation set: {val_detailed_metrics['macro_auc_pr']:.4f}")
        
        # Calculate and save validation set confusion matrix
        val_preds = torch.argmax(val_outputs, dim=1).cpu().numpy()
        val_targets_np = val_targets.cpu().numpy()
        if val_targets_np.dtype == np.float32 or val_targets_np.dtype == np.float64:
            val_targets_np = val_targets_np.astype(np.int64)
        if val_targets_np.ndim > 1:
            val_targets_np = val_targets_np.squeeze()
            
        # Use num_classes from config to ensure confusion matrix includes all classes
        num_classes = config.get('training', {}).get('num_classes', len(np.unique(val_targets_np)))
        val_confusion_matrix = sk_metrics.confusion_matrix(
            val_targets_np, 
            val_preds,
            labels=list(range(num_classes))
        )
        
        # Generate class names
        class_names = [f'Class {i}' for i in range(num_classes)]
        
        # Save confusion matrix plot
        confusion_matrix_path = os.path.join(checkpoint_dir, 'val_confusion_matrix.png')
        plot_confusion_matrix(
            val_confusion_matrix,
            class_names,
            confusion_matrix_path,
            title='Validation Set Confusion Matrix',
            figsize=(12, 10)
        )
    
    # Load best model from validation set for test set evaluation
    print("\nLoading best model from validation set for test evaluation...")
    best_model_files = glob.glob(os.path.join(checkpoint_dir, 'best_*.pth'))
    
    if best_model_files:
        # Parse score value from filename, find model with highest score
        best_model_path = None
        best_score = float('-inf')
        
        # PyTorch Ignite saved filename formats may vary:
        # 1. best_checkpoint_1_macro_auc_pr=0.5852.pt (actual format)
        # 2. best_macro_auc_pr=0.8523_15.pth (standard format)
        # 3. best_model_0.8523_15.pth (old format)
        pattern1 = re.compile(r'best_checkpoint_\d+_macro_auc_pr=([\d.]+)\.pt')  # Actual format
        pattern2 = re.compile(r'best_macro_auc_pr=([\d.]+)(?:_(\d+))?\.(?:pth|pt)')  # Standard format
        pattern3 = re.compile(r'best_model_([\d.]+)_(\d+)\.(?:pth|pt)')  # Old format compatibility
        
        for model_file in best_model_files:
            filename = os.path.basename(model_file)
            match = pattern1.search(filename) or pattern2.search(filename) or pattern3.search(filename)
            if match:
                score = float(match.group(1))
                # epoch may not exist (e.g., some cases of pattern1 and pattern2)
                epoch = int(match.group(2)) if len(match.groups()) > 1 and match.group(2) else 0
                if score > best_score:
                    best_score = score
                    best_model_path = model_file
        
        if best_model_path:
            print(f"Found best model: {os.path.basename(best_model_path)}")
            print(f"  Score (macro_auc_pr): {best_score:.4f}")
            
            # Load best model
            checkpoint = torch.load(best_model_path, map_location=device)
            try:
                if 'model' in checkpoint:
                    # Ignite Checkpoint saved format - may be model object or state_dict
                    model_data = checkpoint['model']
                    if hasattr(model_data, 'state_dict'):
                        # If it's a model object
                        model.load_state_dict(model_data.state_dict())
                    else:
                        # If it's a state_dict
                        model.load_state_dict(model_data)
                elif 'model_state_dict' in checkpoint:
                    # Standard PyTorch save format
                    model.load_state_dict(checkpoint['model_state_dict'])
                else:
                    # If directly saved is state_dict
                    model.load_state_dict(checkpoint)
                print(f"Best model loaded successfully!")
            except Exception as e:
                print(f"Warning: Failed to load best model: {e}")
                print("Using final model state instead.")
        else:
            print("Warning: Could not parse best model files. Using final model state.")
    else:
        print("Warning: No best model files found. Using final model state.")
    
    # Perform test set evaluation
    print("\nEvaluating on test set with best model...")
    evaluator.run(test_loader)
    test_metrics = evaluator.state.metrics
    print(f"Final test loss: {test_metrics['loss']:.4f}")
    
    # If threshold optimization was performed, re-evaluate test set with optimal threshold
    if config['model']['criterion'] == 'cross_entropy' or loss_type.startswith('focal') or loss_type == 'asl':
        print(f"\nRe-evaluating test set for multi-class classification...")
        
        # Collect all outputs and targets from test set
        test_outputs = []
        test_targets = []
        
        with torch.no_grad():
            for batch in test_loader:
                data, target = batch
                data = data.to(device)
                target = target.to(device)
                
                # Multi-class task: ensure target is long type
                if target.dtype == torch.float32:
                    target = target.long()
                
                output = model(data)
                test_outputs.append(output)
                test_targets.append(target)
        
        # Concatenate results from all batches
        test_outputs = torch.cat(test_outputs, dim=0)
        test_targets = torch.cat(test_targets, dim=0)
        
        # Use multi-class evaluation (no threshold needed)
        threshold_test_metrics = evaluate_with_threshold(test_outputs, test_targets, None, loss_type=loss_type)
        
        print(f"Test set results for multi-class classification:")
        print(f"  Accuracy: {threshold_test_metrics['accuracy']:.4f}")
        print(f"  Balanced Accuracy: {threshold_test_metrics['balanced_accuracy']:.4f}")
        print(f"  Micro-F1: {threshold_test_metrics['micro_f1']:.4f}")
        print(f"  Macro-F1: {threshold_test_metrics['macro_f1']:.4f}")
        print(f"  Macro Precision: {threshold_test_metrics['macro_precision']:.4f}")
        print(f"  Macro Recall: {threshold_test_metrics['macro_recall']:.4f}")
        print(f"  Macro AUC-PR: {threshold_test_metrics['macro_auc_pr']:.4f}")
        sys.stdout.flush()
        
        # Print per-class performance
        print(f"\nPer-class performance:")
        for i, (f1, prec, rec, auc_pr) in enumerate(zip(threshold_test_metrics['class_f1'], 
                                                         threshold_test_metrics['class_precision'], 
                                                         threshold_test_metrics['class_recall'],
                                                         threshold_test_metrics['class_auc_pr'])):
            print(f"  Class {i}: F1={f1:.4f}, Precision={prec:.4f}, Recall={rec:.4f}, AUC-PR={auc_pr:.4f}")
        sys.stdout.flush()
        
        # Calculate confusion matrix
        test_preds = torch.argmax(test_outputs, dim=1).cpu().numpy()
        test_targets_np = test_targets.cpu().numpy()
        if test_targets_np.dtype == np.float32 or test_targets_np.dtype == np.float64:
            test_targets_np = test_targets_np.astype(np.int64)
        if test_targets_np.ndim > 1:
            test_targets_np = test_targets_np.squeeze()
            
        # Use num_classes from config to ensure confusion matrix includes all classes
        num_classes = config.get('training', {}).get('num_classes', len(np.unique(test_targets_np)))
        test_confusion_matrix = sk_metrics.confusion_matrix(
            test_targets_np, 
            test_preds,
            labels=list(range(num_classes))
        )
        
        # Generate class names
        class_names = [f'Class {i}' for i in range(num_classes)]
        
        # Save confusion matrix plot
        confusion_matrix_path = os.path.join(checkpoint_dir, 'test_confusion_matrix.png')
        plot_confusion_matrix(
            test_confusion_matrix,
            class_names,
            confusion_matrix_path,
            title='Test Set Confusion Matrix',
            figsize=(12, 10)
        )
        
        # Save multi-class test results
        final_results = {
            'default_metrics': test_metrics,
            'multi_class_metrics': threshold_test_metrics,
            'confusion_matrix': test_confusion_matrix.tolist()
        }
        
        results_path = os.path.join(checkpoint_dir, 'final_results.pkl')
        with open(results_path, 'wb') as f:
            pickle.dump(final_results, f)
        print(f"Final results saved to {results_path}")
        sys.stdout.flush()
    
    # Calculate total runtime
    end_time = time.time()
    total_duration = end_time - start_time
    
    print(f"\n{'='*60}")
    print(f"Training completed!")
    print(f"Total runtime: {format_duration(total_duration)}")
    print(f"Training loop duration: {format_duration(training_duration)}")
    print(f"Other processing time: {format_duration(total_duration - training_duration)}")
    print(f"Completion time: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(end_time))}")
    print(f"{'='*60}")
    sys.stdout.flush()
    
    return model, train_losses, val_losses

def evaluate_train_set_eval_mode(model, train_loader, device, subset_size=0.3, criterion=None):
    """
    Calculate metrics on a leading subset of the training set in eval mode,
    removing dropout/BN perturbations.

    The model is switched to eval mode with ``model.eval()`` and is NOT switched
    back to training mode here. Batches are consumed in loader order until at
    least ``int(len(train_loader.dataset) * subset_size)`` samples have been
    evaluated; whole batches are used, so the evaluated count may overshoot
    that target.

    Args:
        model: Model, called as ``model(data)``; the argmax over dim 1 of its
            output is taken as the predicted class
        train_loader: Training data loader. A batch is either a 2-item
            list/tuple ``(data, target)`` or a single object with ``.to()``
            and a ``y`` (or ``label``) attribute; any other batch is skipped
        device: Device the batch is moved to
        subset_size: Proportion of training set to use
        criterion: Loss function (optional), called as ``criterion(output, target)``

    Returns:
        dict: Dictionary containing ``accuracy``, ``balanced_accuracy``,
        ``micro_f1``, ``macro_f1``, ``per_class_recall``,
        ``prediction_distribution``, ``target_distribution``,
        ``confusion_matrix``, ``sample_count``, ``total_samples`` and, when
        ``criterion`` is given, ``loss``. An empty dict is returned when no
        batch was evaluated.
    """
    model.eval()

    # Calculate number of samples to use
    total_samples = len(train_loader.dataset)
    subset_samples = int(total_samples * subset_size)

    # Collect predictions and targets (raw outputs are not retained: nothing
    # downstream reads them and keeping them only costs host memory)
    all_predictions = []
    all_targets = []
    total_loss = 0.0
    sample_count = 0

    with torch.no_grad():
        for batch in train_loader:
            if sample_count >= subset_samples:
                break

            # Process batch data
            if isinstance(batch, (list, tuple)) and len(batch) == 2:
                data, target = batch
                data = data.to(device)
                target = target.to(device)
            elif hasattr(batch, 'to'):
                data = batch.to(device)
                if hasattr(batch, 'y'):
                    target = batch.y.to(device)
                else:
                    target = getattr(batch, 'label', None)
                    if target is None:
                        continue
                    target = target.to(device)
            else:
                continue

            # Ensure targets are integer type (covers every floating dtype,
            # not only float32)
            if target.is_floating_point():
                target = target.long()

            # Forward propagation
            output = model(data)

            # Calculate loss (if loss function provided); weighting by batch
            # size assumes the criterion returns a per-batch mean - TODO confirm
            if criterion is not None:
                loss = criterion(output, target)
                total_loss += loss.item() * target.size(0)

            # Get predicted class
            pred = torch.argmax(output, dim=1)

            # Collect results
            all_predictions.append(pred.cpu())
            all_targets.append(target.cpu())
            sample_count += target.size(0)

    # Nothing evaluated (empty loader, zero-sized subset or only unsupported batches)
    if not all_predictions:
        return {}

    # Concatenate all results
    all_predictions = torch.cat(all_predictions, dim=0).numpy()
    all_targets = torch.cat(all_targets, dim=0).numpy()

    # Calculate metrics
    accuracy = sk_metrics.accuracy_score(all_targets, all_predictions)
    balanced_acc = sk_metrics.balanced_accuracy_score(all_targets, all_predictions)
    micro_f1 = sk_metrics.f1_score(all_targets, all_predictions, average='micro', zero_division=0)
    macro_f1 = sk_metrics.f1_score(all_targets, all_predictions, average='macro', zero_division=0)

    # Calculate recall for each class
    per_class_recall = sk_metrics.recall_score(all_targets, all_predictions, average=None, zero_division=0)

    # Calculate prediction class distribution
    unique_preds, pred_counts = np.unique(all_predictions, return_counts=True)
    pred_distribution = dict(zip(unique_preds, pred_counts))

    # Calculate true class distribution
    unique_targets, target_counts = np.unique(all_targets, return_counts=True)
    target_distribution = dict(zip(unique_targets, target_counts))

    # Calculate confusion matrix
    confusion_matrix = sk_metrics.confusion_matrix(all_targets, all_predictions)

    results = {
        'accuracy': accuracy,
        'balanced_accuracy': balanced_acc,
        'micro_f1': micro_f1,
        'macro_f1': macro_f1,
        'per_class_recall': per_class_recall,
        'prediction_distribution': pred_distribution,
        'target_distribution': target_distribution,
        'confusion_matrix': confusion_matrix,
        'sample_count': sample_count,
        'total_samples': total_samples
    }

    if criterion is not None:
        results['loss'] = total_loss / sample_count

    return results

# NOTE: plot_prediction_diagnostics function removed - plotting functionality has been removed
# (Only text diagnostics remain, as noted in the code comments)

def plot_confusion_matrix(confusion_matrix, class_names, save_path, title='Confusion Matrix', figsize=(10, 8)):
    """
    Plot and save confusion matrix image

    Each cell shows the raw count; non-zero cells also show the count as a
    percentage of its row total. Rows whose total is zero (a class with no
    true samples) are given 0% instead of dividing by zero.

    Args:
        confusion_matrix: Confusion matrix array (2-D; array-likes are
            converted with ``np.asarray``)
        class_names: List of class names, used as tick labels on both axes
        save_path: Save path
        title: Plot title
        figsize: Image size
    """
    confusion_matrix = np.asarray(confusion_matrix)

    # Calculate percentages; `where` skips empty rows so they stay at 0
    # rather than producing NaN and a RuntimeWarning
    row_totals = confusion_matrix.sum(axis=1)[:, np.newaxis]
    cm_percentage = np.divide(
        confusion_matrix.astype(float),
        row_totals,
        out=np.zeros(confusion_matrix.shape, dtype=float),
        where=row_totals != 0,
    ) * 100

    # Cells above half of the largest count get white text for contrast
    half_max = confusion_matrix.max() / 2 if confusion_matrix.size else 0

    fig, ax = plt.subplots(figsize=figsize)
    try:
        # Plot confusion matrix
        im = ax.imshow(confusion_matrix, cmap='Blues', interpolation='nearest')
        ax.set_title(title, fontsize=16, fontweight='bold', pad=20)
        ax.set_xlabel('Predicted Class', fontsize=12, fontweight='bold')
        ax.set_ylabel('True Class', fontsize=12, fontweight='bold')

        # Add value labels
        for (i, j), count in np.ndenumerate(confusion_matrix):
            # Display original value
            text_str = f'{count}'
            # If value exists, display percentage as well
            if count > 0:
                text_str += f'\n({cm_percentage[i, j]:.1f}%)'
            # Choose font color based on background color
            color = 'white' if count > half_max else 'black'
            ax.text(j, i, text_str, ha="center", va="center", color=color, fontsize=9)

        # Set tick labels
        ax.set_xticks(range(len(class_names)))
        ax.set_yticks(range(len(class_names)))
        ax.set_xticklabels(class_names, rotation=45, ha='right')
        ax.set_yticklabels(class_names)

        # Add colorbar
        fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

        fig.tight_layout()
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
    finally:
        # Release the figure even when drawing or saving fails
        plt.close(fig)

    print(f"Confusion matrix saved to {save_path}")

def _plot_curves_on_axis(ax, epochs, curves, title, ylabel, unit_range=True):
    """Draw ``curves`` (iterable of ``(values, fmt, label)``) on ``ax`` with the shared styling."""
    for values, fmt, label in curves:
        ax.plot(epochs, values, fmt, label=label, linewidth=2)
    ax.set_title(title)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.legend()
    ax.grid(True, alpha=0.3)
    if unit_range:
        ax.set_ylim(0, 1)


def _save_curve_figure(epochs, curves, title, ylabel, save_path, description, unit_range=True):
    """Render ``curves`` as a single-panel figure, save it to ``save_path`` and print a notice."""
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        _plot_curves_on_axis(ax, epochs, curves, title, ylabel, unit_range=unit_range)
        fig.tight_layout()
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
    finally:
        # Release the figure even when drawing or saving fails
        plt.close(fig)
    print(f"{description} saved to {save_path}")


def plot_training_curves(epochs, train_losses, val_losses, train_accuracies, val_accuracies, 
                         val_macro_f1s, val_balanced_accs, val_macro_auc_prs, model_name, checkpoint_dir):
    """Plot training and validation loss, accuracy, macro_f1, balanced_accuracy and macro_auc_pr curves

    Writes six PNG files into ``checkpoint_dir``: a 2x3 overview
    (``training_curves.png``) and one standalone figure each for loss,
    accuracy, Macro-F1, Balanced Accuracy and Macro AUC-PR. All y axes except
    the loss axes are fixed to the range [0, 1].

    Args:
        epochs: X values shared by every curve
        train_losses: Training loss per entry of ``epochs``
        val_losses: Validation loss per entry of ``epochs``
        train_accuracies: Training accuracy per entry of ``epochs``
        val_accuracies: Validation accuracy per entry of ``epochs``
        val_macro_f1s: Validation Macro-F1 per entry of ``epochs``
        val_balanced_accs: Validation balanced accuracy per entry of ``epochs``
        val_macro_auc_prs: Validation Macro AUC-PR per entry of ``epochs``
        model_name: Prefix used in every plot title
        checkpoint_dir: Directory the images are written to
    """
    # One entry per overview panel, in row-major order:
    # (curves, title suffix, y label, clamp y to [0, 1],
    #  standalone file name or None, standalone log description)
    panels = [
        ([(train_losses, 'b-', 'Training Loss'), (val_losses, 'r-', 'Validation Loss')],
         'Training and Validation Loss', 'Loss', False,
         'loss_curves.png', 'Loss curves'),
        ([(train_accuracies, 'b-', 'Training Accuracy'), (val_accuracies, 'r-', 'Validation Accuracy')],
         'Training and Validation Accuracy', 'Accuracy', True,
         'accuracy_curves.png', 'Accuracy curves'),
        ([(val_macro_f1s, 'g-', 'Validation Macro-F1')],
         'Validation Macro-F1', 'Macro-F1', True,
         'macro_f1_curves.png', 'Macro-F1 curves'),
        ([(val_balanced_accs, 'm-', 'Validation Balanced Accuracy')],
         'Validation Balanced Accuracy', 'Balanced Accuracy', True,
         'balanced_accuracy_curves.png', 'Balanced Accuracy curves'),
        ([(val_macro_auc_prs, 'c-', 'Validation Macro AUC-PR')],
         'Validation Macro AUC-PR', 'Macro AUC-PR', True,
         'macro_auc_pr_curves.png', 'Macro AUC-PR curves'),
        # Comparison panel only exists in the overview (no standalone file)
        ([(val_macro_f1s, 'g-', 'Macro-F1'),
          (val_balanced_accs, 'm-', 'Balanced Accuracy'),
          (val_macro_auc_prs, 'c-', 'Macro AUC-PR')],
         'All Validation Metrics', 'Score', True,
         None, None),
    ]

    # Comprehensive 2x3 overview
    fig, axes = plt.subplots(2, 3, figsize=(20, 12))
    try:
        for ax, (curves, title, ylabel, unit_range, _, _) in zip(axes.flat, panels):
            _plot_curves_on_axis(ax, epochs, curves, f'{model_name} - {title}', ylabel,
                                 unit_range=unit_range)
        fig.tight_layout()
        plot_path = os.path.join(checkpoint_dir, 'training_curves.png')
        fig.savefig(plot_path, dpi=300, bbox_inches='tight')
    finally:
        plt.close(fig)
    print(f"Training curves (loss, accuracy, macro_f1, balanced_accuracy, macro_auc_pr) saved to {plot_path}")

    # Standalone figures (loss file kept for backward compatibility)
    for curves, title, ylabel, unit_range, filename, description in panels:
        if filename is None:
            continue
        _save_curve_figure(
            epochs, curves, f'{model_name} - {title}', ylabel,
            os.path.join(checkpoint_dir, filename), description,
            unit_range=unit_range,
        )

def format_duration(seconds):
    """Render a duration given in seconds as a human-readable string.

    Args:
        seconds: Elapsed time in seconds (int or float).

    Returns:
        A string in one of three shapes, chosen by magnitude:
        - under one minute:  "S.SS seconds"
        - under one hour:    "M.MM minutes (S.SS seconds)"
        - one hour or more:  "H hours M.M minutes (S.SS seconds)"
        The parenthesised value is always the full duration in seconds.
    """
    if seconds < 60:
        return f"{seconds:.2f} seconds"
    if seconds < 3600:
        minutes = seconds / 60
        return f"{minutes:.2f} minutes ({seconds:.2f} seconds)"

    # Split into whole hours plus the leftover minutes. Reporting fractional
    # hours here would count the leftover time twice (5400 s would read as
    # "1.5 hours 30.0 minutes" instead of 1 hour 30 minutes).
    whole_hours, leftover_seconds = divmod(seconds, 3600)
    minutes = leftover_seconds / 60
    return f"{int(whole_hours)} hours {minutes:.1f} minutes ({seconds:.2f} seconds)"

def print_dict(d, indent=0):
    """Print each entry of ``d`` on its own line as ``key: value``.

    Args:
        d: Mapping whose items are printed in iteration order.
        indent: Number of spaces placed before every key.

    Values are handed to ``print`` unchanged, so a nested dictionary is
    shown inline using its normal representation.
    """
    for name, item in d.items():
        label = f"{' ' * indent}{name!s}:"
        print(label, item)

def main():
    """Command-line entry point: parse ``--config`` and run ``train`` on it.

    Prints the start time, then a closing summary (headline, elapsed time,
    wall-clock timestamp) for each of the three outcomes:

    - normal completion,
    - ``KeyboardInterrupt`` (reported, not re-raised),
    - any other ``Exception`` (reported, then re-raised to the caller).
    """
    cli = argparse.ArgumentParser(description='Train protein classification model with improved class balancing')
    cli.add_argument('--config', type=str, required=True, help='Path to configuration file')
    options = cli.parse_args()

    stamp_format = '%Y-%m-%d %H:%M:%S'
    started_at = time.time()

    def report(finished_at, headline, runtime_label, time_label):
        # Shared closing banner; only the wording differs between outcomes.
        elapsed = finished_at - started_at
        print(f"\n{'='*60}")
        print(headline)
        print(f"{runtime_label}: {format_duration(elapsed)}")
        print(f"{time_label}: {time.strftime(stamp_format, time.localtime(finished_at))}")
        print(f"{'='*60}")
        sys.stdout.flush()

    print(f"Starting training with config: {options.config}")
    print(f"Main program start time: {time.strftime(stamp_format, time.localtime(started_at))}")
    sys.stdout.flush()

    try:
        train(options.config)
        # The success summary stays inside the try so that a failure while
        # printing it is handled by the same except clauses as a training failure.
        report(
            time.time(),
            "Main program completed!",
            "Main program total runtime",
            "Completion time",
        )
    except KeyboardInterrupt:
        # User interruption is reported and then swallowed.
        report(
            time.time(),
            "Main program interrupted by user!",
            "Main program runtime",
            "Interruption time",
        )
    except Exception as e:
        # Everything else is reported and then propagated unchanged.
        report(
            time.time(),
            f"Error occurred in main program: {e}",
            "Main program runtime",
            "Error time",
        )
        raise

# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
