#!/usr/bin/python3

# FusionAI: AI detection of gene fusion based on DNA sequence
# usage:  FusionAI_pred.py [-h] -f FILENAME [-m MODEL] [-o OUTPUT] [-A COLA] [-B COLB] [-I ROWI]
# usage example: python FusionAI_pred.py -f fusionai_test_bed.txt
# by Hua Tan, copyright (c) warm.tan@gmail.com, 8/11/2020

import numpy as np
import pandas as pd
import argparse

# Row i of this table is the encoding for base code i: 0 -> all-zero row
# ('N' or any unrecognised character), 1..4 -> A, C, G, T.
_ONE_HOT_ROWS = np.asarray([[0, 0, 0, 0],
                            [1, 0, 0, 0],
                            [0, 1, 0, 0],
                            [0, 0, 1, 0],
                            [0, 0, 0, 1]])

# Byte value -> row index into _ONE_HOT_ROWS. Upper- and lower-case ACGT get
# codes 1..4; every other byte (N, '-', IUPAC ambiguity codes, ...) maps to 0,
# so unknown characters are encoded like 'N' instead of as an arbitrary base.
_BASE_INDEX = np.zeros(256, dtype=np.intp)
_BASE_INDEX[np.frombuffer(b"ACGTacgt", dtype=np.uint8)] = [1, 2, 3, 4, 1, 2, 3, 4]


def one_hot_encode(seq):
    """One-hot encode a DNA sequence.

    Args:
        seq: DNA sequence as a string; matching is case-insensitive.

    Returns:
        Integer array of shape (len(seq), 4) whose columns correspond to
        A, C, G, T. 'N' and any other character give an all-zero row.
    """
    # Non-ASCII characters become '?' (-> all-zero row), which guarantees
    # exactly one byte, and therefore one output row, per input character.
    raw = seq.encode("ascii", errors="replace")
    return _ONE_HOT_ROWS[_BASE_INDEX[np.frombuffer(raw, dtype=np.uint8)]]

def raw2code(file, id_seq1, id_seq2, seq_len=20000):
    """Read a fusion table and one-hot encode its concatenated sequences.

    For each row, the sequence in column ``id_seq1`` (5'-gene side) is
    concatenated with the sequence in column ``id_seq2`` (3'-gene side) and
    encoded with ``one_hot_encode``.

    Args:
        file: path or buffer accepted by ``pandas.read_csv``; fields are
            separated by a single whitespace character, with no header row.
        id_seq1: zero-based column index of the 5'-gene sequence.
        id_seq2: zero-based column index of the 3'-gene sequence.
        seq_len: combined sequence length expected by the model; only used
            to print a warning when the input does not match.

    Returns:
        (fg, fg_seq_coded): the table as read by pandas, and an array of
        shape (n_rows, L, 4) holding the encoded sequences.

    Raises:
        ValueError: if the concatenated sequences do not all have the same
            length (they cannot be stacked into a single array).
    """
    # A regex separator is only supported by the python parser engine;
    # request it explicitly instead of relying on pandas' fallback warning.
    fg = pd.read_csv(file, sep=r"\s", header=None, engine="python")
    fg_seq = fg.iloc[:, id_seq1] + fg.iloc[:, id_seq2]
    encoded = [one_hot_encode(seq) for seq in fg_seq]

    # Ragged inputs cannot form one (n_rows, L, 4) array; fail clearly here
    # rather than with NumPy's generic "inhomogeneous shape" error.
    lengths = {len(x) for x in encoded}
    if len(lengths) > 1:
        raise ValueError(
            f"input sequences have different lengths ({min(lengths)}-{max(lengths)} bp); "
            "all rows must have the same combined length")

    fg_seq_coded = np.asarray(encoded)
    print("dimension in raw2code: " + str(fg_seq_coded.shape))
    if fg_seq_coded.shape[-2:] != (seq_len, 4):
        print(f"the default model requires an input of {seq_len}-bp-long DNA sequence,\n"
              "if you input different lengths, train your own FusionAI model\n"
              "first by calling FusionAI_train.py")
    return fg, fg_seq_coded


if __name__ == "__main__":
    # Command-line entry point: parse options, one-hot encode the input table,
    # run the saved Keras model and write the input table with the model's
    # output columns appended.
    parser = argparse.ArgumentParser(
        description="FusionAI: Artificial Intelligence detection of gene-gene "
                    "fusion based solely on DNA sequence of BP+/-5kb.\n"
                    "Author: Hua Tan,\n"
                    "Copyright (c), warm.tan@gmail.com, 8/11/2020 @7900")
    parser.add_argument('-f', '--filename', type=str, required=True,
                        help="File name of input bed file (N*10 table): "
                             "[TEKT4P2 chr21 9966322 - RBPJ chr4 26364067 + "
                             "GGACTACA... CCTTACTT...]")
    parser.add_argument('-m', '--model', type=str, default='newdat_newmod_jj.h5',
                        help="File name of trained FusionAI model in h5 format")
    parser.add_argument('-o', '--output', default="FusionAI_predictions.txt",
                        help="File name of output for FusionAI predictions: the "
                             "input table with the predictions appended as extra "
                             "columns [prob_nofusion prob_fusion]")
    parser.add_argument('-A', '--colA', type=int, default=8,
                        help="column index of 5'-gene sequence in input file")
    parser.add_argument('-B', '--colB', type=int, default=9,
                        help="column index of 3'-gene sequence in input file")
    parser.add_argument('-I', '--rowI', type=int, default=None,
                        help="row index of interested line in input file")

    args = vars(parser.parse_args())

    print("\n###########parameters###########")
    for key, value in args.items():
        print(key, value)
    print("###########parameters###########\n")

    fg, data = raw2code(file=args['filename'], id_seq1=args['colA'], id_seq2=args['colB'])

    if args['rowI'] is not None:
        # Index with a one-element list so fg stays 2-D and data stays 3-D.
        fg = fg.iloc[[args['rowI']]]
        data = data[[args['rowI']]]

    Nsamp, LEN_SEQ, WID_SEQ = data.shape

    # keras is imported only after arguments and input are processed,
    # presumably to keep --help and input errors fast.
    import keras

    # Add a singleton channel axis where the backend expects it.
    if keras.backend.image_data_format() == "channels_first":
        x_test = data.reshape(Nsamp, 1, LEN_SEQ, WID_SEQ)
    else:
        x_test = data.reshape(Nsamp, LEN_SEQ, WID_SEQ, 1)
    print("Dimension of data:" + str(x_test.shape))

    model = keras.models.load_model(args['model'])

    # Sequential.predict_proba was removed in recent Keras releases and never
    # existed on functional models; it only wrapped predict() with a
    # [0, 1] range warning.
    predictions = model.predict(x_test, verbose=1)

    # fg keeps its original row labels (e.g. [rowI] after -I), whereas the
    # prediction frame is labelled 0..N-1; reset them so the column-wise
    # concat pairs rows positionally instead of producing NaN-padded rows.
    predictions_ = pd.concat([fg.reset_index(drop=True), pd.DataFrame(predictions)],
                             axis=1)

    predictions_.to_csv(args['output'], sep='\t', header=False, index=False)














