"""Search for gene-expression patterns associated with growth-rate differences.

Boolean gene expressions are represented as ``ast`` modules (see ``ge``) and
evaluated against tissues / states through the project helpers in ``common``,
``grates`` and ``patternTransitions``.  Plots are produced with seaborn /
matplotlib and with the external ``newman``, ``convert`` and ``montage``
command line tools.
"""
import os
import math
from copy import deepcopy
from itertools import permutations, product, repeat, takewhile, groupby
from functools import partial, reduce
from operator import add, itemgetter
from collections import defaultdict
from subprocess import call
from typing import Dict, List, Tuple
import ast

import astor
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import matplotlib as mpl
from scipy.signal import argrelmax
from scipy.stats import gaussian_kde
from scipy.stats import ttest_ind
from more_itertools import interleave_longest

import common.statesN as f
import common.seg as seg
import common.edict as edict
import common.lin as lin
from common.seg import STissue
import grates as grs
import patternTransitions as p

# Gene names as quoted string literals, ready to be spliced into expressions.
geneNms: List[str] = ['"AG"', '"AHP6"', '"ANT"', '"AP1"', '"AP2"', '"AP3"',
                      '"AS1"', '"ATML1"', '"CLV3"', '"CUC1_2_3"', '"ETTIN"',
                      '"FIL"', '"LFY"', '"MP"', '"PHB_PHV"', '"PI"',
                      '"PUCHI"', '"REV"', '"SEP1"', '"SEP2"', '"SEP3"',
                      '"STM"', '"SUP"', '"SVP"', '"WUS"']

# Timepoints used throughout this module.
TPOINTS = (10, 40, 96, 120, 132)

# External visualisation tool locations (machine specific).
NEWMAN_CONF = "/Users/s1437009/Organism/tools/plot/bin/newmanInit.conf"
NEWMAN_CMD = "/Users/s1437009/Organism/tools/plot/bin/newman"
CONVERT_CMD = "convert"
# Gene column that is overwritten to carry the plotted values.
LABEL_GENE = "ANT"

mpl.rcParams.update(mpl.rcParamsDefault)
sts_ts = f.states_per_t()

fst = itemgetter(0)
snd = itemgetter(1)


def set_plot_params(fontsize=11):
    """Apply the shared seaborn style and set tick/axes/legend font sizes."""
    sns.set_style("whitegrid", {'grid.linestyle': '--'})
    plt.rc('xtick', labelsize=fontsize)
    plt.rc('ytick', labelsize=fontsize)
    plt.rc('axes', labelsize=fontsize)
    plt.rcParams.update({'legend.fontsize': fontsize})


class BExpr2():
    """Boolean expression compared by its gene set and top-level operator.

    Equality is defined as hash equality, so two expressions with the same
    genes and the same top-level operator are treated as the same expression
    regardless of operand order.
    """

    def __init__(self, e_str):
        self.e = ge(e_str)

    def __hash__(self):
        top = getTopExpr(self.e)
        gene_hashes = sorted(hash(g) for g in getGenes(top))
        return sum(gene_hashes) + hash(ast.dump(top.op))

    def __eq__(self, other):
        return hash(self) == hash(other)

    def __repr__(self):
        return astor.to_source(self.e).replace('"', '').strip()


class BExpr2Region():
    """Boolean expression compared by the states it selects at timepoint t."""

    def __init__(self, t, ts, e_str):
        self.t = t
        self.e = ge(e_str)
        self.ts = ts

    def __hash__(self):
        # The hash is the sum of the ids of the states satisfying the
        # expression, so different state sets with equal sums collide.
        return sum(get_sts_expr(self.ts, self.e, sts_ts[self.t]))

    def __eq__(self, other):
        return hash(self) == hash(other)


