#!/usr0/local/bin/python

import SGD
from JJutil import mysqlutils
import time
import insert_generic

class insert_sgd(insert_generic.insert_generic):
    """Insert the records produced by SGD.sgd_parser into the database.

    All database access goes through the ``dbw`` wrapper handed to
    __init__.  Several lookup_*/insert_*/compare_* helpers called by the
    methods below are not defined in this class; presumably they are
    inherited from insert_generic.insert_generic (TODO confirm).
    """

    # Parser object; replaced per instance with SGD.sgd_parser() in __init__.
    sgd = None
    # Per-instance dict; insert_sgd() stores {'version', 'date'} under 'SGD'.
    source_info = None
    
    def __init__(self, dbw):
        """Create an importer bound to one database wrapper.

        dbw -- wrapper object used by this class for execute(), the fetch*()
               calls and commit().
        """
        # Filled in by insert_sgd() once the parser has run.
        self.source_info = {}
        self.sgd = SGD.sgd_parser()
        self.dbw = dbw

    #
    # insert_sgd : the top level function that kicks off the insert
    #
    def insert_sgd(self):
        """Parse the SGD data and insert every sequence that has a feature.

        Runs the parser, records the parsed version and date under
        self.source_info['SGD'], then calls insert_prot_seq_version() for
        each parsed sequence.  A sequence for which the parser returns no
        feature is reported on stdout and skipped.
        """
        self.sgd.parse_sgd()

        # Single-argument print(...) renders the same text under both the
        # print statement and the print function.
        print("SGD date:  %s" % (self.sgd.sgd_date,))

        self.source_info['SGD'] = {'version': self.sgd.sgd_version,
                                   'date': self.sgd.sgd_date}

        # Resolve every record before inserting anything, so an unknown key
        # fails before the first database write.
        records = [self.sgd.sequences[key] for key in self.sgd.get_sequences()]
        for record in records:
            # Identity test: "== None" would dispatch to any __eq__ defined
            # on the feature object.
            if self.sgd.get_feature(record.sgd_id) is None:
                print("we might have a bad record for " + record.sgd_id)
                continue
            self.insert_prot_seq_version(record.sgd_id)

    def insert_prot_seq_version(self, sgd_id):
        """Record the current SGD version for one accession.

        Looks up the newest seq_id already stored for this accession under
        the SGD source.  When such a row exists and compare_prot_seq_sgd()
        confirms it still matches the parsed record, only a version row is
        added; otherwise the full record is inserted.

        sgd_id -- primary SGD accession of the record to insert.
        """
        source_id = self.lookup_prot_seq_source("SGD")
        sgd_info = self.source_info['SGD']
        source_ver_id = self.lookup_prot_seq_source_ver("SGD",
                                                        sgd_info['version'],
                                                        sgd_info['date'])
        self.dbw.commit()

        # NOTE(review): the statement is assembled by string formatting; its
        # safety rests entirely on mysqlutils.quote -- confirm.
        latest_seq_sql = """SELECT seq_id FROM prot_seq_version
        JOIN prot_seq_source_ver USING (source_ver_id)
        WHERE source_id = '%d'
        AND primary_acc = '%s'
        ORDER BY seq_id DESC
        LIMIT 1"""
        quoted = tuple([mysqlutils.quote(value)
                        for value in (source_id, sgd_id)])
        self.dbw.execute(latest_seq_sql % quoted)
        seq_id = self.dbw.fetchsingle()

        print("trying to insert " + sgd_id)

        unchanged = seq_id and seq_id == self.compare_prot_seq_sgd(seq_id,
                                                                   sgd_id)
        if not unchanged:
            self.insert_prot_seq_version_full(seq_id, source_ver_id, sgd_id)
            return
        self.insert_prot_seq_version_simple(seq_id, source_ver_id, sgd_id)


    def insert_prot_seq_version_full(self, seq_id, source_ver_id, sgd_id):
        """Insert a new sequence row and every SGD row that hangs off it.

        seq_id        -- accepted for the caller's convenience but not used;
                         the id returned by insert_prot_seq() is used instead.
        source_ver_id -- source-version row the new sequence is attached to.
        sgd_id        -- primary SGD accession of the parsed record.
        """
        seq_record = self.sgd.get_sequence(sgd_id)
        props = self.sgd.get_property(sgd_id)
        feature = self.sgd.get_feature(sgd_id)
        dbxrefs = self.sgd.get_dbxref(sgd_id)

        # Primary accession first, secondary ids after it.
        accessions = [sgd_id]
        accessions.extend(feature.secondary_sgd_ids)

        # Cross references are stored as 5-tuples; only source and accession
        # are known here, the remaining slots stay empty strings.
        crossrefs = []
        if dbxrefs:
            crossrefs = [(ref.source, ref.accession, '', '', '')
                         for ref in dbxrefs]

        go_codes = self.sgd.get_go_code(sgd_id)

        new_seq_id = self.insert_prot_seq(seq_record.sequence,
                                          props.protein_length,
                                          seq_record.crc64,
                                          props.mol_weight)
        self.insert_prot_seq_version_simple(new_seq_id, source_ver_id,
                                            seq_record.sgd_id)
        # Commit the sequence and version rows before the dependent rows.
        self.dbw.commit()

        self.insert_sgd_feature(new_seq_id, feature)
        self.insert_sgd_go(new_seq_id, go_codes)
        self.insert_prot_seq_accessions(new_seq_id, accessions)
        self.insert_sgd_gene_names(new_seq_id, feature)
        if crossrefs:
            self.insert_sp_crossrefs(new_seq_id, crossrefs)

        self.dbw.commit()

    def lookup_sgd_go_term(self, go_term):
        """Return the go_id of go_term, inserting the term when it is missing.

        go_term -- parsed GO term object exposing go_id (plus the fields
                   insert_sgd_go_term() writes), or None.

        Returns the stored or newly inserted go_id.  Returns None when
        go_term is None, so that the caller's own "bad go code" check can
        report the problem instead of an AttributeError being raised here.
        """
        if go_term is None:
            return None

        q = """ SELECT go_id FROM sgd_go_term
        where go_id = %d    
        """

        self.dbw.execute(q % go_term.go_id)
        go_id = self.dbw.fetchsingle()

        if go_id:
            return go_id

        # Not stored yet: add it, then hand back the id we just inserted.
        self.insert_sgd_go_term(go_term)

        return go_term.go_id
        
    def insert_sgd_go_term(self, go_term):
        """Insert one GO term into sgd_go_term and commit.

        go_term -- object exposing go_id, go_term, go_aspect and
                   go_definition.
        """
        # NOTE(review): go_id is formatted with %d after passing through
        # mysqlutils.quote, so quote presumably leaves numbers numeric --
        # confirm.
        insert_sql = """ INSERT INTO sgd_go_term 
        (go_id, go_term, go_aspect, go_definition)
        VALUES ('%d','%s','%s','%s')
        """

        fields = (go_term.go_id,
                  go_term.go_term,
                  go_term.go_aspect,
                  go_term.go_definition)
        quoted = tuple([mysqlutils.quote(field) for field in fields])

        self.dbw.execute(insert_sql % quoted)
        self.dbw.commit()

    def lookup_sgd_chromosome(self, chromosome_num):
        """Ensure chromosome_num has a row in sgd_chromosome and return it.

        chromosome_num -- chromosome number; formatted into the query with
                          %d, so it must be numeric.

        Returns chromosome_num unchanged; the row is inserted first when the
        lookup finds nothing.
        """
        q = """ SELECT chromosome FROM sgd_chromosome
        where chromosome = %d    
        """

        self.dbw.execute(q % chromosome_num)
        # Identity test against None: a stored value of 0 must still count
        # as "found".
        if self.dbw.fetchsingle() is None:
            self.insert_sgd_chromosome(chromosome_num)
        return chromosome_num
    
    def insert_sgd_chromosome(self, chromosome_num):
        """Insert the parsed chromosome record into sgd_chromosome and commit.

        chromosome_num -- key passed to the parser's get_chromosome(); the
                          returned object supplies chromosome_num and
                          num_genes for the new row.
        """
        insert_sql = """ INSERT INTO sgd_chromosome 
        (chromosome, num_genes)
        VALUES ('%d','%d')
        """

        record = self.sgd.get_chromosome(chromosome_num)
        quoted = tuple([mysqlutils.quote(value)
                        for value in (record.chromosome_num,
                                      record.num_genes)])

        self.dbw.execute(insert_sql % quoted)
        self.dbw.commit()

    def insert_sgd_go(self, seq_id, go_codes):
        """Link a sequence to each of its GO terms in sgd_go.

        seq_id   -- sequence the GO codes belong to.
        go_codes -- iterable of objects exposing go_id, or None when the
                    record has no GO annotation (nothing is inserted).

        Raises AssertionError when a GO code cannot be resolved to a go_id.
        Does not commit; the caller is expected to.
        """
        i = """ INSERT INTO sgd_go 
        (seq_id, go_id)
        VALUES ('%d','%d')
        """

        if go_codes is None:
            return

        for go_code in go_codes:
            go_id = self.lookup_sgd_go_term(self.sgd.get_go_term(go_code.go_id))
            if go_id is None:
                # Raised explicitly rather than via ``assert`` so the check
                # is not stripped under ``python -O``; the exception type is
                # kept for existing callers.
                raise AssertionError(
                    "You should not see this - bad go code in a sgd_record : "
                    + str(go_code.go_id))
            inputtup = tuple(map(mysqlutils.quote, (seq_id, go_id)))
            self.dbw.execute(i % inputtup)

        return

    def insert_sgd_feature(self, seq_id, feature):
        """Insert the feature row for one sequence into sgd_feature.

        seq_id  -- sequence the feature belongs to.
        feature -- parsed feature exposing chromosome, chrom_start,
                   chrom_end, strand, feature_type, feature_qualifier,
                   orf_name and description.

        The chromosome row is created first if it does not exist.  Does not
        commit.
        """
        insert_sql = """ INSERT INTO sgd_feature 
        (seq_id, chromosome, start_coord, end_coord, strand, feature_type,
        feature_qualifier, orf_name, description)    
        VALUES ('%d','%d','%s','%s','%s','%s','%s','%s','%s')
        """

        # Must run before the insert: sgd_feature refers to this chromosome.
        chromosome = self.lookup_sgd_chromosome(feature.chromosome)

        row = (seq_id,
               chromosome,
               feature.chrom_start,
               feature.chrom_end,
               feature.strand,
               feature.feature_type,
               feature.feature_qualifier,
               feature.orf_name,
               feature.description)
        quoted = tuple([mysqlutils.quote(value) for value in row])

        self.dbw.execute(insert_sql % quoted)

    def insert_sgd_gene_names(self, seq_id, feature):
        """Insert the gene names of one feature into sgd_gene_name.

        The stripped primary gene name, when non-empty, is stored with
        acc_order 0; the aliases follow with acc_order 1, 2, ... in the order
        the feature lists them.  Does not commit.

        seq_id  -- sequence the names belong to.
        feature -- parsed feature exposing primary_gene_name and aliases.
        """
        insert_sql = """ INSERT INTO sgd_gene_name
        (seq_id, acc_order, gene_name)     
        VALUES ('%d','%d','%s')
        """

        primary = feature.primary_gene_name.strip()
        aliases = feature.aliases

        if primary != "":
            quoted = tuple([mysqlutils.quote(value)
                            for value in (seq_id, 0, primary)])
            self.dbw.execute(insert_sql % quoted)

        # Aliases always start at 1, whether or not a primary name was stored.
        for position, alias in enumerate(aliases, 1):
            quoted = tuple([mysqlutils.quote(value)
                            for value in (seq_id, position, alias)])
            self.dbw.execute(insert_sql % quoted)


    def compare_prot_seq_sgd(self, seq_id, sgd_id):
        """Check whether stored sequence seq_id still matches parsed sgd_id.

        Runs the individual comparisons in a fixed order (sequence data,
        accessions, GO codes, feature, gene names, cross references) and
        stops at the first one that does not return seq_id.

        Returns seq_id when every comparison agrees, otherwise None.
        """
        seq_record = self.sgd.get_sequence(sgd_id)
        props = self.sgd.get_property(sgd_id)
        feature = self.sgd.get_feature(sgd_id)
        dbxrefs = self.sgd.get_dbxref(sgd_id)

        # Primary accession first, secondary ids after it.
        accessions = [sgd_id]
        accessions.extend(feature.secondary_sgd_ids)

        crossrefs = []
        if dbxrefs:
            crossrefs = [(ref.source, ref.accession, '', '', '')
                         for ref in dbxrefs]

        # Thunks keep the checks lazy: a later comparison is only run once
        # every earlier one has matched.
        checks = (
            lambda: self.compare_prot_seq_str(seq_id, seq_record.sequence,
                                              props.protein_length,
                                              seq_record.crc64,
                                              props.mol_weight),
            lambda: self.compare_prot_seq_accessions(seq_id, accessions),
            lambda: self.compare_sgd_go(seq_id, sgd_id),
            lambda: self.compare_sgd_feature(seq_id, sgd_id),
            lambda: self.compare_sgd_gene_names(seq_id, sgd_id),
            lambda: self.compare_sp_crossrefs(seq_id, crossrefs),
        )

        for check in checks:
            if not seq_id == check():
                return None

        return seq_id

    def compare_sgd_feature(self, seq_id, sgd_id):
        """Compare the stored sgd_feature row with the parsed feature.

        seq_id -- sequence whose stored feature row is read.
        sgd_id -- accession of the parsed record to compare against.

        Returns seq_id when all eight feature columns match (strand is
        compared case-insensitively), otherwise None.  None is also returned
        when the parser has no feature for sgd_id or when no sgd_feature row
        exists for seq_id.
        """
        q = """SELECT chromosome,start_coord,end_coord,strand,feature_type,
        feature_qualifier,orf_name,description FROM sgd_feature  
        WHERE seq_id=%d
        """

        feature = self.sgd.get_feature(sgd_id)
        # Guard before touching any attribute; nothing to compare against.
        if feature is None:
            return None

        self.dbw.execute(q % seq_id)
        stored = self.dbw.fetchone()
        # No stored row means the feature cannot be identical.
        if not stored:
            return None

        # NOTE(review): values are compared with ==, so parsed and stored
        # fields must have matching types (e.g. int vs str coordinates) for a
        # match to be possible -- confirm against the parser.
        parsed_fields = (feature.chromosome,
                         feature.chrom_start,
                         feature.chrom_end,
                         feature.strand.upper(),
                         feature.feature_type,
                         feature.feature_qualifier,
                         feature.orf_name,
                         feature.description)
        stored_fields = (stored[0],
                         stored[1],
                         stored[2],
                         stored[3].upper(),
                         stored[4],
                         stored[5],
                         stored[6],
                         stored[7])

        if parsed_fields == stored_fields:
            return seq_id
        return None


    def compare_sgd_gene_names(self, seq_id, sgd_id):
        """Compare the stored gene names of seq_id with the parsed feature.

        The parsed side is the feature's aliases plus its primary gene name.
        The primary name is stripped and dropped when empty, which is how
        insert_sgd_gene_names() stores it.  Order is ignored.

        Returns seq_id when both sides hold the same names, otherwise None.
        """
        q = """SELECT gene_name
        FROM sgd_gene_name 
        WHERE seq_id='%d'"""

        self.dbw.execute(q % seq_id)

        feature = self.sgd.get_feature(sgd_id)
        # Work on a copy: appending to or sorting feature.aliases in place
        # would corrupt the record that a later full insert writes out.
        gene_names = list(feature.aliases)
        primary_gene_name = (feature.primary_gene_name or "").strip()
        if primary_gene_name:
            gene_names.append(primary_gene_name)
        gene_names.sort()

        f_gene_names = sorted(self.dbw.fetchcolumn())

        if f_gene_names == gene_names:
            return seq_id

        return None

    def compare_sgd_go(self, seq_id, sgd_id):
        """Compare the stored GO ids of seq_id with the parsed GO codes.

        A parsed record without GO codes (get_go_code() returning None) is
        treated as an empty set, matching insert_sgd_go(), which stores
        nothing in that case.

        Returns seq_id when both sides hold the same go_ids, otherwise None.
        """
        q = """SELECT go_id FROM sgd_go 
        WHERE seq_id='%d'
        ORDER BY go_id"""

        go_codes = self.sgd.get_go_code(sgd_id)

        self.dbw.execute(q % seq_id)

        # Sorted again locally so the comparison does not depend on the
        # database collation matching Python's ordering.
        f_go_codes = sorted(self.dbw.fetchcolumn())

        sorted_go_codes = sorted([code.go_id for code in (go_codes or [])])

        if sorted_go_codes == f_go_codes:
            return seq_id

        return None


if __name__ == "__main__":
    # Script entry point: connect to the database and run the full import.
    # Imported here because only the script path needs them.
    import getpass
    import os

    # Connection settings can be overridden through the environment; host,
    # user and database keep their previous values as defaults.
    dbhost = os.environ.get("SGD_DBHOST", "goby.compbio.cs.cmu.edu")
    dbuser = os.environ.get("SGD_DBUSER", "jmjoseph")
    db = os.environ.get("SGD_DBNAME", "DurandLab2")

    # The password is no longer kept in the source file.  The value that used
    # to be committed here should be treated as compromised and rotated.
    dbpass = os.environ.get("SGD_DBPASS")
    if dbpass is None:
        dbpass = getpass.getpass("MySQL password for %s@%s: " % (dbuser, dbhost))

    dbw = mysqlutils.dbwrap(dbhost, dbuser, dbpass, db)
    sgd_import = insert_sgd(dbw)
    sgd_import.insert_sgd()
        
