#!/usr/bin/env python

# Jacob Joseph
# 27 May 2010

# Test clustering comparison using VI

from JJcluster import cluster_sql
import numpy, pp

# waterworks
import ClusterMetrics

def build_confusion_dict(clusters_0, clusters_1):
    """Return {(c0, c1): overlap_count} for every pair of cluster labels.

    clusters_0, clusters_1 -- mappings of cluster label -> set of members.

    Unlike build_confusion_matrix, pairs with zero overlap are included.
    """
    d = {}
    # .items() rather than the Python-2-only .iteritems(); the two iterate
    # the same pairs under py2, and .items() keeps this function portable.
    for (c0, c0_set) in clusters_0.items():
        for (c1, c1_set) in clusters_1.items():
            d[(c0, c1)] = len(c0_set.intersection(c1_set))
    return d

def build_confusion_matrix(clusters_0, clusters_1):
    """Return a ClusterMetrics.ConfusionMatrix of pairwise cluster overlaps.

    clusters_0, clusters_1 -- mappings of cluster label -> set of members.

    Only pairs with a non-zero overlap are added to the matrix (the
    ConfusionMatrix is sparse in that sense).
    """
    cm = ClusterMetrics.ConfusionMatrix()

    # .items() rather than the Python-2-only .iteritems(); equivalent under
    # py2 and portable.
    for (c0, c0_set) in clusters_0.items():
        for (c1, c1_set) in clusters_1.items():
            overlap = len(c0_set.intersection(c1_set))
            if overlap > 0:
                cm.add(c0, c1, count=overlap)
    return cm

def cut_comparison(cr0, cr1, cuts=None):
    """Compare two hierarchical clusterings over a grid of tree-cut distances.

    Cuts each tree at every distance in `cuts`, then computes the comparison
    statistics (via calc_stats) for each pair of cuts with dist0 <= dist1 --
    only the upper triangle of the distance grid is evaluated.

    cr0, cr1 -- hierarchical clustering handles exposing cut_tree(distance=...)
    cuts     -- list of cut distances; defaults to 0.1 ... 0.9

    Returns {dist0: {dist1: stats_dict}} with the keys produced by calc_stats.
    """
    if cuts is None:
        # psycopg2 doesn't seem to be able to adapt numpy floats
        cuts = [float(a) for a in numpy.arange(0.1, 1, 0.1)]  # 0.1...0.9

    stat_dict = {}

    # Fetch every cut once up front; cut_tree is the expensive step.
    # (Single-argument print(...) produces identical output under py2.)
    print("fetching cuts")
    cr0_cuts = dict((dist, cr0.cut_tree(distance=dist)) for dist in cuts)
    cr1_cuts = dict((dist, cr1.cut_tree(distance=dist)) for dist in cuts)
    print("done")

    for (dist0, cut0) in sorted(cr0_cuts.items()):
        if dist0 not in stat_dict:
            stat_dict[dist0] = {}

        for (dist1, cut1) in sorted(cr1_cuts.items()):
            # Upper triangle only.
            if dist0 > dist1:
                continue

            print("calculating stats: %s %s" % (dist0, dist1))
            # Reuse calc_stats rather than duplicating its statistic set here.
            stat_dict[dist0][dist1] = calc_stats(cut0, cut1)

    return stat_dict

def calc_stats(clusters_0, clusters_1):
    """Return comparison statistics for two flat clusterings.

    Both arguments map cluster label -> set of members.  Builds the
    confusion matrix of pairwise overlaps and evaluates the fast
    comparison measures on it.
    """
    cm = build_confusion_matrix(clusters_0, clusters_1)

    vi = cm.variation_of_information()
    nvi = cm.normalized_vi()
    vm = cm.v_measure()
    beta_vm = cm.v_beta()

    # The confusion matrix also offers micro/macro average F and the
    # one-to-one optimal mapping, but those are slow and are omitted here.
    return {'vi': vi,
            'nvi': nvi,
            'vm': vm,
            'beta vm': beta_vm}