class GPairRegions():
    """Pair of regions ``g1 and g2`` / ``g1 and not g2`` at timepoint t."""

    def __init__(self, t, ts, g1, g2):
        self.t = t
        self.g1 = g1
        self.g2 = g2
        self.e1 = ge("'{g1}' and '{g2}'".format(g1=g1, g2=g2))
        self.e2 = ge("'{g1}' and not '{g2}'".format(g1=g1, g2=g2))
        self.ts = ts

    def _region_states(self):
        """Return the states satisfying e1 and the states satisfying e2."""
        sts = sts_ts[self.t]
        return (get_sts_expr(self.ts, self.e1, sts),
                get_sts_expr(self.ts, self.e2, sts))

    def __hash__(self):
        sts_e1, sts_e2 = self._region_states()
        return hash(frozenset(sts_e1)) + hash(frozenset(sts_e2))

    def __eq__(self, other):
        return hash(self) == hash(other)

    def plot(self, txt=""):
        """Render both regions on the tissue; return the image file name.

        Cells matching the expression with the higher mean growth rate get
        value 100.0, cells matching the other expression get -10.0 and all
        remaining cells 0.5.
        """
        grs_state = get_grs_state()
        sts_e1, sts_e2 = self._region_states()
        gr_e1 = meanor0(reduce(add, [grs_state[st] for st in sts_e1], []))
        gr_e2 = meanor0(reduce(add, [grs_state[st] for st in sts_e2], []))
        he, le = [e for e, gr in sorted([(self.e1, gr_e1), (self.e2, gr_e2)],
                                        key=snd, reverse=True)]

        d = dict()
        high, low = getTopExpr(he), getTopExpr(le)
        for c in self.ts:
            d[c.cid] = 100.0 if seg.evalB(high, c.exprs) else 0.5
        # The lower-growth region is painted last and therefore wins.
        for c in self.ts:
            if seg.evalB(low, c.exprs):
                d[c.cid] = -10.0

        return plot_ts_q(self.ts, d,
                         lb="{g1}-{g2}".format(g1=self.g1, g2=self.g2),
                         bounds=(0, 1), txt=txt)

    def plot_distr(self):
        """Plot and save the growth-rate distributions of both regions."""
        set_plot_params(fontsize=12)
        sns.set_style("white")
        cls = sns.color_palette("coolwarm")

        grs_state = get_grs_state()
        sts_e1, sts_e2 = self._region_states()
        grs_e1 = reduce(add, [grs_state[st] for st in sts_e1])
        grs_e2 = reduce(add, [grs_state[st] for st in sts_e2])

        # Rank the two regions by mean growth rate, highest first.
        ranked = sorted([(self.e1, grs_e1, np.mean(grs_e1)),
                         (self.e2, grs_e2, np.mean(grs_e2))],
                        key=itemgetter(2), reverse=True)
        (he, h_grs, _), (le, l_grs, _) = ranked

        fig = plt.figure(figsize=(3.5, 3.7))
        ax = fig.add_subplot(111)
        ax.set_xlim(-0.05, 0.18)
        # NOTE(review): distplot is deprecated in recent seaborn releases.
        sns.distplot(h_grs, bins=10, color=cls[-1], kde=True, ax=ax,
                     label=astor.to_source(he).replace("'", ""))
        sns.distplot(l_grs, bins=10, color=cls[0], kde=True, ax=ax,
                     label=astor.to_source(le).replace("'", ""))
        ax.set_xlabel(r'$\mu$m/h')
        plt.legend(fontsize=10, loc='upper right', frameon=False)
        plt.savefig("{g1}-{g2}-regions.png".format(g1=self.g1, g2=self.g2),
                    dpi=300)
        plt.show()
        return

    def toPair(self):
        """Return the gene pair ``(g1, g2)``."""
        return (self.g1, self.g2)


def groupByRegion(t, ts, es: List[str]) -> Dict[int, List[str]]:
    """Group expression strings by the region (state set) they select."""
    # Hash each expression once; hashing evaluates it against all states.
    hashes = [hash(BExpr2Region(t, ts, e)) for e in es]
    reg_ids = {reg: i for i, reg in enumerate(set(hashes))}

    res = defaultdict(list)
    for e, h in zip(es, hashes):
        res[reg_ids[h]].append(e)
    return res


def mean_gratio(ps):
    """Mean absolute value over ``(pair, value)`` tuples (0.0 if empty)."""
    return meanor0([abs(v) for _, v in ps])


def mean_gratio_(ps):
    """Mean signed value over ``(pair, value)`` tuples (0.0 if empty)."""
    return meanor0([v for _, v in ps])


def groupByRegionGPairs(t: int, ts: STissue, ress, n=6, prc=0.5):
    """Group the strongest gene pairs by the regions they define.

    ``ress`` is an iterable of ``((g1, g2), value)`` tuples or a dict mapping
    pairs to values.  NaN values are dropped, the top ``prc`` fraction by
    absolute value is kept and grouped by region; the ``n`` groups with the
    highest mean absolute value are returned as ``(rank, pairs)`` with ranks
    starting at 1.
    """
    items = ress.items() if isinstance(ress, dict) else ress
    scored = [(GPairRegions(t, ts, pair[0], pair[1]), v)
              for pair, v in items if not math.isnan(v)]
    n_keep = int(np.ceil(len(scored) * prc))
    strongest = sorted(scored, key=lambda x: abs(x[1]), reverse=True)[:n_keep]

    by_region = defaultdict(list)
    for region, v in strongest:
        by_region[hash(region)].append((region.toPair(), v))

    ranked = sorted(by_region.values(), key=mean_gratio, reverse=True)[:n]
    return list(enumerate(ranked, 1))


def plotRegions(t, ts, d):
    """Plot one image per region group, montage them, return the file name."""
    fns = list()
    for rid, ps in d:
        g1, g2 = ps[0][0][0], ps[0][0][1]
        region = GPairRegions(t, ts, g1, g2)
        label = "{rid} ({gr:.3f})".format(rid=str(rid), gr=mean_gratio(ps))
        fns.append(region.plot(txt=label))

    montage_fimgs(fns, im_label="t{t}_regions".format(t=t))
    for fn in fns:
        os.remove(fn)

    return "t{t}_regions.png".format(t=t)


def cluster_states(grs_state: Dict[int, List[float]]):
    """Cluster states by mean growth rate around the KDE density maxima.

    Returns the fitted KDE, the x and y coordinates of its local maxima on
    the grid [-0.05, 0.2) and a dict assigning each state to the index of
    the nearest maximum.
    """
    mgrs = {st: np.mean(grs_st) for st, grs_st in grs_state.items()}
    g = gaussian_kde(np.array(list(mgrs.values())))
    xs = np.arange(-0.05, 0.2, 0.001)
    ys = g.evaluate(xs)
    cms = argrelmax(ys)[0]
    xsm = np.array(xs[cms])
    ysm = np.array(ys[cms])
    cls = {st: np.argmin(np.abs(xsm - grm)) for st, grm in mgrs.items()}
    return g, xsm, ysm, cls


