import numpy as np
import copy

class PathHomology:
    """Path-homology Betti numbers of directed graphs, plain and persistent.

    Internally a path is a string of node labels joined by ``'->'``. An
    *allowed* n-path is a sequence of n+1 nodes in which every consecutive
    pair is an edge of the digraph.
    """

    def __init__(self):
        return

    @staticmethod
    def remove_loops(edges):
        """Drop self-loop edges ``(v, v)``, printing a warning that names them.

        Args:
            edges: array-like of ``(source, target)`` pairs.

        Returns:
            ``np.ndarray`` of the remaining edges (``np.delete`` always returns
            a new array, even when nothing is removed).
        """
        loop_idx = [i for i, e in enumerate(edges) if e[0] == e[1]]
        if loop_idx:
            loop_nodes = [edges[i][0] for i in loop_idx]
            print(f'Warning, loops on node {loop_nodes} were removed.')
        return np.delete(edges, loop_idx, axis=0)

    @staticmethod
    def split_independent_compondent(edges, nodes):
        """Split the underlying undirected graph into connected components.

        Args:
            edges: ``(u, v)`` pairs whose endpoints all appear in ``nodes``.
            nodes: node labels; every node (isolated ones included) ends up in
                exactly one component.

        Returns:
            list of components, each a list of node labels in depth-first
            pre-order; components are ordered by their first node in ``nodes``.
        """
        node_map_idx = {node: idx for idx, node in enumerate(nodes)}
        graph = [[] for _ in range(len(nodes))]
        for u, v in edges:
            graph[node_map_idx[u]].append(v)
            graph[node_map_idx[v]].append(u)

        visited = [False] * len(nodes)
        all_components = []
        for i, start in enumerate(nodes):
            if visited[i]:
                continue
            root_idx = node_map_idx[start]
            visited[root_idx] = True
            component = [start]
            # Explicit stack of neighbour iterators: same visiting order as a
            # recursive DFS, but no RecursionError on long paths/chains.
            stack = [iter(graph[root_idx])]
            while stack:
                for neighbor in stack[-1]:
                    neighbor_idx = node_map_idx[neighbor]
                    if not visited[neighbor_idx]:
                        visited[neighbor_idx] = True
                        component.append(neighbor)
                        stack.append(iter(graph[neighbor_idx]))
                        break
                else:
                    stack.pop()
            all_components.append(component)
        return all_components

    @staticmethod
    def split_independent_digraph(all_components, edges):
        """Group edges by the connected component that contains them.

        Args:
            all_components: output of ``split_independent_compondent``.
            edges: ``(u, v)`` pairs.

        Returns:
            list parallel to ``all_components``. Each entry holds that
            component's edges in input order. An edge goes to the first
            component containing either endpoint; edges touching no component
            are dropped. A single-node component without edges gets
            ``[component]`` (shape ``(1, 1)``) as a marker instead.
        """
        owner = {}
        for i_c, component in enumerate(all_components):
            for node in component:
                owner.setdefault(node, i_c)

        all_digraphs = [[] for _ in all_components]
        for edge in edges:
            hits = [owner[end] for end in (edge[0], edge[1]) if end in owner]
            if hits:
                all_digraphs[min(hits)].append(edge)

        for i_c, component in enumerate(all_components):
            if len(component) == 1 and len(all_digraphs[i_c]) < 1:
                all_digraphs[i_c].append(component)
        return all_digraphs

    def utils_generate_allowed_paths(self, edges, max_path):
        """Enumerate the allowed paths of a digraph.

        Args:
            edges: ``(source, target)`` pairs; the node set is
                ``np.unique(edges)``.
            max_path: paths with up to ``max_path + 1`` edges are generated
                (one beyond the highest Betti dimension, because ``beta_n``
                needs the (n+1)-paths).

        Returns:
            dict ``{n: [path_str, ...]}`` for ``n = 0 .. max_path + 1``. Each
            list is ordered by prefix path, then by successor in
            ``np.unique`` order.
        """
        nodes = np.unique(edges)
        index_of = {node: idx for idx, node in enumerate(nodes)}
        labels = [str(node) for node in nodes]
        successor_sets = [set() for _ in nodes]
        for u, v in edges:
            successor_sets[index_of[u]].add(index_of[v])
        successors = [sorted(s) for s in successor_sets]

        # Frontier of (path string, index of the path's last node); extending
        # along successor lists avoids rescanning all nodes for every path.
        frontier = [(label, idx) for idx, label in enumerate(labels)]
        allowed_path_str = {0: list(labels)}
        for length in range(1, max_path + 2):
            frontier = [(f'{path}->{labels[nxt]}', nxt)
                        for path, last in frontier for nxt in successors[last]]
            allowed_path_str[length] = [path for path, _ in frontier]
        return allowed_path_str

    @staticmethod
    def _boundary_operator(allowed_path, max_path, regular):
        """Build boundary matrices of allowed paths (shared by both variants).

        Args:
            allowed_path: ``{n: [path_str, ...]}`` for ``n = 0 .. max_path + 1``.
            max_path: matrices are built for ``n = 1 .. max_path + 1``.
            regular: if True, faces with a repeated consecutive vertex are
                treated as zero (dropped); if False, they are kept as
                ordinary, non-allowed coordinates.

        Returns:
            ``(boundary_map_matrix, boundary_mat_matrix_rank,
            allowed_path_idx_argument)``, three dicts keyed by ``n``:

            * ``boundary_map_matrix[n]``: one row per allowed n-path, one
              column per name in the sorted union of its faces and the allowed
              (n-1)-paths. The entry is ``(-1)**i`` for the face obtained by
              deleting vertex ``i``. Level 0 is a zero vector; a level with no
              n-paths is a ``1 x |A_{n-1}|`` zero matrix.
            * ``boundary_mat_matrix_rank[n]``: rank of that matrix.
            * ``allowed_path_idx_argument[n]``: 0/1 mask over its columns that
              marks allowed (n-1)-paths.
        """
        boundary_map_matrix = {0: np.zeros([len(allowed_path[0]), ])}
        boundary_mat_matrix_rank = {0: 0}
        allowed_path_idx_argument = {0: [1] * len(allowed_path[0])}

        for n in range(1, max_path + 2):
            n_paths = allowed_path[n]
            prev_paths = allowed_path[n - 1]
            if len(n_paths) == 0:
                boundary_map_matrix[n] = np.zeros([1, len(prev_paths)])
                boundary_mat_matrix_rank[n] = 0
                allowed_path_idx_argument[n] = [1] * len(prev_paths)
                continue

            faces_of = {}
            face_names = []
            for path in n_paths:
                vertices = path.split('->')
                faces = {}
                for i_kill in range(n + 1):
                    face = vertices[:i_kill] + vertices[i_kill + 1:]
                    if regular and any(a == b for a, b in zip(face, face[1:])):
                        continue  # non-regular face is identified with zero
                    face_str = '->'.join(face)
                    faces[face_str] = (-1) ** i_kill
                    face_names.append(face_str)
                faces_of[path] = faces

            considered = np.unique(face_names + list(prev_paths))
            column_of = {name: j for j, name in enumerate(considered)}
            boundary_mat = np.zeros([len(n_paths), len(considered)])
            for i_path, faces in enumerate(faces_of.values()):
                for face_str, sign in faces.items():
                    boundary_mat[i_path, column_of[face_str]] = sign

            prev_set = set(prev_paths)
            boundary_map_matrix[n] = boundary_mat
            boundary_mat_matrix_rank[n] = np.linalg.matrix_rank(boundary_mat)
            allowed_path_idx_argument[n] = [1 if name in prev_set else 0 for name in considered]
        return boundary_map_matrix, boundary_mat_matrix_rank, allowed_path_idx_argument

    def utils_unlimited_boundary_operator(self, allowed_path, max_path):
        """
        Generate boundary matrices and ranks (non-regular variant).

        Faces with a repeated consecutive vertex (e.g. ``'0->0'``) are kept as
        non-allowed coordinates. See ``_boundary_operator`` for the returned
        structures.
        """
        return self._boundary_operator(allowed_path, max_path, regular=False)

    def utils_regular_boundary_operator(self, allowed_path, max_path):
        """Generate boundary matrices and ranks (regular variant).

        Faces with a repeated consecutive vertex are dropped (treated as zero).
        See ``_boundary_operator`` for the returned structures.
        """
        return self._boundary_operator(allowed_path, max_path, regular=True)

    def path_homology_for_connected_digraph(self, allowed_path, max_path):
        """Betti numbers ``beta_0 .. beta_max_path`` of one connected digraph.

        Computes ``beta_n = dim ker d_n - dim(Im d_{n+1} ∩ A_n)``, where
        ``d`` maps allowed paths into all paths (non-regular boundary) and
        ``A_n`` is the span of allowed n-paths. The intersection dimension is
        ``rank d_{n+1} + |A_n| - dim(A_n + Im d_{n+1})``.

        Args:
            allowed_path: ``{n: [path_str, ...]}`` for ``n = 0 .. max_path + 1``.
            max_path: highest Betti dimension returned.

        Returns:
            integer ``np.ndarray`` of length ``max_path + 1``.
        """
        betti_numbers = np.array([0] * (max_path + 1))
        boundary_map_matrix, boundary_rank, allowed_mask = \
            self.utils_unlimited_boundary_operator(allowed_path, max_path)

        betti_numbers[0] = len(allowed_path[0]) - boundary_rank[1]
        for n in range(1, max_path + 1):
            if len(allowed_path[n]) == 0:
                break
            dim_cycles = len(allowed_path[n]) - boundary_rank[n]
            # The (n+1)-level data always exists (paths/matrices go up to
            # max_path + 1), so beta_max_path also accounts for boundaries.
            if len(allowed_path[n + 1]) == 0:
                dim_boundaries = 0
            else:
                span_allowed = np.diag(np.asarray(allowed_mask[n + 1], dtype=float))
                big_mat = np.vstack([span_allowed, boundary_map_matrix[n + 1]])
                dim_sum = np.linalg.matrix_rank(big_mat)
                dim_boundaries = len(allowed_path[n]) + boundary_rank[n + 1] - dim_sum
            betti_numbers[n] = dim_cycles - dim_boundaries
        return betti_numbers

    def path_homology(self, edges, nodes, max_path):
        """Betti numbers of a digraph, summed over its connected components.

        Args:
            edges: array-like of ``(u, v)`` pairs whose endpoints are in
                ``nodes``. Self-loops are removed with a warning.
            nodes: array-like of node labels; isolated nodes each add 1 to
                ``beta_0``. If its dtype differs from ``edges``, both are
                compared as strings.
            max_path: highest Betti dimension returned.

        Returns:
            integer ``np.ndarray`` of length ``max_path + 1``.
        """
        edges = np.asarray(edges)
        nodes = np.asarray(nodes)
        if edges.dtype != nodes.dtype:
            edges = edges.astype(str)
            nodes = nodes.astype(str)
        if len(nodes) == 0:
            return np.zeros(max_path + 1, dtype=int)

        all_components = self.split_independent_compondent(edges, nodes)
        all_digraphs = self.split_independent_digraph(all_components, edges)
        betti_numbers = []
        for component_edges in all_digraphs:
            # A (1, 1)-shaped entry marks an isolated node.
            if np.shape(component_edges)[1] <= 1:
                betti_numbers.append(np.array([1] + [0] * max_path))
                continue
            component_edges = self.remove_loops(component_edges)
            if len(component_edges) == 0:
                betti_numbers.append(np.array([1] + [0] * max_path))
                continue
            allowed_path = self.utils_generate_allowed_paths(component_edges, max_path)
            betti_numbers.append(self.path_homology_for_connected_digraph(allowed_path, max_path))

        return np.sum(betti_numbers, axis=0)

    def persistent_path_homology_from_digraph_given_edges(self, nodes, edge_with_distance, max_path, filtration_dis, max_dis):
        """Path-homology Betti numbers along a distance filtration.

        Args:
            nodes: node list; only its length is used. Edges refer to nodes by
                position ``0 .. len(nodes) - 1``.
            edge_with_distance: ``(u, v, d)`` triples with integer positions
                ``u``, ``v``; a later triple for the same ``(u, v)`` overwrites
                ``d``.
            max_path: highest Betti dimension per step.
            filtration_dis: step between filtration values.
            max_dis: filtration values are
                ``np.arange(0, max_dis + filtration_dis, filtration_dis)``.

        Returns:
            list with one Betti array (length ``max_path + 1``) per filtration
            value. At each value the digraph contains the edges with
            ``d <= value``.

        Side effects:
            sets ``self.total_edges_num`` to the number of distinct edges.
        """
        points_num = len(nodes)
        # Edges are matrix positions, so homology must be computed over
        # positions too; passing the caller's labels broke non-0..n-1 labels.
        node_ids = np.arange(points_num)
        fully_connected_map = np.zeros([points_num, points_num], dtype=int)
        distance_matrix = np.full([points_num, points_num], np.inf)
        for u, v, d in edge_with_distance:
            fully_connected_map[u, v] = 1
            distance_matrix[u, v] = d
        self.total_edges_num = np.sum(fully_connected_map > 0)

        # NOTE(review): the grid can end above max_dis (stop is
        # max_dis + step), and with non-integer steps float rounding in
        # np.arange may add an extra point.
        filtration = np.arange(0, max_dis + filtration_dis, filtration_dis)
        all_betti_num = []
        previous_map = None
        for snapshot_dis in filtration:
            snapshot_map = ((distance_matrix <= snapshot_dis) & (fully_connected_map == 1)).astype(int)
            if previous_map is not None and np.array_equal(snapshot_map, previous_map):
                # Same digraph as the previous step: reuse (copied, no aliasing).
                all_betti_num.append(all_betti_num[-1].copy())
                continue
            previous_map = snapshot_map

            # Row-major (start, end) positions, same order as a nested i/j scan.
            edges = np.argwhere(snapshot_map == 1).astype(node_ids.dtype)
            if len(edges) == 0:
                betti_numbers = np.array([points_num] + [0] * max_path)
            else:
                betti_numbers = self.path_homology(edges, node_ids, max_path)
            all_betti_num.append(betti_numbers)
        return all_betti_num

