#!/usr/bin/env python
# coding: utf-8

# In[27]:


#!/usr/bin/env python
# coding: utf-8

# In[1]:


import os

from Bio.PDB.PDBParser import PDBParser
from Bio.PDB import NeighborSearch
from Bio.PDB import Selection
from Bio.PDB import StructureAlignment
from Bio.Cluster import distancematrix
from Bio.Align import MultipleSeqAlignment

import matplotlib.pyplot as plt
os.environ['KMP_DUPLICATE_LIB_OK']='True'
import time
import numpy as np
from Bio.Cluster import kmedoids
import warnings
warnings.filterwarnings("ignore")
import pandas as pd
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all" 

import torch as t


# In[2]:


def dist_map(pdb_path):
    """Return the all-atom pairwise Euclidean distance matrix of a PDB file.

    Only the first model (``structure[0]``) is used. The matrix therefore has
    one row/column per atom, in the same chain -> residue -> atom order that
    ``chain_length`` counts, so it lines up with the chain mask built from it.

    Args:
        pdb_path: path to a PDB file readable by Biopython's ``PDBParser``.

    Returns:
        torch.float64 tensor of shape (n_atoms, n_atoms). Entry [i, j] is the
        distance between atoms i and j in the file's coordinate units, and
        the diagonal is exactly 0.
    """
    parser = PDBParser()
    structure = parser.get_structure('X', pdb_path)
    # Restrict to model 0 so the size matches chain_length(); iterating the
    # whole structure would also pull in atoms from every extra model.
    coords = np.array(
        [atom.get_coord() for atom in structure[0].get_atoms()],
        dtype=np.float64,
    ).reshape(-1, 3)  # reshape keeps an (0, 3) shape for an empty model
    n_atoms = coords.shape[0]
    dists = np.empty((n_atoms, n_atoms), dtype=np.float64)
    # Fill one vectorised row at a time: O(n) scratch per row instead of an
    # (n, n, 3) broadcast. (a - b)**2 == (b - a)**2, so the result is
    # exactly symmetric, and each row's self-distance is exactly 0.
    for i in range(n_atoms):
        dists[i] = np.linalg.norm(coords - coords[i], axis=1)
    return t.from_numpy(dists)


# In[3]:

def chain_length(pdb_path):
    """Count the atoms of every chain in the first model of a PDB file.

    Args:
        pdb_path: path to a PDB file readable by Biopython's ``PDBParser``.

    Returns:
        A pair ``(per_chain, total)``. ``per_chain`` is a list with one atom
        count per chain, in the order the chains are iterated. ``total`` is
        the sum of those counts.
    """
    first_model = PDBParser().get_structure('X', pdb_path)[0]
    per_chain = [sum(1 for _ in chain.get_atoms()) for chain in first_model]
    return per_chain, sum(per_chain)


# In[4]:


def inter_chains_mask(length):
    """Return a block-diagonal "same chain" mask for concatenated chains.

    Despite the name, entries equal to 1 mark *intra*-chain atom pairs: atoms
    i and j get 1.0 when both fall inside the same chain's block of
    consecutive indices, and 0.0 otherwise.

    Args:
        length: sequence of per-chain atom counts, in atom order (e.g. the
            first element returned by ``chain_length``). Zero-length chains
            contribute no rows.

    Returns:
        torch.float64 tensor of shape (sum(length), sum(length)).
    """
    counts = np.asarray(length, dtype=np.int64).reshape(-1)
    # Label each atom with the index of the chain it belongs to, then compare
    # labels pairwise. This is O(dim^2) vectorised work instead of a
    # dim^2 * n_chains Python loop.
    chain_ids = np.repeat(np.arange(counts.size), counts)
    chain_mask = (chain_ids[:, None] == chain_ids[None, :]).astype(np.float64)
    return t.from_numpy(chain_mask)


# In[5]:


def interface(d, cutoff):
    """Return 1 when distance ``d`` is at most ``cutoff`` (inclusive), else 0.

    Incomparable values such as NaN fail the ``<=`` test and yield 0.
    """
    return 1 if d <= cutoff else 0


# In[6] 0/1 contact map derived from the distance matrix

def judge_contact(pdb_path, intra_cutoff=6, inter_cutoff=8):
    """Build a binary (0/1) all-atom contact map for a PDB file.

    An atom pair is a contact when its distance is <= ``intra_cutoff`` if
    both atoms are in the same chain, or <= ``inter_cutoff`` otherwise.
    NaN distances never count as contacts.

    Args:
        pdb_path: path to a PDB file.
        intra_cutoff: inclusive distance cutoff for same-chain pairs.
        inter_cutoff: inclusive distance cutoff for cross-chain pairs.

    Returns:
        torch.float64 tensor of shape (n_atoms, n_atoms) holding 1.0 for
        contacts and 0.0 otherwise.

    Raises:
        ValueError: if the distance matrix and the chain mask disagree in
            size, i.e. dist_map and chain_length saw different atom sets.
    """
    d_map = dist_map(pdb_path)
    chain_lens, _ = chain_length(pdb_path)
    c_mask = inter_chains_mask(chain_lens)
    if d_map.shape != c_mask.shape:
        raise ValueError(
            f"distance map shape {tuple(d_map.shape)} does not match "
            f"chain mask shape {tuple(c_mask.shape)} for {pdb_path!r}"
        )
    # Per-pair cutoff: c_mask == 1 marks same-chain pairs.
    cutoff = t.where(
        c_mask == 1,
        t.full_like(d_map, intra_cutoff),
        t.full_like(d_map, inter_cutoff),
    )
    return (d_map <= cutoff).to(t.float64)

