import networkx as nx
import numpy as np
import matplotlib.pyplot as plt
import pylab
import cPickle
from fileNames import YR_01_path
from tissueviewer.dtissue import DTissue 
from allTagsDModule import allTagsD_L1_new, allTagsD_L2_new
from tissueviewer.image import makepatternCellsValues


from createTagsD import patterns10hOnlyL1, patterns10hOnlyL2, patterns10hAllTissue
from createTagsD import patterns40hOnlyL1, patterns40hOnlyL2, patterns40hAllTissue
from createTagsD import patterns96hOnlyL1, patterns96hOnlyL2, patterns96hAllTissue
from createTagsD import patterns120hOnlyL1, patterns120hOnlyL2, patterns120hAllTissue
from createTagsD import patterns132hOnlyL1, patterns132hOnlyL2, patterns132hAllTissue

from computeBackwardTransitionGraph import dist_10h_40h_L1, dist_40h_96h_L1, dist_96h_120h_L1, dist_120h_132h_L1
from computeBackwardTransitionGraph import dist_10h_40h_L2, dist_40h_96h_L2, dist_96h_120h_L2, dist_120h_132h_L2

from computeHammingDistances_L1 import dist as hammingD

# Minimum mother ratio (as returned by calculateMotherRatio, rounded to two
# decimals) that an edge must strictly exceed to be drawn in a backward
# transition graph; it is also embedded in the output PNG file names.
threshold = 0.2


# DTissue lineage object queried through dtis.extractDescendants.
# ``with`` guarantees the file is closed even if unpickling fails; ``open`` is
# the portable spelling of the Python-2-only ``file`` builtin (same default mode).
# NOTE: cPickle.load can execute arbitrary code -- only load trusted files.
with open("FM1_dtissue.tis") as fobj:
    dtis = cPickle.load(fobj)



def calculateMotherRatio(D, candidateAncesters, daughterTP, ancestersTP, tissue=None):
    """
    Return the fraction of cell ids in "D" whose ancestor is in "candidateAncesters".

    A cell id in D is counted when it appears among the descendants, at time
    point ``daughterTP``, of any cell of ``candidateAncesters`` taken at time
    point ``ancestersTP`` (as reported by ``tissue.extractDescendants``).

    Parameters
    ----------
    D : sized iterable of cell ids
        Cells at the later time point ``daughterTP``.
    candidateAncesters : iterable of cell ids
        Candidate ancestor cells at the earlier time point ``ancestersTP``.
    daughterTP, ancestersTP : str
        Time-point labels forwarded to ``extractDescendants`` (e.g. "40h", "10h").
    tissue : object, optional
        Provides ``extractDescendants(ancestorTP, cellId, daughterTP)``;
        defaults to the module-level ``dtis``.

    Returns
    -------
    float
        (number of counted cells in D) / len(D), in [0, 1]; 0.0 when D is empty
        (previously this raised ZeroDivisionError).
    """
    if tissue is None:
        tissue = dtis
    if len(D) == 0:
        return 0.0

    # Pool the descendants of every candidate ancestor once, instead of
    # re-querying the lineage for each cell of D; membership is then O(1).
    # Assumes cell ids are hashable (e.g. ints) -- TODO confirm.
    descendants = set()
    for ancestor in candidateAncesters:
        descendants.update(tissue.extractDescendants(ancestersTP, ancestor, daughterTP))

    counted = sum(1 for cid in D if cid in descendants)
    return counted / float(len(D))


def _patternKeys(patterns):
    """Return the keys of *patterns* as a set, without the "L1" and "L2" keys."""
    return set(patterns.keys()) - set(["L1", "L2"])


# Pattern key sets per time point, for the whole-tissue dictionaries.
pat10hKeys = _patternKeys(patterns10hAllTissue)
pat40hKeys = _patternKeys(patterns40hAllTissue)
pat96hKeys = _patternKeys(patterns96hAllTissue)
pat120hKeys = _patternKeys(patterns120hAllTissue)
pat132hKeys = _patternKeys(patterns132hAllTissue)

