#!/usr/bin/env python

import os, time, sys, gzip
from Bio.SwissProt import SProt
from JJutil import mysqlutils, rate
from insert_generic import insert_generic

class insert_swissprot(insert_generic):

    def fetch_current_uniprot( self):
        baseurl = "ftp://ca.expasy.org/databases/uniprot/current_release/knowledgebase/complete/"
        files = ['reldate.txt','uniprot_sprot.dat.gz', 'uniprot_trembl.dat.gz']

        os.chdir( self.tmpdir)

        for f in files:
            url = baseurl+f
            print "Fetching data file from",url

            # only overwrite when date or size is incorrect (-N)
            # wget is used rather than python's urllib to allow resuming
            # FIXME: actually, -N breaks resuming
            os.system( "wget -N %s" % url)


        return

    
    def parse_reldate(self):
        fname = os.path.join(self.tmpdir, 'reldate.txt')
        fd = open( fname)

        for line in fd:
            if line.startswith("UniProt Knowledgebase Release"): continue

            linearr = line.split()
            version = linearr[2]
            date = linearr[4]

            if linearr[0].startswith("UniProtKB/Swiss-Prot"):
                self.source_info['swissprot'] = { 'version': version,
                                                  'date': date}
            elif linearr[0].startswith("UniProtKB/TrEMBL"):
                self.source_info['trembl'] = { 'version': version,
                                               'date': date}
        return


    def parse_sp( self, filename, source_name):
        fd = gzip.open(filename, 'r')

        sp_parser = SProt.RecordParser()
        sp_iterator = SProt.Iterator( fd, sp_parser)

        version = self.source_info[source_name]['version']
        date = self.source_info[source_name]['date']

        i_rate = rate.rate()

        cur_record = sp_iterator.next()

        while cur_record:
            # print the rate information periodically
            if i_rate.count % 500 == 0:
                print i_rate
            i_rate.increment()

            self.insert_prot_seq_version( source_name, version,
                                          date, cur_record)
            cur_record = sp_iterator.next()
        return


    def compare_prot_seq( self, seq_id, entry_name, gene_name,
                          description, molecular_weight,
                          created, sequence_update, annotation_update,
                          keywords, features, gb_tax_ids, accessions,
                          crossrefs, sequence, length, crc):
        """Check if the existing prot_seq seq_id entry is identical to
that to be inserted"""

        if not seq_id == self.compare_sp_xtra( seq_id, entry_name,
            gene_name, description,
            created, sequence_update, annotation_update):
            print "Non-identical sp_xtra", seq_id
            return None

        if not seq_id == self.compare_sp_keywords( seq_id, keywords):
            print "Non-identical keywords", seq_id
            return None

        if not seq_id == self.compare_sp_features( seq_id, features):
            print "Non-identical sp_features", seq_id
            return None

        if not seq_id == self.compare_sp_gb_tax_ids( seq_id, gb_tax_ids):
            print "Non-identical gb_tax_ids", seq_id
            return None

        if not seq_id == self.compare_prot_seq_accessions( seq_id, accessions):
            print "Non-identical accessions", seq_id
            return None

        if not seq_id == self.compare_sp_crossrefs( seq_id, crossrefs):
            print "Non-identical cross references", seq_id
            return None

        if not seq_id == self.compare_prot_seq_str( seq_id, sequence, length,
                                                    crc, molecular_weight):
            print "Non-identical sequences", seq_id
            return None

        return seq_id


    def compare_sp_keywords(self, seq_id, keywords):
        #q = """SELECT keyword FROM sp_keyword
        #JOIN sp_keyword_str USING (keyword_strid)
        #WHERE seq_id='%d'
        #ORDER BY keyword"""

        q = """SELECT keyword FROM sp_keyword
        JOIN sp_keyword_str USING (keyword_strid)
        WHERE seq_id='%d'"""

        # sort without case sensitivity, like MySQL
        keywords.sort(key=str.upper)

        self.dbw.execute( q % seq_id)
        f_keywords = self.dbw.fetchcolumn()
        f_keywords = list(f_keywords)
        f_keywords.sort(key=str.upper)

        if keywords == f_keywords:
            return seq_id
        return None


    def compare_sp_features(self, seq_id, features):
        #q = """SELECT feature_key, from_pos, to_pos, comment, ftid
        #FROM sp_feature
        #JOIN sp_feature_key_str USING (feature_key_strid)
        #WHERE seq_id='%d'
        #ORDER BY feature_key, from_pos, to_pos, comment, ftid"""

        # ORDER BY prompts a filesort. Sort in python.
        q = """SELECT feature_key, from_pos, to_pos, comment, ftid
        FROM sp_feature
        JOIN sp_feature_key_str USING (feature_key_strid)
        WHERE seq_id='%d'"""

        # compare without case, like MySQL
        def compare_features( f1, f2):
            for i in range(0,len(f1)):
                f1_str = str(f1[i]).upper()
                f2_str = str(f2[i]).upper()
                ret = cmp(f1_str, f2_str)
                if ret != 0:
                    return ret
            return 0

        #features.sort(compare_features)

        # convert any ints to strings.  see sp_feature schema
        for (i,f) in enumerate(features):
            features[i] = tuple(map(str,f))

        features.sort(compare_features)

        self.dbw.execute(q % seq_id)
        f_features = self.dbw.fetchall()
        f_features = list(f_features)
        f_features.sort(compare_features)

        if f_features == features:
            return seq_id
        return None


    def compare_sp_gb_tax_ids(self, seq_id, gb_tax_ids):
        q = """SELECT gb_tax_id FROM sp_gb_tax_id
        WHERE seq_id='%d'
        ORDER BY gb_tax_id"""

        gb_tax_ids.sort()
        self.dbw.execute(q % seq_id)
        f_gb_tax_ids = self.dbw.fetchall()

        identical = True
        for (i,id) in enumerate(gb_tax_ids):
            if i > len(f_gb_tax_ids)-1 or f_gb_tax_ids[i][0] != int(id):
                identical = False
                break

        if identical:
            return seq_id
        return None


    def compare_sp_xtra(self, seq_id, entry_name, gene_name,
                        description, created, sequence_update,
                        annotation_update):
        #q = """SELECT seq_id from sp_xtra
        #WHERE seq_id = '%s' AND data_class = '%s' AND entry_name = '%s'
        #AND gene_name = '%s' AND description = '%s' AND molecule_type = '%s'
        #AND molecular_weight = '%s' AND created = '%s' AND sequence_update = '%s'
        #AND annotation_update = '%s'"""
        q = """SELECT seq_id from sp_xtra
        WHERE seq_id = '%d' AND entry_name = '%s'
        AND gene_name = '%s' AND description = '%s'"""    

        #inputtup = tuple(map(mysqlutils.quote, (
        #    seq_id, data_class, entry_name, gene_name, description, molecule_type,
        #    molecular_weight, created, sequence_update, annotation_update)))
        inputtup = tuple(map(mysqlutils.quote, (
            seq_id, entry_name, gene_name, description)))
        self.dbw.execute( q % inputtup)
        f_seq_id = self.dbw.fetchsingle()
        if f_seq_id: return f_seq_id
        return None


    def insert_prot_seq_version( self, source_name, version, date, spe):
        q = """SELECT seq_id FROM prot_seq_version
        JOIN prot_seq_source_ver USING (source_ver_id)
        WHERE source_id = '%d'
        AND primary_acc = '%s'
        ORDER BY seq_id DESC
        LIMIT 1"""

        # reformat dates to ease in mysql parsing
        date = time.strftime(
            '%Y-%m-%d', time.strptime(date, '%d-%b-%Y'))
        # currently broken in BioPython due to format change in Swiss-Prot
        # see http://bugzilla.open-bio.org/show_bug.cgi?id=1956
        #created = time.strftime(
        #    '%Y-%m-%d', time.strptime(spe.created, '%d-%b-%Y'))
        #sequence_update = time.strftime(
        #    '%Y-%m-%d',time.strptime(spe.sequence_update, '%d-%b-%Y'))
        #annotation_update = time.strftime(
        #    '%Y-%m-%d', time.strptime(spe.annotation_update, '%d-%b-%Y'))

        source_id = self.lookup_prot_seq_source( source_name)
        source_ver_id = self.lookup_prot_seq_source_ver( source_name, version, date)
        primary_acc = spe.accessions[0]

        # find the most recent seq_id from prot_seq_version from the most
        # recent version of the source of that name (assuming increasing
        # source_ver_ids)
        inputtup = tuple(map(mysqlutils.quote, (source_id, primary_acc)))
        self.dbw.execute( q % inputtup)
        seq_id = self.dbw.fetchsingle()
        #print "found seq_id:", seq_id

        # grab these here for clarity
        molecular_weight = spe.seqinfo[1]
        crc = spe.seqinfo[2]

        # compare this sequence entry with that to be inserted
        # if identical, just update the prot_seq_version table
        # otherwise, do a full insert
        if seq_id is not None and seq_id == self.compare_prot_seq( seq_id,
            spe.entry_name, spe.gene_name, spe.description,
            molecular_weight, spe.created, spe.sequence_update,
            spe.annotation_update, spe.keywords, spe.features,
            spe.taxonomy_id, spe.accessions, spe.cross_references,
            spe.sequence, spe.sequence_length, crc):
            self.insert_prot_seq_version_simple( seq_id, source_ver_id, primary_acc)
        else:
            self.insert_prot_seq_version_full( source_ver_id, primary_acc,
                seq_id, spe.entry_name, spe.gene_name,
                spe.description, molecular_weight,
                spe.created, spe.sequence_update, spe.annotation_update, 
                spe.keywords, spe.features, spe.taxonomy_id, spe.accessions,
                spe.cross_references, spe.sequence, spe.sequence_length, crc)

        self.dbw.commit()

        return
    

    def insert_prot_seq_version_full( self, source_ver_id, primary_acc, seq_id,
        entry_name, gene_name, description,
        molecular_weight, created, sequence_update, annotation_update,
        keywords, features, gb_tax_ids, accessions, crossrefs, sequence, length,
        crc):

        #print "Full insert"

        # create a new seq_id
        seq_id = self.insert_prot_seq( sequence, length, crc, molecular_weight)
        #print "full insert new seq_id:", seq_id

        # insert into prot_seq_version
        self.insert_prot_seq_version_simple( seq_id, source_ver_id, primary_acc)

        # commit before populating dependent tables
        #self.dbw.commit()

        # populate sp_xtra
        self.insert_sp_xtra( seq_id, entry_name, gene_name, description,
                             created, sequence_update, annotation_update)

        # populate sp_keywords
        self.insert_sp_keywords( seq_id, keywords)

        # populate sp_features
        self.insert_sp_features( seq_id, features)

        # populate gb_tax_ids
        self.insert_sp_gb_tax_ids( seq_id, gb_tax_ids)

        # populate prot_seq_accessions
        self.insert_prot_seq_accessions( seq_id, accessions)

        # populate sp_cross_refs
        self.insert_sp_crossrefs( seq_id, crossrefs)

        return


    def insert_sp_gb_tax_ids( self, seq_id, gb_tax_ids):
        q_prefix = """INSERT INTO sp_gb_tax_id
        (seq_id, gb_tax_id)
        VALUES """
        q_value = """('%d', '%d')"""

        q = q_prefix
        firstentry = True
        for gb_tax_id in gb_tax_ids:
            if not firstentry: q += ', '
            firstentry = False

            gb_tax_id = int(gb_tax_id)
            inputtup = (seq_id,gb_tax_id)

            q += q_value % inputtup

        # execute the insert
        self.dbw.execute(q)
        return


    def lookup_sp_feature_key_str( self, feature_key):
        q = """SELECT feature_key_strid FROM sp_feature_key_str
        WHERE feature_key='%s'"""

        i = """INSERT INTO sp_feature_key_str (feature_key) VALUES ('%s')"""

        inputtup = mysqlutils.quote(feature_key)

        # already exists
        self.dbw.execute(q % inputtup)
        feature_key_strid = self.dbw.fetchsingle()
        #print "Feature key string ID:", feature_key_strid
        if feature_key_strid: return feature_key_strid

        # insert
        self.dbw.execute(i % inputtup)
        feature_key_strid = self.dbw.lastrowid()
        return feature_key_strid


    def insert_sp_feature( self, seq_id, feature, from_pos,
                           to_pos, comment, ftid):
        i = """INSERT INTO sp_feature
        (seq_id, feature_key_strid, from_pos, to_pos, comment, ftid)
        VALUES
        ('%d','%d','%s','%s','%s','%s')"""

        feature_key_strid = self.lookup_sp_feature_key_str( feature)

        inputtup = tuple(map(mysqlutils.quote, (
            seq_id, feature_key_strid, from_pos, to_pos, comment, ftid)))
        self.dbw.execute(i % inputtup)
        return


    def insert_sp_features( self, seq_id, features):
        for (feature, from_pos, to_pos, comment, ftid) in features:
            self.insert_sp_feature( seq_id, feature, from_pos, to_pos,
                               comment, ftid)
        return


    def lookup_sp_keyword_str( self, keyword):
        q = """SELECT keyword_strid FROM sp_keyword_str
        WHERE keyword='%s'"""

        i = """INSERT INTO sp_keyword_str (keyword) VALUES ('%s')"""

        inputtup = mysqlutils.quote(keyword)

        # already exists
        self.dbw.execute(q % inputtup)
        keyword_strid = self.dbw.fetchsingle()
        #print "Keyword string ID:", keyword_strid
        if keyword_strid: return keyword_strid

        # insert
        self.dbw.execute(i % inputtup)
        keyword_strid = self.dbw.lastrowid()
        return keyword_strid


    def insert_sp_keyword( self, seq_id, keyword):
        i = """INSERT INTO sp_keyword
        (seq_id, keyword_strid)
        VALUES
        ('%d','%d')"""

        keyword_strid = self.lookup_sp_keyword_str( keyword)

        inputtup = (seq_id, keyword_strid)
        self.dbw.execute(i % inputtup)
        return


    def insert_sp_keywords( self, seq_id, keywords):
        for keyword in keywords:
            self.insert_sp_keyword( seq_id, keyword)
        return


    def insert_sp_xtra( self, seq_id, entry_name, gene_name, description,
                        created, sequence_update, annotation_update):
        #i = """INSERT INTO sp_xtra
        #(seq_id, entry_name, gene_name, description, molecule_type,
        #molecular_weight, created, sequence_update, annotation_update)
        #VALUES
        #('%s','%s','%s','%s','%s','%s','%s','%s','%s','%s')"""
        i = """INSERT INTO sp_xtra
        (seq_id, entry_name, gene_name, description)
        VALUES
        ('%d','%s','%s','%s')"""

        #inputtup = tuple(map(mysqlutils.quote, (
        #    seq_id, entry_name, gene_name, description, molecule_type,
        #    molecular_weight, created, sequence_update, annotation_update)))
        inputtup = tuple(map(mysqlutils.quote, (
            seq_id, entry_name, gene_name, description)))
        self.dbw.execute( i % inputtup)
        #print i % inputtup
        return

    def insert_sp_crossrefs( self, seq_id, crossrefs):
        "Insert cross references for a Uniprot entry"
        
        # use only the first 5 tuple elements
        crossrefs = [x.__getslice__(0,5) for x in crossrefs]
        for crossref in crossrefs:
            self.insert_sp_crossref( seq_id, *crossref)
        return


    def insert_sp_crossref( self, seq_id, source_name, id_0='', id_1='',
                            id_2='', id_3=''):
        "Insert one cross reference into sp_cross_ref"
        
        i = """INSERT IGNORE INTO sp_cross_ref
        (seq_id, source_id, id_0, id_1, id_2, id_3)
        VALUES
        ('%d', '%d','%s','%s','%s','%s')"""
        # Ignore is here to get around a duplicate MGI entry in Q3TLP9.
        # The duplicate is not found on the website.
        
        source_id = self.lookup_prot_seq_source( source_name)
        inputtup = tuple(map(mysqlutils.quote,
                             (seq_id, source_id, id_0, id_1, id_2, id_3)))
        self.dbw.execute(i % inputtup)
        return


    def compare_sp_crossrefs( self, seq_id, crossrefs):
        #q = """SELECT source_name, id_0, id_1, id_2, id_3 FROM sp_cross_ref
        #JOIN prot_seq_source using (source_id)
        #WHERE seq_id='%d'
        #ORDER BY source_name, id_0, id_1, id_2, id_3"""

        q = """SELECT source_name, id_0, id_1, id_2, id_3 FROM sp_cross_ref
        JOIN prot_seq_source using (source_id)
        WHERE seq_id='%d'"""

        # compare without case, like MySQL
        # we are only concerned with the first 5 tuple fields.  Ignore others
        def compare_crossrefs( f1, f2):
            # compare as upper case, like MySQL
            for i in range(0,5):
                if i >= len(f1)-1 or i >= len(f2)-1:
                    return 0
                f1_str = str(f1[i]).upper()
                f2_str = str(f2[i]).upper()
                ret = cmp(f1_str, f2_str)
                if ret != 0:
                    return ret
            return 0

        # drop all but the first 5 tuple elements,
        # padded with zero-length strings
        crossrefs = [x+tuple(['']*5) for x in crossrefs]
        crossrefs = [x.__getslice__(0,5) for x in crossrefs]
        crossrefs.sort(compare_crossrefs)

        self.dbw.execute(q % seq_id)
        f_crossrefs = self.dbw.fetchall()
        f_crossrefs = list(f_crossrefs)
        f_crossrefs.sort(compare_crossrefs)

        if f_crossrefs == crossrefs:
            return seq_id
        return None


if __name__ == "__main__":
    
    tmpdir = "/net/goby/usr1/jmjoseph/tmp/"

    insp = insert_swissprot( tmpdir=tmpdir)

    # download reldate.txt, uniprot_sprot.dat.gz,
    # and uniprot_trembl.dat.gz from:
    # ftp://ca.expasy.org/databases/uniprot/current_release/knowledgebase/complete
    insp.fetch_current_uniprot()

    # parse the version information
    insp.parse_reldate()

    # now perform the inserts for both swissprot and trembl
    print "Inserting Swiss-Prot"
    insp.parse_sp( os.path.join(local_tmp,'uniprot_sprot.dat.gz'),
              'swissprot')
    print "Inserting Trembl"
    insp.parse_sp( os.path.join(tmpdir,'uniprot_trembl.dat.gz'),
                   'trembl')
    
    # remove the data files
    os.remove(os.path.join(local_tmp, 'uniprot_sprot.dat'))
    os.remove(os.path.join(local_tmp, 'uniprot_trembl.dat'))
