#!/usr/bin/env python
# Jacob Joseph
# 2/23/2007
# Utilities to plot graph statistics

from rpy import *
from matplotlib import pyplot
import bisect
#import networkx as NX
#from JJutil import pickler, rate
#from IPython.Shell import IPShellEmbed


####################################
# Draw Plots
####################################
def plot_graph_k_histograms( net_stats, statname, bitthresh=300, ncthresh=0.999,
                            fprefix='unk', ftype='x11'):
    """For particular bitscore and nc thresholds, plot k vs val from a
    dictionary {k:val}.

    The two series net_stats['bit_score'][bitthresh][statname] and
    net_stats['nc_score'][ncthresh][statname] are drawn as points on
    log-log axes (bit score in blue, NC score in green).

    net_stats -- nested mapping indexed [stype][thresh][statname] -> {k: val}
    statname  -- statistic to plot; 'cc_hist' uses a y-range of (0.1, 1),
                 any other statistic uses (1, 10000)
    bitthresh -- bit-score threshold selecting the first series
    ncthresh  -- NC-score threshold selecting the second series
    fprefix   -- prefix of the output file name for 'ps' / 'png'
    ftype     -- 'ps' or 'png' write a file; any other value opens an
                 interactive x11 window that is left open
    """
    x = {}
    y = {}
    for (stype, thresh) in (('bit_score', bitthresh), ('nc_score', ncthresh)):
        stats = net_stats[stype][thresh][statname]
        keys = sorted(stats.keys())
        x[stype] = keys
        y[stype] = [stats[k] for k in keys]

    fbase = fprefix+'_'+statname+'_'+str(bitthresh)+'_'+str(ncthresh)+'_k'

    # Only a file device must be closed at the end; an interactive window
    # (the fallback for any unrecognised ftype) has to stay open.
    file_device = True
    if ftype == 'ps':
        r.postscript(fbase+'.ps', paper='letter', width=10)
    elif ftype == 'png':
        r.png(fbase+'.png', width=640, height=480)
    else:
        r.x11()
        file_device = False

    xlim = (1, 1000)
    if statname == 'cc_hist':
        ylim = (0.1, 1)
    else:
        ylim = (1, 10000)

    # 10 tickmarks on each axis
    r.par(lab=(10, 10, 0))
    r.plot(x['bit_score'], y['bit_score'], col='blue', type='p',
           pch=20, xlab='k', ylab=statname, main=statname,
           log='xy', xlim=xlim, ylim=ylim)
    r.grid(equilogs=False)
    r.lines(x['nc_score'], y['nc_score'], col='green', type='p', pch=20)
    # Both series are plotted as points (pch=20), so the legend shows points.
    r.legend('bottomleft', legend=['bit '+str(bitthresh), 'nc '+str(ncthresh)],
             col=['blue', 'green'], pch=20)
    if file_device: r.dev_off()
    return