# ... for the L1-only dictionaries.
pat10hKeysL1 = _patternKeys(patterns10hOnlyL1)
pat40hKeysL1 = _patternKeys(patterns40hOnlyL1)
pat96hKeysL1 = _patternKeys(patterns96hOnlyL1)
pat120hKeysL1 = _patternKeys(patterns120hOnlyL1)
pat132hKeysL1 = _patternKeys(patterns132hOnlyL1)

# ... for the L2-only dictionaries.
pat10hKeysL2 = _patternKeys(patterns10hOnlyL2)
pat40hKeysL2 = _patternKeys(patterns40hOnlyL2)
pat96hKeysL2 = _patternKeys(patterns96hOnlyL2)
pat120hKeysL2 = _patternKeys(patterns120hOnlyL2)
pat132hKeysL2 = _patternKeys(patterns132hOnlyL2)

# makepatternCellsValues returns (cell values, tag colormap ids, tag cell ids);
# the calls keep the original L1-then-L2, early-to-late order.
(cellsValues10h_L1, tagsColomapId10h_L1,
 tagsCIds10h_L1) = makepatternCellsValues(patterns10hOnlyL1, pat10hKeysL1)
(cellsValues40h_L1, tagsColomapId40h_L1,
 tagsCIds40h_L1) = makepatternCellsValues(patterns40hOnlyL1, pat40hKeysL1)
(cellsValues96h_L1, tagsColomapId96h_L1,
 tagsCIds96h_L1) = makepatternCellsValues(patterns96hOnlyL1, pat96hKeysL1)
(cellsValues120h_L1, tagsColomapId120h_L1,
 tagsCIds120h_L1) = makepatternCellsValues(patterns120hOnlyL1, pat120hKeysL1)
(cellsValues132h_L1, tagsColomapId132h_L1,
 tagsCIds132h_L1) = makepatternCellsValues(patterns132hOnlyL1, pat132hKeysL1)

(cellsValues10h_L2, tagsColomapId10h_L2,
 tagsCIds10h_L2) = makepatternCellsValues(patterns10hOnlyL2, pat10hKeysL2)
(cellsValues40h_L2, tagsColomapId40h_L2,
 tagsCIds40h_L2) = makepatternCellsValues(patterns40hOnlyL2, pat40hKeysL2)
(cellsValues96h_L2, tagsColomapId96h_L2,
 tagsCIds96h_L2) = makepatternCellsValues(patterns96hOnlyL2, pat96hKeysL2)
(cellsValues120h_L2, tagsColomapId120h_L2,
 tagsCIds120h_L2) = makepatternCellsValues(patterns120hOnlyL2, pat120hKeysL2)
(cellsValues132h_L2, tagsColomapId132h_L2,
 tagsCIds132h_L2) = makepatternCellsValues(patterns132hOnlyL2, pat132hKeysL2)

# Inverse of the tag dictionaries: value -> key (later duplicates win).
allTagsDreverse_L1 = {value: key for key, value in allTagsD_L1_new.iteritems()}
allTagsDreverse_L2 = {value: key for key, value in allTagsD_L2_new.iteritems()}
 

#==========================================================================================
# Backward transition graphs between consecutive time points, for L1 and L2.


