import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


def calculate_effective_number_weights(class_counts, beta=0.999):
    """
    Calculate class-balanced weights from the Effective Number of samples.

    The effective number of class c is E_c = (1 - beta^n_c) / (1 - beta).
    The class-balanced weight is its inverse, w_c = (1 - beta) / (1 - beta^n_c),
    so rarer classes receive larger weights.

    Args:
        class_counts: Number of samples per class (sequence, ndarray or tensor).
            Counts below 1 are treated as 1 so an absent class still gets a
            finite weight instead of inf/nan.
        beta: Effective Number hyper-parameter in [0, 1), default 0.999

    Returns:
        weights: float32 tensor of per-class weights w_c
        alpha_c: float32 tensor of the weights normalized to sum to 1

    Paper: "Class-Balanced Loss Based on Effective Number of Samples" (Cui et al., 2019)
    Formula: E_c = (1 - beta^n_c) / (1 - beta), w_c = 1 / E_c, alpha_c = w_c / sum(w_c)
    """
    if isinstance(class_counts, torch.Tensor):
        class_counts = class_counts.cpu().numpy()
    # asarray also accepts plain Python lists; float64 keeps beta**n accurate.
    counts = np.maximum(np.asarray(class_counts, dtype=np.float64), 1.0)

    effective_num = (1.0 - np.power(beta, counts)) / (1.0 - beta)
    # Weight is the INVERSE of the effective number: fewer samples -> larger weight.
    weights = 1.0 / effective_num

    alpha_c = weights / np.sum(weights)

    return torch.tensor(weights, dtype=torch.float32), torch.tensor(alpha_c, dtype=torch.float32)


def calculate_inverse_frequency_weights(class_counts, power=1.0, min_ratio=20.0):
    """
    Calculate per-class weights proportional to 1 / count^power, with the
    spread between the largest and smallest weight limited by ``min_ratio``.

    Args:
        class_counts: Number of samples per class (sequence, ndarray or tensor);
            counts are expected to be positive
        power: Weight power (1.0=inverse_freq, 0.5=sqrt_inverse)
        min_ratio: Largest allowed max/min weight ratio. Any weight smaller than
            ``max_weight / min_ratio`` is raised to that floor.

    Returns:
        weights: float32 tensor of per-class weights
    """
    if isinstance(class_counts, torch.Tensor):
        class_counts = class_counts.cpu().numpy()
    counts = np.asarray(class_counts, dtype=np.float64)

    # Inverse frequency weights
    inverse_weights = 1.0 / (counts ** power)

    # Floor the small weights individually. Multiplying every weight by one
    # common factor would leave max/min unchanged, so it cannot enforce the floor.
    min_weight = np.max(inverse_weights) / min_ratio
    inverse_weights = np.maximum(inverse_weights, min_weight)

    return torch.tensor(inverse_weights, dtype=torch.float32)


class ClassBalancedFocalLoss(nn.Module):
    """
    Focal loss whose per-class alpha comes from Effective Number weights.

    FL(pt) = -αt * (1-pt)^γ * log(pt)

    where:
    - pt is the softmax probability assigned to the target class
    - αt is the normalized weight (second output of
      ``calculate_effective_number_weights``) looked up for class t;
      without class counts no alpha weighting is applied
    - γ is the focusing parameter (default=2)
    """

    def __init__(self, class_counts=None, beta=0.999, gamma=2.0, reduction='mean'):
        super().__init__()
        self.beta = beta
        self.gamma = gamma
        self.reduction = reduction
        # Only the normalized alpha is kept; the raw weights are not needed here.
        self.alpha = (
            None
            if class_counts is None
            else calculate_effective_number_weights(class_counts, beta)[1]
        )

    def forward(self, inputs, targets):
        """
        Compute the focal loss.

        Args:
            inputs: (N, C) logits, C = number of classes
            targets: (N,) class indices in [0, C-1]

        Returns:
            A scalar for reduction 'mean'/'sum'; otherwise the (N,) per-sample losses.
        """
        ce = F.cross_entropy(inputs, targets, reduction='none')
        # exp(-CE) recovers the probability of the target class.
        p_true = torch.exp(-ce)
        modulator = (1 - p_true) ** self.gamma

        if self.alpha is None:
            weight = modulator
        else:
            # Lazily match alpha's tensor type (dtype + device) to the logits; cached on self.
            if self.alpha.type() != inputs.type():
                self.alpha = self.alpha.type_as(inputs)
            weight = self.alpha.gather(0, targets.view(-1)) * modulator

        per_sample = weight * ce

        if self.reduction == 'mean':
            return per_sample.mean()
        if self.reduction == 'sum':
            return per_sample.sum()
        return per_sample


