#!/usr/bin/env python
# Jacob Joseph

import os
from JJutil import pgutils

class seq_set:
    """Build a protein sequence set and its BLAST input files.

    On construction this either creates a new row in prot_seq_set and
    populates prot_seq_set_member from the given taxonomy IDs and
    source database versions, or reuses an existing set_id.  It then
    writes the set out as a fasta file (optionally also as a series of
    fixed-size chunk files) and runs formatdb on that fasta file.
    For now, all sequences from an organism are used.
    """

    dbw = None                  # database wrapper from pgutils.dbwrap()
    set_name = None
    set_desc = None
    gb_tax_id_list = None       # [4932, ...]
    db_version_dict = None      # {'db': 'version', ...}
    set_id = None
    chunk_size = None           # number of sequences per query chunk file
    nofragments = None          # Do not include Uniprot fragments
    tmpdir = None               # directory receiving the fasta files

    max_chunk = None            # the last query chunk filename index

    def __init__(self, name=None, desc=None, tax_ids=None,
                 db_versions=None, set_id=None, chunk_size=5000,
                 nofragments=False,
                 tmpdir=None,
                 write_chunks=True, write_blastdb=True):
        """Create (or load) the sequence set and write its files.

        name, desc    -- name / description stored for a new set; when
                         set_id is given and either is None, both are
                         fetched from the database instead.
        tax_ids       -- list of NCBI taxonomy ids (needed for a new set).
        db_versions   -- {source_name: version} (needed for a new set).
        set_id        -- existing set to reuse; None creates a new set.
        chunk_size    -- sequences per chunk file.
        nofragments   -- exclude entries whose description contains
                         '(fragment'.
        tmpdir        -- output directory; None selects
                         $HOME/tmp/blast_input/.
        write_chunks  -- also write the per-chunk fasta files.
        write_blastdb -- run formatdb on the full fasta file.

        Raises AssertionError if set_id is None and tax_ids or
        db_versions is missing.
        """
        self.set_name = name
        self.set_desc = desc
        self.gb_tax_id_list = tax_ids
        self.db_version_dict = db_versions
        self.set_id = set_id
        self.chunk_size = chunk_size
        self.nofragments = nofragments
        self._prepare_tmpdir(tmpdir)

        self.dbw = pgutils.dbwrap()

        print("Initializing protein sequence set")

        # if given a set_id, just create the blast database and fasta files
        if self.set_id is None:
            if tax_ids is None or db_versions is None:
                # Raised explicitly (rather than `assert False`) so the
                # check is not stripped when running under -O.
                raise AssertionError(
                    "If set_id is unspecified, tax_ids and db_versions must be given")
            self.create_prot_seq_set()
            self.populate_prot_seq_set_member()
        elif self.set_name is None or self.set_desc is None:
            (self.set_name, self.set_desc) = self.fetch_set_info()

        self.write_set_fasta(write_chunks=write_chunks)
        if write_blastdb:
            self.write_blast_db()

    def _prepare_tmpdir(self, tmpdir):
        """Set self.tmpdir from the argument (or the default) and make
        sure the directory exists."""
        # Test the argument, not self.tmpdir: the class attribute is
        # always None, which would discard a caller-supplied directory.
        if tmpdir is None:
            tmpdir = os.path.expandvars("$HOME/tmp/blast_input/")
        self.tmpdir = tmpdir

        if not os.path.exists(self.tmpdir):
            # makedirs: the default path is nested, and its parent may
            # not exist yet.
            os.makedirs(self.tmpdir)

    def fetch_set_info(self, set_id=None):
        """Fetch (name, description) for a sequence set.

        set_id defaults to self.set_id.  Returns whatever the database
        wrapper's fetchone() yields for the query.
        """
        q = """SELECT name, description
        FROM prot_seq_set
        WHERE set_id = %(set_id)s"""

        if set_id is None:
            set_id = self.set_id

        # Pass only the parameter the query names, not every local.
        self.dbw.execute(q, {'set_id': set_id})
        return self.dbw.fetchone()

    def create_prot_seq_set(self):
        """Insert a new prot_seq_set row and store its id in self.set_id."""
        print("  Creating a new protein sequence set...")

        i = """INSERT INTO prot_seq_set
        (name, description)
        VALUES
        (%(set_name)s, %(set_desc)s)
        RETURNING set_id"""

        self.dbw.execute(i, {'set_name': self.set_name,
                             'set_desc': self.set_desc})
        self.set_id = self.dbw.fetchsingle()
        # NOTE(review): the original called self.commit(), which this
        # class does not define.  Assumes the dbwrap object exposes
        # commit() -- confirm against JJutil.pgutils.
        self.dbw.commit()

    def populate_prot_seq_set_member(self):
        """Populate prot_seq_set_member from self.gb_tax_id_list (NCBI
        taxonomy IDs) and self.db_version_dict (database versions).
        When self.nofragments is set, entries whose description matches
        '%(fragment%' are excluded.

        Raises ValueError if either collection is empty, since the
        resulting WHERE clause would be empty.
        """
        print("  Poplulating protein sequence set... %s" % (self.set_id,))

        if not self.db_version_dict or not self.gb_tax_id_list:
            raise ValueError(
                "tax_ids and db_versions must each contain at least one entry")

        # NOTE(review): values are interpolated directly into the SQL
        # text, so they must come from a trusted source.
        db_clause = " OR ".join(
            "(source_name='%s' AND version='%s')" % (db, version)
            for (db, version) in self.db_version_dict.items())
        tax_clause = " OR ".join(
            "gb_tax_id='%d'" % gb_tax_id
            for gb_tax_id in self.gb_tax_id_list)

        i_base = """INSERT INTO prot_seq_set_member
        SELECT '%d' AS set_id, prot_seq.seq_id FROM prot_seq_version
        JOIN sp_gb_tax_id USING (seq_id)
        JOIN prot_seq_source_ver USING (source_ver_id)
        JOIN prot_seq_source USING (source_id)
        JOIN prot_seq USING (seq_id)"""

        i_base_cond = """WHERE (%s)
        AND (%s)"""

        # construct the query from the above fragments
        parts = [i_base % self.set_id]
        if self.nofragments:
            parts.append("""JOIN sp_xtra using (seq_id)""")
        parts.append(i_base_cond % (db_clause, tax_clause))
        if self.nofragments:
            # Added only after all %-formatting is done: this fragment
            # contains literal '%' characters for the LIKE pattern.
            parts.append("""AND description not like '%(fragment%'""")

        self.dbw.execute("\n".join(parts))
        # See the note in create_prot_seq_set: self.commit() does not exist.
        self.dbw.commit()

    def write_set_fasta(self, write_chunks=True):
        """Write a fasta file of the entire protein sequence set for
        later database creation with formatdb.  When write_chunks is
        true, also write a series of files, each holding up to
        self.chunk_size sequences, to allow incremental, parallel blast
        runs and insertions.

        Files are <tmpdir>/<set_id>.fasta and <tmpdir>/<set_id>.fasta.<k>.
        Sets self.max_chunk to the index of the last chunk file (0 when
        no chunks are written).
        """
        print("  Writing protein sequence set (%d) fasta file(s)..." % self.set_id)

        # fetch all protein sequences
        q = """SELECT seq_id, sequence FROM prot_seq_set_member
        JOIN prot_seq USING (seq_id)
        JOIN prot_seq_str USING (seq_str_id)
        WHERE set_id=%(set_id)s"""

        self.dbw.execute(q, {'set_id': self.set_id})
        ret = self.dbw.fetchall()

        fasta_path = os.path.join(self.tmpdir, str(self.set_id) + '.fasta')
        chunk_i = 0
        chunk_out = None
        db_out = open(fasta_path, 'w')
        # try/finally so that no file handle is leaked if a write fails.
        try:
            if write_chunks:
                chunk_out = open(fasta_path + '.' + str(chunk_i), 'w')

            for (i, (seq_id, sequence)) in enumerate(ret):
                record = ">%d\n%s\n" % (seq_id, sequence)
                db_out.write(record)

                if not write_chunks:
                    continue

                # start a new query fasta file if necessary
                if i > 0 and (i % self.chunk_size == 0):
                    chunk_out.close()
                    chunk_i += 1
                    chunk_out = open(fasta_path + '.' + str(chunk_i), 'w')

                chunk_out.write(record)
        finally:
            db_out.close()
            if chunk_out is not None:
                chunk_out.close()

        self.max_chunk = chunk_i

    def write_blast_db(self):
        """Run formatdb on the set's fasta file.

        The formatdb executable is looked up at ./blast-2.2.15/bin/formatdb,
        i.e. relative to the current working directory.  Failures are
        reported on stdout rather than raised, as the original ignored
        the command's exit status.
        """
        # Function-scope import keeps this block self-contained.
        import subprocess

        print("  Creating blast database set (%d)..." % self.set_id)

        fasta_path = os.path.join(self.tmpdir, str(self.set_id) + '.fasta')
        # Argument list instead of a shell string, so a tmpdir containing
        # spaces or shell metacharacters is passed through intact.
        try:
            status = subprocess.call(
                ["./blast-2.2.15/bin/formatdb", "-i", fasta_path])
        except OSError as err:
            print("  WARNING: could not run formatdb: %s" % (err,))
            return
        if status != 0:
            print("  WARNING: formatdb exited with status %s" % (status,))
