import numpy as np
import torch
from typing import List, Dict, Optional, Tuple
from Bio import PDB
from Bio.PDB import PDBParser
import os
import pandas as pd
from torch_geometric.data import Data
import pickle
from tqdm import tqdm
from codebar import topo_fea  # Import codebar feature calculation class

def extract_ca_coordinates(pdb_file: str) -> Tuple[np.ndarray, List[str]]:
    """
    Extract CA (alpha-carbon) coordinates and residue names from a PDB file.

    Only the first model of the structure is read. Multi-model files such as
    NMR ensembles contain one copy of every residue per model. Concatenating
    all models would duplicate residues at nearly identical positions and
    corrupt the distance graph built from these coordinates.

    Atoms named 'CA' whose element is calcium (e.g. Ca2+ HETATM ions) are
    skipped, because they are not alpha carbons.

    Args:
        pdb_file: PDB file path

    Returns:
        coordinates: CA atom coordinate array [num_residues, 3]. This is an
            empty (0, 3) float32 array when no CA atoms are found.
        residue_names: Residue names (3-letter codes) aligned with coordinates
    """
    parser = PDBParser(QUIET=True)
    structure = parser.get_structure('protein', pdb_file)

    ca_coords = []
    residue_names = []

    # First model only; None when the file contains no models at all.
    first_model = next(iter(structure), None)
    if first_model is not None:
        for chain in first_model:
            for residue in chain:
                if 'CA' not in residue:
                    continue
                ca_atom = residue['CA']
                # A calcium ion is also named 'CA'; its element is 'CA', not 'C'.
                if ca_atom.element == 'CA':
                    continue
                ca_coords.append(ca_atom.get_coord())
                residue_names.append(residue.get_resname())

    if not ca_coords:
        # Keep a 2-D shape so downstream row/column indexing stays valid.
        return np.empty((0, 3), dtype=np.float32), residue_names

    return np.array(ca_coords), residue_names