def flatclust_vs_hcuts(cr_id_arr, hclust, cuts=None):
    """Compare several flat clusterings against cuts of one hierarchical tree.

    cr_id_arr -- cluster_run ids of flat clusterings to fetch via cluster_sql
    hclust    -- hierarchical clustering handle exposing cut_tree(distance=...)
    cuts      -- list of cut distances; defaults to 0.1 ... 0.9

    Returns {cr_id: {distance: stats_dict}} with the keys produced by
    calc_stats.

    NOTE(review): relies on the module-global `job_server` (the pp.Server
    created in the __main__ block), so this raises NameError if the module
    is imported rather than run as a script.
    """
    if cuts is None:
        # psycopg2 doesn't seem to be able to adapt numpy floats
        cuts = [float(a) for a in numpy.arange(0.1, 1, 0.1) ]  # 0.1...0.9

    # Fetch every tree cut once up front; each cut maps label -> member set.
    print "fetching cuts"
    hclust_cuts = dict(( (dist, hclust.cut_tree(distance=dist)) for dist in cuts) )
    print "done"

    stat_dict = {}
    job_dict = {}
    
    for cr_id in cr_id_arr:
        # fetch flat clustering
        c = cluster_sql.flatcluster( cluster_run_id = cr_id, cacheq=True)
        clusters_0 = c.fetch_clusters()

        for (dist, clusters_1) in sorted( hclust_cuts.iteritems()):
            print cr_id, dist
            # Submit one parallel-python job per (flat run, cut distance);
            # the worker needs build_confusion_matrix and ClusterMetrics
            # shipped along with calc_stats.
            job_dict[(cr_id, dist)] = job_server.submit(
                calc_stats, # function
                (clusters_0, clusters_1), # params
                (build_confusion_matrix,), # dependent functions
                ("ClusterMetrics",)) # modules

    job_server.print_stats()
    
    # fetch finished jobs
    for cr_id in cr_id_arr:
        if not cr_id in stat_dict: stat_dict[cr_id] = {}

        for dist in hclust_cuts:
            # Calling the pp job object blocks until that job completes.
            stat_dict[cr_id][dist] = job_dict[(cr_id, dist)]()

    job_server.print_stats()
    
    return stat_dict

def flatclust_vs_flatclust(cr_id_arr0, cr_id_arr1, symmetric=True):
    """Pairwise comparison statistics between two lists of flat clusterings.

    cr_id_arr0, cr_id_arr1 -- cluster_run ids fetched via cluster_sql
    symmetric -- if True, only compute the upper triangle: for row i of
                 cr_id_arr0, compare against cr_id_arr1[i:] (assumes the two
                 lists are the same run set -- TODO confirm).

    Returns {cr_id0: {cr_id1: stats_dict}} with the keys from calc_stats.

    NOTE(review): relies on the module-global `job_server` (pp.Server from
    the __main__ block); NameError if imported rather than run as a script.
    """

    job_dict = {}
    stat_dict = {}

    for (i,cr_id0) in enumerate(cr_id_arr0):
        c0 = cluster_sql.flatcluster( cluster_run_id = cr_id0, cacheq=True)
        clusters_0 = c0.fetch_clusters()

        # Upper-triangle slice when symmetric; full cross product otherwise.
        for cr_id1 in (cr_id_arr1[i:] if symmetric else cr_id_arr1):
            print cr_id0, cr_id1
            c1 = cluster_sql.flatcluster( cluster_run_id = cr_id1, cacheq=True)
            clusters_1 = c1.fetch_clusters()
            
            # One parallel-python job per pair; worker needs
            # build_confusion_matrix and ClusterMetrics shipped along.
            job_dict[(cr_id0, cr_id1)] = job_server.submit(
                calc_stats, # function
                (clusters_0, clusters_1), # params
                (build_confusion_matrix,), # dependent functions
                ("ClusterMetrics",)) # modules

    # fetch finished jobs
    for ((cr_id0, cr_id1), job) in job_dict.iteritems():
        if not cr_id0 in stat_dict: stat_dict[cr_id0] = {}
            
        # Calling the pp job object blocks until that job completes.
        stat_dict[cr_id0][cr_id1] = job()

    job_server.print_stats()
    return stat_dict

