#!/usr/bin/env python

import os, time, sys, gzip
from Bio.SwissProt import SProt
from JJutil import mysqlutils, rate

class insert_utils:
    """Helpers for inserting and comparing protein sequence records.

    Every method talks to the database through ``self.dbw``, a wrapper
    object that must provide ``execute``, ``commit``, ``lastrowid``,
    ``fetchsingle`` and ``fetchall``.  Values interpolated into SQL are
    passed through ``mysqlutils.quote`` first.
    """

    # database wrapper; replaced per instance by __init__
    dbw = None

    def __init__(self, dbw):
        """Store the database wrapper used by all other methods."""
        self.dbw = dbw

    def insert_prot_seq(self, sequence, length, crc, molecular_weight):
        """Insert a new prot_seq row and return its generated id.

        The sequence string itself is stored (or found) in prot_seq_str via
        lookup_prot_seq_str(); the new prot_seq row references it.
        """
        i = """INSERT into prot_seq
        (seq_str_id)
        VALUES ('%d')"""

        seq_str_id = self.lookup_prot_seq_str(sequence, length, crc,
                                              molecular_weight)

        # insert the sequence entry
        self.dbw.execute(i % seq_str_id)
        self.dbw.commit()
        return self.dbw.lastrowid()

    def lookup_prot_seq_source(self, source_name):
        """Return the source_id for source_name, inserting it if missing."""
        q = """SELECT source_id FROM prot_seq_source
        WHERE source_name = '%s'"""

        i = """INSERT INTO prot_seq_source
        (source_name)
        VALUES ('%s')"""

        # already exists?
        self.dbw.execute(q % mysqlutils.quote(source_name))
        source_id = self.dbw.fetchsingle()
        if source_id:
            return source_id

        # insert the source, returning the created serial
        # TODO: should check the insert id for failure
        self.dbw.execute(i % mysqlutils.quote(source_name))
        self.dbw.commit()
        return self.dbw.lastrowid()

    def lookup_prot_seq_source_ver(self, source_name, version, date):
        """Return the source_ver_id for (source_name, version, date).

        The version row, and the source row it refers to, are inserted
        first when they do not exist yet.
        """
        q = """SELECT source_ver_id FROM prot_seq_source_ver
        JOIN prot_seq_source USING (source_id)
        WHERE source_name='%s' AND version='%s'
        AND date='%s'"""

        i = """INSERT INTO prot_seq_source_ver
        (source_id, version, date)
        VALUES
        ('%d', '%s', '%s')"""

        # escape input
        inputtup = tuple(map(mysqlutils.quote, (source_name, version, date)))

        # already exists?
        self.dbw.execute(q % inputtup)
        source_ver_id = self.dbw.fetchsingle()
        if source_ver_id:
            return source_ver_id

        # lookup the source_id, inserting if necessary
        source_id = self.lookup_prot_seq_source(source_name)
        # NOTE(review): the '%d' placeholder requires mysqlutils.quote to
        # hand numeric ids back as numbers -- confirm against that helper.
        inputtup = tuple(map(mysqlutils.quote, (source_id, version, date)))

        # now insert the particular source version
        self.dbw.execute(i % inputtup)
        self.dbw.commit()
        return self.dbw.lastrowid()

    def lookup_prot_seq_str(self, sequence, length, crc, molecular_weight):
        """Return the seq_str_id matching all four values, inserting if new."""
        q = """SELECT seq_str_id FROM prot_seq_str
        WHERE sequence='%s' AND length='%s' AND crc='%s'
        AND molecular_weight='%s'"""

        i = """INSERT INTO prot_seq_str
        (sequence, length, crc, molecular_weight)
        VALUES
        ('%s','%s','%s','%s')"""

        inputtup = tuple(map(mysqlutils.quote, (sequence, length, crc,
                                                molecular_weight)))

        # already exists?
        self.dbw.execute(q % inputtup)
        seq_str_id = self.dbw.fetchsingle()
        if seq_str_id:
            return seq_str_id

        # insert the sequence
        self.dbw.execute(i % inputtup)
        self.dbw.commit()
        return self.dbw.lastrowid()

    # this should be extended to handle cross references with different
    # source_ids
    def insert_prot_seq_accessions(self, seq_id, accessions):
        """Insert all accessions of one database entry in a single statement.

        Each accession is stored with its position in ``accessions`` as
        acc_order; the source_id is resolved by a subquery on the entry's
        prot_seq_version row.  Nothing is executed when ``accessions`` is
        empty (an INSERT without any value tuple is not valid SQL).  No
        commit is issued here.
        """
        q_prefix = """INSERT INTO prot_seq_accession
        (seq_id, source_id, acc_order, accession)
        VALUES """
        q_value = """(
        '%d',
        (SELECT source_id FROM prot_seq_version
        JOIN prot_seq_source_ver USING (source_ver_id)
        WHERE seq_id = '%d'),
        '%d','%s')"""

        values = []
        for (order, accession) in enumerate(accessions):
            inputtup = tuple(map(mysqlutils.quote, (
                seq_id, seq_id, order, accession)))
            values.append(q_value % inputtup)

        if not values:
            return

        # execute the insert
        self.dbw.execute(q_prefix + ', '.join(values))
        return

    def insert_prot_seq_version_simple(self, seq_id, source_ver_id, primary_acc):
        """Link a sequence to a source version (INSERT IGNORE, no commit)."""
        i = """INSERT IGNORE INTO prot_seq_version
        (seq_id, source_ver_id, primary_acc)
        VALUES
        ('%d','%d','%s')"""

        inputtup = tuple(map(mysqlutils.quote, (seq_id, source_ver_id, primary_acc)))
        self.dbw.execute(i % inputtup)

        return

    def compare_prot_seq_accessions(self, seq_id, accessions):
        """Check ``accessions`` against those stored for seq_id, in order.

        Returns seq_id when every given accession equals the stored
        accession at the same acc_order position, otherwise None.

        NOTE(review): stored accessions beyond len(accessions) are not
        treated as a mismatch -- confirm that this is intended.
        """
        q = """SELECT accession FROM prot_seq_accession
        WHERE seq_id='%d'
        ORDER BY acc_order"""

        self.dbw.execute(q % seq_id)
        f_accessions = self.dbw.fetchall()

        for (i, acc) in enumerate(accessions):
            # fewer stored rows than given, or a differing value
            if i >= len(f_accessions) or f_accessions[i][0] != acc:
                return None

        # the verdict is only known once every accession has been checked
        return seq_id

    def compare_prot_seq_str(self, seq_id, sequence, length, crc, molecular_weight):
        """Return seq_id if its stored sequence data matches, else None."""
        q = """SELECT seq_id FROM prot_seq
        JOIN prot_seq_str USING (seq_str_id)
        WHERE seq_id = '%d'
        AND sequence = '%s'
        AND length = '%s'
        AND crc = '%s'
        AND molecular_weight = '%s'"""

        inputtup = tuple(map(mysqlutils.quote, (seq_id, sequence, length, crc,
                                                molecular_weight)))
        self.dbw.execute(q % inputtup)
        f_seq_id = self.dbw.fetchsingle()
        if f_seq_id:
            return f_seq_id
        return None

    def compare_sp_crossrefs(self, seq_id, crossrefs):
        """Check ``crossrefs`` against the cross references stored for seq_id.

        Only the first five fields of each cross reference are used
        (source_name, id_0 .. id_3); shorter ones are padded with empty
        strings.  The given cross references are sorted the way the query
        orders its rows before the two lists are compared.

        Returns seq_id when both lists are equal, otherwise None.
        """
        q = """SELECT source_name, id_0, id_1, id_2, id_3 FROM sp_cross_ref
            JOIN prot_seq_source using (source_id)
            WHERE seq_id='%d'
            ORDER BY source_name, id_0, id_1, id_2, id_3"""

        # order by all five fields, compared as upper case, like MySQL
        def sort_key(ref):
            return tuple(str(field).upper() for field in ref)

        # drop all but the first 5 tuple elements,
        # padded with zero-length strings
        crossrefs = [(tuple(x) + ('',) * 5)[:5] for x in crossrefs]
        crossrefs.sort(key=sort_key)

        self.dbw.execute(q % seq_id)
        f_crossrefs = list(self.dbw.fetchall())

        if f_crossrefs == crossrefs:
            return seq_id
        return None

    def insert_sp_crossref(self, seq_id, source_name, id_0='', id_1='',
                           id_2='', id_3=''):
        """Insert one cross reference row for seq_id (no commit).

        The source is looked up by name and created when missing.
        """
        i = """INSERT IGNORE INTO sp_cross_ref
        (seq_id, source_id, id_0, id_1, id_2, id_3)
        VALUES
        ('%d', '%d','%s','%s','%s','%s')"""
        # Ignore is here to get around a duplicate MGI entry in Q3TLP9.
        # The duplicate is not found on the website.

        source_id = self.lookup_prot_seq_source(source_name)
        inputtup = tuple(map(mysqlutils.quote,
                             (seq_id, source_id, id_0, id_1, id_2, id_3)))
        self.dbw.execute(i % inputtup)
        return

    def insert_sp_crossrefs(self, seq_id, crossrefs):
        """Insert every cross reference in ``crossrefs`` for seq_id."""
        for crossref in crossrefs:
            # use only the first 5 tuple elements
            self.insert_sp_crossref(seq_id, *crossref[:5])
        return
        