# In[7]  

def get_gtcontact_maps(pdb_path, intra_cutoff, inter_cutoff):
    """Return the ground-truth 0/1 contact map of ``pdb_path``.

    This is a thin wrapper around ``judge_contact``, which already builds the
    distance matrix it needs, so no separate distance map is computed here.

    Args:
        pdb_path: path to the ground-truth PDB file.
        intra_cutoff: inclusive distance cutoff for same-chain atom pairs.
        inter_cutoff: inclusive distance cutoff for cross-chain atom pairs.

    Returns:
        The contact map produced by ``judge_contact``.
    """
    return judge_contact(pdb_path, intra_cutoff, inter_cutoff)
def get_difference_map(gt_result, pred_result):
    """Return absolute differences of paired (distance, contact) results.

    NOTE: a second ``get_difference_map`` defined later in this module
    rebinds this name. After the module finishes loading, callers get that
    later version, not this one.

    Args:
        gt_result: indexable whose items [0] and [1] are ground-truth maps.
        pred_result: indexable whose items [0] and [1] are predicted maps.

    Returns:
        Tuple ``(|gt[0] - pred[0]|, |gt[1] - pred[1]|)``.
    """
    gt_dist, gt_con = gt_result[0], gt_result[1]
    pred_dist, pred_con = pred_result[0], pred_result[1]
    return abs(gt_dist - pred_dist), abs(gt_con - pred_con)


# In[8] Compute the new weighted matrix (3x3 neighbourhood average of the contact map)

def maskmap(contact_map):
    """Return the 3x3 neighbourhood average of a 2-D contact map.

    Each output cell is the sum of the 3x3 window centred on that cell,
    divided by 9. Cells outside the map count as 0 (zero padding), so border
    cells are averaged over a partly empty window.

    Args:
        contact_map: 2-D tensor or array-like, e.g. the 0/1 map returned by
            ``judge_contact``.

    Returns:
        torch.float64 tensor with the same shape as ``contact_map``. Every
        cell is written, including the last row and column.
    """
    grid = np.asarray(contact_map, dtype=np.float64)
    n_rows, n_cols = grid.shape
    padded = np.pad(grid, 1)  # one ring of zeros around the map
    window_sum = np.zeros((n_rows, n_cols), dtype=np.float64)
    # Add the 9 shifted views in row-major window order (top-left first).
    for di in range(3):
        for dj in range(3):
            window_sum += padded[di:di + n_rows, dj:dj + n_cols]
    return t.from_numpy(window_sum / 9)

# In[9] Compute the distance-matrix difference

def get_difference_map(gt_map, pred_map):
    """Return the element-wise absolute difference ``|gt_map - pred_map|``.

    Args:
        gt_map: ground-truth distance map (e.g. from ``dist_map``).
        pred_map: predicted distance map; it must be subtractable from
            ``gt_map`` (same shape for tensors).

    Returns:
        ``abs(gt_map - pred_map)``, of whatever type the subtraction yields.
    """
    delta = gt_map - pred_map
    return abs(delta)


# In[10] RMSD-style score: sqrt of the mask-weighted mean of the (unsquared) distance differences

def maskrt(maskmaprt, distdiff_map):
    """Return the RMSD-style score ``sqrt(sum(mask * diff) / sum(mask))``.

    The differences are weighted by the mask but NOT squared before
    averaging, so this is the square root of a weighted mean absolute
    difference rather than a classical RMSD.

    Args:
        maskmaprt: weight/contact mask (tensor or array).
        distdiff_map: absolute distance differences, broadcastable against
            ``maskmaprt``.

    Returns:
        NumPy float scalar with the score. A mask whose elements sum to 0
        makes the Python division raise ``ZeroDivisionError``.
    """
    weighted_sum = np.multiply(maskmaprt, distdiff_map).sum().item()
    weight_total = maskmaprt.sum().item()
    return np.sqrt(weighted_sum / weight_total)

# Run the full scoring pipeline for one ground-truth / predicted structure pair
def result(pdb_pathgt, pdb_pathpred):
    """Score a predicted structure against its ground-truth structure.

    Pipeline:
      1. Build the ground-truth 0/1 contact map (intra-chain cutoff 6,
         inter-chain cutoff 8).
      2. Build all-atom distance maps for both structures.
      3. Take their absolute difference and reduce it with ``maskrt``, using
         the contact map as the mask.

    Args:
        pdb_pathgt: path to the ground-truth PDB file.
        pdb_pathpred: path to the predicted PDB file (must yield a distance
            map of the same shape as the ground truth).

    Returns:
        The score computed by ``maskrt``.
    """
    contact_map = get_gtcontact_maps(pdb_pathgt, intra_cutoff=6, inter_cutoff=8)
    gt_distances = dist_map(pdb_pathgt)
    pred_distances = dist_map(pdb_pathpred)
    return maskrt(contact_map, get_difference_map(gt_distances, pred_distances))


import sys

# Command-line entry point: ``python <script> GT_PDB PRED_PDB`` prints the
# score. It is guarded so that importing this module (for reuse or tests)
# does not read sys.argv or run the pipeline.
if __name__ == "__main__":
    if len(sys.argv) < 3:
        sys.exit(f"usage: {sys.argv[0]} GT_PDB PRED_PDB")
    pdb_pathgt = sys.argv[1]
    pdb_pathpred = sys.argv[2]
    rtt = result(pdb_pathgt, pdb_pathpred)
    print(rtt)




# In[ ]:




