FAQ | This is a LIVE service | Changelog

Skip to content
Snippets Groups Projects
growth_control.py 35.5 KiB
Newer Older

    return d


def calcTs(tss, grates, gs, f, tpoints=(10, 40, 96, 120, 132)):
    """Evaluate ``f`` on every ordered gene pair and regroup by time point.

    For each ordered pair ``(g1, g2)`` drawn from ``gs`` (including
    ``g1 == g2``), ``f(tss, grates, g1, g2)`` is called and is expected to
    return a ``{timepoint: value}`` mapping.  The results are regrouped as
    ``{timepoint: {(g1, g2): value}}``.

    Parameters
    ----------
    tss, grates :
        Passed through unchanged to ``f``.
    gs : iterable
        Gene names; all ordered pairs (Cartesian product) are evaluated.
    f : callable
        ``f(tss, grates, g1, g2) -> {timepoint: value}``.
    tpoints : iterable, optional
        Time points that are always present in the result (as empty dicts if
        ``f`` never reports them).  Time points returned by ``f`` that are
        not listed here are added on the fly instead of raising KeyError.

    Returns
    -------
    dict
        ``{timepoint: {(g1, g2): value}}``.
    """
    from itertools import product

    d = {t: dict() for t in tpoints}

    for g1, g2 in product(gs, gs):
        # Progress output: f can be slow and this loop is quadratic in len(gs).
        print(g1, g2)
        v = f(tss, grates, g1, g2)
        for t, gratio in v.items():
            d.setdefault(t, dict())[(g1, g2)] = gratio

    return d

def matrixify(k2_dict):
    """Lay out a dict keyed by ``(row_key, col_key)`` pairs as a dense matrix.

    Row keys are the sorted distinct first elements of the dict keys and
    column keys the sorted distinct second elements.  Every combination of
    row and column key must be present in ``k2_dict`` (otherwise KeyError).

    Returns
    -------
    (row_keys, col_keys, matrix)
        ``matrix[i, j] == k2_dict[(row_keys[i], col_keys[j])]`` stored as a
        float NumPy array of shape ``(len(row_keys), len(col_keys))``.
    """
    row_keys = sorted({key[0] for key in k2_dict.keys()})
    col_keys = sorted({key[1] for key in k2_dict.keys()})

    matrix = np.zeros((len(row_keys), len(col_keys)))
    for r, row_key in enumerate(row_keys):
        for c, col_key in enumerate(col_keys):
            matrix[r, c] = k2_dict[(row_key, col_key)]

    return row_keys, col_keys, matrix

def matrixify_df(k2_dict, kss, df=""):
    """Build a square string matrix from values looked up by key pairs.

    Rows and columns are both ``sorted(kss)``.  Cell ``(i, j)`` holds
    ``str(k2_dict[(keys[i], keys[j])])``, or ``df`` when that pair is
    missing from ``k2_dict``.

    Returns
    -------
    (row_keys, col_keys, matrix)
        ``row_keys`` and ``col_keys`` are equal (but distinct) sorted lists;
        ``matrix`` is a NumPy string array of shape ``(len(kss), len(kss))``.
    """
    keys = sorted(kss)

    # Build the cells first and let NumPy size the string dtype to the longest
    # entry. Allocating with ``np.empty(..., dtype=str)`` gives dtype '<U1',
    # which silently truncates every stored string to its first character.
    cells = [[str(k2_dict.get((k1, k2), df)) for k2 in keys] for k1 in keys]
    # reshape keeps the 2-D shape (0, 0) when kss is empty.
    feat_mat = np.array(cells, dtype=str).reshape(len(keys), len(keys))

    return keys, list(keys), feat_mat

def arrange_rows_cols(vals, gs):
    """Split a 2-D array into a dict of its columns.

    Column ``i`` of ``vals`` becomes the list stored under key ``gs[i]``;
    if ``gs`` repeats a key, the later column wins.
    """
    return {col_key: list(vals[:, col_idx])
            for col_idx, col_key in enumerate(gs)}