# Convenience wrapper: one-call entry point for persistent path homology of a weighted digraph.
def input_digraph_simple(nodes, edge_with_distance, max_path, filtration, max_dis):
    """Persistent path-homology Betti numbers of a weighted digraph.

    Thin wrapper that builds a fresh ``PathHomology`` and delegates to
    ``persistent_path_homology_from_digraph_given_edges``.

    Args:
        nodes: node list (edges index into it by position).
        edge_with_distance: ``(u, v, distance)`` triples.
        max_path: highest Betti dimension reported per step.
        filtration: step between consecutive filtration distances.
        max_dis: upper end of the filtration grid.

    Returns:
        list with one Betti-number array per filtration step.
    """
    return PathHomology().persistent_path_homology_from_digraph_given_edges(
        nodes, edge_with_distance, max_path, filtration, max_dis
    )

# Demo entry point: runs the pipeline on a small hand-made digraph and prints the Betti numbers.
def do_main():
    """Demo: persistent path homology of a 4-node weighted digraph, printed per step."""
    demo_args = dict(
        nodes=[0, 1, 2, 3],
        edge_with_distance=[
            [0, 1, 0.5],
            [1, 0, 0.5],
            [1, 2, 1.2],
            [2, 3, 0.7],
            [3, 0, 2.0],
        ],
        max_path=3,
        filtration=1,
        max_dis=3.0,
    )
    betti_per_step = input_digraph_simple(**demo_args)
    print(betti_per_step)
    for step, betti in enumerate(betti_per_step):
        print(f"Filtration Step {step}: Betti Numbers = {betti}")

