#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Single protein sequence prediction script
Functions:
1. Compute ESM2 features (1280 dimensions)
2. Compute Path homology features (360 dimensions)
3. Use trained model for prediction
"""

import os
import sys

# Set environment variables to avoid MKL errors.
# NOTE(review): these assignments run before torch/numpy/pandas are imported
# further down; presumably the native libraries read them at load time, so
# keep this section above those imports - confirm before reordering.
os.environ['MKL_SERVICE_FORCE_INTEL'] = '1'
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
# Restrict MKL and OpenMP to a single thread each (the original author's
# workaround for segmentation faults)
os.environ['MKL_NUM_THREADS'] = '1'
os.environ['OMP_NUM_THREADS'] = '1'

import torch
import numpy as np
import pandas as pd
import argparse
import tempfile
from torch_geometric.data import Data
from typing import Tuple

# Set PyTorch thread count
torch.set_num_threads(1)

# Try to import Bio for FASTA reading
try:
    from Bio import SeqIO
    BIO_AVAILABLE = True
except ImportError:
    BIO_AVAILABLE = False

# Put this script's own directory at the front of sys.path so that the
# `model` import below is resolved from there first, independent of the
# current working directory
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

from model import QCformer
# Import ph-protein-sequence module (note: filename contains hyphen, needs special handling)
# A hyphen is not valid in a module name for a normal `import` statement, so the
# file is loaded by path: build a spec, create the module object, then execute it.
import importlib.util
ph_protein_sequence_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "ph-protein-sequence.py")
spec = importlib.util.spec_from_file_location("ph_protein_sequence", ph_protein_sequence_path)
ph_protein_sequence = importlib.util.module_from_spec(spec)
# exec_module runs the file's top-level code right here, at import time
spec.loader.exec_module(ph_protein_sequence)

# Re-bind the required functions and constants as module-level names of this script
compute_kmer_features = ph_protein_sequence.compute_kmer_features
clean_sequence = ph_protein_sequence.clean_sequence
read_fasta = ph_protein_sequence.read_fasta
AA = ph_protein_sequence.AA

# Label mapping (15 classes)
# Based on label merging and class mapping relationships during training
# During training: original label -> merged label -> class index
# Keys are the class indices used to look up a predicted class; values are
# the human-readable symmetry labels printed and written to the CSV.
LABEL_MAPPING = {
    0: 'C1',       # C1 -> C1 -> class 0
    1: 'C10-C17',  # C10-C17 -> C10-C17 -> class 1
    2: 'C2',       # C2 -> C2 -> class 2
    3: 'C3',       # C3 -> C3 -> class 3
    4: 'C4',       # C4 -> C4 -> class 4
    5: 'C5',       # C5 -> C5 -> class 5
    6: 'C6',       # C6 -> C6 -> class 6
    7: 'C7-C9',    # C7-C9 -> C7-C9 -> class 7
    8: 'D2',       # D2 -> D2 -> class 8
    9: 'D3',       # D3 -> D3 -> class 9
    10: 'D4',      # D4 -> D4 -> class 10
    11: 'D5',      # D5 -> D5 -> class 11
    12: 'D6-D8',   # D6-D8 -> D6-D8 -> class 12
    13: 'H',       # H -> H -> class 13
    14: 'T'        # T -> T -> class 14
}

def compute_esm2_from_sequence(sequence: str, esm2_model=None, alphabet=None, batch_converter=None, device='cpu'):
    """
    Embed a protein sequence with ESM2 and mean-pool it into one vector

    Args:
        sequence: Protein sequence string (passed through clean_sequence first)
        esm2_model: Already loaded ESM2 model; if this or alphabet is None,
            esm2_t33_650M_UR50D is loaded here and moved to device
        alphabet: ESM alphabet belonging to esm2_model
        batch_converter: Batch converter; built from alphabet when None
        device: Device for the token tensor (and for a model loaded here)

    Returns:
        Tensor [1280] holding the layer-33 token representations averaged over
        the positions between the start and end special tokens; a zero tensor
        [1280] when the cleaned sequence is empty; None when the esm library
        is missing or any other error occurs (the error is printed).
    """
    try:
        import esm

        # Load the pretrained model only when the caller did not hand one in
        must_load = esm2_model is None or alphabet is None
        if must_load:
            print("Loading ESM2 model (esm2_t33_650M_UR50D)...")
            esm2_model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
            esm2_model = esm2_model.to(device)
            esm2_model.eval()
            print("ESM2 model loaded successfully")

        if batch_converter is None:
            batch_converter = alphabet.get_batch_converter()

        cleaned = clean_sequence(sequence)
        if not cleaned:
            print("Warning: Empty sequence after cleaning")
            return torch.zeros(1280, dtype=torch.float32, device=device)

        # The converter takes (label, sequence) pairs; only the tokens are needed
        _, _, tokens = batch_converter([("protein_1", cleaned)])
        tokens = tokens.to(device)
        non_pad_counts = (tokens != alphabet.padding_idx).sum(1)

        with torch.no_grad():
            outputs = esm2_model(tokens, repr_layers=[33], return_contacts=False)
            per_token = outputs["representations"][33]  # [batch_size, seq_len, 1280]

        # Position 0 is the start token and the last non-padding position is
        # the end token; average only what lies in between.
        end = non_pad_counts[0] - 1
        pooled = per_token[0, 1:end].mean(0)  # [1280]

        # The result stays on the device the model produced it on
        return pooled

    except ImportError:
        print("Error: esm library not found. Please install it: pip install fair-esm")
        print("Alternatively, provide ESM2 features as a CSV file.")
        return None
    except Exception as e:
        print(f"Error computing ESM2 features: {e}")
        import traceback
        traceback.print_exc()
        return None

def compute_path_homology_features(sequence: str) -> np.ndarray:
    """
    Compute Path homology features from protein sequence (360 dimensions)

    The vector is laid out as 20 amino acids (in the fixed order
    "ACDEFGHIKLMNPQRSTVWY") x 18 features each. Every amino acid owns a fixed
    18-value slot: a residue missing from the k-mer features leaves its slot at
    zero, and a residue whose feature list is shorter or longer than 18 is
    zero-padded or truncated inside its own slot, so the slots of the other
    residues never shift.

    Args:
        sequence: Protein sequence string

    Returns:
        Path homology feature array [360], dtype float32 (all zeros when the
        cleaned sequence is empty)
    """
    residue_order = "ACDEFGHIKLMNPQRSTVWY"
    per_residue = 18
    total = len(residue_order) * per_residue  # 360

    # Clean sequence
    clean_seq = clean_sequence(sequence)
    if not clean_seq:
        print("Warning: Empty sequence after cleaning")
        return np.zeros(total, dtype=np.float32)

    # Compute k-mer features (only compute 1-mer, 20 amino acids)
    feats = compute_kmer_features(clean_seq)

    feature_vector = np.zeros(total, dtype=np.float32)
    for slot, aa in enumerate(residue_order):
        if aa not in feats:
            continue
        values = np.asarray(feats[aa], dtype=np.float32).ravel()
        if values.size != per_residue:
            # Repair inside this residue's slot only; repairing the tail of the
            # concatenated vector would misalign every residue after this one.
            print(f"Warning: Path homology features for residue '{aa}' have "
                  f"{values.size} values, expected {per_residue}")
        usable = min(values.size, per_residue)
        start = slot * per_residue
        feature_vector[start:start + usable] = values[:usable]

    return feature_vector

def _extract_state_dict(checkpoint):
    """Return the state dict stored under 'model' or 'model_state_dict', else the checkpoint itself."""
    if isinstance(checkpoint, dict):
        for key in ('model', 'model_state_dict'):
            if key in checkpoint:
                return checkpoint[key]
    # Assume entire checkpoint is state_dict
    return checkpoint


def load_model(checkpoint_path: str, config_path: str, device='cpu'):
    """
    Build a QCformer from a YAML config and load trained weights into it

    Weights are loaded with strict=False, so parameters missing from the
    checkpoint keep their initial values; every such mismatch is printed so a
    wrong checkpoint/config pair does not go unnoticed.

    Args:
        checkpoint_path: Model checkpoint path
        config_path: Configuration file path (YAML; its parsed content is
            passed to QCformer)
        device: Device the model and the checkpoint tensors are mapped to

    Returns:
        Tuple (model, config): the model in eval mode on device, and the
        parsed configuration

    Raises:
        Exceptions from opening/parsing the config, from torch.load, or from
        load_state_dict (when even the shape-filtered load fails) propagate
        to the caller.
    """
    import yaml

    # Load configuration
    with open(config_path, 'r') as f:
        config = yaml.safe_load(f)

    # Create model
    model = QCformer(config)
    # Move to device first, then load checkpoint, ensure all buffers are on correct device
    model = model.to(device)
    model.eval()

    # SECURITY: weights_only=False lets torch.load unpickle arbitrary Python
    # objects, which can execute code. Only load checkpoints from a trusted source.
    try:
        try:
            checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
        except TypeError:
            # Older PyTorch versions don't support weights_only parameter
            checkpoint = torch.load(checkpoint_path, map_location=device)
    except Exception as e:
        print(f"Error loading checkpoint: {e}")
        print("Trying alternative loading method...")
        # Last-resort retry with the pickle module passed explicitly
        import pickle
        checkpoint = torch.load(checkpoint_path, map_location=device, pickle_module=pickle)

    # Handle different checkpoint formats
    state_dict = _extract_state_dict(checkpoint)

    try:
        load_result = model.load_state_dict(state_dict, strict=False)
    except RuntimeError as e:
        # With strict=False this is raised for tensors whose shape differs from
        # the model's. Retrying the same call cannot succeed, so keep only the
        # entries whose name AND shape match the model.
        print(f"Warning: Error loading state dict: {e}")
        print("Attempting partial load...")
        reference = model.state_dict()
        compatible = {}
        shape_mismatch = []
        for key, value in state_dict.items():
            if key not in reference:
                continue
            if getattr(value, 'shape', None) == reference[key].shape:
                compatible[key] = value
            else:
                shape_mismatch.append(key)
        if shape_mismatch:
            print(f"Warning: Skipped {len(shape_mismatch)} tensor(s) with mismatched shape, "
                  f"e.g. {shape_mismatch[:5]}")
        if not compatible:
            print("Warning: No checkpoint tensor matches the model; "
                  "the model keeps its initial weights")
        load_result = model.load_state_dict(compatible, strict=False)

    # Surface what strict=False would otherwise hide
    missing = list(getattr(load_result, 'missing_keys', []))
    unexpected = list(getattr(load_result, 'unexpected_keys', []))
    if missing:
        print(f"Warning: {len(missing)} model parameter(s) not loaded from checkpoint "
              f"(left at initial values), e.g. {missing[:5]}")
    if unexpected:
        print(f"Warning: {len(unexpected)} checkpoint entr(y/ies) not used by the model, "
              f"e.g. {unexpected[:5]}")

    # Ensure model and all buffers are on correct device (move again after loading state_dict)
    model = model.to(device)
    model.eval()

    print(f"Model loaded from {checkpoint_path}")
    return model, config

def predict_protein(sequence: str, model, config, esm2_features=None, 
                   esm2_dir=None, esm2_model=None, alphabet=None, batch_converter=None, device='cpu'):
    """
    Predict protein symmetry type

    Args:
        sequence: Protein sequence
        model: Trained model; called as model(batch) with a single-graph batch
        config: Configuration dictionary (accepted for interface
            compatibility; not read by this function)
        esm2_features: ESM2 features (optional, use directly if provided).
            May be a tensor, a numpy array or a sequence of numbers; any shape
            that squeezes to (1280,) is accepted, e.g. (1, 1280)
        esm2_dir: ESM2 feature directory (accepted for interface
            compatibility; not read by this function)
        esm2_model: ESM2 model (optional, for reusing already loaded model)
        alphabet: ESM alphabet (optional, for reuse)
        batch_converter: batch converter (optional, for reuse)
        device: Device

    Returns:
        Tuple (pred_class, probs): the arg-max class index (int) and the
        softmax probabilities as a numpy array. (None, None) is returned,
        after printing the reason, when features cannot be obtained, have
        the wrong size, or the model call fails.
    """
    import traceback

    # Compute ESM2 features unless the caller supplied them
    if esm2_features is None:
        esm2_features = compute_esm2_from_sequence(sequence, esm2_model=esm2_model, 
                                                   alphabet=alphabet, batch_converter=batch_converter, 
                                                   device=device)
        if esm2_features is None:
            print("Error: Failed to compute ESM2 features")
            return None, None

    # Normalise to a float32 tensor on the target device; a caller-supplied
    # tensor may be float64 or live on another device.
    if isinstance(esm2_features, np.ndarray):
        esm2_features = torch.from_numpy(esm2_features)
    elif not isinstance(esm2_features, torch.Tensor):
        esm2_features = torch.tensor(esm2_features, dtype=torch.float32)
    esm2_features = esm2_features.to(device=device, dtype=torch.float32)

    # Accept any shape that only differs from (1280,) by singleton dimensions
    # and reject everything else with a message instead of an exception.
    given_shape = esm2_features.shape
    esm2_features = esm2_features.squeeze()
    if esm2_features.shape != (1280,):
        print(f"Error: ESM2 features dimension incorrect: {given_shape}, expected (1280,)")
        return None, None

    # Compute Path homology features
    ph_features = compute_path_homology_features(sequence)
    if ph_features is None or ph_features.shape != (360,):
        print(f"Error: Path homology features dimension incorrect: {ph_features.shape if ph_features is not None else None}, expected (360,)")
        return None, None

    ph_features = torch.from_numpy(ph_features).float().to(device)

    # Create graph data object (minimal graph), ensure all tensors are on correct device
    graph = Data(
        x=torch.zeros(1, 1, dtype=torch.float32).to(device),  # Placeholder node features
        edge_index=torch.zeros(2, 0, dtype=torch.long).to(device),  # Empty edges
        num_nodes=1,
        path_homology_features=ph_features,  # [360], already on device
        esm2_features=esm2_features  # [1280], already on device
    )

    # Create batch (single graph)
    from torch_geometric.data import Batch
    batch = Batch.from_data_list([graph])

    # Move to device (add error handling)
    try:
        batch = batch.to(device)
    except Exception as e:
        print(f"Error moving batch to device: {e}")
        traceback.print_exc()
        return None, None

    # Predict (add error handling)
    try:
        with torch.no_grad():
            # Ensure model is in eval mode
            model.eval()
            output = model(batch)
            if output.dim() > 1:
                output = output[0]  # Take first (and only) sample

            # Compute probabilities
            probs = torch.softmax(output, dim=0)
            pred_class = torch.argmax(probs, dim=0).item()

        # Ensure data is moved to CPU before returning
        probs_np = probs.cpu().numpy()
        return pred_class, probs_np
    except RuntimeError as e:
        print(f"Runtime error during prediction: {e}")
        print("This might be due to memory issues or model compatibility.")
        traceback.print_exc()
        return None, None
    except Exception as e:
        print(f"Unexpected error during prediction: {e}")
        traceback.print_exc()
        return None, None

def _resolve_device(requested):
    """Map the --device string to a usable torch.device, falling back to CPU with a warning."""
    name = (requested or 'cpu').strip().lower()
    cpu = torch.device('cpu')
    if name == 'cpu':
        return cpu
    if not name.startswith('cuda'):
        print(f"Warning: Unsupported device '{requested}'; falling back to CPU")
        return cpu
    if not torch.cuda.is_available():
        print(f"Warning: CUDA requested ('{requested}') but not available; falling back to CPU")
        return cpu
    try:
        device = torch.device(name)
    except RuntimeError as e:
        print(f"Warning: Invalid device '{requested}' ({e}); falling back to CPU")
        return cpu
    if device.index is not None and device.index >= torch.cuda.device_count():
        print(f"Warning: CUDA device index {device.index} not present; falling back to CPU")
        return cpu
    return device


def _read_input_sequences(args):
    """Return (sequences, names) from --fasta or --sequence, or None after printing an error."""
    sequences = []
    sequence_names = []

    if args.fasta:
        # Read from FASTA file
        if not os.path.exists(args.fasta):
            print(f"Error: FASTA file not found: {args.fasta}")
            return None

        print(f"Reading sequences from FASTA file: {args.fasta}")
        try:
            if BIO_AVAILABLE:
                # Use Bio to read FASTA (preserve sequence names)
                for record in SeqIO.parse(args.fasta, "fasta"):
                    clean_seq = clean_sequence(str(record.seq))
                    if clean_seq:
                        sequences.append(clean_seq)
                        sequence_names.append(record.id)
                    else:
                        print(f"Warning: Skipped empty sequence: {record.id}")
            else:
                # If Bio not available, use custom read_fasta function
                print("Bio not available, using custom FASTA reader")
                for i, seq in enumerate(read_fasta(args.fasta), 1):
                    if seq:
                        sequences.append(seq)
                        sequence_names.append(f"sequence_{i}")
        except Exception as e:
            print(f"Error reading FASTA file: {e}")
            import traceback
            traceback.print_exc()
            return None

        if not sequences:
            print("Error: No valid sequences found in FASTA file")
            return None
        print(f"Found {len(sequences)} sequence(s) in FASTA file")

    elif args.sequence:
        # Directly use provided sequence
        clean_seq = clean_sequence(args.sequence)
        if not clean_seq:
            print("Error: Empty sequence after cleaning")
            return None
        sequences = [clean_seq]
        sequence_names = ["input_sequence"]

    return sequences, sequence_names


def _load_esm2_feature_file(path):
    """Read row 0 of a header-less CSV with 1280 columns as a float tensor; None on any problem."""
    if not os.path.exists(path):
        print(f"Warning: ESM2 file not found: {path}")
        return None
    print(f"Loading ESM2 features from {path}")
    try:
        esm2_data = pd.read_csv(path, header=None)
        if esm2_data.shape[1] != 1280:
            print(f"Warning: ESM2 file has wrong shape: {esm2_data.shape}")
            return None
        return torch.from_numpy(esm2_data.iloc[0].values).float()
    except Exception as e:
        print(f"Error loading ESM2 file: {e}")
        return None


def _preload_esm2(device):
    """Load the ESM2 model once for reuse; returns (model, alphabet, batch_converter) or three Nones."""
    print("Preloading ESM2 model...")
    try:
        import esm
        esm2_model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
        esm2_model = esm2_model.to(device)
        esm2_model.eval()
        batch_converter = alphabet.get_batch_converter()
        print("ESM2 model preloaded successfully")
        return esm2_model, alphabet, batch_converter
    except ImportError:
        print("Warning: esm library not found. ESM2 computation may fail.")
    except Exception as e:
        print(f"Warning: Failed to preload ESM2 model: {e}")
    return None, None, None


def _label_for(class_idx):
    """Human-readable label for a class index; unknown indices get a generic name instead of a KeyError."""
    class_idx = int(class_idx)
    return LABEL_MAPPING.get(class_idx, f"class_{class_idx}")


def main():
    """
    Command-line entry point: read sequence(s), predict, print and optionally save results

    Sequences come from --fasta (one or many) or --sequence (one). For every
    sequence the predicted label and confidence are printed; a single-sequence
    run also prints the top-5 classes. When --output is given, all successful
    predictions are written to that CSV file. Errors are reported on stdout
    and the function returns early; it does not raise for missing input files.
    """
    parser = argparse.ArgumentParser(description='Predict protein symmetry type from FASTA file or sequence')
    input_group = parser.add_mutually_exclusive_group(required=True)
    input_group.add_argument('--fasta', type=str,
                            help='Path to FASTA file containing protein sequence(s)')
    input_group.add_argument('--sequence', type=str,
                            help='Protein sequence string (alternative to --fasta)')
    parser.add_argument('--model', type=str, 
                       default='trained_models/protein_classifier_unified/best_checkpoint_1_macro_auc_pr=0.5918.pt',
                       help='Path to model checkpoint')
    parser.add_argument('--config', type=str,
                       default='config/protein_classifier_config.yaml',
                       help='Path to config file')
    parser.add_argument('--esm2_dir', type=str, default=None,
                       help='Directory containing ESM2 feature files (optional)')
    parser.add_argument('--esm2_file', type=str, default=None,
                       help='Path to ESM2 feature CSV file (optional, only for single sequence)')
    parser.add_argument('--output', type=str, default=None,
                       help='Output CSV file path (optional)')
    parser.add_argument('--device', type=str, default='cpu',
                       help='Device to use (cpu, cuda or cuda:N)')

    args = parser.parse_args()

    # Set device ("cuda:N" is honoured; anything unusable falls back to CPU)
    device = _resolve_device(args.device)
    print(f"Using device: {device}")

    # Check if model file exists
    if not os.path.exists(args.model):
        print(f"Error: Model file not found: {args.model}")
        return

    # Check if config file exists
    if not os.path.exists(args.config):
        print(f"Error: Config file not found: {args.config}")
        return

    # Read sequences
    loaded = _read_input_sequences(args)
    if loaded is None:
        return
    sequences, sequence_names = loaded
    total = len(sequences)

    # Load model
    print("\nLoading model...")
    model, config = load_model(args.model, args.config, device=device)

    # Precomputed ESM2 features from --esm2_file apply to the first sequence only
    file_features = _load_esm2_feature_file(args.esm2_file) if args.esm2_file else None

    # Preload the ESM2 model only when at least one sequence needs it computed
    esm2_model = alphabet = batch_converter = None
    if file_features is None or total > 1:
        esm2_model, alphabet, batch_converter = _preload_esm2(device)

    # Predict all sequences
    results = []

    for idx, (seq, seq_name) in enumerate(zip(sequences, sequence_names), 1):
        print(f"\n{'='*70}")
        print(f"Processing sequence {idx}/{total}: {seq_name}")
        print(f"{'='*70}")
        print(f"Sequence length: {len(seq)}")

        # Predict
        pred_class, probs = predict_protein(
            seq, model, config, 
            esm2_features=file_features if idx == 1 else None,
            esm2_dir=args.esm2_dir,
            esm2_model=esm2_model,
            alphabet=alphabet,
            batch_converter=batch_converter,
            device=device
        )

        if pred_class is None:
            print(f"Error: Prediction failed for {seq_name}")
            continue

        # Display results
        predicted_label = _label_for(pred_class)
        confidence = probs[pred_class]
        sorted_indices = np.argsort(probs)[::-1]

        print(f"\nPredicted Symmetry Type: {predicted_label}")
        print(f"Confidence: {confidence:.4f} ({confidence*100:.2f}%)")

        if total == 1:
            # Show detailed results for single sequence
            print("\nTop 5 Predictions (sorted by probability):")
            print(f"{'Rank':<8} {'Symmetry Type':<20} {'Probability':<15} {'Percentage':<15}")
            print("-" * 70)

            for rank, class_idx in enumerate(sorted_indices[:5], 1):
                label_name = _label_for(class_idx)
                prob = probs[class_idx]
                percentage = prob * 100
                marker = " ← " if class_idx == pred_class else "   "
                print(f"{marker}{rank:<5} {label_name:<20} {prob:<15.4f} {percentage:<15.2f}%")

        # Save results
        top2_prob = probs[sorted_indices[1]] if len(sorted_indices) > 1 else 0.0
        top3_prob = probs[sorted_indices[2]] if len(sorted_indices) > 2 else 0.0

        results.append({
            'sequence_name': seq_name,
            'sequence_length': len(seq),
            'predicted_label': predicted_label,
            'predicted_class': pred_class,
            'confidence': confidence,
            'top1_prob': probs[pred_class],
            'top2_prob': top2_prob,
            'top3_prob': top3_prob,
        })

    # Output results
    if not results:
        print(f"\n{'='*70}")
        print("ERROR: No successful predictions")
        print("="*70)
        print("All predictions failed. Please check the error messages above.")
        return

    print(f"\n{'='*70}")
    print("SUMMARY")
    print("="*70)
    if total == 1:
        print(f"Final Prediction: {results[0]['predicted_label']} (confidence: {results[0]['confidence']*100:.2f}%)")
    else:
        # Batch results display table
        print(f"\n{'Sequence Name':<30} {'Length':<10} {'Predicted Label':<20} {'Confidence':<15}")
        print("-" * 75)
        for r in results:
            print(f"{r['sequence_name']:<30} {r['sequence_length']:<10} {r['predicted_label']:<20} {r['confidence']*100:<15.2f}%")

        failed = total - len(results)
        if failed:
            print(f"\n{failed} of {total} sequence(s) failed; see the error messages above.")

    # Save to CSV file whenever the user asked for it (single or batch run)
    if args.output:
        try:
            pd.DataFrame(results).to_csv(args.output, index=False)
            print(f"\nResults saved to: {args.output}")
        except OSError as e:
            print(f"\nError: Could not write results to {args.output}: {e}")
    elif total > 1:
        print("\nTip: Use --output <file.csv> to save batch results to a file")

if __name__ == '__main__':
    main()