class AggressiveFocalLoss(nn.Module):
    """
    Aggressive Focal Loss that amplifies the loss of hard samples and uses
    sharpened Effective Number α weights.

    Features:
    1. α comes from ``calculate_effective_number_weights``, raised to
       ``alpha_power`` and renormalized, which widens the gap between classes
    2. With ``dynamic_gamma`` the focusing parameter ramps linearly from
       ``min_gamma`` to ``max_gamma`` over ``total_epochs`` epochs
       (``gamma`` is only used when ``dynamic_gamma`` is False)
    3. Samples whose target-class probability pt is below ``hard_threshold``
       have their loss multiplied by ``hard_weight`` (with or without α)
    4. The fraction of hard samples in the most recent batch is tracked

    Call ``update_epoch(epoch)`` once per epoch to drive the gamma schedule.
    """

    def __init__(self, class_counts=None, beta=0.9999, gamma=2.5,
                 dynamic_gamma=True, min_gamma=2.0, max_gamma=4.0,
                 reduction='mean', total_epochs=50, hard_threshold=0.3,
                 hard_weight=2.0, alpha_power=1.5):
        super(AggressiveFocalLoss, self).__init__()
        if total_epochs <= 0:
            raise ValueError(f"total_epochs must be positive, got {total_epochs}")
        self.beta = beta
        self.gamma = gamma
        self.dynamic_gamma = dynamic_gamma
        self.min_gamma = min_gamma
        self.max_gamma = max_gamma
        self.reduction = reduction
        self.total_epochs = total_epochs
        self.hard_threshold = hard_threshold
        self.hard_weight = hard_weight
        self.alpha_power = alpha_power

        if class_counts is not None:
            _, alpha = calculate_effective_number_weights(class_counts, beta)
            # Raising to a power > 1 sharpens the differences between classes.
            alpha = alpha ** alpha_power
            self.alpha = alpha / alpha.sum()  # Renormalize to sum to 1
        else:
            self.alpha = None

        # Training statistics
        self.register_buffer('epoch', torch.tensor(0))
        # Number of hard samples in the most recent batch.
        self.register_buffer('difficult_samples', torch.tensor(0))
        # Size of the most recent batch; turns the count above into a ratio.
        self._last_batch_size = 0

    def update_epoch(self, epoch):
        """Set the current epoch used by the dynamic gamma schedule."""
        # Create the new value on the buffer's current device so .to(device) is respected.
        self.epoch = torch.tensor(epoch, device=self.epoch.device)

    def _current_gamma(self):
        """Return the focusing parameter for the current epoch."""
        if not self.dynamic_gamma:
            return self.gamma
        progress = self.epoch.item() / float(self.total_epochs)
        progress = max(0.0, min(progress, 1.0))
        return self.min_gamma + (self.max_gamma - self.min_gamma) * progress

    def forward(self, inputs, targets):
        """
        Args:
            inputs: (N, C) logits, C = number of classes
            targets: (N,) class indices in [0, C-1]

        Returns:
            A scalar for reduction 'mean'/'sum'; otherwise the (N,) per-sample losses.
        """
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)
        modulator = (1 - pt) ** self._current_gamma()

        if self.alpha is not None:
            if self.alpha.device != inputs.device or self.alpha.dtype != inputs.dtype:
                self.alpha = self.alpha.to(device=inputs.device, dtype=inputs.dtype)
            at = self.alpha.gather(0, targets.view(-1))
            focal_loss = at * modulator * ce_loss
        else:
            focal_loss = modulator * ce_loss

        # Extra weight for hard samples, applied with or without alpha.
        hard_sample_mask = pt < self.hard_threshold
        focal_loss = torch.where(hard_sample_mask, focal_loss * self.hard_weight, focal_loss)

        # Refresh statistics on every call so they never describe a stale batch.
        self.difficult_samples = hard_sample_mask.sum().float()
        self._last_batch_size = hard_sample_mask.numel()

        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:
            return focal_loss

    def get_difficult_sample_ratio(self):
        """Return the fraction of the last batch with pt < hard_threshold (0.0 before any batch)."""
        batch_size = getattr(self, '_last_batch_size', 0)
        if batch_size == 0:
            return 0.0
        return self.difficult_samples.item() / batch_size