def drawBackwardTransitionGraph(distMatrix, first, second, tagsReverse,
                                start, stop, layer, hamming=None):
    """
    Build, render and return one backward transition graph.

    Every edge (i, j) of ``nx.from_numpy_matrix(distMatrix)`` is a candidate.
    Here i is the tag index at time ``start`` and j is the tag index at time
    ``stop``. For each candidate, calculateMotherRatio gives the fraction of
    cells of tag j (at ``stop``) whose ancestor is a cell of tag i (at
    ``start``). Edges whose ratio, rounded to 2 decimals, exceeds the
    module-level ``threshold`` are kept and labelled with that ratio. The
    graph is written to
    "transitionGraph<start>h_<stop>h_<layer>_backward_<threshold>.png".

    Parameters
    ----------
    distMatrix : 2-D array-like
        Transition matrix indexed by tag index (e.g. dist_10h_40h_L1).
    first, second : dict
        Tag -> cell ids at the ``start`` and ``stop`` time points, in the form
        consumed by calculateMotherRatio (e.g. tagsCIds10h_L1, tagsCIds40h_L1).
    tagsReverse : dict
        Tag index -> tag (allTagsDreverse_L1 or allTagsDreverse_L2).
    start, stop : str
        Time points in hours without the "h" suffix, e.g. "10" and "40".
    layer : str
        Suffix used in the output file name, "L1" or "L2".
    hamming : 2-D array-like, optional
        Hamming distance matrix indexed like ``distMatrix``; defaults to
        ``hammingD``. NOTE(review): the original script used the L1 Hamming
        matrix (computeHammingDistances_L1) for the L2 graphs as well -- pass an
        L2 matrix here if that was unintended.

    Returns
    -------
    (G, wwDes)
        The thresholded nx.DiGraph, and one [hamming * 27, rounded dist] pair
        per kept edge.
    """
    if hamming is None:
        hamming = hammingD
    fName = "transitionGraph%sh_%sh_%s" % (start, stop, layer)

    GMain = nx.from_numpy_matrix(np.matrix(distMatrix), create_using=nx.MultiDiGraph())

    G = nx.DiGraph()
    wwDes = []
    # Keys are MultiDiGraph edge keys (u, v, key); only the endpoints matter,
    # the stored weight itself is not used.
    for item in nx.get_edge_attributes(GMain, 'weight'):
        src, dst = item[0], item[1]
        val = calculateMotherRatio(second[tagsReverse[dst]], first[tagsReverse[src]],
                                   "%sh" % stop, "%sh" % start)
        val = np.round(val, 2)
        if val > threshold:
            G.add_edges_from([(src, dst)], label=str(val))
            # NOTE(review): 27 presumably turns a normalised Hamming distance
            # back into a count -- confirm.
            w = hamming[src, dst] * 27
            wDes = np.round(distMatrix[src, dst], 2)
            wwDes.append([w, wDes])

    p = nx.drawing.nx_pydot.to_pydot(G)
    p.write_png('%s_backward_%0.2f.png' % (fName, threshold))
    return G, wwDes


# (start, stop, layer, transition matrix, tag -> cells at start, at stop, index -> tag)
# Listed in the order the original per-transition sections produced their PNGs.
backwardTransitions = [
    ("10", "40", "L1", dist_10h_40h_L1, tagsCIds10h_L1, tagsCIds40h_L1, allTagsDreverse_L1),
    ("40", "96", "L1", dist_40h_96h_L1, tagsCIds40h_L1, tagsCIds96h_L1, allTagsDreverse_L1),
    ("96", "120", "L1", dist_96h_120h_L1, tagsCIds96h_L1, tagsCIds120h_L1, allTagsDreverse_L1),
    ("120", "132", "L1", dist_120h_132h_L1, tagsCIds120h_L1, tagsCIds132h_L1, allTagsDreverse_L1),
    ("10", "40", "L2", dist_10h_40h_L2, tagsCIds10h_L2, tagsCIds40h_L2, allTagsDreverse_L2),
    ("40", "96", "L2", dist_40h_96h_L2, tagsCIds40h_L2, tagsCIds96h_L2, allTagsDreverse_L2),
    ("96", "120", "L2", dist_96h_120h_L2, tagsCIds96h_L2, tagsCIds120h_L2, allTagsDreverse_L2),
    ("120", "132", "L2", dist_120h_132h_L2, tagsCIds120h_L2, tagsCIds132h_L2, allTagsDreverse_L2),
]

# (start, stop, layer) -> (thresholded graph, wwDes), kept for later inspection.
backwardGraphs = {}
for start, stop, layer, distMatrix, first, second, tagsReverse in backwardTransitions:
    backwardGraphs[(start, stop, layer)] = drawBackwardTransitionGraph(
        distMatrix, first, second, tagsReverse, start, stop, layer)