#!/usr/bin/env python

# Jacob Joseph
# 29 December 2008

# A test of clique recovery by neighborhood correlation (NC).  Basic
# principle: construct a set of cliques, randomly add some edges, and
# randomly remove some others.  Calculate unweighted NC, and calculate
# network statistics before and after NC.

import sys, os, time, cProfile
import random, math, numpy
import networkx as NX
from JJnetstat import stathelper
from matplotlib import pyplot
from JJutil import pickler

sys.path.append(os.path.expandvars('$HOME/Durand/neighborhood_correlation'))
import nc_base

class randnetwork( object):
    """Clique-recovery experiment for neighborhood correlation (NC).

    G_orig is built as a disjoint union of cliques described by
    clique_dist.  randomize() replaces G with a noisy copy of G_orig, and
    unweighted_nc() / nc_optimize_density() derive an NC graph from G, so
    that statistics of the original, noisy and NC graphs can be compared
    (test_graph, test_noise, table_test_graph).
    """
    G = None            # working graph: a noisy copy of G_orig, later replaced by its NC graph
    G_orig = None       # clique graph built by construct_net()
    clique_dist = None  # { clique size: number of cliques of that size }


    def __init__(self, clique_distribution):
        """clique_distribution = { size: count, ...}

        Builds G_orig from the distribution and starts G as a copy of it.
        Only networkx 1.6 is accepted (checked by assertion).
        """

        assert NX.__version__=="1.6", "Networkx version %s not tested" % NX.__version__

        self.clique_dist = clique_distribution
        self.construct_net()

        self.G = self.G_orig.copy()

    def construct_net(self):
        """Build G_orig as disjoint cliques with consecutive integer labels.

        Each clique of `size` nodes gets the next `size` integers, starting
        at 0, so G_orig's nodes are 0 .. (total size - 1).
        """
        self.G_orig = NX.Graph()

        node_i = 0
        for (size, count) in self.clique_dist.items():
            for _ in range(count):
                members = list(range(node_i, node_i + size))
                node_i += size
                # Add the nodes explicitly: a size-1 clique has no edges and
                # would otherwise never appear in the graph.
                self.G_orig.add_nodes_from(members)
                for (j, n0) in enumerate(members):
                    for n1 in members[j + 1:]:
                        self.G_orig.add_edge(n0, n1)


    def randomize(self, p_del, p_add=None, simulate = False):
        """Replace self.G with a randomly perturbed copy of G_orig.

        Every edge of G_orig is deleted independently with probability
        p_del (expected p_del * M deletions), and every node pair absent
        from G_orig is added independently with probability p_add.  Both
        sets are drawn against G_orig, so no edge is deleted and re-added.

        If p_add is None it is chosen to keep the expected number of edges
        constant:

            Padd * (N(N-1)/2 - M) = Pdel * M
            Padd = Pdel * M / (N(N-1)/2 - M)

        (0 when G_orig has no absent pairs).  With simulate=True the edges
        are drawn and counted, but self.G is left equal to G_orig.

        Returns (number of added edges, number of deleted edges).
        """

        self.G = self.G_orig.copy()

        nodes = self.G.nodes()
        edges = set(self.G.edges())

        if p_add is None:
            N = float(len(nodes))
            M = float(len(edges))
            n_absent = N * (N - 1) / 2 - M
            p_add = self._balanced_p_add(p_del, N, M)

            print("N, M: %s %s" % (N, M))
            print("p_add, p_del: %s %s" % (p_add, p_del))
            print("expected add / delete: %g / %g" % (p_add * n_absent,
                                                      p_del * M))

        del_edges = self.del_random( p_del, edges=edges)
        add_edges = self.add_random_fast( p_add, nodes, edges)

        if not simulate:
            self.G.remove_edges_from( del_edges)
            self.G.add_edges_from( add_edges)

        print("added / deleted edges: %d / %d" % (len(add_edges), len(del_edges)))
        return (len(add_edges), len(del_edges))

    @staticmethod
    def _balanced_p_add(p_del, N, M):
        """P(add) that balances p_del * M expected deletions.

        Returns p_del * M / (N(N-1)/2 - M), or 0.0 when there are no absent
        node pairs (complete or trivial graph) and nothing can be added.
        """
        n_absent = N * (N - 1) / 2.0 - M
        if n_absent <= 0:
            return 0.0
        return p_del * M / n_absent

    def test(self, p_del, n=100, direct=False):
        """Compare the edge-addition samplers against their expectation.

        Prints N, M, the expected number of deletions and the balancing
        P(add) (see randomize), then the mean number of edges proposed by
        add_random_fast over n draws on self.G.  With direct=True, the
        O(N^2) add_random is sampled as well, which is slow on large graphs.
        """

        nodes = self.G.nodes()
        edges = set(self.G.edges())

        N = float(len(nodes))
        M = float(len(edges))

        print("N, N(N-1)/2: %s %s" % (N, N*(N-1)/2))
        print("M: %s" % (M,))
        print("expected (M * p_del): %s" % (p_del * M,))

        # Both samplers take the same per-pair probability.
        p_add = self._balanced_p_add(p_del, N, M)
        print("p_add_direct: %g" % p_add)
        print("p_add_fast: %g" % p_add)

        if n <= 0:
            return

        sum_direct = 0.0
        sum_fast = 0.0
        for _ in range(n):
            if direct:
                sum_direct += len(self.add_random( p_add, nodes, edges))
            sum_fast += len(self.add_random_fast( p_add, nodes, edges))

        if direct:
            print("actual added direct: %s" % (sum_direct / n,))
        print("actual added fast: %s" % (sum_fast / n,))

    def add_random(self, p_add, nodes = None, edges = None):
        """Return a set of absent edges, each chosen with probability p_add.

        Tests every pair (nodes[i], nodes[j]) with i < j directly, which is
        O(N^2).  A pair present in `edges` in either orientation is never
        returned.  nodes / edges default to those of self.G.
        """
        rand = random.random
        add_edges = set()
        if nodes is None: nodes = self.G.nodes()
        if edges is None: edges = set(self.G.edges())

        for (i, n0) in enumerate(nodes):
            for n1 in nodes[i+1:]:
                if (n0, n1) not in edges and (n1,n0) not in edges and rand() < p_add:
                    add_edges.add( (n0, n1))
        return add_edges


    def add_random_fast(self, p_add, nodes = None, edges = None):
        """Return a set of absent edges, each chosen with probability p_add.

        Uses geometric skipping over the N(N-1)/2 node pairs; see networkx
        fast_gnp_random_graph, or

            Batagelj and Brandes, "Efficient generation of large random networks",
            Phys. Rev. E, 71, 036113, 2005.

        Pairs are visited as (nodes[w], nodes[v]) with w < v.  A drawn pair
        is kept only if it is absent from `edges` in both orientations.
        p_add <= 0 returns an empty set and p_add >= 1 returns every absent
        pair; the skipping formula divides by log(1 - p_add), which is
        undefined at both ends.  nodes / edges default to those of self.G.
        """
        add_edges = set()
        if nodes is None: nodes = self.G.nodes()
        if edges is None: edges = set(self.G.edges())

        n = len(nodes)

        if p_add <= 0:
            return add_edges

        if p_add >= 1:
            for v in range(1, n):
                n1 = nodes[v]
                for w in range(v):
                    n0 = nodes[w]
                    if (n0, n1) not in edges and (n1, n0) not in edges:
                        add_edges.add( (n0, n1))
            return add_edges

        rand = random.random
        lp = math.log(1.0 - p_add)

        v = 1            # index of the second node of the current pair
        w = -1           # index of the first node; w < v once normalised

        while v < n:
            # Skip a geometrically distributed number of pairs.
            w = w + 1 + int(math.log(1.0 - rand()) / lp)

            # Carry the flat pair offset over into (w, v) coordinates.
            while w >= v and v < n:
                w = w - v
                v = v + 1

            if v < n:
                n0 = nodes[w]
                n1 = nodes[v]
                if (n0, n1) not in edges and (n1, n0) not in edges:
                    add_edges.add( (n0, n1) )

        return add_edges


    def del_random(self, p_del, edges=None):
        """Return a list of edges to delete.

        Each edge is chosen independently with probability p_del.  `edges`
        defaults to the edges of self.G.
        """
        rand = random.random

        if edges is None: edges = self.G.edges_iter()
        return [ (n0, n1) for (n0, n1) in edges if rand() < p_del]


    def thresh_weighted_graph(self, G, thresh):
        """Retain only edges with weight >= thresh.  Produce an
        unweighted graph that contains every node of G."""

        G_new = NX.Graph()
        G_new.add_nodes_from( G.nodes())

        # Each undirected edge is seen from both ends; re-adding it is a no-op.
        for (n0, nbrdict) in G.adjacency_iter():
            for (n1, edict) in nbrdict.items():
                if edict['weight'] >= thresh:
                    G_new.add_edge(n0, n1)

        return G_new

    def thresh_weighted_graph_faster(self, G, thresh):
        """Retain only edges with weight >= thresh.  Produce an
        unweighted graph.  This may not actually be faster.

        Same result as thresh_weighted_graph, but visits each edge once.
        """

        G_new = NX.Graph()
        G_new.add_nodes_from( G.nodes())

        for (n0, n1, edict) in G.edges_iter(data=True):
            if edict['weight'] >= thresh:
                G_new.add_edge( n0, n1)

        return G_new


    def unweighted_nc(self, nc_thresh=0.05):
        """Calculate unweighted NC.  Produces a weighted graph.

        For each pair of nodes (n0, n1) in the same connected component of
        self.G, with closed neighbourhoods X and Y (each node counts itself)
        and N nodes in total:

            NC = (N*|X & Y| - |X|*|Y|) / sqrt(|X|*(N-|X|)*|Y|*(N-|Y|))

        An edge with weight=NC is added when NC >= nc_thresh.  Pairs without
        a shared neighbour are skipped.  A node adjacent to every other node
        (|X| == N) makes the denominator zero, so NC is undefined and the node
        is skipped.  Every node of self.G appears in the result.
        """

        G_new = NX.Graph()
        G = self.G

        nodes = G.nodes()
        N = len(nodes)

        G_new.add_nodes_from( nodes)

        sqrt = math.sqrt

        # Closed neighbourhoods of nodes in different components are
        # disjoint (Nxy == 0), so only pairs within a component are scored.
        for ccomp in NX.connected_components(G):
            ccomp = list(ccomp)
            if len(ccomp) == 1: continue

            # Build each closed neighbourhood once, not once per pair.
            closed = []
            for node in ccomp:
                neighs = set(G[node])
                neighs.add(node)
                closed.append(neighs)

            for (i, n0) in enumerate(ccomp):
                n0_neighs = closed[i]
                Nx = len(n0_neighs)
                if Nx == 1 or Nx == N: continue

                for j in range(i + 1, len(ccomp)):
                    n1_neighs = closed[j]
                    Ny = len(n1_neighs)
                    if Ny == 1 or Ny == N: continue

                    Nxy = len(n0_neighs & n1_neighs)
                    if Nxy == 0: continue

                    NC = float(N * Nxy - Nx * Ny) / sqrt(Nx * (N-Nx) * Ny * (N-Ny))

                    if NC >= nc_thresh:
                        G_new.add_edge(n0, ccomp[j], weight=NC)

        return G_new

    def nc_optimize_density(self, n_iter_max=20, delta=0.001, nc_calc_thresh=0.1,
                            nc_thresh_init=0.5):
        """Optimize the NC threshold to obtain a network with density
        within Dorig +/- Dorig * delta.

        NC is computed once on self.G, keeping pairs with NC >=
        nc_calc_thresh.  Candidate thresholds are then applied to that
        weighted graph, starting from nc_thresh_init, which should lie
        strictly between 0 and 1.

        Until a usable slope exists, each step moves the threshold by half
        its distance to the nearer of 0 and 1, in the direction that moves
        the density toward the target.  After that it takes secant steps,
        clipped to [0, 1].  The search stops when the density is in range
        or after n_iter_max - 1 updates.

        Returns (thresholded unweighted graph, threshold used for it).
        """

        d_orig = NX.density(self.G_orig)

        G_nc = self.unweighted_nc(nc_thresh=nc_calc_thresh)

        # Start from the graph that actually corresponds to nc_thresh, so
        # every (nc_thresh, density) pair used for a slope is consistent.
        nc_thresh = nc_thresh_init
        G_cur = self.thresh_weighted_graph(G_nc, thresh=nc_thresh)
        d_nc = NX.density(G_cur)

        d_low = d_orig * (1 - delta)
        d_high = d_orig * (1 + delta)

        d_last = None
        nc_last = None
        m_last = None
        n_iter = 1
        while d_nc > d_high or d_nc < d_low:
            if d_last is None or (d_nc == d_last and m_last is None):
                # No slope available yet: fixed-fraction step.
                d_last = d_nc
                nc_last = nc_thresh

                sign = 1 if d_nc > d_orig else -1
                nc_thresh = nc_thresh + sign * 0.5 * min(nc_thresh, 1 - nc_thresh)

            else:
                if d_nc == d_last:
                    # Density unchanged: keep the previous slope.
                    m = m_last

                    # if we hit an endpoint, turn back
                    if nc_thresh == 0 or nc_thresh == 1:
                        m = -1 * m
                else:
                    m = (nc_thresh - nc_last) / (d_last - d_nc)
                    m_last = m

                d_last = d_nc
                nc_last = nc_thresh

                nc_thresh += m * (d_nc - d_orig)
                nc_thresh = min(max(nc_thresh, 0), 1)

            G_cur = self.thresh_weighted_graph(G_nc, thresh=nc_thresh)
            d_nc = NX.density(G_cur)

            n_iter += 1
            if n_iter >= n_iter_max: break

        if nc_thresh < nc_calc_thresh:
            print("WARNING: Optimal threshold is below NC calculation threshold")

        print("%d iterations reached nc_thresh (%g) for Dnc (%g).  Dorig (%g)." % (
            n_iter, nc_thresh, d_nc, d_orig))

        return (G_cur, nc_thresh)


    def test_graph(self, p_del, p_add=None, trials=10, nc_thresh=0.5,
                   nc_fix_density=True, pickle=True,name=None):
        """Average graph statistics over `trials` noisy copies of G_orig.

        Each trial randomizes G_orig (p_del, p_add; see randomize) and records
        the noisy graph's statistics.  It then replaces self.G by the NC graph
        and records its statistics.  The NC graph is density-matched by
        nc_optimize_density when nc_fix_density is true; otherwise it is
        thresholded at nc_thresh.  self.G is left holding the last trial's
        NC graph.

        Returns {'orig' | 'rand' | 'nc': {stat_name: {'mean': m, 'std': s}}};
        'orig' has one value and std None.  When name is given and pickle is
        true, the result is looked up / stored via pickler.cachefn under
        $HOME/tmp/pickles.
        """

        pickledir = os.path.expandvars('$HOME/tmp/pickles')

        # The first four key fields keep the historical format, so existing
        # caches for the default p_add / nc_fix_density remain valid.  The
        # suffixes stop other settings from reusing those results.
        cache_args = "%s_%s_%s_%s" % (name, p_del, trials, nc_thresh)
        if p_add is not None:
            cache_args += "_padd%s" % (p_add,)
        if not nc_fix_density:
            cache_args += "_fixedthresh"
        use_cache = name is not None and pickle

        if use_cache:
            retval = pickler.cachefn( pickledir=pickledir, args=cache_args)
            if retval: return retval

        stat_rand = []
        stat_nc = []

        nxstat = stathelper.nxstat(G=self.G_orig)
        stat_orig = nxstat.calc_statistics(omit_spl=True, omit_cc=False)

        print("p_del: %s" % (p_del,))
        for i in range(trials):
            print("Trial:  %s" % (i,))

            # copies self.G_orig to self.G, randomizes self.G
            self.randomize( p_del = p_del, p_add = p_add, simulate=False)

            nxstat.G = self.G
            stat_rand.append( nxstat.calc_statistics(omit_spl=True, omit_cc=False))

            if nc_fix_density:
                self.G = self.nc_optimize_density()[0]
            else:
                self.G = self.unweighted_nc(nc_thresh=nc_thresh)

            nxstat.G = self.G
            stat_nc.append( nxstat.calc_statistics(omit_spl=True, omit_cc=False))

        d = {'orig': {}, 'rand': {}, 'nc': {}}
        for stat_name in ('graph_density', 'graph_mean_cc', 'ccomp_num',
                          'ccomp_num_single', 'ccomp_avg_density_w'):
            d['orig'][stat_name] = { 'mean': stat_orig[stat_name],
                                     'std': None}
            for (key, stats) in ( ('rand', stat_rand), ('nc', stat_nc)):
                statarr = numpy.array([ a[stat_name] for a in stats])
                d[key][stat_name] = { 'mean': numpy.mean(statarr),
                                      'std': numpy.std(statarr) }

        if use_cache:
            pickler.cachefn( retval=d, pickledir=pickledir, args=cache_args)
        return d

    def test_noise(self, ntrials=5, delrange=None, name=None,
                   pickle_use=False, pickle_write=True, profile=False):
        """Run test_graph for each deletion probability in delrange.

        Returns {p_del: test_graph result}.  delrange defaults to 0.025,
        0.05, ..., 0.5.

        When name is given, the whole sweep is read from the pickle cache
        (pickle_use, skipped when profiling) and/or written to it
        (pickle_write) under "<name>_<ntrials>".  That key does not include
        delrange.  test_graph also caches each p_del on its own, because its
        pickle flag is left at the default (True) whatever pickle_use is.

        With profile=True the sweep runs under cProfile.
        """
        pickledir = os.path.expandvars('$HOME/tmp/pickles')
        print(".... %s %s %s" % (name, pickle_use, pickle_write))

        if name is not None and pickle_use and not profile:
            retval = pickler.cachefn( pickledir=pickledir, args="%s_%d" % (name, ntrials))
            if retval: return retval

        if delrange is None:
            delrange = [float(a)/40 for a in range(1,21)]

        score_dict = {}

        def sweep():
            for p_del in delrange:
                score_dict[p_del] = self.test_graph(p_del, trials=ntrials, name=name)

        if profile:
            # runctx takes a statement string; it simply calls the closure.
            cProfile.runctx("sweep()", globals(), {'sweep': sweep})
        else:
            sweep()

        if name is not None and pickle_write:
            pickler.cachefn( retval = score_dict, pickledir=pickledir, args="%s_%d" % (name, ntrials))

        return score_dict


    def table_test_graph(self, p_del, p_add=None, ntrials=10,
                         nc_fix_density=True, nc_thresh=0.5):
        """Run test_graph and format its statistics as a text table.

        nc_thresh is forwarded to test_graph; it only takes effect when
        nc_fix_density is false.
        """
        score_dict = self.test_graph(p_del=p_del, p_add=p_add, trials=ntrials,
                                     nc_fix_density=nc_fix_density, nc_thresh=nc_thresh)

        parts = ["--------------------------------------------------------------------------------\n",
                 "  Nodes: %5d   Edges: %d\n" % (self.G_orig.number_of_nodes(),
                                                  self.G_orig.number_of_edges()),
                 "#Trials: %5d  P(del): %g\n" % (ntrials, p_del),
                 "           Statistic    Cliques     Random (sd)                  NC (sd)\n",
                 "--------------------------------------------------------------------------------\n"]
        for stat_name in score_dict['orig']:
            parts.append("%20s %10.6f %10.6f (%-11g) %10.6f (%-11g)\n" % (
                stat_name,
                score_dict['orig'][stat_name]['mean'],
                score_dict['rand'][stat_name]['mean'],
                score_dict['rand'][stat_name]['std'],
                score_dict['nc'][stat_name]['mean'],
                score_dict['nc'][stat_name]['std']))
        return "".join(parts)