class FocalLoss(nn.Module):
    """
    Focal Loss for addressing class imbalance.

    FL(pt) = -αt * (1-pt)^γ * log(pt)

    where:
    - pt is the predicted probability for the true class
    - αt is the weight for class t (or a single global weight when alpha is a scalar)
    - γ is the focusing parameter (default=2)

    Args:
        alpha: None (no weighting), a per-class weight sequence/ndarray/tensor
            of length C, or a scalar applied to every sample
        gamma: Focusing parameter
        reduction: 'mean', 'sum', or anything else for per-sample losses
    """

    def __init__(self, alpha=None, gamma=2.0, reduction='mean'):
        super(FocalLoss, self).__init__()
        # Accept lists/tuples/ndarrays/floats as well as tensors.
        if alpha is not None and not isinstance(alpha, torch.Tensor):
            alpha = torch.as_tensor(alpha, dtype=torch.float32)
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        """
        Args:
            inputs: (N, C) where C is the number of classes
            targets: (N,) where each value is 0 <= targets[i] <= C-1

        Returns:
            A scalar for reduction 'mean'/'sum'; otherwise the (N,) per-sample losses.
        """
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)
        modulator = (1 - pt) ** self.gamma

        if self.alpha is not None:
            # Cache alpha on the logits' device/dtype.
            if self.alpha.device != inputs.device or self.alpha.dtype != inputs.dtype:
                self.alpha = self.alpha.to(device=inputs.device, dtype=inputs.dtype)
            # A 0-d alpha is a single global weight; otherwise look up alpha[target].
            if self.alpha.dim() == 0:
                at = self.alpha
            else:
                at = self.alpha.gather(0, targets.view(-1))
            focal_loss = at * modulator * ce_loss
        else:
            focal_loss = modulator * ce_loss

        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:
            return focal_loss


class FocalLossWithLogits(nn.Module):
    """
    Focal Loss with built-in sigmoid activation for binary classification.

    FL = α_t * (1 - pt)^γ * BCEWithLogits, with α_t = α·t + (1 - α)·(1 - t).

    For hard 0/1 targets α_t is α for positives and 1 - α for negatives; soft
    (e.g. label-smoothed) targets interpolate between the two.

    Args:
        alpha: Weight for the positive class in [0, 1], or None to disable weighting
        gamma: Focusing parameter
        reduction: 'mean', 'sum', or anything else for per-element losses
    """

    def __init__(self, alpha=0.25, gamma=2.0, reduction='mean'):
        super(FocalLossWithLogits, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        """
        Args:
            inputs: (N,) raw logits
            targets: (N,) binary (or soft, in [0, 1]) targets; any numeric dtype

        Returns:
            A scalar for reduction 'mean'/'sum'; otherwise the (N,) per-element losses.
        """
        # BCE-with-logits needs floating targets matching the logits' dtype.
        targets = targets.to(dtype=inputs.dtype)
        bce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        pt = torch.exp(-bce_loss)
        focal_loss = (1 - pt) ** self.gamma * bce_loss

        if self.alpha is not None:
            # Linear form: equals alpha / 1-alpha for 0/1 targets, interpolates for soft ones.
            alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets)
            focal_loss = alpha_t * focal_loss

        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:
            return focal_loss


class AsymmetricLoss(nn.Module):
    """
    Asymmetric Loss (ASL) for class-imbalanced binary targets.

    Positive and negative samples use separate focusing exponents, and negatives
    use a shifted probability p_m = max(p - margin, 0), so easy negatives with
    p <= margin contribute no loss at all:

        L+ = -(1 - p)^gamma_pos * log(p)
        L- = -p_m^gamma_neg * log(1 - p_m)

    where p = sigmoid(logit). Both terms are non-negative.

    Paper: "Asymmetric Loss For Multi-Label Classification" (Ridnik et al., ICCV 2021)
    """

    def __init__(self, gamma_neg=2.0, gamma_pos=1.0, margin=0.05, eps=1e-8, reduction='mean'):
        super(AsymmetricLoss, self).__init__()
        self.gamma_neg = gamma_neg
        self.gamma_pos = gamma_pos
        self.margin = margin
        self.eps = eps  # lower bound inside log() to avoid log(0)
        self.reduction = reduction

    def forward(self, inputs, targets):
        """
        Args:
            inputs: (N, C) logits or (N,) single-logit scores. For C >= 2 the
                column of class 1 is used as the positive logit (class 1 vs. rest);
                for C == 1 that single column is used.
            targets: (N,) class indices; class 1 is the positive class

        Returns:
            A scalar for reduction 'mean'/'sum'; otherwise the (N,) per-sample
            losses in the same order as ``targets``.
        """
        # Reduce to one binary logit per sample.
        if inputs.dim() == 1:
            logits = inputs
        elif inputs.size(1) == 1:
            logits = inputs[:, 0]
        else:
            logits = inputs[:, 1]

        pos_mask = targets.reshape(-1) == 1
        neg_mask = ~pos_mask
        probs = torch.sigmoid(logits)

        # Per-sample losses written back in input order.
        loss = torch.zeros_like(probs)

        # Positive term: focal-style down-weighting of easy positives.
        p_pos = probs[pos_mask]
        loss[pos_mask] = -torch.pow(1 - p_pos, self.gamma_pos) * torch.log(p_pos.clamp(min=self.eps))

        # Negative term with probability shifting: p <= margin is treated as solved.
        p_neg = (probs[neg_mask] - self.margin).clamp(min=0)
        loss[neg_mask] = -torch.pow(p_neg, self.gamma_neg) * torch.log((1 - p_neg).clamp(min=self.eps))

        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        else:
            return loss