def plot_clustering(g, xsm, ysm):
    """Plot the density estimate ``g`` together with its maxima and save it."""
    xs = np.arange(-0.05, 0.2, 0.001)
    ys = g.evaluate(xs)
    cls = sns.color_palette()
    plt.plot(xs, ys, color=cls[0])
    plt.plot(xsm, ysm, 'o', color=cls[1])
    plt.xlabel("growth rate")
    plt.ylabel("density estimate")
    plt.savefig("clusters_grates_states.png", dpi=300)
    plt.show()


def get_grs_state() -> Dict[int, List[float]]:
    """Load the time series from disk and return growth rates per state.

    NOTE: every call reloads and recomputes everything.
    """
    tss, linss = lin.mkSeries1(d="../data/FM1/tv/",
                               dExprs="../data/geneExpression/",
                               linDataLoc="../data/FM1/tracking_data/",
                               ft=lambda t: t in set(TPOINTS))
    lin.filterL1_st(tss)
    G = p.mkTGraphN()
    grates = grs.grates_avg_cons(tss, linss)
    grates_pats = p.addCombPatterns(grates, tss)
    return p.getGAnisosPerPattern(G, grates_pats)


def ge(s: str):
    """Parse an expression string into an ``ast`` module."""
    return ast.parse(s)


def construct_expr(nms, fns):
    """Interleave operand names with operators and parse the result."""
    return ge(" ".join(interleave_longest(nms, fns)))


def mk_bexprs(k, geneNms):
    """Yield every k-operand and/or expression over genes and negated genes."""
    negated = ["not {gn}".format(gn=gn) for gn in geneNms]
    operands = geneNms + negated
    op_choices = [['and', 'or']] * (k - 1)

    for gns in permutations(operands, k):
        for fns in product(*op_choices):
            yield construct_expr(gns, fns)


def search_bexprs(ts, k, obj):
    """Score every k-operand expression with ``obj``.

    ``obj`` is called as ``obj(cids2=<parsed expression>)``; note that it
    receives the expression itself, not the cell ids it selects.
    """
    return {bexpr: obj(cids2=bexpr) for bexpr in mk_bexprs(k, geneNms)}


def hash_expr(ts, e):
    """Sum of whatever ``ts.filterGBExpr(e)`` returns (cell ids, presumably)."""
    return np.sum(ts.filterGBExpr(e))


def get_cids_highg(ts):
    """Return the cells of ``ts`` whose state falls in growth cluster 1."""
    sts = f.statesI()
    grs_state = get_grs_state()
    g, xsm, ysm, cls = cluster_states(grs_state)
    stsHigh = [st for st, cl in cls.items() if cl == 1]
    return reduce(add, [ts.filterGs(sts[st]) for st in stsHigh])


def get_cids_sts(ts, sts):
    """Return the concatenated cells of ``ts`` belonging to states ``sts``."""
    st_genes = f.statesI()
    return reduce(add, [ts.filterGs(st_genes[st]) for st in sts])


def bacc(ts, cids1, cids2):
    """Balanced accuracy of ``cids2`` as a predictor of ``cids1`` over ``ts``.

    A rate whose denominator is zero is taken to be 1.0.
    """
    truth = set(cids1)
    pred = set(cids2)

    pos = tp = tn = total = 0
    for c in ts:
        total += 1
        in_truth = c.cid in truth
        in_pred = c.cid in pred
        pos += in_truth
        tp += in_truth and in_pred
        tn += (not in_truth) and (not in_pred)
    neg = total - pos

    tpr = tp / pos if pos != 0 else 1.0
    tnr = tn / neg if neg != 0 else 1.0
    return 0.5 * (tpr + tnr)


def meanor0(xs):
    """Mean of ``xs``, or 0.0 when ``xs`` is empty."""
    if xs:
        return np.mean(xs)
    return 0.0


def avg_grate(grates, cids2):
    """Mean growth rate of the cells in ``cids2`` that have one recorded."""
    rates = (grates.get(cid, None) for cid in cids2)
    # Only skip missing cells; a growth rate of exactly 0.0 is valid data.
    return meanor0([gr for gr in rates if gr is not None])


def gratio(xs, xs1):
    """Normalised difference of means ``(m1 - m2) / (m1 + m2)``.

    Returns NaN when either input is empty.  (The file used to define this
    function twice; this is the definition that was in effect at runtime.)
    """
    if not xs or not xs1:
        return np.nan

    m1 = np.mean(xs)
    m2 = np.mean(xs1)
    return (m1 - m2) / (m1 + m2)


def states_d(ts, cids2, sts, grs_state_m):
    """``gratio`` between states satisfying an expression and the rest.

    Despite its name, ``cids2`` is the parsed boolean expression (this is
    what ``search_bexprs`` passes).  ``ts`` is unused.
    """
    sts_envs = f.statesI_env(seg.geneNms)
    e = getTopExpr(cids2)

    mgrs_sts_t = []
    mgrs_sts_f = []
    for st in sts:
        target = mgrs_sts_t if seg.evalB(e, sts_envs[st]) else mgrs_sts_f
        target.append(grs_state_m[st])

    return gratio(mgrs_sts_t, mgrs_sts_f)


def go_search(t, k):
    """Score all k-operand expressions by balanced accuracy vs. high growth."""
    tss, linss = lin.mkSeries()
    lin.filterL1_st(tss)
    ts = tss[t]
    cids_pat = get_cids_highg(ts)

    def bacc_obj(cids2):
        # search_bexprs hands over the expression; bacc needs cell ids.
        return bacc(ts, cids_pat, ts.filterGBExpr(cids2))

    return search_bexprs(ts, k, bacc_obj)