def plot_stats( stat_dict, fprefix=None, title_p='',
                fext=None, dpi=None):
    """Plot every statistic of a noise sweep (as returned by test_noise).

    Draws one figure per statistic via plot_stat; only the
    'ccomp_num_single' figure carries a legend.  When fprefix is given,
    the figures are saved as <fprefix>_<stat>.<fext>.  fext=None falls
    back to 'png' (plot_stat's default) instead of producing '.None'
    files.  dpi is passed through unchanged.
    """
    x, statarr = build_plot_array(stat_dict)
    if not statarr:
        return

    if fext is None:
        fext = 'png'

    first_key = next(iter(statarr))
    for stat_name in statarr[first_key]:
        plot_stat( x, statarr, stat_name, fprefix, title_p,
                   legend = (stat_name=='ccomp_num_single'),
                   fext = fext, dpi = dpi)


def plot_stat( x, statarr, stat_name, fprefix=None, title_p='', dpi=100,
               legend=True, fext='png'):
    """Plot one statistic against the noise level P_d.

    statarr maps series keys ('orig', 'rand', 'nc') to
    {stat_name: (means, stds)}, as built by build_plot_array.  'orig' is
    drawn as a plain line.  Other series get one-std error bars, with the
    lower bar capped at the mean itself, so bars of non-negative means
    never dip below zero.

    When fprefix is given, the figure is saved to
    <fprefix>_<stat_name>.<fext> and closed; otherwise it is left open.
    title_p is accepted for compatibility but currently unused.
    """

    if fprefix is not None:
        pyplot.ioff() # turn off figure updates

    pyplot.rc('figure', figsize=(3.25,2.25))
    pyplot.rc('font', size=10, family = 'serif',
              serif = ['Computer Modern Roman'])
    pyplot.rc('text', usetex=True)

    # Legend labels for the noisy series; unknown keys are labelled by name.
    series_labels = {'rand': "Simulation ($G_S$)",
                     'nc': r"NC ($G_{\mbox{NC}}$)"}

    for key in statarr.keys():
        series = statarr[key][stat_name]
        means = series[0]

        if key=='orig':
            pyplot.plot( x, means, label = 'Original ($G_H$)', alpha=0.7)
            continue

        stds = series[1]
        # don't draw error bars that extend into the negative region
        ypos = numpy.array(means)
        yerr = numpy.array([stds, stds])
        yerr[0] = numpy.min([yerr[0], ypos], axis=0)

        pyplot.errorbar(x = x, y = means,
                        yerr = yerr,
                        label = series_labels.get(key, key),
                        alpha=0.7)

    pyplot.xlabel('Noise ($P_d$)')

    pyplot.xlim(0,0.51)
    if stat_name in ('graph_mean_cc','ccomp_avg_density_w') : pyplot.ylim(-0.01,1.01)

    pyplot.grid(color='gray', alpha=0.5)

    if legend:
        l = pyplot.legend(loc=2) # upper left
        l.legendPatch.set(fill=True, facecolor='gray', edgecolor='gray', alpha=0.5)

    pyplot.subplots_adjust(left=0.13, right=0.985, top=0.98, bottom=0.16)

    if fprefix is not None:
        pyplot.savefig( fprefix+"_%s.%s" % (stat_name, fext), dpi=dpi)
        pyplot.close()
    
        
