import torch
import torch.nn as nn
import torch.nn.functional as F


# Select the compute device once at import time: the first CUDA GPU when CUDA
# reports itself usable, otherwise the CPU. Model parameters in this module are
# created directly on this device.
if not torch.cuda.is_available():
    device = torch.device("cpu")
else:
    # Note: cuda_device_count is only bound on CUDA-capable hosts.
    cuda_device_count = torch.cuda.device_count()
    device = torch.device("cuda:0" if cuda_device_count > 0 else "cpu")

print(f"Using device: {device}")


class QCformer(nn.Module):
    """QCformer: ESM2 + path-homology feature fusion with an MLP prediction head.

    Architecture:
    1. ESM2 features (1280 dim) -> 256 dim via a two-layer reduction network
    2. Path_homology features (360 dim) -> 256 dim via a two-layer reduction network
    3. Concatenate both reduced vectors (256 + 256 = 512 dim)
    4. Mixed features -> MLP hidden layers -> output projection -> prediction

    Config keys read (all optional; a missing or empty ``training`` / ``model``
    section is treated as an empty mapping):
        training.num_classes, training.target,
        model.num_outputs, model.output_activation, model.dropout,
        model.hidden_dim, model.mlp_hidden_dims, model.mlp_num_layers,
        model.normalize_features, model.esm2_mean, model.esm2_std,
        model.path_homology_mean, model.path_homology_std

    Parameters and buffers are created on the module-level ``device``.
    """

    # Numerical guard rails applied in forward().
    _FEATURE_BOUND = 1e6      # raw inputs and intermediate activations
    _NORMALIZED_BOUND = 10.0  # inputs after standardization
    _LOGIT_BOUND = 50.0       # pre-activation outputs

    def __init__(self, config):
        super().__init__()

        # `or {}` also covers sections that are present but empty (None).
        training_cfg = config.get('training') or {}
        model_cfg = config.get('model') or {}

        # Number of model outputs.
        # Priority: 1) training.num_classes, 2) model.num_outputs, 3) infer from target.
        if 'num_classes' in training_cfg:
            self.num_outputs = training_cfg['num_classes']
        elif 'num_outputs' in model_cfg:
            self.num_outputs = model_cfg['num_outputs']
        else:
            # Backwards compatibility with previous default
            self.num_outputs = 51 if training_cfg.get('target') is None else 1

        self.output_activation = model_cfg.get('output_activation', 'none')
        self.dropout = model_cfg.get('dropout', 0.0)
        self.hidden_dim = model_cfg.get('hidden_dim', 256)

        # MLP configuration; mlp_num_layers (if > 0) overrides mlp_hidden_dims.
        self.mlp_hidden_dims = model_cfg.get('mlp_hidden_dims', [512, 256])
        self.mlp_num_layers = model_cfg.get('mlp_num_layers', None)

        # Feature sizes
        self.esm2_feature_size = 1280
        self.esm2_reduced_dim = 256
        self.path_homology_feature_size = 360
        self.path_homology_reduced_dim = 256

        # Feature normalization configuration (enabled by default).
        self.normalize_features = model_cfg.get('normalize_features', True)
        self.esm2_mean = model_cfg.get('esm2_mean', None)
        self.esm2_std = model_cfg.get('esm2_std', None)
        self.path_homology_mean = model_cfg.get('path_homology_mean', None)
        self.path_homology_std = model_cfg.get('path_homology_std', None)

        if self.normalize_features:
            self._register_feature_stats(
                'esm2', 'ESM2', self.esm2_mean, self.esm2_std, self.esm2_feature_size)
            self._register_feature_stats(
                'path_homology', 'Path_homology', self.path_homology_mean,
                self.path_homology_std, self.path_homology_feature_size)

        # Reduction networks. Module creation order matches earlier releases so
        # seeded initialization and state_dict keys are unchanged.
        self.esm2_reduction = nn.Sequential(
            *self._dense_block(self.esm2_feature_size, 512),
            *self._dense_block(512, self.esm2_reduced_dim),
        )
        self.path_homology_reduction = nn.Sequential(
            *self._dense_block(self.path_homology_feature_size, 512),
            *self._dense_block(512, self.path_homology_reduced_dim),
        )

        mixed_feature_dim = self.esm2_reduced_dim + self.path_homology_reduced_dim  # 512
        self.mixed_feature_dim = mixed_feature_dim

        self.mlp_hidden_dims = self._resolve_mlp_hidden_dims(mixed_feature_dim)

        mlp_layers = []
        in_dim = mixed_feature_dim
        for width in self.mlp_hidden_dims:
            mlp_layers.extend(self._dense_block(in_dim, width))
            in_dim = width
        self.mlp = nn.Sequential(*mlp_layers)

        # Output layer: MLP output -> hidden_dim -> num_outputs
        self.output_projection = nn.Sequential(
            *self._dense_block(in_dim, self.hidden_dim),
            nn.Linear(self.hidden_dim, self.num_outputs, device=device),
        )

    def _dense_block(self, in_dim, out_dim):
        """Return the [Linear, LayerNorm, SiLU, Dropout] layers mapping in_dim -> out_dim."""
        return [
            nn.Linear(in_dim, out_dim, device=device),
            # eps=1e-5 is PyTorch's LayerNorm default, kept explicit.
            nn.LayerNorm(out_dim, eps=1e-5, device=device),
            nn.SiLU(),
            nn.Dropout(self.dropout),
        ]

    def _register_feature_stats(self, prefix, label, mean, std, size):
        """Register `<prefix>_mean_tensor` / `<prefix>_std_tensor_safe` buffers.

        Without statistics, zero mean / unit std buffers are registered so the
        normalization code path (and the state_dict layout) stays the same.
        """
        if mean is not None and std is not None:
            mean_tensor = torch.tensor(mean, dtype=torch.float32, device=device)
            # Prevent division by zero.
            std_tensor = torch.clamp(torch.tensor(std, dtype=torch.float32, device=device), min=1e-8)
        else:
            mean_tensor = torch.zeros(size, dtype=torch.float32, device=device)
            std_tensor = torch.ones(size, dtype=torch.float32, device=device)
            print(f"Warning: {label} feature normalization is enabled but statistics are not provided. "
                  f"Normalization will be skipped.")
        self.register_buffer(f'{prefix}_mean_tensor', mean_tensor)
        self.register_buffer(f'{prefix}_std_tensor_safe', std_tensor)

    def _resolve_mlp_hidden_dims(self, input_dim):
        """Return the list of MLP hidden widths (never empty).

        With ``mlp_num_layers > 0`` the widths are linearly interpolated from
        ``input_dim`` towards ``hidden_dim``. Otherwise the configured
        ``mlp_hidden_dims`` is used: a list or tuple of widths, or a single
        width. An empty configuration falls back to one 512-wide layer.
        """
        if self.mlp_num_layers is not None and self.mlp_num_layers > 0:
            dims = []
            for i in range(self.mlp_num_layers):
                ratio = (i + 1) / (self.mlp_num_layers + 1)
                dims.append(int(input_dim * (1 - ratio) + self.hidden_dim * ratio))
        elif isinstance(self.mlp_hidden_dims, (list, tuple)):
            dims = list(self.mlp_hidden_dims)
        else:
            dims = [self.mlp_hidden_dims] if self.mlp_hidden_dims else []
        return dims or [512]

    def _determine_batch_size(self, data, features):
        """Infer the batch size from ``data.num_graphs``, ``data.batch``, or the features."""
        num_graphs = getattr(data, 'num_graphs', None)
        if num_graphs is not None:
            return num_graphs
        batch = getattr(data, 'batch', None)
        if batch is not None:
            return int(batch.max().item()) + 1
        if features is None:
            return 1
        if isinstance(features, (list, tuple)):
            # One entry per sample; _reshape_features stacks them.
            return len(features)
        features = torch.as_tensor(features)
        return features.shape[0] if features.dim() > 1 else 1

    def _reshape_features(self, features, batch_size, expected_dim):
        """Reshape features to ``[batch_size, expected_dim]``.

        Accepts a tensor, an array-like, or a list/tuple of per-sample entries.
        ``reshape`` (not ``view``) is used so non-contiguous inputs work.

        Raises:
            ValueError: if the trailing dimension is not ``expected_dim``.
        """
        if isinstance(features, (list, tuple)):
            features = torch.stack([torch.as_tensor(f) for f in features], dim=0)
        else:
            features = torch.as_tensor(features)

        if features.dim() == 1:
            features = features.unsqueeze(0)
        elif features.dim() > 2:
            features = features.reshape(features.size(0), -1)

        # A batch flattened into a single row, or any other layout with the right
        # number of elements, is unflattened back to one row per sample.
        if batch_size > 1 and features.size(0) == 1 and features.size(1) == batch_size * expected_dim:
            features = features.reshape(batch_size, expected_dim)
        elif features.numel() == batch_size * expected_dim and features.size(0) != batch_size:
            features = features.reshape(batch_size, expected_dim)

        if features.size(-1) != expected_dim:
            raise ValueError(
                f"Expected feature dimension {expected_dim}, "
                f"got {features.size(-1)}"
            )

        return features

    @staticmethod
    def _require_feature(data, name):
        """Return ``data.<name>``; raise ValueError if it is missing or None."""
        if not hasattr(data, name):
            available_attrs = [attr for attr in dir(data) if not attr.startswith('_')]
            raise ValueError(
                f"Input data must contain {name}. "
                f"Available attributes: {available_attrs}"
            )
        value = getattr(data, name)
        if value is None:
            raise ValueError(
                f"Input data.{name} is None. "
                f"Data type: {type(data)}, has {name} attr: {hasattr(data, name)}"
            )
        return value

    @staticmethod
    def _sanitize(tensor, bound):
        """Map NaN -> 0 and +/-Inf -> +/-bound, then clamp every element to [-bound, bound].

        The check is unconditional rather than guarded by ``isnan(...).any()``.
        The result for one sample then never depends on the rest of the batch,
        and no device-to-host sync is needed.
        """
        return torch.nan_to_num(tensor, nan=0.0, posinf=bound, neginf=-bound).clamp(-bound, bound)

    def _standardize(self, features, mean_name, std_name):
        """Return ``(features - mean) / std`` using the named buffers (broadcasting).

        Scalar, ``[D]`` and ``[1, D]`` statistics all broadcast against
        ``[batch, D]`` features. If the buffers are absent, features are returned unchanged.
        """
        mean = getattr(self, mean_name, None)
        std = getattr(self, std_name, None)
        if mean is None or std is None:
            return features
        mean = mean.to(device=features.device, dtype=features.dtype)
        std = std.to(device=features.device, dtype=features.dtype)
        return (features - mean) / std

    def _apply_output_activation(self, out):
        """Apply ``self.output_activation`` (unknown names mean identity)."""
        activation = self.output_activation
        if activation == 'relu':
            return torch.relu(out)
        if activation == 'softplus':
            # NOTE(review): F.softplus is already overflow-safe (it switches to
            # the linear branch above its threshold). This clamp caps outputs at
            # ~20 and is kept only so existing checkpoints behave the same.
            return F.softplus(out.clamp(-20.0, 20.0))
        if activation == 'elu':
            return F.elu(out) + 1
        if activation == 'tanh':
            return torch.tanh(out)
        if activation == 'sigmoid':
            return torch.sigmoid(out)
        # `out` is already clamped to +/-_LOGIT_BOUND, so (log_)softmax needs no extra clamp.
        if activation == 'softmax':
            return torch.softmax(out, dim=1)
        if activation == 'log_softmax':
            return torch.log_softmax(out, dim=1)
        return out

    def _replace_non_finite(self, out):
        """Replace non-finite outputs per element (or per row for (log_)softmax).

        For (log_)softmax, any row containing NaN/Inf becomes the uniform
        distribution: probabilities 1/C for softmax, log(1/C) for log_softmax.
        Other activations map NaN -> 0, +Inf -> 1, -Inf -> -1. Finite values
        are never modified, so one bad sample cannot alter the rest of the batch.
        """
        if self.output_activation in ('softmax', 'log_softmax'):
            row_ok = torch.isfinite(out).all(dim=1, keepdim=True)
            uniform = torch.full_like(out, 1.0 / out.size(1))
            if self.output_activation == 'log_softmax':
                uniform = uniform.log()
            return torch.where(row_ok, out, uniform)
        return torch.nan_to_num(out, nan=0.0, posinf=1.0, neginf=-1.0)

    def forward(self, data):
        """Predict from pre-computed ESM2 and path-homology features.

        Args:
            data: Object exposing ``esm2_features`` (1280 values per sample) and
                ``path_homology_features`` (360 values per sample), as tensors,
                lists of per-sample tensors, or array-likes. Optional
                ``num_graphs`` / ``batch`` attributes are used to infer the batch
                size (presumably a PyTorch Geometric batch; confirm with callers).

        Returns:
            Tensor of shape ``[batch_size, num_outputs]`` after the configured
            output activation, with no NaN/Inf entries.

        Raises:
            ValueError: if a feature attribute is missing, None, or has the wrong
                trailing dimension.
        """
        # Checked in this order so error messages match earlier releases.
        path_homology_features = self._require_feature(data, 'path_homology_features')
        esm2_features = self._require_feature(data, 'esm2_features')

        batch_size = self._determine_batch_size(data, path_homology_features)
        path_homology_features = self._reshape_features(
            path_homology_features, batch_size, self.path_homology_feature_size)
        esm2_features = self._reshape_features(esm2_features, batch_size, self.esm2_feature_size)

        # Match the parameters' device AND dtype (e.g. float64 numpy-derived
        # inputs would otherwise fail inside nn.Linear).
        reference = next(self.parameters())
        path_homology_features = path_homology_features.to(device=reference.device, dtype=reference.dtype)
        esm2_features = esm2_features.to(device=reference.device, dtype=reference.dtype)

        esm2_features = self._sanitize(esm2_features, self._FEATURE_BOUND)
        path_homology_features = self._sanitize(path_homology_features, self._FEATURE_BOUND)

        if self.normalize_features:
            # NOTE(review): the +/-10 clamp also applies when no statistics were
            # configured (identity normalization), so raw features outside
            # [-10, 10] are truncated in that case.
            esm2_features = self._standardize(
                esm2_features, 'esm2_mean_tensor', 'esm2_std_tensor_safe'
            ).clamp(-self._NORMALIZED_BOUND, self._NORMALIZED_BOUND)
            path_homology_features = self._standardize(
                path_homology_features, 'path_homology_mean_tensor', 'path_homology_std_tensor_safe'
            ).clamp(-self._NORMALIZED_BOUND, self._NORMALIZED_BOUND)

        # Steps 1-2: per-modality reduction to 256 dims each.
        esm2_reduced = self._sanitize(self.esm2_reduction(esm2_features), self._FEATURE_BOUND)
        path_homology_reduced = self._sanitize(
            self.path_homology_reduction(path_homology_features), self._FEATURE_BOUND)

        # Step 3: concatenation of two already-sanitized tensors needs no re-check.
        mixed_features = torch.cat([esm2_reduced, path_homology_reduced], dim=1)  # [batch, 512]

        # Step 4: MLP hidden layers.
        mlp_output = self._sanitize(self.mlp(mixed_features), self._FEATURE_BOUND)

        # Step 5: output projection, bounded before the activation.
        out = self._sanitize(self.output_projection(mlp_output), self._LOGIT_BOUND)
        out = self._apply_output_activation(out)
        return self._replace_non_finite(out)


# Public API exported by `from <module> import *`.
__all__ = ['QCformer']