def mkStripPlotCombs(vals_, ax, cl, lb="", annot=True):
    """Draw a strip plot of per-label value lists onto ``ax``.

    Each key of ``vals_`` becomes one categorical x position (keys sorted);
    every value in that key's list is drawn as a point above it.

    Parameters
    ----------
    vals_ : dict
        Mapping from x-axis label (e.g. gene name) to a list of y values.
    ax : matplotlib.axes.Axes
        Axes to draw on.
    cl : int
        Index into ``sns.color_palette("Blues")`` selecting the point colour
        (callers in this module use 0-5).
    lb : str
        Legend label attached to the plotted points.
    annot : bool
        If True, add the fixed "ETTIN, STM" and "SEP1, LFY" text labels at
        hard-coded data coordinates.
    """
    from itertools import chain, repeat

    set_plot_params()
    plt.rcParams['svg.fonttype'] = 'none'

    xs_labs = sorted(vals_.keys())
    ax.yaxis.set_major_locator(ticker.MultipleLocator(0.4))

    # Flatten {label: [y, ...]} into parallel x / y lists, one entry per point.
    # chain is linear and, unlike reduce(add, ...), also works for empty input.
    xs = list(chain.from_iterable(repeat(x, len(vals_[x])) for x in xs_labs))
    ys = list(chain.from_iterable(vals_[x] for x in xs_labs))

    cls = sns.color_palette("Blues")
    # Keyword data arguments (positional x/y are rejected by seaborn >= 0.12),
    # an explicit ax (otherwise seaborn draws on plt.gca()), and a single
    # color instead of a palette without hue.
    sns.stripplot(x=xs, y=ys, order=xs_labs, color=cls[cl], size=5,
                  label=lb, ax=ax)
    ax.set_xticks(list(range(len(xs_labs))))

    if annot:
        ax.text(21.3, -0.383, "ETTIN, STM", alpha=0.5, weight='ultralight')
        ax.text(12.2, 0.36, "SEP1, LFY", alpha=0.5, weight='ultralight')

    # Rotate the tick labels of *this* axes (plt.xticks targets plt.gca()).
    ax.tick_params(axis='x', labelrotation=90)

def go_combs(tss, linss, geneNms, f):
    """Evaluate ``f`` for all gene pairs and rank the results per time point.

    Growth rates come from ``grs.grates_avg_cons(tss, linss)``; ``calcTs``
    then evaluates ``f`` for every ordered pair in ``geneNms``.

    Returns
    -------
    dict
        ``{timepoint: [((g1, g2), value), ...]}``.  The non-NaN entries are
        sorted by decreasing absolute value, with ties keeping their original
        order.  They are followed by the NaN entries in their original order.
    """
    grates = grs.grates_avg_cons(tss, linss)
    by_tpoint = calcTs(tss, grates, geneNms, f)

    ranked = {}
    for t, pair_vals in by_tpoint.items():
        finite, missing = [], []
        for pair, gr in pair_vals.items():
            bucket = missing if math.isnan(gr) else finite
            bucket.append((pair, gr))
        finite.sort(key=lambda item: abs(item[1]), reverse=True)
        ranked[t] = finite + missing

    return ranked

def plotHeatMap(vals, ykeys, xkeys, annot, bounds=(0, 1)):
    """Draw an annotated 'coolwarm' heatmap in a new 5x4-inch figure.

    Parameters
    ----------
    vals :
        Matrix of cell values.
    ykeys, xkeys :
        Row and column tick labels.
    annot :
        Per-cell annotation text, rendered verbatim (``fmt=''``).
    bounds : (vmin, vmax)
        Colour-scale limits.

    The colour bar is labelled "RGD", the y axis "gene A" and the x axis
    "gene B".  Nothing is returned; the new figure stays current, so the
    caller can save it with ``plt.savefig``.
    """
    set_plot_params(fontsize=7)
    plt.figure(figsize=(5, 4))

    lo, hi = bounds
    colorbar_opts = {"shrink": .5, "label": "RGD"}
    ax = sns.heatmap(vals,
                     xticklabels=xkeys,
                     yticklabels=ykeys,
                     annot=annot,
                     fmt='',
                     annot_kws={"size": 7},
                     cmap='coolwarm',
                     vmin=lo,
                     vmax=hi,
                     linewidths=.25,
                     cbar_kws=colorbar_opts)
    ax.set_ylabel("gene A")
    ax.set_xlabel("gene B")

def strip_vals(d):
    """Drop the values from ``(region_id, [(pair, value), ...])`` entries.

    Returns a new list of ``(region_id, [pair, ...])`` tuples in input order.
    """
    return [(rid, [pair for pair, _value in pair_vals])
            for rid, pair_vals in d]

def plot_heatmap(ress, tss, t):
    """Save the gene-pair RGD heatmap for time point ``t`` and plot its regions.

    The heatmap cells come from ``ress[t]`` and are annotated with the region
    labels produced by ``groupByRegionGPairs``.  The heatmap is written to
    ``grates_gene_pairs_t{t}.svg``; ``plotRegions`` is then called with the
    same region grouping.
    """
    set_plot_params(fontsize=12)

    _, _, rgd_mat = matrixify(dict(ress[t]))

    regions = groupByRegionGPairs(t, tss[t], ress[t])
    pairs_by_region = dict(strip_vals(regions))
    # NOTE(review): rgd_mat is ordered by the sorted genes present in ress[t],
    # whereas the annotation matrix and tick labels use sorted(seg.geneNms);
    # they only line up if both gene sets are identical — confirm.
    _, genes, annot_mat = matrixify_df(invert_dlist(pairs_by_region),
                                       seg.geneNms)

    plotHeatMap(rgd_mat, genes, genes, annot_mat, (-0.4, 0.4))
    plt.savefig("grates_gene_pairs_t{t}.svg".format(t=t), dpi=300,
                bbox_inches='tight')

    plotRegions(t, tss[t], regions)