def create_protein_graph(ca_coords: np.ndarray, 
                        residue_names: List[str], 
                        distance_threshold: float = 6.0,
                        pdb_file: Optional[str] = None) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Create a residue-level protein graph from CA coordinates.

    Two residues are connected when their CA-CA distance is strictly below
    ``distance_threshold``. Each undirected edge is stored once, with i < j.

    Args:
        ca_coords: CA atom coordinates [num_residues, 3]
        residue_names: Residue names (3-letter codes), aligned with ca_coords
        distance_threshold: Connection threshold (Angstrom)
        pdb_file: PDB file path used to compute codebar features. When it is
            None or does not exist, the codebar features are all zeros.

    Returns:
        node_attr: [num_residues, 35] float32 (20 residue one-hot +
            4 physicochemical + 11 degree one-hot)
        edge_index: [2, num_edges] long
        edge_attr: [num_edges, 44] float32 (distance, relative position,
            one-hot of both endpoint residues)
        triangle_index: [2, num_triangles] long
        triangle_dis: [num_triangles, 69] float32
        codebar_features: [num_residues, 35] float32; zeros when the codebar
            computation is unavailable or fails
    """
    coords = np.asarray(ca_coords)
    num_residues = len(coords)

    # Widths of the helper encoders; used to keep empty outputs 2-D.
    node_dim = 35  # residue_features: 20 one-hot + 4 physicochemical + 11 degree
    edge_dim = 44  # edge_features: 4 geometric + 2 x 20 one-hot

    edge_index = [[], []]
    edge_attr = []
    for i in range(num_residues):
        # Distances from residue i to every later residue in one vectorized call.
        # Restricting to j > i avoids self-loops and duplicate edges.
        later_distances = np.linalg.norm(coords[i + 1:] - coords[i], axis=1)
        for offset in np.flatnonzero(later_distances < distance_threshold):
            j = i + 1 + int(offset)
            edge_index[0].append(i)
            edge_index[1].append(j)
            edge_attr.append(edge_features(
                later_distances[offset], coords[i], coords[j],
                residue_names[i], residue_names[j]))

    edge_index = torch.tensor(edge_index, dtype=torch.long)  # [2, 0] when there are no edges
    if edge_attr:
        edge_attr = torch.tensor(edge_attr, dtype=torch.float32)
    else:
        # torch.tensor([]) would be 1-D; keep the [num_edges, edge_dim] contract.
        edge_attr = torch.zeros(0, edge_dim, dtype=torch.float32)

    degrees = calculate_residue_degrees(edge_index, num_residues)

    node_rows = [residue_features(name, degrees[i]) for i, name in enumerate(residue_names)]
    if node_rows:
        node_attr = torch.tensor(node_rows, dtype=torch.float32)
    else:
        node_attr = torch.zeros(0, node_dim, dtype=torch.float32)

    triangle_index, triangle_dis = create_triangle_features(coords, edge_index, residue_names, distance_threshold)

    codebar_features = _compute_codebar_features(pdb_file, num_residues)

    return node_attr, edge_index, edge_attr, triangle_index, triangle_dis, codebar_features


def _compute_codebar_features(pdb_file: Optional[str], num_residues: int) -> torch.Tensor:
    """Return per-residue codebar (topological) features from ``topo_fea``, or zeros if unavailable."""
    codebar_dim = 35
    zeros = torch.zeros(num_residues, codebar_dim, dtype=torch.float32)
    if not pdb_file or not os.path.exists(pdb_file):
        return zeros
    try:
        topo_calculator = topo_fea(pdb_file, neighbor_dis=8, Cut=8)
        codebar_df = topo_calculator.cal_fea()

        # Every column except the identifier column is a feature.
        feature_cols = [col for col in codebar_df.columns if col != 'ID']
        values = codebar_df[feature_cols].values

        if values.size == 0:
            return zeros
        if values.shape[1] != codebar_dim:
            print(f"Warning: Codebar feature dimension mismatch, expected 35, got {values.shape[1]}")
            return zeros
        if values.shape[0] != num_residues:
            # Rows are expected to align with graph nodes. Report the mismatch
            # but still return the features, as before.
            print(f"Warning: Codebar feature rows ({values.shape[0]}) do not match "
                  f"number of residues ({num_residues}) for {pdb_file}")
        return torch.tensor(values, dtype=torch.float32)  # [num_rows, 35]
    except Exception as e:
        # Deliberate best effort: a failed topology computation must not drop the protein.
        print(f"Warning: Failed to compute codebar features for {pdb_file}: {e}")
        return zeros

def calculate_residue_degrees(edge_index: torch.Tensor, num_residues: int) -> List[int]:
    """
    Calculate the degree (number of incident edges) of each residue.

    Each column of ``edge_index`` is one undirected edge stored once, so both
    endpoints are credited. All endpoints are counted with a single
    ``torch.bincount`` rather than a per-edge Python loop.

    Args:
        edge_index: Integer edge index [2, num_edges] (undirected edges)
        num_residues: Number of residues

    Returns:
        List of length ``num_residues`` with the degree of each residue

    Raises:
        IndexError: if an edge references a node index >= num_residues
    """
    endpoints = edge_index.reshape(-1).long()
    if endpoints.numel() == 0:
        return [0] * num_residues

    counts = torch.bincount(endpoints, minlength=num_residues)
    if counts.numel() > num_residues:
        raise IndexError(
            f"edge_index references node {counts.numel() - 1}, "
            f"but num_residues is {num_residues}"
        )
    return counts.tolist()

def residue_features(residue_name: str, degree: int) -> List[float]:
    """
    Build the node feature vector for one residue.

    Layout (35 values):
        [0:20]  one-hot residue type from ``residue_to_onehot`` (all zeros
                for names outside the 20 standard codes)
        [20:24] physicochemical descriptors from ``get_physicochemical_features``
        [24:35] one-hot degree from ``degree_to_onehot`` (clipped to 0-10)

    Args:
        residue_name: Residue name (3-letter code)
        degree: Residue degree (number of connected edges)

    Returns:
        Feature vector as a new list
    """
    # Reuse the shared encoder so node, edge and triangle one-hots always agree.
    # residue_to_onehot returns a fresh list, so extending it in place is safe.
    features = residue_to_onehot(residue_name)
    features.extend(get_physicochemical_features(residue_name))
    features.extend(degree_to_onehot(degree))
    return features

# Simplified per-residue descriptors: (hydrophobicity, polarity, charge, size).
# Tuples keep the shared table immutable; callers always receive a fresh list.
_PHYSICOCHEMICAL_TABLE = {
    'ALA': (0.0, 0.0, 0.0, 0.0),
    'ARG': (0.0, 1.0, 1.0, 1.0),
    'ASN': (0.0, 1.0, 0.0, 0.5),
    'ASP': (0.0, 1.0, -1.0, 0.5),
    'CYS': (0.5, 0.0, 0.0, 0.5),
    'GLN': (0.0, 1.0, 0.0, 0.5),
    'GLU': (0.0, 1.0, -1.0, 0.5),
    'GLY': (0.0, 0.0, 0.0, 0.0),
    'HIS': (0.5, 0.5, 0.5, 0.5),
    'ILE': (1.0, 0.0, 0.0, 1.0),
    'LEU': (1.0, 0.0, 0.0, 1.0),
    'LYS': (0.0, 1.0, 1.0, 1.0),
    'MET': (0.5, 0.0, 0.0, 1.0),
    'PHE': (1.0, 0.0, 0.0, 1.0),
    'PRO': (0.5, 0.0, 0.0, 0.5),
    'SER': (0.0, 1.0, 0.0, 0.5),
    'THR': (0.0, 1.0, 0.0, 0.5),
    'TRP': (1.0, 0.0, 0.0, 1.0),
    'TYR': (0.5, 0.5, 0.0, 1.0),
    'VAL': (1.0, 0.0, 0.0, 1.0),
}

def get_physicochemical_features(residue_name: str) -> List[float]:
    """
    Look up the four physicochemical descriptors of a residue.

    Args:
        residue_name: Residue name (3-letter code, exact match)

    Returns:
        [hydrophobicity, polarity, charge, size] as a new list, or four zeros
        for any name not in the table
    """
    descriptor = _PHYSICOCHEMICAL_TABLE.get(residue_name)
    if descriptor is None:
        return [0.0, 0.0, 0.0, 0.0]
    return list(descriptor)

def degree_to_onehot(degree: int, max_degree: int = 10) -> List[float]:
    """
    Convert a degree to a one-hot vector, clipping it to ``[0, max_degree]``.

    Args:
        degree: Degree (integer). Values below 0 map to bin 0; values above
            ``max_degree`` map to the last bin.
        max_degree: Largest degree with its own bin. The default of 10 gives the
            original 11-dimensional encoding (bins 0-10).

    Returns:
        One-hot list of length ``max_degree + 1``

    Raises:
        ValueError: if ``max_degree`` is negative
    """
    if max_degree < 0:
        raise ValueError(f"max_degree must be non-negative, got {max_degree}")

    clipped = max(0, min(max_degree, degree))

    onehot = [0.0] * (max_degree + 1)
    onehot[clipped] = 1.0
    return onehot

def residue_to_onehot(residue_name: str) -> List[float]:
    """
    Encode a 3-letter residue name as a 20-dimensional one-hot list.

    Columns follow the order of ``standard_codes`` below. Matching is exact and
    case-sensitive. Any name that is not one of the 20 codes (e.g. 'HOH', or
    lowercase codes) yields an all-zero vector.

    Args:
        residue_name: Residue name (3-letter code)

    Returns:
        New list of 20 floats
    """
    standard_codes = (
        'ALA', 'ARG', 'ASN', 'ASP', 'CYS', 'GLN', 'GLU', 'GLY', 'HIS', 'ILE',
        'LEU', 'LYS', 'MET', 'PHE', 'PRO', 'SER', 'THR', 'TRP', 'TYR', 'VAL',
    )
    return [1.0 if code == residue_name else 0.0 for code in standard_codes]

def edge_features(distance: float, coord1: np.ndarray, coord2: np.ndarray, residue1: str, residue2: str) -> List[float]:
    """
    Build the 44-dimensional feature vector of an edge.

    Layout:
        [0]     distance (passed through unchanged)
        [1:4]   displacement coord2 - coord1 (x, y, z), pointing from the
                first residue to the second
        [4:24]  one-hot of residue1
        [24:44] one-hot of residue2

    Args:
        distance: Distance between the two points
        coord1: Coordinates of the first point
        coord2: Coordinates of the second point
        residue1: Residue name of the first point
        residue2: Residue name of the second point

    Returns:
        Feature list of 4 geometric + 40 residue values
    """
    delta = coord2 - coord1

    features = [distance]
    features.extend(delta[axis] for axis in range(3))
    features.extend(residue_to_onehot(residue1))
    features.extend(residue_to_onehot(residue2))
    return features

def create_triangle_features(ca_coords: np.ndarray, edge_index: torch.Tensor, residue_names: List[str], distance_threshold: float) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Enumerate triangles of the residue graph and compute features for each one.

    For every edge column ``i`` of ``edge_index`` (endpoints src, dst), each
    third residue adjacent to both endpoints produces one entry. Third residues
    are visited in increasing index order. A triangle is therefore emitted once
    per edge it contains, i.e. three times when every edge is stored once.
    Adjacency is built from the columns with src < dst.

    Args:
        ca_coords: CA atom coordinates [num_residues, 3]
        edge_index: Edge index [2, num_edges]
        residue_names: Residue name list
        distance_threshold: Not used by this function; kept for interface compatibility

    Returns:
        triangle_index: [2, num_triangles] long; both rows hold the index of the
            edge column the triangle was found from
        triangle_dis: [num_triangles, 69] float32 features:
            [0:3]  edge lengths d1=|src-dst|, d2=|src-third|, d3=|dst-third|
            [3:6]  cosines of the interior angles at src, dst and third
            [6:9]  area, perimeter, compactness = area**2 / perimeter**3
            [9:69] one-hot residue types of src, dst and third (20 each)
            Edge lengths are raw values, presumably RBF-expanded later by the model.
    """
    num_residues = len(ca_coords)
    src_nodes = edge_index[0].tolist()
    dst_nodes = edge_index[1].tolist()

    # Undirected adjacency from the src < dst columns. A residue closes a
    # triangle on edge (src, dst) iff it is a neighbour of both endpoints, so a
    # set intersection replaces a scan over every residue (O(E*deg) vs O(E*N)).
    neighbors: Dict[int, set] = {}
    for src, dst in zip(src_nodes, dst_nodes):
        if src < dst:
            neighbors.setdefault(src, set()).add(dst)
            neighbors.setdefault(dst, set()).add(src)

    no_neighbors = frozenset()
    onehot_cache: Dict[int, List[float]] = {}

    def residue_onehot(node: int) -> List[float]:
        # Each residue's one-hot is computed once and reused; callers only copy it.
        if node not in onehot_cache:
            onehot_cache[node] = residue_to_onehot(residue_names[node])
        return onehot_cache[node]

    triangle_index = []
    triangle_dis = []

    for edge_id, (src, dst) in enumerate(zip(src_nodes, dst_nodes)):
        common = neighbors.get(src, no_neighbors) & neighbors.get(dst, no_neighbors)
        # Sorted and restricted to valid residue indices, matching a 0..N-1 scan.
        thirds = sorted(t for t in common if 0 <= t < num_residues)
        if not thirds:
            continue

        # Quantities that depend only on the edge are computed once per edge.
        d1 = np.linalg.norm(ca_coords[src] - ca_coords[dst])
        v1 = ca_coords[dst] - ca_coords[src]
        v1_norm = v1 / (np.linalg.norm(v1) + 1e-8)
        onehot_src = residue_onehot(src)
        onehot_dst = residue_onehot(dst)

        for third in thirds:
            d2 = np.linalg.norm(ca_coords[src] - ca_coords[third])
            d3 = np.linalg.norm(ca_coords[dst] - ca_coords[third])

            v2 = ca_coords[third] - ca_coords[src]
            v3 = ca_coords[third] - ca_coords[dst]
            v2_norm = v2 / (np.linalg.norm(v2) + 1e-8)
            v3_norm = v3 / (np.linalg.norm(v3) + 1e-8)

            # Cosines of the interior angles at src, dst and third.
            cos1 = np.dot(v1_norm, v2_norm)
            cos2 = np.dot(-v1_norm, v3_norm)
            cos3 = np.dot(v2_norm, v3_norm)

            area = 0.5 * np.linalg.norm(np.cross(v1, v2))
            perimeter = d1 + d2 + d3
            compactness = (area ** 2) / (perimeter ** 3 + 1e-8)

            triangle_feature = [d1, d2, d3, cos1, cos2, cos3, area, perimeter, compactness]
            triangle_feature.extend(onehot_src)
            triangle_feature.extend(onehot_dst)
            triangle_feature.extend(residue_onehot(third))

            triangle_index.append([edge_id, edge_id])
            triangle_dis.append(triangle_feature)

    if not triangle_dis:
        # 3 edge lengths + 6 geometric features + 60 residue features
        return torch.zeros(2, 0, dtype=torch.long), torch.zeros(0, 69, dtype=torch.float32)

    triangle_index = torch.tensor(triangle_index, dtype=torch.long).t()
    # Converting via one float32 ndarray avoids per-element Python conversion in torch.tensor.
    triangle_dis = torch.from_numpy(np.asarray(triangle_dis, dtype=np.float32))
    return triangle_index, triangle_dis

