#!/usr/bin/env python

# Jacob Joseph
# 2008-01-23
# A class to run blast on a previously defined sequence set.

import os, sys, re, time
from Bio.Blast import NCBIStandalone, NCBIXML
from JJutil import rate

from DurandDB import blastq

class blast_run:
    """Run and insert blast for a particular fasta query file.

    Typical lifecycle:
      1. new_run()      -- register a run in the `blast_run` table, get br_id
      2. continue_run() -- execute blastall for one pre-split query chunk and
                           insert the parsed hits
    or, for an existing run:
      1. fetch_info()   -- load run metadata for a known br_id
      2. continue_run() / insert_blast()

    Output of each chunk is gzipped to `blast_backup_dir` so the XML can be
    re-parsed/re-inserted later without re-running blast.

    NOTE: this is Python 2 code (print statements, iteritems, .next()).
    """

    # Shared handles / run identity.  These class-level defaults are
    # overwritten per-instance in __init__ / new_run / fetch_info.
    bq = None                # DurandDB.blastq wrapper (holds the db connection, bq.dbw)
    set_id = None            # sequence set searched against (the blast database)
    query_set_id = None      # sequence set the queries come from (may equal set_id)
    br_id = None             # primary key of this run in the blast_run table
    blast_backup_dir = None  # where gzipped raw blast XML output is kept
    tmpdir = None            # where fasta databases/query chunks live
    
    def __init__(self, tmpdir=None,
                 blast_backup_dir=None,
                 debug=False):
        """Prepare working directories and open the database connection.

        tmpdir           -- directory holding '<set_id>.fasta' databases and
                            '<query_set_id>.fasta.<chunk>' query files
                            (default: $HOME/tmp/blast_input)
        blast_backup_dir -- directory for gzipped blast XML output
                            (default: $HOME/blast_backup/)
        debug            -- passed through to the blastq database wrapper
        """
        
        if tmpdir is None:
            self.tmpdir = os.path.expandvars("$HOME/tmp/blast_input")
        else:
            self.tmpdir = tmpdir

        if blast_backup_dir is None:
            self.blast_backup_dir = os.path.expandvars("$HOME/blast_backup/")
        else:
            self.blast_backup_dir = blast_backup_dir

        # Create either directory if missing.  os.mkdir (not makedirs):
        # the parent ($HOME, $HOME/tmp) is assumed to exist.
        for p in (self.tmpdir, self.blast_backup_dir):
            if not os.path.exists( p): os.mkdir( p)

        self.debug = debug
        self.bq = blastq.blastq(debug=self.debug)

        #print "Initializing blast run for protein sequence set",self.set_id
        return

    def new_run( self, set_id, param_str, comment, blastall_path,
                 query_set_id=None):
        """Create a blast_run entry. Specifying query_set_id restricts
        blast queries to only a specific set.  Hits will be in set_id.

        param_str     -- a Python-expression string of blastall keyword
                         arguments; it is eval()'d later in execute_blast()
        blastall_path -- path to the blastall binary, stored with the run
        Returns the new br_id (also cached on self).  Commits immediately.
        """
        self.set_id =  set_id

        # Default: queries are drawn from the same set that is searched.
        if query_set_id is None: query_set_id = set_id
        self.query_set_id = query_set_id
        
        # Sequence/residue counts are computed server-side from the set
        # membership tables at insert time.  RETURNING requires PostgreSQL.
        i = """INSERT INTO blast_run
        (set_id, date, num_sequences, num_residues, params, comment, blastall_path,
        query_set_id)
        VALUES
        (%(set_id)s,
        NOW(),
        (SELECT COUNT(*) FROM prot_seq_set_member WHERE set_id=%(set_id)s),
        (SELECT SUM(length) FROM prot_seq_set_member
         JOIN prot_seq USING (seq_id)
         JOIN prot_seq_str USING (seq_str_id)
         WHERE set_id=%(set_id)s),
        %(param_str)s, %(comment)s, %(blastall_path)s,
        %(query_set_id)s)
        RETURNING br_id"""

        # locals() supplies the named query parameters (set_id, param_str, ...).
        self.br_id = self.bq.dbw.fetchsingle( i, locals())

        print "  Creating a new blast run (%d) for set (%d)" % (self.br_id, self.set_id)

        self.bq.dbw.commit()
        return self.br_id

    def fetch_info(self, br_id=None):
        """Load run metadata (set ids, counts, params, blastall path) for
        br_id into instance attributes.

        br_id may be omitted only if self.br_id is already set.  Exits the
        process on a missing or conflicting br_id rather than raising.
        """
        q_blast = """SELECT set_id, num_sequences, num_residues, params, blastall_path,
        query_set_id
        FROM blast_run
        WHERE br_id = %(br_id)s"""

        if self.br_id is None:
            if br_id is None:
                print "fetch_info(%s): No br_id specified" % br_id
                sys.exit(1)
            else:
                self.br_id = br_id
        # NOTE(review): if self.br_id is already set, calling fetch_info()
        # with br_id=None lands here (self.br_id != None) and exits --
        # presumably a refresh with no argument was meant to be allowed.
        elif self.br_id != br_id:
            print """fetch_info(%s): This instance of blast run already initialized with
br_id %s.""" % (br_id, self.br_id)
            sys.exit(1)

        self.bq.dbw.execute( q_blast, {'br_id':self.br_id})
        (self.set_id, self.num_sequences, self.num_residues,
         self.param_str, self.blastall_path,
         self.query_set_id) = self.bq.dbw.fetchone()
        return
        
    def continue_run( self, chunk_i, br_id=None):
        "Execute a particular br_id chunk."

        self.fetch_info(br_id)

        # Gzipped raw XML backup for this chunk.
        # NOTE(review): no separator between the ids -- yields e.g.
        # 'br_id_7chunk_0.gz'.  Harmless but probably intended 'br_id_7_chunk_0.gz';
        # changing it now would orphan existing backups.
        blast_out = os.path.join(self.blast_backup_dir,
                                 'br_id_' + str(self.br_id) +
                                 'chunk_' + str(chunk_i) + '.gz')
        # Query chunks are pre-split elsewhere as '<query_set_id>.fasta.<chunk>'.
        query_file = os.path.join(self.tmpdir,
                                  str(self.query_set_id)+'.fasta.'+str(chunk_i))
        
        pair_rate = self.execute_blast( blast_out, chunk_i, query_file)
        #pair_rate = rate.rate()
        self.insert_blast( blast_out, pair_rate)

    # assume a database in the current directory
    # FIXME: figure out how to pass blast params intelligently
    def execute_blast( self, blast_out, chunk_i, query_file):
        """Run blastall on query_file against the set_id database, streaming
        the XML output through gzip into blast_out.

        Progress (queries and HSP pairs seen) is estimated by counting
        '<Iteration>' / '<Hsp>' substrings in ~10MB read blocks.
        Returns the pair-count rate.rate object for use by insert_blast.
        Requires fetch_info()/new_run() to have populated run attributes.
        """
        # Expected chunk size (number of queries) recorded by whatever
        # process locked/split this chunk; used to scale the progress meter.
        q_lock = """SELECT chunk_size
        FROM blast_lock
        WHERE chunk=%(chunk_i)s
        AND br_id=%(br_id)s"""

        self.bq.dbw.execute( q_lock, {'chunk_i': chunk_i, 'br_id': self.br_id})
        # NOTE(review): fetchsingle() with no arguments -- assumes the dbw
        # wrapper fetches from the execute() just above; verify in DurandDB.
        chunk_size = self.bq.dbw.fetchsingle()
                              
        print "  Executing blast run", self.br_id
        print "    Query file: ", query_file
        print "    Num Sequences: %d, Num Residues: %d" % (self.num_sequences, self.num_residues)
        print "    Parameters:", self.param_str
        print "    Blastall path:", self.blastall_path
        print "    Chunk size:", chunk_size

        # write the blast results to a file
        # this allows a backup of the run
        #print "now writing to file..."
        b_out_fd = os.popen('gzip > '+blast_out, 'w')

        # Evaluate the parameters from the database
        # SECURITY: eval() of a database-stored string with access to this
        # instance's attributes (self.__dict__ as globals).  Acceptable only
        # because param_str is written by new_run() from trusted callers;
        # never let untrusted input reach the params column.
        blast_params = eval(self.param_str, self.__dict__)

        # Formatted blast database is expected at <tmpdir>/<set_id>.fasta
        blast_db = os.path.join( self.tmpdir,
                                 str(self.set_id) + '.fasta')

        # expectation and search_length parameters chosen from NC manuscript
        (b_out,b_err) = NCBIStandalone.blastall(
            self.blastall_path,
            'blastp', blast_db, query_file, **blast_params)

        # Rate of blast processing
        query_rate = rate.rate(totcount = chunk_size)
        pair_rate = rate.rate()

        # count substrings in the output to monitor blast output. we
        # could miss a few if these occur and the bounds of the read
        # blocks, but we're looking only for a rough approximation
        # anyway
        cnt_re_iter = re.compile('<Iteration>')
        cnt_re_hsp = re.compile('<Hsp>')

        while True:
            #l = b_out.readline()
            #b_out_fd.writelines( [l])
            block = b_out.read(10485760) # approx 10MB blocks
            #print "write, parse block", time.time()
            # Writing before the EOF test is harmless: at EOF block is ''
            # and write('') is a no-op.
            b_out_fd.write( block)
            
            if block is None or block == '':
                break

            # Write a bit of status per <?xml block
            #if l.__contains__('<Iteration>'):           # rare

            #c_hsp = len(cnt_re_hsp.findall(l))
            c_hsp = len(cnt_re_hsp.findall(block))

            #if l.__contains__('<Hsp>'):                 # very common
            # count the number of pairs
            if c_hsp > 0:
                pair_rate.increment(c_hsp)

            #c_iter = len(cnt_re_iter.findall(l))
            c_iter = len(cnt_re_iter.findall(block))
            if c_iter > 0:
                query_rate.increment(c_iter)

                # Print status roughly every 100 queries; the second clause
                # catches the case where a block jumped across a multiple
                # of 100 without landing exactly on it.
                if (query_rate.count % 100 == 0 or 
                    (query_rate.count % 100 < (query_rate.count - c_iter) % 100)):
                    print "q:", query_rate
                    print "p:", pair_rate

            #print "end", time.time()
            #print query_rate
            #print pair_rate
            
        # No queries parsed at all: surface blast's stderr for diagnosis.
        if query_rate.count == 0:
            print b_err.readlines()
            
        b_out_fd.close()
        return pair_rate

    def insert_blast( self, blast_out, pair_rate=None, max_hit_rank=0, commit=True):
        """Parse the gzipped blast XML in blast_out and insert hits.

        blast_out    -- path to the gzipped XML produced by execute_blast
        pair_rate    -- unused at present (kept for a disabled total-count
                        progress display; see commented block below)
        max_hit_rank -- keep only HSPs with rank <= max_hit_rank per
                        alignment (default 0: first/best HSP only)
        commit       -- commit the transaction after all records insert

        Assumes query and hit FASTA deflines are integer seq_ids
        (int(b_record.query) / int(alignment.hit_def)).
        """
        print "  Inserting blast results to database (br_id: %d)..." % self.br_id
        # using a pipe from gzip is faster (2x with xml parsing)
        # (and it runs as a separate process)
        #b_out = gzip.open( blast_out_file, 'r')
        b_out = os.popen('gzip -c -d '+blast_out, 'r')

        b_iterator = NCBIXML.parse(b_out)

        # count the number of inserts, with a total count if
        # a backup was written
        #if pair_rate:
        #    insert_rate = rate.rate(totcount = pair_rate.count)
        #else: insert_rate = rate.rate()
        insert_rate = rate.rate()

        while True:
            try:
                b_record = b_iterator.next()
            except StopIteration:
                print "No more records"
                break

            seq_id_0 = int(b_record.query)

            # We want to send blocks of (seq_id_0, hit_rank) to the database, so buffer all records for this query.
            # Maps hit_rank -> {seq_id_1: hit tuple}.
            query_rows = {}

            for (alignment_num,alignment) in enumerate(b_record.alignments):
                seq_id_1 = int(alignment.hit_def)

                for (hit_rank, hsp) in enumerate(alignment.hsps):

                    # only store the first HSP
                    if hit_rank > max_hit_rank: continue

                    # calculate sequence end from length minus gaps
                    seq_start_0 = hsp.query_start
                    seq_start_1 = hsp.sbjct_start
                    seq_end_0 = hsp.query_end
                    seq_end_1 = hsp.sbjct_end

                    # there are differences between the standalone and XML parsers
                    # these values don't seem to follow the Biopython source
                    # documentation. (None,None) essentially means 0
                    # use custom parser for correct functionality
                    identities = hsp.identities
                    gaps = hsp.gaps
                    positives = hsp.positives

                    if identities is None or identities==(None,None):
                        identities = 0
                    if gaps is None or gaps==(None,None):
                        gaps = 0
                    if positives is None or positives==(None,None):
                        positives = 0

                    # other derived information
                    # 'X' in the aligned query indicates SEG-masked
                    # (low-complexity) residues in the alignment.
                    if hsp.query.find('X') > -1: low_complexity=True
                    else: low_complexity=False

                    # Alignment length over the mean of the two full
                    # sequence lengths.
                    coverage = float(hsp.align_length) / ((b_record.query_letters +
                                                          alignment.length) / 2)

                    if not hit_rank in query_rows:
                        query_rows[hit_rank] = {}

                    query_rows[hit_rank][seq_id_1] = (
                        hsp.bits, hsp.expect, seq_start_0, seq_end_0,
                        seq_start_1, seq_end_1, identities, gaps,
                        positives, low_complexity, coverage)

            # Insert all rows for a query. Commit the transaction only
            # once all have been successfully inserted
            for (hit_rank, record_dict) in query_rows.iteritems():
                self.bq.insert_blast_hit(self.br_id, hit_rank,
                                         seq_id_0, record_dict,
                                         commit=False)

            # print the query insert status periodically
            # (count is checked before increment, so this also prints on
            # the very first record at count 0)
            if insert_rate.count % 100 == 0:
                print "insert", insert_rate
                #self.dbh.commit()
            insert_rate.increment()

        if commit: self.bq.dbw.commit()
        b_out.close()
        return