#!/usr/bin/env python

import os, time, sys, gzip
#from Bio.SwissProt import SProt
from JJutil import rate, pgutils

class insert_generic(object):
    """Lookup/insert helpers for the protein sequence tables.

    All SQL goes through a ``pgutils.dbwrap`` instance (``self.dbw``) using
    named ``%(name)s`` parameters. Tables used by the queries below:
    prot_seq, prot_seq_str, prot_seq_version, prot_seq_source,
    prot_seq_source_ver and prot_seq_accession.
    """

    # Database wrapper (pgutils.dbwrap); assigned in __init__.
    dbw = None
    # Scratch directory path; assigned in __init__.
    tmpdir = None
    # Per-source cache filled by the lookup_* methods:
    # {source_name: {'source_id': ..., 'source_ver_id': <last looked up>,
    #                'source_ver_ids': {(version, date): source_ver_id}}}
    source_info = None

    def __init__(self, tmpdir=None, dbhost="goby.compbio.cs.cmu.edu",
                 dbname="DurandLab2", debug=True):
        """Open the database wrapper and initialise the caches.

        tmpdir -- scratch directory; defaults to $HOME/tmp (expanded).
        dbhost, dbname, debug -- forwarded to pgutils.dbwrap; the defaults
            are the values that used to be hard-coded here.
        """
        self.dbw = pgutils.dbwrap(dbhost=dbhost, dbname=dbname, debug=debug)

        # The default used to be assigned to a local only, leaving
        # self.tmpdir as None; store it on the instance in both cases.
        if tmpdir is None:
            tmpdir = os.path.expandvars('$HOME/tmp')
        self.tmpdir = tmpdir

        self.source_info = {}

    def insert_prot_seq_version_simple(self, seq_id, source_ver_id, primary_acc):
        """Insert (only) a new data source version for an existing sequence id.

        Adds one prot_seq_version row. Does not commit.
        """
        i = """INSERT INTO prot_seq_version
        (seq_id, source_ver_id, primary_acc)
        VALUES
        (%(seq_id)s, %(source_ver_id)s, %(primary_acc)s)"""

        self.dbw.execute(i, {'seq_id': seq_id,
                             'source_ver_id': source_ver_id,
                             'primary_acc': primary_acc})

    def insert_prot_seq(self, sequence, length, crc, molecular_weight):
        """Insert a new prot_seq row and return its seq_id.

        The sequence string is first resolved (or inserted) with
        lookup_prot_seq_str; the new row is committed.
        """
        i = """INSERT into prot_seq
        (seq_str_id)
        VALUES ( %(seq_str_id)s )
        RETURNING seq_id"""

        seq_str_id = self.lookup_prot_seq_str(sequence, length, crc,
                                              molecular_weight)

        seq_id = self.dbw.fetchsingle(i, {'seq_str_id': seq_str_id})
        self.dbw.commit()
        return seq_id

    def fetch_prot_seq_str(self, seq_id):
        """Return the prot_seq_str row used by sequence ``seq_id``.

        prot_seq_str is keyed by seq_str_id; seq_id lives in prot_seq, so the
        lookup joins through prot_seq (the same join compare_prot_seq_str
        uses).
        """
        q = """SELECT prot_seq_str.* FROM prot_seq
        JOIN prot_seq_str USING (seq_str_id)
        WHERE seq_id = %(seq_id)s"""

        return self.dbw.fetchone(q, {'seq_id': seq_id})

    def lookup_prot_seq_str(self, sequence, length, crc, molecular_weight,
                            dont_insert=False):
        """Return the seq_str_id of the matching prot_seq_str row.

        A row must match all of sequence, length, crc and molecular_weight.
        If none does, one is inserted and committed, unless ``dont_insert``
        is true, in which case None is returned.
        """
        params = {'sequence': sequence, 'length': length, 'crc': crc,
                  'molecular_weight': molecular_weight}

        q = """SELECT seq_str_id FROM prot_seq_str
        WHERE sequence = %(sequence)s
        AND length= %(length)s
        AND crc= %(crc)s
        AND molecular_weight= %(molecular_weight)s"""

        i = """INSERT INTO prot_seq_str
        (sequence, length, crc, molecular_weight)
        VALUES
        (%(sequence)s, %(length)s, %(crc)s, %(molecular_weight)s)
        RETURNING seq_str_id"""

        # already exists?
        seq_str_id = self.dbw.fetchsingle(q, params)

        if seq_str_id is None and not dont_insert:
            seq_str_id = self.dbw.fetchsingle(i, params)
            self.dbw.commit()

        return seq_str_id

    def compare_prot_seq_str(self, seq_id, sequence,
                             length, crc, molecular_weight):
        """Return ``seq_id`` if that sequence's string matches, else None.

        All of sequence, length, crc and molecular_weight must match.
        """
        q = """SELECT seq_id FROM prot_seq
        JOIN prot_seq_str USING (seq_str_id)
        WHERE seq_id = %(seq_id)s
        AND sequence = %(sequence)s
        AND length = %(length)s
        AND crc = %(crc)s
        AND molecular_weight = %(molecular_weight)s"""

        return self.dbw.fetchsingle(q, {'seq_id': seq_id,
                                        'sequence': sequence,
                                        'length': length,
                                        'crc': crc,
                                        'molecular_weight': molecular_weight})

    def lookup_prot_seq_source(self, source_name):
        """Return the source_id for ``source_name``, inserting it if needed.

        A newly inserted source is committed. The id is cached in
        self.source_info[source_name]['source_id'].
        """
        cache = self.source_info.setdefault(source_name, {})
        if 'source_id' in cache:
            return cache['source_id']

        q = """SELECT source_id FROM prot_seq_source
        WHERE source_name = %(source_name)s"""

        i = """INSERT INTO prot_seq_source
        (source_name)
        VALUES (%(source_name)s)
        RETURNING source_id"""

        params = {'source_name': source_name}

        # already exists?
        source_id = self.dbw.fetchsingle(q, params)
        if source_id is None:
            # insert the source, returning the created serial
            source_id = self.dbw.fetchsingle(i, params)
            self.dbw.commit()

        cache['source_id'] = source_id
        return source_id

    def fetch_prot_seq_source_ver(self, source_name, version, date):
        """Return the source_ver_id for (source_name, version, date), or None."""
        q = """SELECT source_ver_id FROM prot_seq_source_ver
        JOIN prot_seq_source USING (source_id)
        WHERE source_name=%(source_name)s AND version=%(version)s
        AND date=%(date)s"""

        return self.dbw.fetchsingle(q, {'source_name': source_name,
                                        'version': version,
                                        'date': date})

    def lookup_prot_seq_source_ver(self, source_name, version, date):
        """Return the source_ver_id for (source_name, version, date).

        If the source version does not exist yet, it is inserted and
        committed. The source itself is created too if needed.

        Ids are cached per (version, date) under
        self.source_info[source_name]['source_ver_ids']. The most recently
        returned id is also stored under 'source_ver_id'.
        """
        # Create the per-source cache up front: previously, if the version
        # already existed in the DB and the source was not cached yet, the
        # final cache write raised KeyError.
        cache = self.source_info.setdefault(source_name, {})
        ver_cache = cache.setdefault('source_ver_ids', {})

        # Key by (version, date) so a different version of the same source
        # is not answered from a stale cache entry.
        key = (version, date)
        if key in ver_cache:
            cache['source_ver_id'] = ver_cache[key]
            return ver_cache[key]

        i = """INSERT INTO prot_seq_source_ver
        (source_id, version, date)
        VALUES
        (%(source_id)s, %(version)s, %(date)s)
        RETURNING source_ver_id"""

        # already exists?
        source_ver_id = self.fetch_prot_seq_source_ver(
            source_name, version, date)
        if source_ver_id is None:
            # lookup the source_id, inserting if necessary
            source_id = self.lookup_prot_seq_source(source_name)

            # now insert the particular source version
            source_ver_id = self.dbw.fetchsingle(i, {'source_id': source_id,
                                                     'version': version,
                                                     'date': date})
            self.dbw.commit()

        ver_cache[key] = source_ver_id
        cache['source_ver_id'] = source_ver_id
        return source_ver_id

    # this should be extended to handle cross references with different source_ids
    def insert_prot_seq_accessions(self, seq_id, accessions):
        """Insert all ``accessions`` for ``seq_id`` into prot_seq_accession.

        acc_order is each accession's 0-based position in ``accessions``.
        source_id is selected from the sequence's prot_seq_version row.

        All values are sent as query parameters, so the driver quotes them.
        The rows go in as one multi-row INSERT. Nothing is executed when
        ``accessions`` is empty. Does not commit.
        """
        q_prefix = """INSERT INTO prot_seq_accession
        (seq_id, source_id, acc_order, accession)
        VALUES """

        # NOTE(review): this scalar subquery errors if the sequence has more
        # than one prot_seq_version row -- confirm that cannot happen.
        row_prefix = """(
        %(seq_id)s,
        (SELECT source_id FROM prot_seq_version
        JOIN prot_seq_source_ver USING (source_ver_id)
        WHERE seq_id = %(seq_id)s),
        """

        params = {'seq_id': seq_id}
        rows = []
        for n, accession in enumerate(accessions):
            order_key = 'order_%d' % n
            acc_key = 'acc_%d' % n
            params[order_key] = n
            params[acc_key] = accession
            rows.append(row_prefix + '%(' + order_key + ')s, %('
                        + acc_key + ')s)')

        if not rows:
            # "INSERT ... VALUES " with no rows is invalid SQL.
            return

        self.dbw.execute(q_prefix + ', '.join(rows), params)

    def compare_prot_seq_accessions(self, seq_id, accessions):
        """Return ``seq_id`` if its stored accessions equal ``accessions``.

        The comparison ignores order. Returns None otherwise. ``accessions``
        may be any iterable and is not modified.
        """
        q = """SELECT accession FROM prot_seq_accession
        WHERE seq_id=%(seq_id)s
        ORDER BY acc_order"""

        f_accessions = self.dbw.fetchcolumn(q, {'seq_id': seq_id})

        # sorted() copies, so the caller's list is no longer reordered.
        if sorted(f_accessions or []) == sorted(accessions):
            return seq_id
        return None