def plot_res_single_all_tpoints(res):
    """Overlay the "A vs not A" RGD strip plots of all five time points.

    ``res`` maps each time point (10, 40, 96, 120, 132) to a
    ``{gene: [values]}`` dict as consumed by ``mkStripPlotCombs``.  The figure
    is saved to ``gene_rgd_all_tpoints.png``.
    """
    set_plot_params(fontsize=11)

    fig = plt.figure(figsize=(9, 5))
    # add_subplot expects an int; the string '111' raises ValueError on
    # Matplotlib >= 3.5.
    ax = fig.add_subplot(111)
    ax.set_title("A vs not A")
    ax.set_xlabel("Gene A")
    ax.set_ylabel("RGD")
    ax.set_ylim((-0.6, 0.6))

    # Latest time point is drawn first, so earlier ones are drawn on top;
    # each time point gets its own colour index in the Blues palette.
    for t, color_idx in ((132, 4), (120, 3), (96, 2), (40, 1), (10, 0)):
        mkStripPlotCombs(res[t], ax, color_idx, annot=False)

    plt.savefig("gene_rgd_all_tpoints.png", dpi=300,
                bbox_inches='tight')

def plot_res_single_comb_tpoints(res):
    """Strip-plot the "A vs not A" RGD values pooled over time points.

    ``res`` is a single ``{gene: [values]}`` dict as consumed by
    ``mkStripPlotCombs``.  The figure is saved to
    ``gene_rgd_all_tpoints_combined.png``.
    """
    set_plot_params(fontsize=11)

    fig = plt.figure(figsize=(9, 5))
    # add_subplot expects an int; the string '111' raises ValueError on
    # Matplotlib >= 3.5.
    ax = fig.add_subplot(111)
    ax.set_title("A vs not A (combined timepoints)")
    ax.set_xlabel("Gene A")
    ax.set_ylabel("RGD")
    ax.set_ylim((-0.6, 0.6))

    mkStripPlotCombs(res, ax, 4, annot=False)

    plt.savefig("gene_rgd_all_tpoints_combined.png", dpi=300,
                bbox_inches='tight')

def plot_res_combs_all_tpoints(res):
    """Overlay "(A and B) vs (A and not B)" RGD strip plots for all stages.

    For each time point ``t``, ``res[t]`` (gene-pair -> value items) is turned
    into a matrix with ``matrixify``.  Each column (gene B) becomes one strip.
    Time points 10, 40, 96, 120 and 132 are shown as stages 0-4.  Stage ``s``
    uses Blues palette colour ``s + 1`` and legend label ``str(s)``.  The
    figure is saved to ``negative_infl_grates_all_tpoints.svg``.
    """
    import matplotlib.lines as mlines
    set_plot_params(fontsize=11)

    tpoints = [10, 40, 96, 120, 132]

    # {t: {gene B: [values over gene A]}}
    vals_by_t = {}
    for t in reversed(tpoints):
        _, col_genes, vals = matrixify(dict(res[t]))
        vals_by_t[t] = arrange_rows_cols(vals, col_genes)

    fig = plt.figure(figsize=(9, 5))
    # add_subplot expects an int; the string '111' raises ValueError on
    # Matplotlib >= 3.5.
    ax = fig.add_subplot(111)
    ax.set_title("(A and B) vs (A and not B)")
    ax.set_xlabel("Gene B")
    ax.set_ylabel("RGD")
    ax.set_ylim((-1.2, 1.2))

    # Latest stage first, so earlier stages are drawn on top.
    for stage, t in reversed(list(enumerate(tpoints))):
        mkStripPlotCombs(vals_by_t[t], ax, stage + 1, lb=str(stage))

    cls = sns.color_palette("Blues")

    # Proxy artists so the legend shows one marker per stage.
    tpoint_handles = [mlines.Line2D([], [], color=cls[stage + 1], marker="o",
                                    markersize=5, label=str(stage),
                                    linestyle="")
                      for stage in range(len(tpoints))]
    ax.legend(handles=tpoint_handles,
              title="stage",
              ncol=3,
              loc="lower left",
              columnspacing=0.5,
              handletextpad=0.02,
              borderpad=0.15)

    plt.savefig("negative_infl_grates_all_tpoints.svg",
                bbox_inches='tight')