def print_stats(stat_dict, key):
    """Format one statistic from a cut-comparison grid as a text table.

    stat_dict -- {row_cut: {col_cut: {stat_name: value}}}; rows and columns
                 share the same sorted set of cut distances.
    key       -- which statistic to show in each cell.

    Cells absent from a row (e.g. the lower triangle) are left blank.
    Returns the table as a single string.
    """
    cuts = sorted(stat_dict)

    parts = ["Key: %s\n" % key]

    # Column-header row of cut distances.
    parts.append("       ")
    for col in cuts:
        parts.append("%9.2f" % col)
    parts.append("\n")

    # Separator line under the header.
    parts.append("      " + "-" * 9 * len(cuts) + "\n")

    # One row per cut, labelled on the left.
    for row in cuts:
        parts.append("%4.2f |" % row)
        for col in cuts:
            if col in stat_dict[row]:
                parts.append("%9.2g" % stat_dict[row][col][key])
            else:
                parts.append("         ")
        parts.append("\n")

    return "".join(parts)

def print_stats2(stat_dict, keys0=None, keys1=None, statistic='nvi'):
    """Format one statistic from a nested comparison dict as a text table.

    stat_dict -- {row_key: {col_key: {stat_name: value}}}
    keys0     -- row keys to print; defaults to sorted(stat_dict)
    keys1     -- column keys to print; defaults to the sorted column keys of
                 an arbitrary (first-iterated) row
    statistic -- which statistic to show in each cell

    Cells absent from a row are left blank.  Returns the table as a string.
    """
    if keys0 is None:
        keys0 = sorted( stat_dict.keys())

    if keys1 is None:
        # next(iter(...)) instead of stat_dict.keys()[0]: same "arbitrary
        # first key" behavior, but doesn't rely on keys() being indexable.
        keys1 = sorted(stat_dict[next(iter(stat_dict))].keys())

    m = "Statistic: %s\n" % statistic

    # Column-header row.
    m += "       "
    for col in keys1:
        m += "%9.2f" % col
    m += "\n"
    m += "      " + "-" * 9 * len(keys1) + "\n"

    # One row per key, labelled on the left.
    for row in keys0:
        m += "%4.2f |" % row
        for col in keys1:
            if not col in stat_dict[row]:
                m += "         "

            else:
                m += "%9.3g" % stat_dict[row][col][statistic]
        m += "\n"
    return m


if __name__ == "__main__":
    # Parallel-python server; the flatclust_* functions above use this as a
    # module-global, so it must exist before they are called.
    job_server = pp.Server()

    # Hierarchical clustering handle for cluster run 82 (only needed by the
    # commented-out hierarchical comparisons below).
    cr0 = cluster_sql.hcluster(cluster_run_id=82, cacheq=True)
    #cr1 = cluster_sql.hcluster(cluster_run_id=82)

    #cut_0 = cr0.cut_tree(distance=0.5)
    #cut_1 = cr1.cut_tree(distance=0.5)

    #confmatrix = build_confusion_matrix(cut_0, cut_1)

    #cut_stats = cut_comparison(cr0, cr1)
    
    #flatstats = flatclust_vs_hcuts( [93, 92, 91, 90, 85, 86, 87, 88, 89, 84], cr0)
    # Symmetric all-vs-all comparison of these flat cluster runs.
    flatstats = flatclust_vs_flatclust( [93, 92, 91, 90, 85, 86, 87, 88, 89, 84],
                                        [93, 92, 91, 90, 85, 86, 87, 88, 89, 84])

    # Explicit key lists keep the table rows/columns in this run order
    # rather than sorted order.
    print print_stats2(flatstats,
                       keys0=[93, 92, 91, 90, 85, 86, 87, 88, 89, 84],
                       keys1=[93, 92, 91, 90, 85, 86, 87, 88, 89, 84])
