#!/usr/bin/env python

from rpy import *
#from IPython.Shell import IPShellEmbed

class pointset:
    "Class to hold a set of points, primarily for scatterplots in R"
    x = None
    y = None
    col = None      # point color
    pch = None      # point type (R plotting symbol)
    lty = None      # line type
    ptype = None    # plot style, passed to R as `type` (e.g. 'p', 'o')
    legend = None   # legend label
    angle = None    # fill line angle (used in histograms)
    border = None   # border line color (stored for histograms)
    cex = None      # symbol size scaling
    lwd = None      # line width

    def __init__( self, x=None, y=None, col=None, pch=None, lty=None,
                  ptype=None, legend=None, angle=None, border=None,
                  cex=None, lwd=None):
        """Store the data and R styling for one point set.

        Any argument left as None gets a default: x and y become new empty
        lists, col and border 'blue', pch '.', lty 1, ptype 'o', legend '',
        angle 0, cex 1 and lwd 1.

        Only None triggers a default. Falsy but meaningful R values such as
        pch=0 (square) or lty=0 (blank line) are kept as given.
        """
        # Fresh lists per instance so point sets never share storage.
        self.x = [] if x is None else x
        self.y = [] if y is None else y
        self.col = 'blue' if col is None else col
        self.pch = '.' if pch is None else pch
        self.lty = 1 if lty is None else lty
        self.ptype = 'o' if ptype is None else ptype
        self.legend = '' if legend is None else legend
        self.angle = 0 if angle is None else angle
        self.border = 'blue' if border is None else border
        self.cex = 1 if cex is None else cex
        self.lwd = 1 if lwd is None else lwd

class plot:
    "Generic R plot class"

    sets = None
    fname = None
    title = None
    xlab = None
    ylab = None
    x_axis_ticks = None     # list of (tick_list, col, length) tuples

    def __init__( self, fname, title=None, xlab=None,
                  ylab=None, pointsets=None):
        """Create a plot whose output file is based on `fname`.

        plot_init appends the extension ('.ps' or '.png') for the chosen
        output type. title, xlab and ylab default to the empty string.
        pointsets defaults to a fresh empty list of pointset objects.
        """
        self.fname = fname
        self.title = title or ''
        self.xlab = xlab or ''
        self.ylab = ylab or ''
        self.sets = pointsets or []
        self.x_axis_ticks = []

    def add_set( self, x, y, col=None, pch=None, lty=None, ptype='p',
                 legend=None, cex=None, lwd=None):
        "Append a new pointset built from x, y and the given style options."
        self.sets.append( pointset( x=x, y=y, col=col, pch=pch, lty=lty,
                                    ptype=ptype, legend=legend, cex=cex, lwd=lwd))

    def plot_init( self, ftype='x11', paper=None, width=None,
                   height=None, horizontal=None, pointsize=None):
        """Open an R graphics device for this plot.

        ftype selects the device:
          'ps'  -- writes <fname>.ps; paper, width and height are required.
          'png' -- writes <fname>.png; width and height are required.
          'x11' -- opens an on-screen window.
        horizontal (default False) and pointsize (default 12) are only
        passed to the postscript device.

        Raises AssertionError for missing size arguments or an unknown ftype.
        """
        if horizontal is None: horizontal=False
        if pointsize is None: pointsize=12
        if ftype == 'ps':
            assert paper and width and height, \
                   "Paper, width, height must be set for postscript output: %s, %s, %s" % (
                paper, width, height)
            r.postscript( self.fname+'.ps', paper=paper, width=width,
                          height=height, horizontal=horizontal,
                          pointsize=pointsize)
        elif ftype == 'png':
            assert width and height, \
                   "Width, height must be set for png output: %s, %s" % (width, height)
            r.png( self.fname+'.png', width, height)
        elif ftype == 'x11':
            r.x11()
        else: assert False, "Unknown ftype: %s" % ftype
        # Title text
        # 1=normal text, 2=bold, 3=italic, 4=bold italic, 5=symbol font
        r.par(font_main=1)

    def plot_close( self, ftype):
        "Close a plot file.  (Do not close x11 plots)"
        if ftype != 'x11': r.dev_off()

    def plot( self, ftype='x11', xlimit=None, ylimit=None,
              xjitter=0, yjitter=0, grid=None, log='', legend=True,
              paper=None, width=None, height=None, horizontal=None,
              pointsize=None):
        """Draw every non-empty point set, then close the device.

        The first non-empty set opens the plot with r.plot; later sets are
        overlaid with r.points. xlimit/ylimit override the computed bounds.
        xjitter/yjitter are passed as the R jitter factor. grid, if given, is
        an (nx, ny) pair for r.grid. The device arguments (paper, width,
        height, horizontal, pointsize) are forwarded to plot_init.
        """
        self.plot_init( ftype, paper=paper, width=width, height=height,
                        horizontal=horizontal, pointsize=pointsize)

        # Bounds start at 0, so the origin is always inside the default
        # limits. NOTE(review): with a log axis, pass explicit limits; a 0
        # limit is presumably rejected by R on a log scale.
        xmin = xmax = ymin = ymax = 0
        for pset in self.sets:
            if len(pset.x) > 0:
                xmin = min(xmin, min(pset.x))
                xmax = max(xmax, max(pset.x))
                ymin = min(ymin, min(pset.y))
                ymax = max(ymax, max(pset.y))

        if xlimit is None: m_xlimit = (xmin, xmax)
        else: m_xlimit = xlimit
        if ylimit is None: m_ylimit = (ymin, ymax)
        else: m_ylimit = ylimit

        first = True
        for pset in self.sets:
            # Empty sets were left out of the bounds above; skip drawing them too.
            if len(pset.x) == 0:
                continue
            if first:
                r.plot( r.jitter( pset.x, xjitter), r.jitter( pset.y, yjitter),
                        xlab = self.xlab, ylab = self.ylab,
                        xlim = m_xlimit, ylim = m_ylimit,
                        log = log,
                        main = self.title,
                        col = pset.col, pch = pset.pch, lty = pset.lty,
                        type = pset.ptype, cex=pset.cex, lwd=pset.lwd)
                first = False
            else:
                r.points( r.jitter( pset.x, xjitter), r.jitter( pset.y, yjitter),
                          col = pset.col, pch = pset.pch, lty=pset.lty, type = pset.ptype,
                          cex=pset.cex, lwd=pset.lwd)

        if grid: r.grid( grid[0], grid[1])
        if legend: self.legend()

        self.plot_close( ftype)

    def legend( self, position='topright'):
        """Draw a legend with one entry per point set (label, color, symbol
        and line type) at the given R legend position keyword."""
        r.legend(position,
                 [pset.legend for pset in self.sets],
                 col=[pset.col for pset in self.sets],
                 pch=[pset.pch for pset in self.sets],
                 lty=[pset.lty for pset in self.sets])