class nxplot:
    """Plotting helpers for network statistics.

    Two backends are used: R through rpy's ``r`` object (threshold vs.
    statistic line plots, component histograms) and matplotlib's
    ``pyplot`` (scatter and score-equivalence plots).

    From the calls made in this class, ``net_stats`` must provide
    ``get_runs()`` (also called with one argument, presumably a
    (stype, run_id) pair -- confirm), ``get_stypes(runs)`` and
    ``get_stat(stat_name, run=run)``.  Run objects are read for
    ``stype``, ``run_id``, ``descr`` and ``min_score``.
    """
    # statistics source; replaced per instance by __init__
    net_stats = None

    def __init__(self, net_stats):
        self.net_stats = net_stats

    ####################################
    # R device handling
    ####################################
    def plot_init(self, ftype, fbasename):
        """Open an R graphics device.

        'ps', 'pdf' and 'png' write fbasename + '.<ftype>'; 'x11' opens an
        interactive window.  Any other ftype fails an assertion.
        """
        if ftype == 'ps':
            r.postscript(fbasename+'.ps',
                         paper='special', width=7.5, height=5.5, horizontal=False)
        elif ftype == 'pdf':
            # A Computer Modern family could be registered with R's
            # pdfFonts()/Type1Font() and selected here with family='CM'.
            r.pdf(fbasename+'.pdf',
                  paper='special',
                  pointsize=8,
                  width=3.25, height=2.5)
        elif ftype == 'png':
            r.png(fbasename+'.png', width=700, height=500)
        elif ftype == 'x11':
            r.x11()
        else:
            assert False, "unknown ftype: %s" % ftype
        return

    def plot_close(self, ftype):
        """Close the current R device unless it is an interactive x11 window."""
        if ftype != 'x11': r.dev_off()

    ####################################
    # Threshold vs. statistic plots
    ####################################
    def build_scatter_data(self, runs, stat_name, base_stype=None,
                           top_transform=None, absolute=False):
        """Collect (threshold, statistic) points for every run series.

        runs may contain at most two score types (stypes).  The first
        (base_stype if given) is plotted on the bottom x-axis as-is.  A
        second stype is mapped through top_transform (identity by default;
        it may return None to drop a threshold) and, unless absolute is
        True, rescaled linearly so its range spans the bottom axis range.

        Returns (xlim_base, xlim_top, ax_top_labelat, ylim, points):
          xlim_base      -- (min, max) threshold of the base stype
          xlim_top       -- (first, last) raw threshold of the top stype,
                            ordered by transformed value; None when there
                            is no top axis data
          ax_top_labelat -- [(position, raw threshold), ...] for the top
                            axis, or None with a single stype
          ylim           -- [min, max] of the statistic over plotted points
          points         -- {(stype, run_id, descr): [(x, stat), ...]}
                            sorted by x
        """
        stypes = list(self.net_stats.get_stypes(runs))
        assert len(stypes)<=2, "Unimplemented when len(stypes) > 2. stypes: %s" % stypes

        # Transformation of the upper axis relative to the bottom.  It
        # should return a float, or None if no transformation exists.
        if top_transform is None:
            top_transform = lambda a: float(a)

        # Set the stype axis order, if specified
        if base_stype is not None:
            assert base_stype in stypes, "base_stype: %s not in stypes: %s" % (
                base_stype, stypes)
            stypes.remove(base_stype)
            stypes = [base_stype] + stypes

        # Bottom-axis thresholds, and (raw, transformed) top-axis pairs for
        # every top threshold that top_transform can map.
        x_base = []
        top_pairs = []
        for run in runs:
            if run.stype == stypes[0]:
                x_base.append(run.min_score)
            else:
                trans = top_transform(run.min_score)
                if trans is not None:
                    top_pairs.append((run.min_score, trans))

        x_base.sort()
        top_pairs.sort(key=lambda a: a[1])   # order by transformed value

        xlim_base = (x_base[0], x_base[-1])
        # With a single stype there is no top axis at all.
        xlim_top = (top_pairs[0][0], top_pairs[-1][0]) if top_pairs else None

        # Top axis scaling, and tick positions
        ax_top_scale = None
        ax_top_offset = None
        ax_top_labelat = None
        if len(stypes) == 2:
            ax_top_labelat = []
            if top_pairs:
                ax_top_scale = float(xlim_base[1] - xlim_base[0]) / (
                    top_pairs[-1][1] - top_pairs[0][1])
                ax_top_offset = xlim_base[0] - top_pairs[0][1] * ax_top_scale

            for (raw, trans) in top_pairs:
                pos = trans if absolute else trans * ax_top_scale + ax_top_offset
                ax_top_labelat.append((pos, raw))

        # Values of this statistic for each threshold of each run series
        points = {}
        ylim = [None, None]
        for run in runs:
            key = (run.stype, run.run_id, run.descr)
            series = points.setdefault(key, [])

            if run.stype == stypes[0]:
                xpos = run.min_score
            else:
                # Check for an unmappable threshold BEFORE scaling it.
                xpos = top_transform(run.min_score)
                if xpos is not None and not absolute:
                    xpos = xpos * ax_top_scale + ax_top_offset

            if xpos is None: continue

            stat = self.net_stats.get_stat(stat_name, run=run)

            if ylim[0] is None or ylim[0] > stat: ylim[0] = stat
            if ylim[1] is None or ylim[1] < stat: ylim[1] = stat

            series.append((xpos, stat))

        for key in points:
            points[key].sort(key=lambda a: a[0])

        return (xlim_base, xlim_top, ax_top_labelat, ylim, points)

    @staticmethod
    def _axis_label(stype):
        """Human-readable axis title for a score type."""
        if stype == 'nc_score': return "NC score"
        if stype == 'bit_score': return "Bit score"
        return stype.replace('_', ' ')

    @staticmethod
    def _series_label(stype, run_id, descr):
        """Legend text for one (stype, run_id, descr) series."""
        if stype == 'bit_score':
            return "Sequence similarity"
        if stype == 'nc_score':
            named = {731: "NC: S cerevisiae",
                     734: "NC: Four genomes",
                     735: "NC: Nine genomes",
                     808: "NC"}
            if run_id in named:
                return named[run_id]
        return "%s: (%s %d)" % (descr, stype, run_id)

    def plot_statistic( self, stat_name, runs=None, fprefix='unk', ftype='x11',
                        base_stype=None, ylimit=None,
                        top_transform=None, absolute=False,
                        legend_position='bottom', draw_title=True,
                        lwd=None, ylab=True, draw_legend=True,
                        ylog=False, colarr=None):
        """Plot stat_name against score threshold, one line per run series.

        The base stype (runs[0].stype by default) uses the bottom x-axis;
        a second stype gets its own top x-axis (see build_scatter_data).
        ylimit, when given, overrides the y-range computed from the data.
        colarr colours are reused cyclically if there are more series than
        colours.  Output goes to fprefix + '_' + stat_name + '.<ftype>'.
        """
        if runs is None: runs = self.net_stats.get_runs()
        if base_stype is None: base_stype = runs[0].stype

        if colarr is None:
            colarr = ['blue', 'brown', 'green', 'orange', 'yellow',
                      'brown', 'black', 'violet']
        color = lambda idx: colarr[idx % len(colarr)]

        if lwd is None: lwd = 1

        (xlim_base, xlim_top, ax_top_labelat, ylim, points) = self.build_scatter_data(
            runs, stat_name, base_stype=base_stype,
            top_transform=top_transform, absolute=absolute)

        # An explicit y-range must win over the data-derived one.
        if ylimit is not None: ylim = ylimit

        self.plot_init(ftype, fprefix+'_'+stat_name)

        # vertical labels, 10 grid lines
        r.par(lab=(10,10,0))

        # tight margins
        r.par(mai=(.26,.18,.26,.05)) # bottom, left, top, right
        r.par(mgp=(0,0,0))
        r.par(tck=.018)

        title = stat_name if draw_title else ''

        if ylab:
            if stat_name=='ccomp_avg_density_w':
                ylab = "Component density"
                legend_position='bottomright'
            elif stat_name=='ccomp_num':
                ylab = "Component count"
            elif stat_name=='graph_mean_cc':
                ylab = "Clustering coefficient"
            else:
                ylab = stat_name
        else:
            ylab = ''

        legend = []
        created_base_axis = False
        created_top_axis = False
        for (i, (stype, run_id, descr)) in enumerate(points):
            series = points[(stype, run_id, descr)]
            x = [xy[0] for xy in series]
            y = [xy[1] for xy in series]

            legend.append(self._series_label(stype, run_id, descr))

            if i == 0:
                # Plot twice, to put the grid in the background
                r.plot(x, y, col=color(i), type='o', pch=20,
                       xlab='',
                       ylab=ylab, main=title,
                       xlim=xlim_base, ylim=ylim, lwd=lwd,
                       log='y' if ylog else ''
                       )
                if ylog:
                    r.grid(lty='dotted',lwd=1,equilogs=False)
                else:
                    r.grid(lty='dotted',lwd=1)

            r.lines(x, y, col=color(i), type='o', pch=20, lwd=lwd)

            # Two x-axes
            if not created_base_axis and stype == base_stype:
                r.axis(1, at=x, labels=False, col='black')
                r.mtext(self._axis_label(stype), side=1, padj=1.2, col='black')
                created_base_axis = True
            elif not created_top_axis and stype != base_stype:
                xat, xlab = zip(*ax_top_labelat)
                r.axis(3, at=xat, labels=xlab, col='black')
                r.mtext(self._axis_label(stype), side=3, padj=-1.2, col='black')
                created_top_axis = True

        if draw_legend:
            r.legend(legend_position, legend=legend,
                     col=[color(k) for k in range(len(legend))], lty=1, pch=20,
                     bty='o', bg="#00000010", box_lty=0,
                     inset=(0.03, 0.03))

        self.plot_close(ftype)
        return

    ####################################
    # Connected-component plots
    ####################################
    def plot_cc_histogram(self, run, stat='size', fprefix='unk', ftype='x11',
                             xlimit=None, ylimit=None):
        """Draw a histogram of either component 'size' or 'density'.

        Reads the 'ccomp_sizedensities' statistic of run, a mapping of
        component size to a list of component densities (as used below).
        Raises ValueError for any other stat, before a device is opened.
        """
        if stat not in ('size', 'density'):
            raise ValueError("stat must be 'size' or 'density', not %r" % (stat,))

        self.plot_init(ftype, "%s_%s_%d_%s" % (fprefix, run.stype, run.run_id, stat))

        sizeden = self.net_stats.get_stat('ccomp_sizedensities', run=run)

        if stat == 'size':
            # number of components of each size
            x = sorted(sizeden.keys())
            y = [len(sizeden[a]) for a in x]
        else:
            # number of components with each density
            density_dict = {}
            for cc_densities in sizeden.values():
                for cc_density in cc_densities:
                    density_dict[cc_density] = density_dict.get(cc_density, 0) + 1
            x = sorted(density_dict.keys())
            y = [density_dict[a] for a in x]

        # 10 tickmarks on each axis
        r.par(lab=(10,10,0))

        r.plot(x, y, col='brown', type='o', pch=20,
               ylab='count', xlab=stat,
               main="%s: (%s, %d, %g) %s" % (run.descr, run.stype, run.run_id,
                                             run.min_score, stat),
               xlim=xlimit, ylim=ylimit)
        r.grid()

        self.plot_close(ftype)
        return

    def plot_cc_scatter(self, run, fprefix=None,
                        xlimit=None, ylimit=None):
        """Draw a scatterplot of density vs size for a particular run.

        Each distinct (size, density) pair is one point, coloured by how
        many components share it.  With fprefix None the figure is left
        open; otherwise it is saved as a PNG by mplot_close.
        """
        self.mplot_init(fprefix)

        sizeden = self.net_stats.get_stat('ccomp_sizedensities', run=run)

        s_d_count = {}
        for (size, densities) in sizeden.items():
            for density in densities:
                s_d_count[(size, density)] = s_d_count.get((size, density), 0) + 1

        x = []
        y = []
        cols = []
        for ((size, density), cnt) in s_d_count.items():
            x.append(size)
            y.append(density)
            cols.append(cnt)

        pyplot.scatter(x, y, c=cols)
        pyplot.axis('tight')
        pyplot.title("%s: (%s, %d, %g)" % (run.descr, run.stype, run.run_id,
                                           run.min_score))
        pyplot.xlabel('Component Size')
        pyplot.ylabel('Component Density')
        pyplot.grid(color='gray', alpha=0.5)
        if xlimit: pyplot.xlim(*xlimit)
        if ylimit: pyplot.ylim(*ylimit)

        pyplot.subplots_adjust(hspace=0.3, bottom=0.08, left=0.06, top=0.95, right=0.98)
        pyplot.colorbar(pad=0, aspect=40)

        self.mplot_close(fprefix, "_%s_%d_%s" % (run.stype, run.run_id, run.min_score))
        return

    ####################################
    # matplotlib figure handling
    ####################################
    def mplot_init(self, fprefix, figsize=(12,9)):
        """Prepare matplotlib for a new figure.

        When output goes to a file (fprefix not None), interactive figure
        updates are turned off.  figsize (width, height) in inches becomes
        the default figure size.
        """
        if fprefix is not None:
            pyplot.ioff() # turn off figure updates

        pyplot.rc('figure', figsize=figsize)
        return

    def mplot_close(self, fprefix, f_name, dpi=100):
        """Save the current figure to fprefix + f_name + '.png' and close it.

        Does nothing when fprefix is None (interactive use).
        """
        if fprefix is not None:
            pyplot.savefig(fprefix+f_name+'.png', format='png', dpi=dpi)
            pyplot.close()
        return

    ####################################
    # Score equivalence between two run series
    ####################################
    def plot_equivalent_x(self, stat_names, equiv_runs=None, fprefix=None,
                          dpi=100, symbols=None, draw_legend=True,
                          xlimit=None, ylimit=None):
        """Plot equivalent thresholds of two run series, one curve per statistic.

        equiv_runs is a pair of arguments for net_stats.get_runs (its items
        are formatted as "%s (%d)", presumably (stype, run_id) -- confirm).
        For each name in stat_names, thresholds of the first series are
        matched to thresholds of the second that give the same statistic
        value (see build_equiv_x).  symbols are reused cyclically.  With a
        fprefix the figure is saved as fprefix + '_equivalence.png'.
        """
        if symbols is None:
            symbols = ['b', 'g', 'r', 'c', 'y', 'm', 'k']

        runs = [self.net_stats.get_runs(stype_run_id) for stype_run_id in equiv_runs]

        self.mplot_init(fprefix)

        for (i, stat_name) in enumerate(stat_names):
            xy0 = self.build_stat_data(runs[0], stat_name)
            xy1 = self.build_stat_data(runs[1], stat_name)

            (equiv_0, equiv_1) = self.build_equiv_x(xy0, xy1)

            pyplot.plot(equiv_0, equiv_1, symbols[i % len(symbols)], label=stat_name)

        # Tick positions come from the last statistic's thresholds.
        pyplot.xscale('log', basex=10)
        pyplot.xticks(xy0[0], xy0[0], rotation='vertical')
        pyplot.yticks(xy1[0], ['%0.2f' % f for f in xy1[0]])

        pyplot.axis('tight')
        pyplot.title('Equivalent Scores: %s' % equiv_runs)
        pyplot.xlabel("%s (%d) threshold" % (equiv_runs[0][0], equiv_runs[0][1]))
        pyplot.ylabel("%s (%d) threshold" % (equiv_runs[1][0], equiv_runs[1][1]))
        if xlimit is not None: pyplot.xlim(xlimit)
        if ylimit is not None: pyplot.ylim(ylimit)
        pyplot.grid(color='gray', alpha=0.5)

        pyplot.subplots_adjust(hspace=0.3, bottom=0.09, left=0.07, top=0.95, right=0.98)

        if draw_legend:
            l = pyplot.legend(pad=0, loc=0)   # no extra space, 'best' location
            l.legendPatch.set(fill=True, facecolor='gray', edgecolor='gray', alpha=0.7)

        # mplot_close appends '.png' itself.
        self.mplot_close(fprefix, '_equivalence', dpi=dpi)
        return

    def build_equiv_transform(self, stat_name, equiv_runs=None):
        """Return a function mapping a threshold of equiv_runs[0] to the
        equivalent threshold of equiv_runs[1] for stat_name.

        The returned function interpolates linearly between known
        equivalent points and returns None outside their range.
        """
        runs = [self.net_stats.get_runs(stype_run_id) for stype_run_id in equiv_runs]
        xy0 = self.build_stat_data(runs[0], stat_name)
        xy1 = self.build_stat_data(runs[1], stat_name)
        (equiv_0, equiv_1) = self.build_equiv_x(xy0, xy1)

        def transform(x0):
            return self._interpolate_equiv(equiv_0, equiv_1, x0)

        return transform

    @staticmethod
    def _interpolate_equiv(equiv_0, equiv_1, x0):
        """Map x0 through the piecewise-linear correspondence equiv_0 -> equiv_1.

        equiv_0 must be sorted ascending (build_equiv_x produces it that
        way).  Exact matches return the stored value; points strictly
        inside the range are linearly interpolated; anything outside
        [equiv_0[0], equiv_0[-1]] returns None.
        """
        i = bisect.bisect_left(equiv_0, x0)

        if i < len(equiv_0) and equiv_0[i] == x0:
            return equiv_1[i]

        if i == 0 or i == len(equiv_0):
            # We are not within the range of equivalence
            return None

        # bisect_left guarantees equiv_0[i-1] < x0 < equiv_0[i], so the
        # bracketing interval is (i-1, i), not (i, i+1).
        lo, hi = i - 1, i
        m = float(equiv_1[hi] - equiv_1[lo]) / (equiv_0[hi] - equiv_0[lo])
        return equiv_1[lo] + m * (x0 - equiv_0[lo])

    def build_stat_data(self, runs, stat_name):
        """Return (thresholds, values) of stat_name for runs, sorted by threshold.

        Call this with runs of a single stype and run_id.
        """
        xy = [(run.min_score, self.net_stats.get_stat(stat_name, run=run))
              for run in runs]
        xy.sort(key=lambda a: a[0])
        x, y = zip(*xy)
        return x, y

    def build_equiv_x(self, xy0, xy1):
        """INPUT: xy0 = (xlist, ylist), xy1 = (xlist, ylist)

        Produce a set of corresponding x-axis values between two functions.

        For each point in xy0 (in increasing x order), trace along the
        curve xy1 for the same y value (linearly interpolating as
        necessary) and store the corresponding x-values: x0 and x1.
        Points of xy0 whose y lies outside the y-range of xy1 are skipped.
        On a flat segment of xy1 at exactly that y, the segment's left end
        is used.

        Note that this function will only behave properly when xy1 is
        monotonic in the range of xy0.

        Returns (equiv_0, equiv_1) as lists.
        """
        # lists of (x, y) pairs, sorted by x
        pts0 = sorted(zip(*xy0), key=lambda a: a[0])
        pts1 = sorted(zip(*xy1), key=lambda a: a[0])

        min_y1 = min(p[1] for p in pts1)
        max_y1 = max(p[1] for p in pts1)

        equiv_0 = []
        equiv_1 = []
        for (x_0, y_0) in pts0:
            if y_0 < min_y1 or y_0 > max_y1:
                continue

            # Find the first segment of xy1 that brackets y_0 and
            # interpolate along it.
            for ((x_prev, y_prev), (x_1, y_1)) in zip(pts1, pts1[1:]):
                if (y_prev <= y_0 <= y_1) or (y_prev >= y_0 >= y_1):
                    equiv_0.append(x_0)
                    if y_1 == y_prev:
                        # flat segment: avoid dividing by zero
                        equiv_1.append(float(x_prev))
                    else:
                        equiv_1.append(float(x_prev) + (float(x_1 - x_prev)
                                         * (y_0 - y_prev) / (y_1 - y_prev)))
                    break

        return (equiv_0, equiv_1)