def process_protein_dataset_from_dir(pdb_dir: str, 
                                   distance_threshold: float = 6.0,
                                   output_file: Optional[str] = None,
                                   esm2_dir: str = 'data/ESM2') -> List[Data]:
    """
    Build a protein graph for every ``.pdb`` / ``.pdbqt`` file in a directory.

    Files are processed in sorted filename order, so the output list order is
    reproducible (``os.listdir`` order is arbitrary). Files that fail to process
    are reported and skipped.

    Args:
        pdb_dir: Directory containing PDB files
        distance_threshold: Connection threshold (Angstrom)
        output_file: If given, the resulting list is pickled to this path
        esm2_dir: ESM2 feature file directory

    Returns:
        List of ``Data`` graphs with attributes x, edge_index, edge_attr,
        triangle_index, triangle_dis, codebar_features, esm2_features and
        num_nodes. The list is empty if ``pdb_dir`` is not a directory.
    """
    protein_graphs = []

    # isdir (not exists): a regular file here would make os.listdir raise.
    if not os.path.isdir(pdb_dir):
        print(f"Warning: Directory {pdb_dir} does not exist!")
        return protein_graphs

    pdb_files = sorted(
        name for name in os.listdir(pdb_dir)
        if name.endswith(('.pdb', '.pdbqt')) and os.path.isfile(os.path.join(pdb_dir, name))
    )

    print(f"Found {len(pdb_files)} PDB files in {pdb_dir}")

    for filename in tqdm(pdb_files, desc=f"Processing {os.path.basename(pdb_dir)}"):
        pdb_path = os.path.join(pdb_dir, filename)

        try:
            ca_coords, residue_names = extract_ca_coordinates(pdb_path)

            node_attr, edge_index, edge_attr, triangle_index, triangle_dis, codebar_features = create_protein_graph(
                ca_coords, residue_names, distance_threshold, pdb_path
            )

            # The ESM2 file is looked up by the PDB file name without its extension.
            protein_id = os.path.splitext(os.path.basename(pdb_path))[0]
            esm2_features = load_esm2_features(protein_id, esm2_dir)

            data = Data(
                x=node_attr,
                edge_index=edge_index,
                edge_attr=edge_attr,
                triangle_index=triangle_index,
                triangle_dis=triangle_dis,
                codebar_features=codebar_features,
                esm2_features=esm2_features,
                num_nodes=len(ca_coords)
            )

            protein_graphs.append(data)

        except Exception as e:
            # Best effort: one bad file should not abort the whole dataset.
            print(f"Error processing {filename}: {e}")
            continue

    if output_file:
        with open(output_file, 'wb') as f:
            pickle.dump(protein_graphs, f)
        print(f"Saved {len(protein_graphs)} protein graphs to {output_file}")

    return protein_graphs