class densityplot(plot):
    "Plot the density of set.x"

    def add_set( self, x, col=None, pch=None, lty=None, ptype='p', legend=None,
                 cex=None, lwd=None):
        "Append a pointset holding sample data x; only x is used for the density."
        self.sets.append( pointset( x=x, col=col, pch=pch, lty=lty,
                                    ptype=ptype, legend=legend,
                                    cex=cex, lwd=lwd))

    def plot( self, ftype='x11', xlimit=None, ylimit=None,
              grid=None, log='', legend=True, bwadjust=0.05,
              paper=None, width=None, height=None, horizontal=None,
              pointsize=None):
        """Draw an R kernel density estimate of each non-empty set's x values.

        bwadjust is passed to r.density as `adjust`, a bandwidth multiplier.
        Default limits are x in [min(0, data min), max(0, data max)] and
        y in [0, highest density]. xlimit/ylimit override them. The first
        non-empty set opens the plot with r.plot; later sets are overlaid
        with r.points. Device arguments are forwarded to plot_init.
        """
        self.plot_init( ftype, paper=paper, width=width, height=height,
                        horizontal=horizontal, pointsize=pointsize)

        # Bounds of the raw data (x) and of the estimated densities (y).
        xmin = xmax = 0
        ymax = 0
        densities = {}
        for (i, pset) in enumerate(self.sets):
            if len(pset.x) > 0:
                dens = r.density(pset.x, adjust=bwadjust)
                densities[i] = dens
                xmin = min(xmin, min(pset.x))
                xmax = max(xmax, max(pset.x))
                ymax = max(ymax, max(dens['y']))
            else:
                densities[i] = None

        if xlimit is None: m_xlimit = (xmin, xmax)
        else: m_xlimit = xlimit
        if ylimit is None: m_ylimit = (0, ymax)
        else: m_ylimit = ylimit

        first = True
        for (i, pset) in enumerate(self.sets):
            # No density was computed for an empty set; nothing to draw.
            if densities[i] is None:
                continue
            if first:
                r.plot( densities[i],
                        xlab = self.xlab, ylab = self.ylab,
                        xlim = m_xlimit,
                        ylim = m_ylimit,
                        log = log,
                        main = self.title,
                        col = pset.col, pch = pset.pch, lty=pset.lty,
                        type = pset.ptype, cex=pset.cex, lwd=pset.lwd)
                first = False
            else:
                r.points( densities[i],
                          col = pset.col, pch = pset.pch, lty=pset.lty, type = pset.ptype,
                          cex=pset.cex, lwd=pset.lwd)

        # Same grid-then-legend order as plot.plot.
        if grid: r.grid( grid[0], grid[1])
        if legend: self.legend(position='topright')

        self.plot_close( ftype)