def go_search1(t, k):
    """Score all k-operand expressions by the mean growth rate they select."""
    tss, linss = lin.mkSeries()
    lin.filterL1_st(tss)
    ts = tss[t]
    grates = grs.grates_avg_cons(tss, linss)

    def gr_obj(cids2):
        # search_bexprs hands over the expression; avg_grate needs cell ids.
        return avg_grate(grates[t], ts.filterGBExpr(cids2))

    return search_bexprs(ts, k, gr_obj)


def get_scores(t, es_reg, reg_sts):
    """Balanced accuracy of every expression against its region's cells."""
    tss, linss = lin.mkSeries()
    lin.filterL1_st(tss)
    ts = tss[t]

    scores = defaultdict(list)
    for ri, ess in es_reg.items():
        if not ess:
            continue
        # The region's cells do not depend on the expression being scored.
        cids1 = get_cids_sts(ts, reg_sts[ri])
        for es1 in ess:
            cids2 = ts.filterGBExpr(ge(es1))
            scores[ri].append((es1, bacc(ts, cids1, cids2)))

    return scores


def filterNonEmpty(ts, es):
    """Keep the expressions whose genes are all expressed in the tissue.

    An expression only makes sense if every gene it mentions selects at
    least one cell of ``ts``.
    """
    kept = list()
    for e in es:
        gs = ["'{g}'".format(g=g.replace("not", "").strip())
              for g in getGenes(getTopExpr(ge(e)))]
        if all(ts.filterGBExpr(ge(g)) for g in gs):
            kept.append(e)
    return kept


def const(cids2, y):
    """Constant objective: ignore ``cids2`` and return ``y``."""
    return y


def go_search2(t, k):
    """Find the expressions with extreme state growth ratio, group by region.

    Writes one ``reg<i>_t<t>_exprs.txt`` file and one image per region.
    Returns the raw scores, the non-empty expressions per region and the
    states per region.
    """
    tss, linss = lin.mkSeries()
    lin.filterL1_st(tss)
    ts = tss[t]
    sts = f.states_per_t()[t]
    grs_state = get_grs_state()
    mgrs_state = {st: np.mean(grs_st) for st, grs_st in grs_state.items()}

    st_gr_obj = partial(states_d, ts=ts, sts=sts, grs_state_m=mgrs_state)
    ress = search_bexprs(ts, k, st_gr_obj)

    # NOTE(review): states_d can return NaN (see gratio); NaN scores make the
    # ordering below unreliable and never compare equal to max_v / min_v.
    ress_ = pprint_list(ress)
    max_v = ress_[0][1]
    min_v = ress_[-1][1]
    ress_max_elems = list(takewhile(lambda r: r[1] == max_v, ress_))
    ress_min_elems = list(takewhile(lambda r: r[1] == min_v, reversed(ress_)))

    extreme = [e for e, v in ress_max_elems + ress_min_elems]
    es_reg = groupByRegion(t, ts, extreme)

    for rid, es in es_reg.items():
        uniq = [repr(e) for e in set([BExpr2(e) for e in es])]
        fn = "reg{i}_t{t}_exprs.txt".format(i=rid, t=t)
        with open(fn, "w+") as fout:
            fout.write("\n".join(uniq))

    plot_state_mgrs(t)

    reg_sts = {rid: get_sts_expr(ts, ge(es[0]), sts)
               for rid, es in es_reg.items()}

    for rid, es in es_reg.items():
        plot_ts_binary(ts, ts.filterGBExpr(ge(es[0])),
                       "reg{i}_t{t}".format(i=rid, t=t),
                       txt="states: " + ", ".join(map(str, reg_sts[rid])))

    es_reg_ = dict()
    for rid, es in es_reg.items():
        uniq = [repr(e) for e in set([BExpr2(e) for e in es])]
        es_reg_[rid] = filterNonEmpty(ts, uniq)

    return ress, es_reg_, reg_sts


def go_search3(t, k):
    """Group all k-operand expressions by region, without scoring them."""
    tss, linss = lin.mkSeries()
    lin.filterL1(tss)
    ts = tss[t]
    sts = f.states_per_t()[t]

    ress = search_bexprs(ts, k, partial(const, y=0.0))
    es = [e for e, v in pprint_list(ress)]
    es_reg = groupByRegion(t, ts, es)
    reg_sts = {rid: get_sts_expr(ts, ge(es_r[0]), sts)
               for rid, es_r in es_reg.items()}

    return es_reg, reg_sts


def invert_dlist(d):
    """Invert a dict of lists: map every list element to its key."""
    inverted = dict()
    for k, xs in d.items():
        for x in xs:
            inverted[x] = k
    return inverted


def vis_stripplot(d):
    """Strip plot of the values in ``d`` (label -> list of values)."""
    fig = plt.figure()
    ax = fig.add_subplot(111)
    sns.set_style("whitegrid", {'grid.linestyle': '--'})
    xs_labs = sorted(d.keys())
    ax.yaxis.set_major_locator(ticker.MultipleLocator(0.05))
    xs = reduce(add, [list(repeat(x, len(d[x]))) for x in xs_labs])
    ys = reduce(add, [d[x] for x in xs_labs])
    cls = sns.color_palette()
    sns.stripplot(x=xs, y=ys, palette=[cls[0]], size=5, alpha=0.75,
                  jitter=False)
    ax.set_xticks(list(range(len(xs_labs))))
    ax.set_ylabel("BAcc")
    ax.set_xlabel("expr length")

    return ax