def process_protein_dataset(pdb_dir: str, 
                          distance_threshold: float = 6.0,
                          output_file: Optional[str] = None,
                          esm2_dir: str = 'data/ESM2') -> List[Data]:
    """
    Process a protein dataset (compatible with the old version).

    Unlike ``process_protein_dataset_from_dir``, this variant shows no progress
    bar and lets ``os.listdir`` raise if ``pdb_dir`` does not exist. Files are
    processed in sorted filename order; files that fail are reported and skipped.

    Args:
        pdb_dir: Directory containing PDB files
        distance_threshold: Connection threshold (Angstrom)
        output_file: If given, the resulting list is pickled to this path
        esm2_dir: ESM2 feature file directory. Without this parameter the name
            was undefined, and every file failed with NameError.

    Returns:
        List of protein graphs
    """
    protein_graphs = []

    for filename in sorted(os.listdir(pdb_dir)):
        if not filename.endswith(('.pdb', '.pdbqt')):
            continue
        pdb_path = os.path.join(pdb_dir, filename)

        try:
            ca_coords, residue_names = extract_ca_coordinates(pdb_path)

            node_attr, edge_index, edge_attr, triangle_index, triangle_dis, codebar_features = create_protein_graph(
                ca_coords, residue_names, distance_threshold, pdb_path
            )

            # The ESM2 file is looked up by the PDB file name without its extension.
            protein_id = os.path.splitext(os.path.basename(pdb_path))[0]
            esm2_features = load_esm2_features(protein_id, esm2_dir)

            data = Data(
                x=node_attr,
                edge_index=edge_index,
                edge_attr=edge_attr,
                triangle_index=triangle_index,
                triangle_dis=triangle_dis,
                codebar_features=codebar_features,
                esm2_features=esm2_features,
                num_nodes=len(ca_coords)
            )

            protein_graphs.append(data)

        except Exception as e:
            # Best effort: one bad file should not abort the whole dataset.
            print(f"Error processing {filename}: {e}")
            continue

    if output_file:
        with open(output_file, 'wb') as f:
            pickle.dump(protein_graphs, f)

    return protein_graphs

def load_protein_graphs(file_path: str) -> List[Data]:
    """
    Unpickle a list of protein graphs, e.g. one saved by the dataset-processing functions.

    Security note: ``pickle.load`` can execute arbitrary code embedded in the
    file. Only load files produced by this pipeline or otherwise trusted.

    Args:
        file_path: Path to the pickle file

    Returns:
        The unpickled object (expected to be a list of ``Data`` graphs)
    """
    with open(file_path, 'rb') as graph_file:
        return pickle.load(graph_file)

def load_esm2_features(protein_id: str, esm2_dir: str = 'data/ESM2') -> torch.Tensor:
    """
    Load a per-protein ESM2 embedding from ``<esm2_dir>/<protein_id>.csv`` as a 1280-d vector.

    Parsing strategy:
      1. Read with ``pd.read_csv(header=None)``:
         - 1280 columns and 1 row: that row is used.
         - 1280 columns and several rows: the column-wise mean over rows is
           used. Presumably rows are per-residue embeddings; confirm against
           the ESM2 export script.
         - 1280 rows: the first column is used.
         - Any other shape, or any pandas error: go to step 2.
      2. Read the raw text and split it on commas. Exactly 1280
         float-convertible values are required.

    The result is rejected (zero vector returned) if its final shape is not
    (1280,) or it contains NaN/Inf. Magnitudes above 1000 only print a warning.

    Args:
        protein_id: Protein ID, used verbatim as the CSV file stem
        esm2_dir: ESM2 feature file directory

    Returns:
        float32 tensor of shape [1280]. On success the tensor carries an ad-hoc
        ``_is_graph_level = True`` attribute; the zero-vector fallbacks do not.
        Failures print a message and return zeros instead of raising.
    """
    try:
        # Locate the feature file (only the '.csv' extension is tried)
        esm2_file = None
        for ext in ['.csv']:
            potential_file = os.path.join(esm2_dir, protein_id + ext)
            if os.path.exists(potential_file):
                esm2_file = potential_file
                break

        if esm2_file is None:
            print(f"Warning: ESM2 feature file does not exist: {protein_id}")
            return torch.zeros(1280, dtype=torch.float32)

        # Expected format: a single row of 1280 comma-separated values.
        # Try pandas first; fall back to manual parsing on any failure.
        try:
            esm2_data = pd.read_csv(esm2_file, header=None)

            # Accepted layouts: [n, 1280] (row-wise) or [1280, k] (column-wise)
            if esm2_data.shape[1] == 1280:
                # 1280 columns - row-wise layout
                if esm2_data.shape[0] == 1:
                    # Exactly one row: use it directly
                    esm2_features = esm2_data.iloc[0].values  # [1280]
                else:
                    # Several rows: average over rows
                    esm2_features = esm2_data.mean(axis=0).values  # [1280]
            elif esm2_data.shape[0] == 1280:
                # 1280 rows - take the first column
                esm2_features = esm2_data.iloc[:, 0].values  # [1280]
            else:
                # Unrecognized shape: this ValueError is caught by the except
                # below, which prints a second message and parses manually.
                print(f"Warning: ESM2 feature dimension incorrect: {protein_id}, shape {esm2_data.shape}, trying manual parsing")
                raise ValueError("Unexpected shape")

        except Exception as e:
            print(f"Pandas reading failed for {protein_id}, trying manual parsing: {e}")

            # Manual fallback: treat the whole file as one comma-separated list
            with open(esm2_file, 'r') as f:
                content = f.read().strip()
                values = content.split(',')

                if len(values) != 1280:
                    print(f"Warning: ESM2 feature value count incorrect: {protein_id}, expected 1280, got{len(values)}")
                    return torch.zeros(1280, dtype=torch.float32)

                # Convert to float
                try:
                    esm2_features = [float(val) for val in values]
                except ValueError as e:
                    print(f"Error converting ESM2 values to float for {protein_id}: {e}")
                    return torch.zeros(1280, dtype=torch.float32)

        # The pandas path yields an ndarray; the manual path yields a list
        if not hasattr(esm2_features, 'shape'):
            esm2_features = np.array(esm2_features)

        # Final check
        if esm2_features.shape != (1280,):
            print(f"Warning: Final ESM2 features shape incorrect: {protein_id}, shape {esm2_features.shape}")
            return torch.zeros(1280, dtype=torch.float32)

        # Reject non-finite values
        if np.isnan(esm2_features).any():
            print(f"Warning: NaN values detected in ESM2 features for {protein_id}")
            return torch.zeros(1280, dtype=torch.float32)

        if np.isinf(esm2_features).any():
            print(f"Warning: Inf values detected in ESM2 features for {protein_id}")
            return torch.zeros(1280, dtype=torch.float32)

        # Large magnitudes are only reported, not rejected
        min_val = esm2_features.min()
        max_val = esm2_features.max()
        if abs(min_val) > 1000 or abs(max_val) > 1000:
            print(f"Warning: ESM2 features have unusually large values for {protein_id}: range [{min_val:.6f}, {max_val:.6f}]")

        # Mark as a graph-level feature via an ad-hoc Python attribute.
        # NOTE(review): not confirmed here whether this attribute survives
        # pickling or PyG batching.
        esm2_tensor = torch.tensor(esm2_features, dtype=torch.float32)
        esm2_tensor._is_graph_level = True  # Mark as graph-level feature

        return esm2_tensor

    except Exception as e:
        # Catch-all so that a broken feature file never aborts dataset processing
        print(f"Error loading ESM2 features for {protein_id}: {e}")
        return torch.zeros(1280, dtype=torch.float32)