class histogramplot(plot):
    "Plot the histogram of set.x"

    def add_set( self, x, col=None, pch=None, ptype='p', lty=None, legend=None, angle=None,
                 border=None, cex=None, lwd=None):
        """Append a pointset holding sample data x for a histogram.

        All style arguments, including lty, are stored on the new pointset.
        Unset ones take pointset's defaults.
        """
        self.sets.append( pointset( x=x, col=col, pch=pch, lty=lty,
                                    ptype=ptype, legend=legend, angle=angle,
                                    border=border, cex=cex, lwd=lwd))

    def add_x_axis_ticks( self, tick_list, col="blue", length=1 ):
        """Queue extra tick marks at the positions in tick_list.

        plot() draws them on the x axis (R side 1) in color col. length is
        passed to R as tcl=-length.
        """
        self.x_axis_ticks.append( (tick_list, col, length) )

    def plot( self, ftype='x11', xlimit=None, ylimit=None,
              grid=None, legend=True, breaks='Sturges',
              paper=None, width=None, height=None, horizontal=None,
              pointsize=None):
        """Draw an unfilled histogram of each non-empty set's x values.

        breaks is passed to R's hist. Default limits are
        x in [min(0, data min), max(0, data max)] and y in [0, highest count].
        xlimit/ylimit override them. The first non-empty set opens the plot;
        later sets are overlaid with r.lines. Any ticks queued with
        add_x_axis_ticks are drawn afterwards. Device arguments are forwarded
        to plot_init.
        """
        self.plot_init( ftype, paper=paper, width=width, height=height,
                        horizontal=horizontal, pointsize=pointsize)

        # find x bounds and the tallest bin
        xmin = xmax = 0
        ymax = 0
        hists_r = {}
        for (i, pset) in enumerate(self.sets):
            if len(pset.x) > 0:
                xmin = min(xmin, min(pset.x))
                xmax = max(xmax, max(pset.x))

                # hists_r contains the actual r object, rather than
                # converting it to python, and later back. This way,
                # the object is of class "histogram", so the correct
                # plot functions (plot.histogram and lines.histogram)
                # are called.
                hists_r[i] = with_mode(NO_CONVERSION,
                                       r.hist)(pset.x, breaks=breaks, plot=False)
                # A converted copy is used only to read the bin counts.
                counts = r.hist(pset.x, breaks=breaks, plot=False)['counts']
                ymax = max(ymax, max(counts))
            else:
                hists_r[i] = None

        if xlimit is None: m_xlimit = (xmin, xmax)
        else: m_xlimit = xlimit
        if ylimit is None: m_ylimit = (0, ymax)
        else: m_ylimit = ylimit

        # NOTE(review): pset.border is stored by add_set but not used here;
        # the bar border follows pset.col.
        first = True
        for (i, pset) in enumerate(self.sets):
            # No histogram was computed for an empty set; nothing to draw.
            if hists_r[i] is None:
                continue
            if first:
                r.plot( hists_r[i],
                        xlab = self.xlab, ylab = self.ylab,
                        xlim = m_xlimit,
                        ylim = m_ylimit,
                        main = self.title,
                        lty = pset.lty,
                        border = pset.col,
                        col = pset.col,
                        density=0,
                        angle = pset.angle,
                        cex = pset.cex,
                        lwd = pset.lwd)
                first = False
            else:
                r.lines( hists_r[i], border=pset.col, col = pset.col, lty=pset.lty,
                         density=0, angle=pset.angle, cex=pset.cex, lwd=pset.lwd)

        for (tick_list, col, length) in self.x_axis_ticks:
            r.axis(1,
                   at = tick_list,
                   labels = False,
                   lwd=1,
                   cex_axis=1,
                   col = col,
                   tcl = -length,
                   line = 0
                   )
            # Black axis segment over [0, 1] with zero-length ticks (tcl=0).
            r.axis(1,
                   at = [0, 1],
                   labels = False,
                   lwd=1,
                   cex_axis=1,
                   col = "black",
                   tcl = 0,
                   line = 0
                   )

        # Same grid-then-legend order as plot.plot.
        if grid: r.grid( grid[0], grid[1])
        if legend: self.legend(position='topright')

        self.plot_close( ftype)


class scatterplot(plot):
    """Class to handle scatterplots in R.

    Adds nothing of its own; plotting behavior is inherited from plot.
    """

    
    