def vis_stripplot_hue2(d, i):
    """Strip plot of ``g`` vs ``not g`` values per gene; ``i`` picks the shade.

    ``d`` must contain both ``g`` and ``"not g"`` keys for every gene.
    """
    fig = plt.figure(figsize=(12, 6))
    ax = fig.add_subplot(111)
    sns.set_style("whitegrid", {'grid.linestyle': '--'})
    xs_labs = sorted(set(gn.replace("not", "").strip() for gn in d.keys()))
    xs = reduce(add, [list(repeat(x, len(d[x]) * 2)) for x in xs_labs])
    hs = reduce(add, [(list(repeat("g", len(d[x])))
                       + list(repeat("not g", len(d[x]))))
                      for x in xs_labs])
    ys = reduce(add, [d[x] + d["not {x}".format(x=x)] for x in xs_labs])
    cls = sns.color_palette('Blues')
    cls_not = sns.color_palette('Reds')
    sns.stripplot(x=xs, y=ys, hue=hs, palette=[cls[i], cls_not[i]], size=4,
                  alpha=0.6, dodge=True)
    ax.set_xticks(list(range(len(xs_labs))))
    ax.set_ylim((0.01, 0.11))
    plt.xticks(rotation=90)
    plt.savefig("grates_gene.png", dpi=300)
    plt.show()


def cellsToDat(fn, cells):
    """Write ``cells`` to ``fn`` in the data format read by newman."""
    nvars = len(cells[0].exprs.keys()) + 5
    ncells = len(cells)
    ntpoints = 1

    header = "\n".join([str(ntpoints),
                        " ".join([str(ncells), str(nvars), "0"])])
    with open(fn, 'w+') as fout:
        fout.write("\n".join([header] + [cell.toOrganism() for cell in cells]))


def _newman_tif(fn):
    """Name of the tiff newman writes for output label ``fn``."""
    root, _ = os.path.splitext(fn)
    return root + "000" + ".tif"


def _gene_column(ts, g):
    """Index of gene ``g`` in the gene-name list of the first cell of ts."""
    return list(ts)[0].geneNms.index(g)


def plot_ts_binary(ts_, cids, lb="region", txt=""):
    """Render the tissue with the cells in ``cids`` highlighted.

    Works on a deep copy of ``ts_``; writes ``<lb>.png`` and opens it.
    """
    ts = deepcopy(ts_)

    scids = set(cids)
    for c in ts:
        c.exprs[LABEL_GENE] = c.cid in scids

    cellsToDat(lb + ".data", list(ts))
    call([NEWMAN_CMD,
          "-shape", "sphereVolume",
          "-d", "3",
          lb + ".data",
          "-column", str(_gene_column(ts, LABEL_GENE) + 1),
          "-schema", str(0),
          "-output", "tiff", lb,
          "-camera", NEWMAN_CONF,
          "-size", str(720),
          "-min", str(-0.03),
          "-max", str(1.03)])

    # Arguments go to the program verbatim (no shell), so they must not be
    # wrapped in shell-style quotes.
    call([CONVERT_CMD,
          _newman_tif(lb),
          "-fill", "white",
          "-font", "Times-New-Roman",
          "-pointsize", str(30),
          "-undercolor", "#85929E",
          "-gravity", "North",
          "-annotate", "+0+5", " {txt} ".format(txt=txt.replace("'", "")),
          "{lb}.png".format(lb=lb)])

    os.remove(lb + ".data")
    os.remove(_newman_tif(lb))
    call(["open", "{lb}.png".format(lb=lb)])


def plot_ts_q(ts_, d, lb="vals", bounds=(-0.03, 0.05), txt="",
              interactive=False, size=100, cm=3):
    """Render the tissue coloured by the per-cell values in ``d``.

    Works on a deep copy of ``ts_``.  Returns the name of the png written.
    """
    ts = deepcopy(ts_)

    for c in ts:
        if c.cid in d:
            c.exprs[LABEL_GENE] = d[c.cid]

    cellsToDat(lb + ".data", list(ts))
    call([NEWMAN_CMD,
          "-shape", "sphereVolume",
          "-d", "3",
          lb + ".data",
          "-column", str(_gene_column(ts, LABEL_GENE) + 1),
          "-schema", str(cm),
          "-output", "tiff", lb,
          "-min", str(bounds[0]),
          "-max", str(bounds[1]),
          "-size", str(size),
          "-camera", NEWMAN_CONF])

    png = "{lb}.png".format(lb=lb)
    if lb != "":
        call([CONVERT_CMD,
              _newman_tif(lb),
              "-fill", "white",
              "-font", "Times-New-Roman",
              "-pointsize", str(15),
              "-undercolor", "#696969",
              "-gravity", "North",
              "-annotate", "+0+0", " {txt} ".format(txt=txt),
              png])
    else:
        call([CONVERT_CMD, _newman_tif(lb), png])

    if interactive:
        call(["open", png])

    os.remove(lb + ".data")
    os.remove(_newman_tif(lb))

    return png


def montage_fimgs(fns, im_label="montage"):
    """Collage the images ``fns`` into ``<im_label>.png``, 3 per row."""
    montageCmd = "montage"
    n = len(fns)
    rows = int(math.ceil(n / 3))  # integer, so the tile spec is e.g. "3x2"
    cols = min(n, 3)

    call([montageCmd] + fns
         + ["-tile", str(cols) + "x" + str(rows),
            "-geometry", "+20+20",
            im_label + ".png"])


