# -*- coding: utf-8 -*-
"""
Created on Fri Dec 19 18:30:58 2025

@author: LJY
"""

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
------------------------------------
1. Only compute 1-mer (20 amino acids)
2. Extract 18-dimensional topological features (barcode statistics) for each 1-mer
   = 6 features each for H0/H1/H2 → 18 dimensions
3. Use Digraph construction method (V, WG, max_distance=8) unchanged
4. Output dimension = 20 × 18 = 360 columns
5. Each FASTA → separate CSV (file name/column name/directory unchanged)
------------------------------------
"""

import os
import csv

# ===== Use provided GLMYnonregular.py =====
from GLMYnonregular import input_digraph_simple

# Residue letters kept by clean_sequence(); compute_kmer_features() iterates
# over them as the 20 1-mers, in this order (which is also the CSV column order).
AA = "ACDEFGHIKLMNPQRSTVWY"
# Largest position gap for which build_graph() emits an edge. The same value is
# passed on to the GLMY computation and to summarize_barcode().
FIXED_MAX_DISTANCE = 8


# --------------------------
# Utility: Clean sequence
# --------------------------
def clean_sequence(seq):
    """Upper-case *seq* and drop every character that is not listed in ``AA``.

    The relative order of the surviving characters is preserved.
    """
    kept = []
    for residue in seq.upper():
        if residue in AA:
            kept.append(residue)
    return "".join(kept)


# --------------------------
# FASTA reading
# --------------------------
def read_fasta(path):
    """Yield the cleaned sequence of every record in the FASTA file at *path*.

    A record starts at a line beginning with ``>`` and runs until the next
    such line or the end of the file. Blank lines are ignored, and sequence
    lines that appear before the first header are discarded. Only the
    sequences are yielded (passed through ``clean_sequence``); header text
    is not returned.

    A header with an empty name (a bare ``>`` line) still starts a record.
    Record presence is tracked with an explicit flag rather than the
    truthiness of the header text, so such records are not silently dropped.
    """
    in_record = False
    chunks = []
    with open(path, "r") as handle:
        for raw_line in handle:
            line = raw_line.strip()
            if not line:
                continue
            if line.startswith(">"):
                # A new header closes the record collected so far, if any.
                if in_record:
                    yield clean_sequence("".join(chunks))
                in_record = True
                chunks = []
            elif in_record:
                chunks.append(line)
    # Flush the final record, which has no following header to close it.
    if in_record:
        yield clean_sequence("".join(chunks))


# --------------------------
# Build Digraph
# --------------------------
def get_positions(seq, kmer):
    """Return the 1-based start positions of every occurrence of *kmer* in *seq*.

    Overlapping occurrences are all reported, in increasing order. The
    result is empty when *seq* is shorter than *kmer*.
    """
    width = len(kmer)
    hits = []
    for start in range(len(seq) - width + 1):
        if seq[start:start + width] == kmer:
            hits.append(start + 1)
    return hits


def build_graph(positions):
    """Build the weighted digraph over the occurrence positions of one k-mer.

    Vertices are the ranks ``0..n-1`` of the sorted positions. For every
    ordered pair of distinct vertices whose positions differ by at most
    ``FIXED_MAX_DISTANCE``, an edge ``[source, target, distance]`` is
    emitted, so each close pair contributes an edge in both directions.

    Returns:
        ``(V, WG)``: the vertex list and the list of weighted edges.
    """
    ordered = sorted(positions)

    # Position -> vertex index lookup used for the edge endpoints.
    vertex_of = {}
    for vertex, position in enumerate(ordered):
        vertex_of[position] = vertex

    weighted_edges = [
        [vertex_of[src], vertex_of[dst], abs(dst - src)]
        for a, src in enumerate(ordered)
        for b, dst in enumerate(ordered)
        if a != b and abs(dst - src) <= FIXED_MAX_DISTANCE
    ]

    return list(range(len(ordered))), weighted_edges


# ============================================================
# GLMYnonregular returns Betti curve → rebuild barcode diagram
# ============================================================

def betti_curve_to_diagram(betti_num_all, dims=("0", "1", "2")):
    """Rebuild a barcode diagram from a Betti curve.

    Args:
        betti_num_all: sequence indexed by filtration step ``t = 0..T``;
            element ``t`` is itself a sequence ``[beta0, beta1, beta2, ...]``.
            Both plain lists and NumPy arrays are accepted. A row that is
            too short for a requested dimension counts as 0 for it.
        dims: homology dimensions to rebuild, given as decimal strings.

    Returns:
        Dict mapping each entry of *dims* to a list of ``(birth, death)``
        tuples. Bars closed during the filtration come first, in order of
        death; bars still open at the last step follow with ``death = -1``.

    Rebuild rules (a Betti curve does not record which bar dies, so the
    pairing below is a fixed convention chosen for reproducibility):
      - At t=0, open betti(0) bars with birth=0.
      - If betti increases by k at t: open k bars with birth=t.
      - If betti decreases by k at t: close the k oldest open bars
        (FIFO, first born first dead) with death=t.
      - Bars still open at the end get death=-1.
    """
    diagram = {d: [] for d in dims}
    # Use len() instead of truthiness: ``not array`` raises ValueError for a
    # NumPy array holding more than one element.
    if betti_num_all is None or len(betti_num_all) == 0:
        return diagram

    def betti_at(step, dim):
        """Betti number of *dim* at *step*, or 0 if the row is too short."""
        row = betti_num_all[step]
        return int(row[dim]) if len(row) > dim else 0

    n_steps = len(betti_num_all)

    # The dimensions never interact, so each one is replayed independently.
    for dim_str in dims:
        dim = int(dim_str)
        bars = diagram[dim_str]

        previous = betti_at(0, dim)
        open_births = [0] * previous  # births of the currently open bars, oldest first

        for t in range(1, n_steps):
            current = betti_at(t, dim)
            change = current - previous
            if change > 0:
                open_births.extend([t] * change)
            elif change < 0:
                k = -change
                # Close the k oldest bars in one slice; if fewer than k are
                # open (not expected), only the available ones are closed.
                closing = open_births[:k]
                del open_births[:k]
                bars.extend((birth, t) for birth in closing)
            previous = current

        # Whatever is still open never died within the filtration.
        bars.extend((birth, -1) for birth in open_births)

    return diagram


def compute_diagram_by_glmy_nonregular(V, WG, max_distance=8):
    """Compute the H0/H1/H2 barcode diagram of a weighted digraph.

    The Betti curve is obtained from ``input_digraph_simple`` (imported from
    GLMYnonregular) and converted to bars by ``betti_curve_to_diagram``.

    Args:
        V: vertex list.
        WG: weighted edges, passed through unchanged as the edge list.
        max_distance: passed as the last (max distance) argument.

    Returns:
        Dict with keys "0", "1", "2", each a list of ``(birth, death)``
        tuples; all three lists are empty when *V* is empty.
    """
    # Nothing to compute for an empty graph; skip the GLMY call entirely.
    if len(V) == 0:
        return {"0": [], "1": [], "2": []}

    # Positional arguments: nodes, edges with distance, max_path, filtration,
    # max distance. Per the original author's note, max_path must be at
    # least 3 to obtain H0/H1/H2; filtration is fixed to 1.
    betti_curve = input_digraph_simple(V, WG, 3, 1, max_distance)
    return betti_curve_to_diagram(betti_curve, dims=("0", "1", "2"))


# --------------------------
# Barcode statistics (keep as is)
# --------------------------
def summarize_barcode(diagram, max_distance=8):
    """Condense a barcode diagram into an 18-value feature tuple.

    For each of the dimensions "0", "1", "2" (in that order) six values are
    emitted:
        bar count, longest lifetime, mean lifetime, summed lifetime,
        number of bars living longer than ``max_distance / 2``,
        shortest lifetime.
    A dimension that is missing or has no bars contributes six zeros.

    Lifetime of a bar ``(birth, death)``: ``max_distance - birth`` when
    ``death == -1`` (the bar never closed), otherwise ``death - birth``
    floored at 0.
    """
    summary = []

    for key in ("0", "1", "2"):
        spans = []
        for birth, death in diagram.get(key, []):
            if death == -1:
                spans.append(max_distance - birth)
            else:
                spans.append(max(0, death - birth))

        if not spans:
            summary.extend([0, 0, 0, 0, 0, 0])
            continue

        n_bars = len(spans)
        total = sum(spans)
        n_long = len([span for span in spans if span > max_distance / 2])

        summary.extend(
            [n_bars, max(spans), total / n_bars, total, n_long, min(spans)]
        )

    return tuple(summary)


# --------------------------
# Main computation: Compute topological features for all 1-mers of a sequence
# --------------------------
def compute_kmer_features(seq):
    """Compute the 18 barcode features of every 1-mer in ``AA`` for one sequence.

    The sequence is cleaned first. For each residue letter, its occurrence
    positions are turned into a digraph, the barcode diagram is computed and
    then summarised. A residue that does not occur gets 18 zeros.

    Returns:
        Dict mapping each letter of ``AA`` to an 18-tuple of features.
    """
    cleaned = clean_sequence(seq)
    absent = (0,) * 18
    features_by_residue = {}

    for residue in AA:
        occurrences = get_positions(cleaned, residue)
        if occurrences:
            vertices, weighted_edges = build_graph(occurrences)
            barcode = compute_diagram_by_glmy_nonregular(
                vertices, weighted_edges, FIXED_MAX_DISTANCE
            )
            features_by_residue[residue] = summarize_barcode(
                barcode, FIXED_MAX_DISTANCE
            )
        else:
            # No occurrence: no graph to build, emit the all-zero vector.
            features_by_residue[residue] = absent

    return features_by_residue


# --------------------------
# Process single FASTA
# --------------------------
def process_single_fasta(fpath, out_dir):
    """Write the 1-mer topological features of one FASTA file to a CSV.

    The CSV is placed in *out_dir* and named after the FASTA file with its
    extension replaced by ``.csv``. It has one row per FASTA record: the
    0-based record index in column ``seq_id`` followed by 18 feature
    columns (``<letter>_f0`` .. ``<letter>_f17``) for each letter of ``AA``.
    Progress messages are printed before and after processing.
    """
    fname = os.path.basename(fpath)
    stem, _extension = os.path.splitext(fname)
    out_csv = os.path.join(out_dir, stem + ".csv")
    print(f"[RUN] Processing {fname}")

    # Read the whole FASTA up front, before the output file is opened.
    records = []
    for index, raw in enumerate(read_fasta(fpath)):
        records.append((index, clean_sequence(raw)))

    # 1 id column + 20 letters x 18 features = 361 columns.
    header = ["seq_id"] + [f"{aa}_f{k}" for aa in AA for k in range(18)]

    with open(out_csv, "w", newline="") as handle:
        writer = csv.writer(handle)
        writer.writerow(header)

        for seq_id, seq in records:
            feats = compute_kmer_features(seq)
            row = [seq_id, *(value for aa in AA for value in feats[aa])]
            writer.writerow(row)

    print(f"[OK] Output → {out_csv}")


# --------------------------
# Main program entry
# --------------------------
def process_all_fasta(fasta_dir, out_dir):
    """Run ``process_single_fasta`` on every FASTA file directly inside *fasta_dir*.

    *out_dir* is created if it does not exist. A directory entry counts as a
    FASTA file when its name ends, case-insensitively, with ``.fa``,
    ``.fasta`` or ``.faa``. Entries are processed in the order returned by
    ``os.listdir``.
    """
    os.makedirs(out_dir, exist_ok=True)

    extensions = (".fa", ".fasta", ".faa")

    # Collect every match first, then process, so the listing is complete
    # before any output is written.
    targets = []
    for entry in os.listdir(fasta_dir):
        if entry.lower().endswith(extensions):
            targets.append(os.path.join(fasta_dir, entry))

    for target in targets:
        process_single_fasta(target, out_dir)


if __name__ == "__main__":
    # Script entry: process every FASTA file found in ./symm_fasta and write
    # one CSV per file into ./symm1merandtop20_8dis (created if missing).
    process_all_fasta(
        r"./symm_fasta",
        r"./symm1merandtop20_8dis"
    )