# Module-level record of missing path-homology feature files that have already
# been reported. load_path_homology_features declares it `global`, so the
# "missing file" warning is emitted only once across calls. Entries are
# presumably protein IDs or file paths (the insertion site is outside this
# view; confirm).
_missing_ph_files_warned = set()

def load_path_homology_features(protein_id: str, path_homology_dir: str = '/backup/data3/Homology_test/PH_360_5/symm_360_5/symm1mer_012', verbose: bool = False) -> torch.Tensor:
    """
    Load path_homology features (360 dimensions)
    
    Args:
        protein_id: Protein ID (needs to be converted to uppercase, e.g., 1A0P)
        path_homology_dir: path_homology feature file directory
        verbose: Whether to print detailed warning information (default False, only warn once)
    
    Returns:
        path_homology feature tensor [360], returns zero vector if loading fails
    """
    global _missing_ph_files_warned
    try:
        # Convert protein_id to uppercase (because file names are uppercase, e.g., 1A0P.csv)
        protein_id_upper = str(protein_id).upper()
        
        # Try different file extensions
        ph_file = None
        for ext in ['.csv']:
            potential_file = os.path.join(path_homology_dir, protein_id_upper + ext)
            if os.path.exists(potential_file):
                ph_file = potential_file
                break
        
        # If exact match fails, try case-insensitive matching
        # Note: Only use case-insensitive matching when exact match fails
        # And need to verify file format (240 dimensions) to avoid matching wrong format files
        if ph_file is None and os.path.exists(path_homology_dir):
            files_in_dir = [f for f in os.listdir(path_homology_dir) if f.endswith('.csv')]
            # Perform case-insensitive matching
            for f in files_in_dir:
                # Compare after removing extension
                file_id = os.path.splitext(f)[0]
                if file_id.upper() == protein_id_upper:
                    candidate_file = os.path.join(path_homology_dir, f)
                    # Verify file format: check if file size is reasonable (360 values approximately need 1500-4500 bytes)
                    file_size_check = os.path.getsize(candidate_file)
                    if 500 < file_size_check < 6000:  # Reasonable file size range
                        ph_file = candidate_file
                        break
                    else:
                        # File size unreasonable, may be wrong format file, skip
                        print(f"Warning: Found case-insensitive match {f} but file size {file_size_check} bytes is suspicious (expected ~1500-4500 bytes for 360 values). Skipping.")
                        continue
        
        if ph_file is None:
            # Only warn on first encounter of missing file, or every time if verbose=True
            if verbose or protein_id_upper not in _missing_ph_files_warned:
                print(f"Warning: path_homology feature file does not exist: {protein_id_upper} in {path_homology_dir}")
                # Only list similar files in verbose mode
                if verbose and os.path.exists(path_homology_dir):
                    files_in_dir = [f for f in os.listdir(path_homology_dir) if f.endswith('.csv')]
                    similar_files = [f for f in files_in_dir if protein_id_upper.upper() in f.upper()][:5]
                    if similar_files:
                        print(f"  Similar files in directory: {similar_files}")
                _missing_ph_files_warned.add(protein_id_upper)
            # Silently return zero vector
            return torch.zeros(360, dtype=torch.float32)
        
        # First verify file size and path
        file_size = os.path.getsize(ph_file)
        
        # Add debug information
        if file_size < 500:  # File too small, may not be correct file
            print(f"Warning: path_homology feature file unusually small: {protein_id_upper}")
            print(f"  File path: {ph_file}")
            print(f"  file size: {file_size} bytes (expected at least hundreds of bytes, 360 values usually require 1500+ bytes)")
            print(f"  This may be an incorrect file or the file may be truncated")
            
            # Check if there are other files with same name
            if os.path.exists(path_homology_dir):
                all_files = [f for f in os.listdir(path_homology_dir) if f.startswith(protein_id_upper.upper())]
                if len(all_files) > 1:
                    print(f"  Found multiple matching files: {all_files}")
                    # Try to use largest file
                    file_sizes = [(f, os.path.getsize(os.path.join(path_homology_dir, f))) for f in all_files]
                    file_sizes.sort(key=lambda x: x[1], reverse=True)
                    if file_sizes[0][1] > file_size:
                        print(f"  Trying to use larger file: {file_sizes[0][0]} ({file_sizes[0][1]} bytes)")
                        ph_file = os.path.join(path_homology_dir, file_sizes[0][0])
                        file_size = file_sizes[0][1]
        if file_size < 100:  # File too small, may not be correct file
            print(f"Warning: path_homology feature file too small: {protein_id_upper}, file size {file_size} bytes, file path: {ph_file}")
            print(f"  This may be an incorrect file, expected at least hundreds of bytes (360 values)")
            return torch.zeros(360, dtype=torch.float32)
        
        # Read CSV file - format is single row with 360 comma-separated values (no header)
        # Prefer manual parsing to ensure reading complete 360-dimensional features
        ph_features = None
        
        # Manual parsing (most reliable method, ensures reading complete file)
        try:
            # Read in binary mode, then decode, to avoid truncation due to encoding issues
            with open(ph_file, 'rb') as f:
                # Read entire file content (binary mode, ensures complete reading)
                raw_content = f.read()
                
                # Verify read binary content length matches file size
                if len(raw_content) != file_size:
                    print(f"Error: Binary read incomplete: {protein_id_upper}, file size {file_size}, read length{len(raw_content)}")
                    print(f"  File path: {ph_file}")
                    # Try re-reading
                    with open(ph_file, 'rb') as f2:
                        raw_content = f2.read()
                        if len(raw_content) != file_size:
                            print(f"  Re-read still incomplete, file may be locked or corrupted")
                            ph_features = None
                            raise ValueError(f"File read incomplete: expected {file_size} bytes, got {len(raw_content)} bytes")
                
                # Try decoding as UTF-8, if fails try other encodings
                try:
                    content = raw_content.decode('utf-8')
                except UnicodeDecodeError:
                    try:
                        content = raw_content.decode('latin-1')
                    except UnicodeDecodeError:
                        content = raw_content.decode('utf-8', errors='ignore')
                
                # Record original content information for debugging
                original_length = len(content)
                original_comma_count = content.count(',')
                
                # Check if file is truncated (if read length is significantly less than file size)
                if original_length < file_size * 0.8:  # If read length is less than 80% of file size, may have problem
                    print(f"Warning: Decoded content length mismatch: {protein_id_upper}, file size {file_size}, decoded length{original_length}")
                    print(f"  File path: {ph_file}")
                    # Try re-decoding with errors='replace'
                    try:
                        content = raw_content.decode('utf-8', errors='replace')
                        original_length = len(content)
                        original_comma_count = content.count(',')
                        print(f"  Re-decoding with errors=\'replace\': length{original_length}, comma count{original_comma_count}")
                    except Exception as e:
                        print(f"  Re-decoding failed: {e}")
                
                # Remove all newlines, carriage returns, and tabs
                content = content.replace('\n', '').replace('\r', '').replace('\t', '')
                
                # Remove possible quotes
                content = content.strip().strip('"').strip("'")
                
                # Split by comma, keep all values (including possible empty values)
                values = content.split(',')
                
                # Clean each value: remove leading/trailing whitespace, but keep empty strings (for counting)
                values = [v.strip() for v in values]
                
                # Remove trailing empty values (may be caused by trailing comma or newline)
                while values and not values[-1]:
                    values.pop()
                
                # Verify value count
                num_values = len(values)
                if num_values == 360:
                    # Correct! Convert to float array
                    try:
                        ph_features = np.array([float(val) if val else 0.0 for val in values], dtype=np.float32)
                    except ValueError as e:
                        print(f"Error converting path_homology values to float for {protein_id_upper}: {e}")
                        print(f"  Problematic values: {[v for v in values if not v or not v.replace('.', '').replace('-', '').isdigit()][:10]}")
                        ph_features = None
                elif num_values > 360:
                    # Too many values, only take first 360
                    print(f"Warning: path_homology feature value count too many: {protein_id_upper}, expected 360, got {num_values}, taking first 360")
                    try:
                        ph_features = np.array([float(val) if val else 0.0 for val in values[:360]], dtype=np.float32)
                    except ValueError as e:
                        print(f"  Error truncating features: {e}")
                        ph_features = None
                elif num_values > 0 and num_values < 360:
                    # Too few values, file may be read incompletely, try other methods
                    print(f"Warning: path_homology feature value count insufficient (manual parsing): {protein_id_upper}, expected 360, got{num_values}")
                    print(f"  File path: {ph_file}")
                    print(f"  File size: {file_size}, Content length: {original_length}, Comma count: {original_comma_count}")
                    print(f"  First 100 chars: {content[:100]}")
                    print(f"  Last 100 chars: {content[-100:]}")
                    
                    # If read content is significantly less than file size, try multiple methods to re-read
                    if original_length < file_size * 0.5:
                        print(f"  File read may be incomplete (read length {original_length} < 50% of file size {file_size}), trying multiple methods to re-read...")
                        
                        # Method 1: Read line by line
                        try:
                            with open(ph_file, 'r', encoding='utf-8', errors='ignore') as f3:
                                lines = f3.readlines()
                                full_content = ''.join(lines)
                                full_content = full_content.replace('\n', '').replace('\r', '').replace('\t', '')
                                full_content = full_content.strip().strip('"').strip("'")
                                full_values = full_content.split(',')
                                full_values = [v.strip() for v in full_values]
                                while full_values and not full_values[-1]:
                                    full_values.pop()
                                
                                if len(full_values) >= 360:
                                    print(f"  Method 1 (line-by-line read) succeeded, found {len(full_values)} values")
                                    ph_features = np.array([float(val) if val else 0.0 for val in full_values[:360]], dtype=np.float32)
                                else:
                                    print(f"  Method 1 (line-by-line read) still insufficient: {len(full_values)} values")
                                    ph_features = None
                        except Exception as e3:
                            print(f"  Method 1 (line-by-line read) failed: {e3}")
                            ph_features = None
                        
                        # Method 2: If method 1 fails, try binary read and force UTF-8 decoding
                        if ph_features is None:
                            try:
                                with open(ph_file, 'rb') as f4:
                                    raw_content2 = f4.read()
                                    if len(raw_content2) == file_size:
                                        # Force UTF-8 decoding, ignore errors
                                        content2 = raw_content2.decode('utf-8', errors='replace')
                                        content2 = content2.replace('\n', '').replace('\r', '').replace('\t', '')
                                        content2 = content2.strip().strip('"').strip("'")
                                        values2 = content2.split(',')
                                        values2 = [v.strip() for v in values2]
                                        while values2 and not values2[-1]:
                                            values2.pop()
                                        
                                        if len(values2) >= 360:
                                            print(f"  Method 2 (binary + force UTF-8) succeeded, found {len(values2)} values")
                                            ph_features = np.array([float(val) if val else 0.0 for val in values2[:360]], dtype=np.float32)
                                        else:
                                            print(f"  Method 2 (binary + force UTF-8) still insufficient: {len(values2)} values")
                                            ph_features = None
                                    else:
                                        print(f"  Method 2 (binary read) failed: read length {len(raw_content2)} != file size {file_size}")
                                        ph_features = None
                            except Exception as e4:
                                print(f"  Method 2 (binary + force UTF-8) failed: {e4}")
                                ph_features = None
                        
                        # 如果所有方法都失败，尝试pandas
                        if ph_features is None:
                            print(f"  All manual read methods failed, trying pandas read...")
                    else:
                        print(f"  Trying pandas read...")
                        ph_features = None  # 让pandas尝试读取
                else:
                    # 完全没有值
                    print(f"Error: path_homology feature file is empty: {protein_id_upper}")
                    ph_features = None
                    
        except Exception as e:
            print(f"Manual parsing failed for {protein_id_upper}: {e}")
            import traceback
            traceback.print_exc()
            ph_features = None
        
        # If manual parsing fails or values insufficient, try using pandas (as fallback)
        if ph_features is None:
            try:
                # Try different pandas reading methods
                # Method 1: Use python engine, no column limit, explicitly specify column count
                # Note: pandas may limit columns by default, need explicit handling
                try:
                    # First try to read all columns (no limit)
                    ph_data = pd.read_csv(ph_file, header=None, engine='python', sep=',', 
                                         quotechar='"', on_bad_lines='skip', 
                                         dtype=str, keep_default_na=False,
                                         names=[f'col_{i}' for i in range(360)])  # Explicitly specify 360 columns
                    
                    # Check actual number of columns read
                    actual_cols = ph_data.shape[1]
                    if actual_cols == 360:
                        # Single row, 360 columns - this is the correct format
                        # Use safe conversion method: first convert to string, then to float, avoid NaN
                        str_values = ph_data.iloc[0].astype(str).values
                        try:
                            ph_features = np.array([float(v) if v and v.lower() not in ['nan', 'none', 'null', ''] else 0.0 
                                                   for v in str_values], dtype=np.float32)
                        except (ValueError, TypeError) as e:
                            print(f"Warning: Safe conversion failed, using standard conversion: {e}")
                            ph_features = ph_data.iloc[0].values.astype(np.float32)
                            # Check and replace NaN values
                            if np.isnan(ph_features).any():
                                nan_count = np.isnan(ph_features).sum()
                                print(f"Warning: Found {nan_count} NaN values after pandas read, will replace with 0.0")
                                ph_features = np.nan_to_num(ph_features, nan=0.0)
                    elif actual_cols < 360:
                        # 列数不足，尝试读取原始数据
                        print(f"Warning: pandas read insufficient columns: {protein_id_upper}, expected 360, got{actual_cols}")
                        # 尝试不使用列名限制
                        ph_data2 = pd.read_csv(ph_file, header=None, engine='python', sep=',', 
                                              quotechar='"', on_bad_lines='skip', 
                                              dtype=str, keep_default_na=False)
                        
                        # 检查维度
                        if ph_data2.shape[1] >= 360:
                            # 取前360列
                            # 使用安全的转换方法
                            str_values = ph_data2.iloc[0, :360].astype(str).values
                            try:
                                ph_features = np.array([float(v) if v and v.lower() not in ['nan', 'none', 'null', ''] else 0.0 
                                                       for v in str_values], dtype=np.float32)
                            except (ValueError, TypeError) as e:
                                print(f"Warning: Safe conversion failed, using standard conversion: {e}")
                                ph_features = ph_data2.iloc[0, :360].values.astype(np.float32)
                                # Check and replace NaN values
                                if np.isnan(ph_features).any():
                                    nan_count = np.isnan(ph_features).sum()
                                    print(f"Warning: Found {nan_count} NaN values after pandas read, will replace with 0.0")
                                    ph_features = np.nan_to_num(ph_features, nan=0.0)
                        elif ph_data2.shape[0] >= 360 and ph_data2.shape[1] == 1:
                            # 360 rows, single column - take first 360 rows
                            # Use safe conversion method
                            str_values = ph_data2.iloc[:360, 0].astype(str).values
                            try:
                                ph_features = np.array([float(v) if v and v.lower() not in ['nan', 'none', 'null', ''] else 0.0 
                                                       for v in str_values], dtype=np.float32)
                            except (ValueError, TypeError) as e:
                                print(f"Warning: Safe conversion failed, using standard conversion: {e}")
                                ph_features = ph_data2.iloc[:360, 0].values.astype(np.float32)
                                # Check and replace NaN values
                                if np.isnan(ph_features).any():
                                    nan_count = np.isnan(ph_features).sum()
                                    print(f"Warning: Found {nan_count} NaN values after pandas read, will replace with 0.0")
                                    ph_features = np.nan_to_num(ph_features, nan=0.0)
                        else:
                            # Dimension incorrect, try manual parsing of raw data
                            print(f"  pandas read failed, trying manual parsing again...")
                            ph_features = None
                    else:
                        # Too many columns, take first 360 columns
                        # Use safe conversion method
                        str_values = ph_data.iloc[0, :360].astype(str).values
                        try:
                            ph_features = np.array([float(v) if v and v.lower() not in ['nan', 'none', 'null', ''] else 0.0 
                                                   for v in str_values], dtype=np.float32)
                        except (ValueError, TypeError) as e:
                            print(f"Warning: Safe conversion failed, using standard conversion: {e}")
                            ph_features = ph_data.iloc[0, :360].values.astype(np.float32)
                            # Check and replace NaN values
                            if np.isnan(ph_features).any():
                                nan_count = np.isnan(ph_features).sum()
                                print(f"Warning: Found {nan_count} NaN values after pandas read, will replace with 0.0")
                                ph_features = np.nan_to_num(ph_features, nan=0.0)
                        
                except Exception as e1:
                    # If specifying column names fails, try without column names
                        print(f"  pandas read (specified column names) failed: {e1}")
                    try:
                        ph_data = pd.read_csv(ph_file, header=None, engine='python', sep=',', 
                                             quotechar='"', on_bad_lines='skip', 
                                             dtype=str, keep_default_na=False)
                        
                        # Check dimension
                        if ph_data.shape[1] >= 360:
                            # Single row, at least 360 columns - take first 360 columns
                            # Use safe conversion method
                            str_values = ph_data.iloc[0, :360].astype(str).values
                            try:
                                ph_features = np.array([float(v) if v and v.lower() not in ['nan', 'none', 'null', ''] else 0.0 
                                                       for v in str_values], dtype=np.float32)
                            except (ValueError, TypeError) as e:
                                print(f"Warning: Safe conversion failed, using standard conversion: {e}")
                                ph_features = ph_data.iloc[0, :360].values.astype(np.float32)
                                # Check and replace NaN values
                                if np.isnan(ph_features).any():
                                    nan_count = np.isnan(ph_features).sum()
                                    print(f"Warning: Found {nan_count} NaN values after pandas read, will replace with 0.0")
                                    ph_features = np.nan_to_num(ph_features, nan=0.0)
                        elif ph_data.shape[0] >= 360 and ph_data.shape[1] == 1:
                            # At least 360 rows, single column - take first 360 rows
                            # Use safe conversion method
                            str_values = ph_data.iloc[:360, 0].astype(str).values
                            try:
                                ph_features = np.array([float(v) if v and v.lower() not in ['nan', 'none', 'null', ''] else 0.0 
                                                       for v in str_values], dtype=np.float32)
                            except (ValueError, TypeError) as e:
                                print(f"Warning: Safe conversion failed, using standard conversion: {e}")
                                ph_features = ph_data.iloc[:360, 0].values.astype(np.float32)
                                # Check and replace NaN values
                                if np.isnan(ph_features).any():
                                    nan_count = np.isnan(ph_features).sum()
                                    print(f"Warning: Found {nan_count} NaN values after pandas read, will replace with 0.0")
                                    ph_features = np.nan_to_num(ph_features, nan=0.0)
                        elif ph_data.shape[1] > 0 and ph_data.shape[1] < 360:
                            # Insufficient columns, may be pandas default limit
                            print(f"Error: path_homology feature dimension incorrect (pandas): {protein_id_upper}, shape {ph_data.shape}")
                            print(f"  Expected: (1, 360) or (360, 1)")
                            print(f"  This may be due to pandas default column limit, trying manual parsing of full file again...")
                            
                            # Last attempt: re-manual parsing, ensure complete reading
                            with open(ph_file, 'r', encoding='utf-8') as f2:
                                content2 = f2.read()
                                content2 = content2.replace('\n', '').replace('\r', '').replace('\t', '')
                                content2 = content2.strip().strip('"').strip("'")
                                values2 = content2.split(',')
                                values2 = [v.strip() for v in values2]
                                while values2 and not values2[-1]:
                                    values2.pop()
                                
                                if len(values2) >= 360:
                                    ph_features = np.array([float(val) if val else 0.0 for val in values2[:360]], dtype=np.float32)
                                else:
                                    ph_features = None
                        else:
                            ph_features = None
                            
                    except Exception as e2:
                        print(f"  pandas read (without column names) also failed: {e2}")
                        ph_features = None
                    
            except Exception as e:
                print(f"Pandas reading failed for {protein_id_upper}: {e}")
                ph_features = None
        
        # Ensure it's a numpy array and verify dimension
        if ph_features is None:
            print(f"Error: Failed to load path_homology features for {protein_id_upper}, using zero vector")
            return torch.zeros(360, dtype=torch.float32)
        
        if not hasattr(ph_features, 'shape'):
            ph_features = np.array(ph_features)
        
        # Final verification: must be 360 dimensions, do not accept other dimensions
        if ph_features.shape == (360,):
            # Correct! Already 360 dimensions
            pass
        elif len(ph_features.shape) == 1 and ph_features.shape[0] < 360:
            # Insufficient dimensions, pad with zeros to 360 dimensions
            print(f"Warning: path_homology feature dimension insufficient: {protein_id_upper}, shape {ph_features.shape}, padding with zeros to 360 dimensions")
            padded_features = np.zeros(360, dtype=np.float32)
            padded_features[:ph_features.shape[0]] = ph_features
            ph_features = padded_features
        elif len(ph_features.shape) == 1 and ph_features.shape[0] > 360:
            # Too many dimensions, truncate to 360 dimensions
            print(f"Warning: path_homology feature dimension too many: {protein_id_upper}, shape {ph_features.shape}, truncating to 360 dimensions")
            ph_features = ph_features[:360]
        else:
            # Dimension incorrect, this is an error condition
            print(f"Error: path_homology feature dimension incorrect: {protein_id_upper}, shape {ph_features.shape}, expected (360,)")
            print(f"  This indicates a problem with file reading logic, please check file format")
            return torch.zeros(360, dtype=torch.float32)
        
        # Check for abnormal values and try to fix
        if np.isnan(ph_features).any():
            nan_count = np.isnan(ph_features).sum()
            print(f"Warning: NaN values detected in path_homology features for {protein_id_upper}: {nan_count} NaN values")
            print(f"  Attempting to replace NaN values with 0.0...")
            # Try to replace NaN values instead of directly returning zero vector
            ph_features = np.nan_to_num(ph_features, nan=0.0, posinf=0.0, neginf=0.0)
            # Check again
            if np.isnan(ph_features).any():
                print(f"  Error: NaN values still present after replacement, using zero vector")
                return torch.zeros(360, dtype=torch.float32)
            else:
                print(f"  Success: NaN values successfully replaced with 0.0")
        
        if np.isinf(ph_features).any():
            inf_count = np.isinf(ph_features).sum()
            print(f"Warning: Inf values detected in path_homology features for {protein_id_upper}: {inf_count} Inf values")
            print(f"  Attempting to replace Inf values with 0.0...")
            # Replace Inf values
            ph_features = np.nan_to_num(ph_features, nan=0.0, posinf=0.0, neginf=0.0)
            # Check again
            if np.isinf(ph_features).any():
                print(f"  Error: Inf values still present after replacement, using zero vector")
                return torch.zeros(360, dtype=torch.float32)
            else:
                print(f"  Success: Inf values successfully replaced with 0.0")
        
        # Create tensor and mark as graph-level feature
        ph_tensor = torch.tensor(ph_features, dtype=torch.float32)
        ph_tensor._is_graph_level = True  # Mark as graph-level feature
        
        return ph_tensor
        
    except Exception as e:
        print(f"Error loading path_homology features for {protein_id}: {e}")
        import traceback
        traceback.print_exc()
        return torch.zeros(360, dtype=torch.float32)