def mark_state_gr(ts, grs_state):
    """Map each cell id to the mean growth rate of its state (0.0 if none)."""
    st_genes = f.states
    mgrs = {st: np.mean(grs_st) for st, grs_st in grs_state.items()}

    return {c.cid: mgrs.get(st_genes.get(c.getOnGenes(), None), 0.0)
            for c in ts}


def plot_state_mgrs(t):
    """Render the tissue at ``t`` coloured by mean state growth rate."""
    tss, linss = lin.mkSeries()
    lin.filterL1_st(tss)
    ts = tss[t]
    grs_state = get_grs_state()
    d = mark_state_gr(ts, grs_state)

    plot_ts_q(ts, d, "grates_sts_{t}h".format(t=t))


def get_st_cids_expr(ts, e, sts: List[int]):
    """Set of cells of ``ts`` in the states of ``sts`` that satisfy ``e``."""
    sts_genes = f.statesI()
    sts_t = get_sts_expr(ts, e, sts)
    return set(reduce(add, [ts.filterGs(sts_genes[st]) for st in sts_t]))


def get_sts_expr(ts, e, sts: List[int]) -> List[int]:
    """States of ``sts`` whose gene environment satisfies expression ``e``.

    ``ts`` is unused.
    """
    sts_envs = f.statesI_env(seg.geneNms)
    top = getTopExpr(e)
    return [st for st in sts if seg.evalB(top, sts_envs[st])]


def showR(r):
    """Format an ``(expression string, value)`` result for display."""
    e, gr = r
    return "{e}, {gr:.2f}".format(e=e.replace("'", ""), gr=gr)


def getGenesOp(e):
    """Gene names of a BoolOp / UnaryOp node; negations get a "not " prefix.

    Returns None for operators other than and / or / not.
    """
    if isinstance(e.op, (ast.And, ast.Or)):
        return reduce(add, [getGenes(e1) for e1 in e.values])
    if isinstance(e.op, ast.Not):
        return ["not {gn}".format(gn=reduce(add, getGenes(e.operand)))]
    return None


def getGenes(e):
    """Gene names mentioned by expression node ``e``.

    Genes may be written as string literals or as bare names.  Returns None
    for node types that are not handled.
    """
    if isinstance(e, (ast.UnaryOp, ast.BoolOp)):
        return getGenesOp(e)
    # Python >= 3.8 parses string literals as ast.Constant.
    if isinstance(e, ast.Constant) and isinstance(e.value, str):
        return [e.value]
    # Older parsers produce the legacy Str node instead.
    if type(e).__name__ == "Str":
        return [e.s]
    if isinstance(e, ast.Name):
        return [e.id]
    return None


def getTopExpr(e):
    """Top-level expression node of a module parsed by ``ge``."""
    return e.body[0].value


def expr(r):
    """First component of a result tuple."""
    return r[0]


def val(r):
    """Second component of a result tuple."""
    return r[1]


def getResPerGene(ress):
    """Collect, per gene, the results of every expression mentioning it."""
    ress_gene = defaultdict(list)
    for bexpr, res in ress.items():
        for gn in getGenes(getTopExpr(bexpr)):
            ress_gene[gn].append(res)
    return ress_gene


def pprint_list(ress):
    """Results as ``(source string, value)`` pairs, highest value first."""
    pairs = [(astor.to_source(k).replace('"', '').strip(), v)
             for k, v in ress.items()]
    return sorted(pairs, key=lambda x: x[1], reverse=True)


def getLinGBExpr(tss, e):
    """Apply ``filterGBExpr(e)`` to the tissue at every timepoint."""
    return {t: ts.filterGBExpr(e) for t, ts in tss.items()}


def getValsT(cidsT, valsT, tpoints=frozenset(TPOINTS)):
    """Look up the values of the given cells at each timepoint in tpoints."""
    return {t: edict.gets(valsT[t], cids)
            for t, cids in cidsT.items()
            if t in tpoints}


def fgr(tss, grates, g1, g2):
    """Per timepoint ``gratio`` of (g1 and g2) vs. (g1 and not g2)."""
    bexpr1 = "'{g1}' and not '{g2}'".format(g1=g1, g2=g2)
    bexpr2 = "'{g1}' and '{g2}'".format(g1=g1, g2=g2)

    cidsT = getLinGBExpr(tss, seg.ge(bexpr2))
    cidsT1 = getLinGBExpr(tss, seg.ge(bexpr1))

    rates_both = getValsT(cidsT, grates)
    rates_only = getValsT(cidsT1, grates)

    return {t: gratio(rates_both[t], rates_only[t])
            for t in rates_both.keys()}


def get_vals_exprs(tss, e1, e2, vals):
    """Per timepoint values of the cells selected by ``e1`` and by ``e2``."""
    cidsT = getLinGBExpr(tss, seg.ge(e1))
    cidsT1 = getLinGBExpr(tss, seg.ge(e2))

    return getValsT(cidsT, vals), getValsT(cidsT1, vals)


def rgd_gene(tss, g, grates):
    """Per timepoint ``gratio`` of cells expressing ``g`` vs. the others."""
    bexpr1 = "'{g1}'".format(g1=g)
    bexpr2 = "not '{g1}'".format(g1=g)

    on, off = get_vals_exprs(tss, bexpr1, bexpr2, grates)

    return {t: gratio(on[t], off[t]) for t in on.keys()}