def go():
    """Run the gene-pair growth analysis on the FM1 series and write figures.

    Loads time points 10, 40, 96, 120 and 132 with ``lin.mkSeries1`` and
    filters them with ``lin.filterL1_st``.  It then ranks all gene-pair
    combinations of ``seg.geneNms`` with ``fgr`` and produces:

    * one heatmap per time point (``plot_heatmap``);
    * the all-stage combination strip plot (``plot_res_combs_all_tpoints``);
    * region plots for the ETTIN/STM and SEP1/LFY pairs at t=132.
    """
    tpoints = [10, 40, 96, 120, 132]

    tss, linss = lin.mkSeries1(d="../data/FM1/tv/",
                               dExprs="../data/geneExpression/",
                               linDataLoc="../data/FM1/tracking_data/",
                               ft=lambda t: t in {10, 40, 96, 120, 132})
    lin.filterL1_st(tss)
    ress = go_combs(tss, linss, seg.geneNms, fgr)

    for t in tpoints:
        plot_heatmap(ress, tss, t)

    plot_res_combs_all_tpoints(ress)

    t = 132
    ts = tss[t]
    g = GPairRegions(t, ts, 'ETTIN', 'STM')
    g1 = GPairRegions(t, ts, 'SEP1', 'LFY')

    g.plot()
    g.plot_distr()

    g1.plot()
    g1.plot_distr()


def go_():
    """Plot three LFY gene-pair regions at time point 132.

    Loads time points 10, 40, 96, 120 and 132 with ``lin.mkSeriesIm0`` and
    filters them with ``lin.filterL1_st``.  It builds ``GPairRegions`` for
    ETTIN/LFY, SEP1/LFY and AP1/LFY at t=132, then calls ``plot`` and
    ``plot_distr`` on each in that order.
    """
    wanted = {10, 40, 96, 120, 132}
    tss, _linss = lin.mkSeriesIm0(dataDir="../data/",
                                  ft=lambda t: t in wanted)
    lin.filterL1_st(tss)

    t = 132
    cells = tss[t]
    gene_pairs = [('ETTIN', 'LFY'),  # zone 1
                  ('SEP1', 'LFY'),   # zone 2
                  ('AP1', 'LFY')]    # zone 3
    regions = [GPairRegions(t, cells, gene_a, gene_b)
               for gene_a, gene_b in gene_pairs]

    for region in regions:
        region.plot()
        region.plot_distr()

def go_t120():
    """Plot two LFY gene-pair regions at time point 120.

    Loads time points 10, 40, 96, 120 and 132 with ``lin.mkSeriesIm0`` and
    filters them with ``lin.filterL1_st``.  It builds ``GPairRegions`` for
    ETTIN/LFY and AP1/LFY at t=120, then calls ``plot`` and ``plot_distr`` on
    each in that order.
    """
    wanted = {10, 40, 96, 120, 132}
    tss, _linss = lin.mkSeriesIm0(dataDir="../data/",
                                  ft=lambda t: t in wanted)
    lin.filterL1_st(tss)

    t = 120
    cells = tss[t]
    gene_pairs = [('ETTIN', 'LFY'),  # zone 1
                  ('AP1', 'LFY')]    # zone 2
    regions = [GPairRegions(t, cells, gene_a, gene_b)
               for gene_a, gene_b in gene_pairs]

    for region in regions:
        region.plot()
        region.plot_distr()


def go__():
    """Compare LFY-expressing and CUC1_2_3-expressing cells at time point 132.

    Loads time points 10, 40, 96, 120 and 132 with ``lin.mkSeriesIm0`` and
    filters them with ``lin.filterL1_st``.  It builds a per-cell value map for
    t=132 and passes it, together with the cells, to ``plot_ts_q_``.  It then
    passes the t=132 growth rates and both cell-id lists to ``plot_distrs``
    with label "boundary".
    """
    tss, linss = lin.mkSeriesIm0(dataDir="../data/",
                                 ft=lambda t: t in {10, 40, 96, 120, 132})
    lin.filterL1_st(tss)

    t = 132
    ts = tss[t]

    # Ids of cells whose expression entry for the gene is truthy; a cell may
    # appear in both lists.
    r1 = [c.cid for c in ts if c.exprs['LFY']]
    r2 = [c.cid for c in ts if c.exprs['CUC1_2_3']]

    # Per-cell value: 1.0 for LFY cells (wins if a cell is in both lists),
    # 0.0 for CUC1_2_3-only cells, 0.5 for all others.
    # NOTE(review): `in` on lists is O(len) per cell; sets would avoid the
    # quadratic cost on large tissues.
    d = dict()
    for c in ts:
        if c.cid in r1:
            d[c.cid] = 1.0
        elif c.cid in r2:
            d[c.cid] = 0.0
        else:
            d[c.cid] = 0.5

    grates = grs.grates_avg_cons(tss, linss)

    plot_ts_q_(ts, d)
    plot_distrs(grates[t], r1, r2, lb="boundary")