if __name__ == "__main__":
    do_main()


def do_main2(node, weighted_edge):
    """Summarise low-dimensional Betti numbers of a weighted (k-mer) digraph.

    Runs persistent path homology with ``max_path=5``, a filtration step of 1
    and ``max_dis=3``. It then reports beta_0 at f=0 and f=1 and beta_1 at
    f=1 and f=2, prints the summary and returns it.

    Args:
        node: node list (edges refer to nodes by position).
        weighted_edge: ``(u, v, distance)`` triples.

    Returns:
        dict with keys ``"β₀ (f=0)"``, ``"β₀ (f=1)"``, ``"β₁ (f=1)"`` and
        ``"β₁ (f=2)"``. Values are Python ints, or None when the filtration
        has no such step.
    """
    max_path = 5
    filtration = 1
    max_dis = 3
    betti_num_all = input_digraph_simple(node, weighted_edge, max_path, filtration, max_dis)

    def _pick(step, dim):
        # Plain int, so the printed summary does not show numpy scalar reprs.
        if len(betti_num_all) > step:
            return int(betti_num_all[step][dim])
        return None

    # Read beta_0 at f=0 from the computed result: it equals len(node) only
    # when no edge has distance <= 0.
    beta_00 = _pick(0, 0)
    if beta_00 is None:
        beta_00 = len(node)

    result = {
        "β₀ (f=0)": beta_00,
        "β₀ (f=1)": _pick(1, 0),
        "β₁ (f=1)": _pick(1, 1),
        "β₁ (f=2)": _pick(2, 1),
    }

    print("k-mer Betti summary:", result)
    return result


        