def rgd_gene_(tss, g, grates):
    """``gratio`` of cells expressing ``g`` vs. not, pooled over timepoints."""
    bexpr1 = "'{g1}'".format(g1=g)
    bexpr2 = "not '{g1}'".format(g1=g)

    grs_t, grs1_t = get_vals_exprs(tss, bexpr1, bexpr2, grates)

    return gratio(reduce(add, grs_t.values()),
                  reduce(add, grs1_t.values()))


def rgd_state(st, grs_state):
    """``gratio`` of the growth rates of state ``st`` vs. all other states."""
    others = reduce(add, [grs_o for st_id, grs_o in grs_state.items()
                          if st_id != st])
    return gratio(grs_state[st], others)


def pval_state(st, grs_state):
    """t-test p-value of state ``st``'s growth rates vs. all other states."""
    others = reduce(add, [grs_o for st_id, grs_o in grs_state.items()
                          if st_id != st])
    _, pval = ttest_ind(others, grs_state[st])
    return pval


def pvals_all_states(grs_state):
    """``pval_state`` for every state."""
    return {st_id: pval_state(st_id, grs_state)
            for st_id in grs_state.keys()}


def rgd_all_states(grs_state):
    """``rgd_state`` for every state."""
    return {st_id: rgd_state(st_id, grs_state)
            for st_id in grs_state.keys()}


def go_single(tss, geneNms):
    """Per timepoint and gene, the ``rgd_gene`` value wrapped in a list.

    NOTE: the ``tss`` argument is ignored; the series is reloaded from disk.
    """
    d = {t: dict() for t in TPOINTS}
    tss, linss = lin.mkSeries()
    lin.filterL1_st(tss)
    grates = grs.grates_avg_cons(tss, linss)

    for g in geneNms:
        for t, gr in rgd_gene(tss, g, grates).items():
            d[t][g] = [gr]

    return d


def go_single_(tss, geneNms):
    """Per gene, the pooled ``rgd_gene_`` value wrapped in a list.

    NOTE: the ``tss`` argument is ignored; the series is reloaded from disk.
    """
    tss, linss = lin.mkSeries()
    lin.filterL1_st(tss)
    grates = grs.grates_avg_cons(tss, linss)

    return {g: [rgd_gene_(tss, g, grates)] for g in geneNms}


def calcTs(tss, grates, gs, f):
    """Evaluate ``f(tss, grates, g1, g2)`` for every ordered gene pair.

    Returns a dict timepoint -> {(g1, g2): value}.
    """
    d = {t: dict() for t in TPOINTS}
    for g1, g2 in product(gs, gs):
        print(g1, g2)  # progress output
        for t, ratio in f(tss, grates, g1, g2).items():
            d[t][(g1, g2)] = ratio

    return d


def matrixify(k2_dict):
    """Turn a dict keyed by pairs into (row keys, column keys, matrix)."""
    kss = k2_dict.keys()
    kss1 = sorted(set(ks[0] for ks in kss))
    kss2 = sorted(set(ks[1] for ks in kss))

    feat_mat = np.zeros((len(kss1), len(kss2)))
    for i, k1 in enumerate(kss1):
        for j, k2 in enumerate(kss2):
            feat_mat[i, j] = k2_dict[(k1, k2)]

    return kss1, kss2, feat_mat


def matrixify_df(k2_dict, kss, df=""):
    """Square matrix of string labels over ``kss``; ``df`` fills the gaps."""
    kss1 = sorted(kss)
    kss2 = sorted(kss)

    # dtype=object: a plain ``str`` dtype holds one character per cell and
    # would truncate longer labels.
    feat_mat = np.empty((len(kss1), len(kss2)), dtype=object)
    for i, k1 in enumerate(kss1):
        for j, k2 in enumerate(kss2):
            feat_mat[i, j] = str(k2_dict.get((k1, k2), df))

    return kss1, kss2, feat_mat


def arrange_rows_cols(vals, gs):
    """Map each gene in ``gs`` to the matching column of ``vals``."""
    return {g: list(vals[:, i]) for i, g in enumerate(gs)}


def mkStripPlotCombs(vals_, ax, cl, lb="", annot=True):
    """Strip plot of ``vals_`` (label -> values) in Blues shade ``cl``."""
    set_plot_params()
    plt.rcParams['svg.fonttype'] = 'none'
    d = vals_

    xs_labs = sorted(d.keys())
    ax.yaxis.set_major_locator(ticker.MultipleLocator(0.4))
    xs = reduce(add, [list(repeat(x, len(d[x]))) for x in xs_labs])
    ys = reduce(add, [d[x] for x in xs_labs])
    cls = sns.color_palette("Blues")
    sns.stripplot(x=xs, y=ys, palette=[cls[cl]], size=5, label=lb, ax=ax)
    ax.set_xticks(list(range(len(xs_labs))))

    if annot:
        ax.text(21.3, -0.383, "ETTIN, STM", alpha=0.5, weight='ultralight')
        ax.text(12.2, 0.36, "SEP1, LFY", alpha=0.5, weight='ultralight')

    plt.xticks(rotation=90)


def go_combs(tss, linss, geneNms, f):
    """Per timepoint list of ``(pair, value)``, strongest first, NaNs last."""
    grates = grs.grates_avg_cons(tss, linss)
    d = calcTs(tss, grates, geneNms, f)
    res = {}
    for t, pair_vals in d.items():
        ranked = sorted([(pair, gr) for pair, gr in pair_vals.items()
                         if not math.isnan(gr)],
                        key=lambda x: abs(x[1]), reverse=True)
        nans = [(pair, gr) for pair, gr in pair_vals.items()
                if math.isnan(gr)]
        res[t] = ranked + nans

    return res