def build_plot_array(score_dict, stat_names=None):
    """Reshape a noise-sweep score dict into per-series plot arrays.

    score_dict maps p_del -> {series key: {stat_name: {'mean': m, 'std': s}}}
    (the shape returned by test_noise).  Returns (x, vals), where x is the
    sorted list of p_del values and vals[key][stat_name] is a
    (means, stds) pair of lists aligned with x.

    Series keys, and statistic names unless stat_names is given, are taken
    from one entry of score_dict.  An empty score_dict gives ([], {}).
    """
    if not score_dict:
        return [], {}

    # Assumes every p_del entry carries the same series and statistics.
    sample = score_dict[next(iter(score_dict))]
    keys = list(sample)

    if stat_names is None:
        stat_names = list(sample[keys[0]]) if keys else []

    x = sorted(score_dict)
    vals = dict((key, {}) for key in keys)

    for stat_name in stat_names:
        for key in keys:
            cells = [score_dict[p_del][key][stat_name] for p_del in x]
            vals[key][stat_name] = ([cell['mean'] for cell in cells],
                                    [cell['std'] for cell in cells])

    return x, vals


def print_table():
    """Command-line driver: compute and print one statistics table.

    Reads its arguments from sys.argv:
        argv[1]  number of trials (int)
        argv[2]  P(del) (float)
        argv[3]  'True' to run under cProfile, anything else otherwise
        argv[4]  optional P(add) (float); if absent, randomize() balances it

    Returns the formatted table string as well as printing it.
    """
    ntrials = int(sys.argv[1])
    p_del = float(sys.argv[2])
    profile = sys.argv[3]=='True'
    nc_thresh = 0.5
    p_add = float(sys.argv[4]) if len(sys.argv) >= 5 else None
    nc_fix_density=True

    # Alternative clique distributions:
    #   { 4: 10, 8: 10, 16: 10, 32: 10, 64: 10 }
    #   { 7: 1, 8: 1, 10: 1, 12: 2, 14: 1, 22: 1, 31: 1, 32: 1, 38: 2, 44: 3 }
    rn = randnetwork( { 4: 100} )

    if profile:
        # runctx executes in the namespace it is given; use an explicit dict
        # so the result can be read back afterwards.
        namespace = {'rn': rn, 'p_del': p_del, 'p_add': p_add,
                     'ntrials': ntrials, 'nc_thresh': nc_thresh,
                     'nc_fix_density': nc_fix_density}
        cProfile.runctx("stat_table = rn.table_test_graph(p_del=p_del, p_add = p_add, ntrials=ntrials, nc_thresh=nc_thresh, nc_fix_density=nc_fix_density)",
                        globals(), namespace)
        stat_table = namespace['stat_table']
    else:
        stat_table = rn.table_test_graph(p_del=p_del, p_add=p_add,
                                         ntrials=ntrials, nc_thresh=nc_thresh,
                                         nc_fix_density=nc_fix_density)

    print(stat_table)
    return stat_table


if __name__ =="__main__":
    date = time.strftime('%Y%m%d')

    # Toggle: single-configuration statistics table (arguments from sys.argv).
    if False:
        print_table()

    # Toggle: noise sweep over several clique motifs, plotted into figures/.
    if True:
        ntrials = 1000
        # constant 1024 nodes
        motifs = [ ('cliques-4', { 4: 256},),
                   # ('cliques-8', { 8: 128}),
                   ('cliques-16', {16: 64}),
                   # ('cliques-32', {32: 32}),
                   # ('cliques-64', {64: 16})
                   ('cliques-mix', { 4: 16, 8: 8, 16: 8, 32: 8, 64:8})
                   ]

        # The figures are saved under figures/; create it if it is missing.
        if not os.path.isdir('figures'):
            os.makedirs('figures')

        for (name, motif) in motifs:
            rn = randnetwork( motif)
            score_dict = rn.test_noise( ntrials = ntrials, pickle_write=True, pickle_use=True,
                                        name = name, profile=False)
            plot_stats( score_dict, title_p=name+": ",
                        fprefix='figures/%s_%s_%s' % (date, name, ntrials),
                        fext='pdf')

    


        