def get_balanced_class_weights(labels, method='capped', max_ratio=2.0):
    """
    Calculate balanced class weights, avoiding excessive weights.

    Args:
        labels: List/array of labels
        method: Weight calculation method ('sqrt', 'log', 'linear', 'capped', 'gentle')
        max_ratio: Maximum weight ratio (largest/smallest) used by 'capped'

    Returns:
        class_weights: float32 tensor with one weight per distinct label, ordered
        like ``np.unique(labels)`` (ascending) and normalized to sum to the
        number of distinct labels. Labels absent from ``labels`` get no entry.

    Raises:
        ValueError: if ``labels`` is empty or ``method`` is unknown.
    """

    # Number of samples per class
    unique_labels, counts = np.unique(labels, return_counts=True)
    total_samples = len(labels)
    if total_samples == 0:
        raise ValueError("labels must not be empty")

    print(f"Class distribution: {dict(zip(unique_labels, counts))}")

    if method == 'sqrt':
        # w = sqrt(N/Ni)
        weights = np.sqrt(total_samples / counts)
    elif method == 'log':
        # w = log(N/Ni) + 1
        weights = np.log(total_samples / counts) + 1
    elif method == 'linear':
        # w = N/(2*Ni)
        weights = total_samples / (2 * counts)
    elif method == 'capped':
        # Linear weights with the largest/smallest ratio capped at max_ratio.
        weights = total_samples / (2 * counts)
        # Clip rather than rescale: multiplying every weight by one factor
        # leaves max/min unchanged, so the cap would never take effect.
        weights = np.minimum(weights, weights.min() * max_ratio)
    elif method == 'gentle':
        # Square-root weights, ratio capped at 1.5, then only 60% of the
        # remaining compensation kept to avoid over-compensation.
        weights = np.sqrt(total_samples / counts)
        weights = np.minimum(weights, weights.min() * 1.5)
        weights = 1.0 + (weights - 1.0) * 0.6
    else:
        raise ValueError(f"Unknown method: {method}")

    # Normalize so the weights sum to the number of classes
    weights = weights / weights.sum() * len(unique_labels)

    print(f"Class weights: {weights}")
    print(f"Weight ratio (minority/majority): {weights.max()/weights.min():.3f}")

    return torch.tensor(weights, dtype=torch.float32)


def create_enhanced_weighted_sampler(labels, weights, per_class_limit=True, repeat_factor=8, base_limit=None):
    """
    Create a WeightedRandomSampler that draws each sample with its class weight.

    Args:
        labels: Sequence/array/tensor of per-sample class labels
        weights: Per-class weights, indexed in ascending order of the distinct
            labels (the order produced by ``torch.unique``)
        per_class_limit: Reserved for per-class sampling limits. Both settings
            currently return the same per-sample weighted sampler; an actual
            per-class limit must be enforced in the training loop.
        repeat_factor: Reserved; currently unused
        base_limit: Reserved; currently unused

    Returns:
        sampler: WeightedRandomSampler over ``len(labels)`` dataset indices,
        sampling with replacement
    """
    from torch.utils.data import WeightedRandomSampler

    # The sampler must receive one weight PER SAMPLE. Passing the class-weight
    # vector directly would make it draw class indices instead of dataset indices.
    # return_inverse maps each sample to its label's position among the sorted
    # distinct labels, so the class weights are looked up in one vectorized step.
    _, class_index = torch.unique(torch.as_tensor(labels), return_inverse=True)
    sample_weights = torch.as_tensor(weights, dtype=torch.float32)[class_index]

    return WeightedRandomSampler(sample_weights, len(sample_weights), replacement=True)


def create_weighted_sampler(labels, weights):
    """
    Create a standard WeightedRandomSampler from per-class weights.

    Args:
        labels: Sequence/array/tensor of per-sample class labels
        weights: Per-class weights, indexed in ascending order of the distinct
            labels (the order produced by ``torch.unique``)

    Returns:
        sampler: WeightedRandomSampler over ``len(labels)`` dataset indices,
        sampling with replacement
    """
    from torch.utils.data import WeightedRandomSampler

    # return_inverse gives each sample's position among the sorted distinct labels.
    # Indexing the class-weight vector with it replaces a per-sample Python loop
    # (O(n * C) with one tensor comparison per sample) by one vectorized gather.
    _, class_index = torch.unique(torch.as_tensor(labels), return_inverse=True)
    sample_weights = torch.as_tensor(weights, dtype=torch.float32)[class_index]

    return WeightedRandomSampler(sample_weights, len(sample_weights), replacement=True)