def plotHeatMap(vals, ykeys, xkeys, annot, bounds=(0, 1)):
    """Draw an annotated coolwarm heat map of ``vals`` on a new figure."""
    set_plot_params(fontsize=7)
    plt.figure(figsize=(5, 4))
    vmin, vmax = bounds
    ax = sns.heatmap(vals, yticklabels=ykeys, xticklabels=xkeys,
                     linewidths=.25, cmap='coolwarm',
                     vmin=vmin, vmax=vmax,
                     annot_kws={"size": 7},
                     annot=annot, fmt='',
                     cbar_kws={"shrink": .5, "label": "RGD"})
    ax.set_ylabel("gene A")
    ax.set_xlabel("gene B")


def strip_vals(d):
    """Drop the values from ``(rid, [(pair, value), ...])`` entries."""
    return [(rid, [pair for pair, v in pvs]) for rid, pvs in d]


def plot_heatmap(ress, tss, t):
    """Save the gene-pair heat map and the region montage for timepoint t."""
    set_plot_params(fontsize=12)
    gs, gs, vals = matrixify(dict(ress[t]))
    d_ = groupByRegionGPairs(t, tss[t], ress[t])
    d = strip_vals(d_)
    gs, gs, vals_annot = matrixify_df(invert_dlist(dict(d)), seg.geneNms)

    plotHeatMap(vals, gs, gs, vals_annot, (-0.4, 0.4))
    plt.savefig("grates_gene_pairs_t{t}.svg".format(t=t), dpi=300,
                bbox_inches='tight')

    plotRegions(t, tss[t], d_)


def plot_res_single_all_tpoints(res):
    """Strip plot of single-gene RGD, one shade per timepoint; saves a png."""
    set_plot_params(fontsize=11)
    fig = plt.figure(figsize=(9, 5))
    ax = fig.add_subplot(111)
    ax.set_title("A vs not A")
    ax.set_xlabel("Gene A")
    ax.set_ylabel("RGD")
    ax.set_ylim((-0.6, 0.6))

    # Latest timepoint first, in the darkest shade.
    for shade, t in reversed(list(enumerate(TPOINTS))):
        mkStripPlotCombs(res[t], ax, shade, annot=False)

    plt.savefig("gene_rgd_all_tpoints.png", dpi=300, bbox_inches='tight')


def plot_res_single_comb_tpoints(res):
    """Strip plot of single-gene RGD pooled over timepoints; saves a png."""
    set_plot_params(fontsize=11)
    fig = plt.figure(figsize=(9, 5))
    ax = fig.add_subplot(111)
    ax.set_title("A vs not A (combined timepoints)")
    ax.set_xlabel("Gene A")
    ax.set_ylabel("RGD")
    ax.set_ylim((-0.6, 0.6))

    mkStripPlotCombs(res, ax, 4, annot=False)

    plt.savefig("gene_rgd_all_tpoints_combined.png", dpi=300,
                bbox_inches='tight')


def plot_res_combs_all_tpoints(res):
    """Strip plot of gene-pair RGD for all timepoints; saves an svg."""
    import matplotlib.lines as mlines
    set_plot_params(fontsize=11)

    vals_t = dict()
    for t in TPOINTS:
        gs, gs, vals = matrixify(dict(res[t]))
        vals_t[t] = arrange_rows_cols(vals, gs)

    fig = plt.figure(figsize=(9, 5))
    ax = fig.add_subplot(111)
    ax.set_title("(A and B) vs (A and not B)")
    ax.set_xlabel("Gene B")
    ax.set_ylabel("RGD")
    ax.set_ylim((-1.2, 1.2))

    # Stage i (timepoint TPOINTS[i]) is drawn in shade i + 1, latest first.
    for stage, t in reversed(list(enumerate(TPOINTS))):
        mkStripPlotCombs(vals_t[t], ax, stage + 1, lb=str(stage))

    cls = sns.color_palette("Blues")
    tpoint_handles = [mlines.Line2D([], [], color=cls[stage + 1], marker="o",
                                    markersize=5, label=str(stage),
                                    linestyle="")
                      for stage in range(len(TPOINTS))]

    plt.legend(handles=tpoint_handles,
               title="stage",
               ncol=3,
               loc="lower left",
               columnspacing=0.5,
               handletextpad=0.02,
               borderpad=0.15)

    plt.savefig("negative_infl_grates_all_tpoints.svg", bbox_inches='tight')


def go():
    """Run the gene-pair analysis and produce all of its figures."""
    tss, linss = lin.mkSeries1(d="../data/FM1/tv/",
                               dExprs="../data/geneExpression/",
                               linDataLoc="../data/FM1/tracking_data/",
                               ft=lambda t: t in set(TPOINTS))
    lin.filterL1_st(tss)

    ress = go_combs(tss, linss, seg.geneNms, fgr)

    for t in TPOINTS:
        plot_heatmap(ress, tss, t)

    plot_res_combs_all_tpoints(ress)

    t = 132
    ts = tss[t]
    g = GPairRegions(t, ts, 'ETTIN', 'STM')
    g1 = GPairRegions(t, ts, 'SEP1', 'LFY')

    g.plot()
    g.plot_distr()

    g1.plot()
    g1.plot_distr()