#!/usr/bin/env python
# Jacob Joseph

import socket, sys
from JJutil import pgutils

class blast_lock:
    """Hand out BLAST query chunks to workers, coordinated through the
    ``blast_lock`` database table.

    Each ``blast_lock`` row describes one chunk of one run (``br_id``).
    A chunk is unclaimed while ``start_time`` is NULL, claimed once
    ``start_time``/``hostname`` are set (next_chunk_i), and finished once
    ``end_time`` is set (complete_chunk_i).  Every method that touches the
    table first takes a table-level lock (lock) and releases it by
    committing the transaction (unlock).
    """

    def __init__(self, br_id=None, set_id=None,
                 chunk_size=None, max_chunk=None,
                 query_set_id=None,
                 debug=False):
        """Store run parameters and open a database wrapper.

        br_id        -- run identifier; may be None when a worker only
                        claims chunks (next_chunk_i then fills it in)
        set_id       -- set id stored with each chunk row
        chunk_size   -- chunk size stored with each chunk row
        max_chunk    -- highest chunk index; init_chunks creates 0..max_chunk
        query_set_id -- query set id; defaults to set_id when None
        debug        -- passed through to pgutils.dbwrap
        """
        self.br_id = br_id
        self.set_id = set_id
        self.query_set_id = query_set_id if query_set_id is not None else set_id
        self.chunk_size = chunk_size
        self.max_chunk = max_chunk
        self.debug = debug

        self.dbw = pgutils.dbwrap(debug=self.debug)

    def lock(self):
        """Take a SHARE ROW EXCLUSIVE lock on blast_lock for the current
        transaction.

        This mode conflicts with itself in PostgreSQL, so only one session
        at a time can be between lock() and unlock().
        """
        self.dbw.execute("LOCK TABLE blast_lock in SHARE ROW EXCLUSIVE MODE")

    def unlock(self):
        """Release the table lock by committing the current transaction.

        PostgreSQL has no UNLOCK statement; table locks last until the
        transaction ends, so the commit both persists any changes made
        while the lock was held and releases the lock.
        """
        self.dbw.commit()

    def init_chunks(self):
        """Create one unclaimed blast_lock row per chunk 0..max_chunk.

        br_id, set_id, query_set_id, chunk_size and max_chunk must all be
        set.  If rows for br_id already exist, prints an error and exits
        the process with status 1.

        Raises AssertionError if any required attribute is None.
        """
        required = (self.br_id, self.set_id, self.query_set_id,
                    self.chunk_size, self.max_chunk)
        # Raise explicitly instead of using `assert`, which `python -O`
        # strips; the exception type and message are unchanged.
        if any(value is None for value in required):
            raise AssertionError("""Error: Specify br_id, set_id, query_set_id, chunk_size,
        and max_chunk to create a new run (%s, %s, %s, %s, %s)""" % required)

        self.lock()
        # Refuse to initialize the same run twice.
        q = """SELECT * from blast_lock
        WHERE br_id=%(br_id)s"""
        self.dbw.execute(q, {'br_id': self.br_id})
        if self.dbw.fetchall():
            print("Error: Table 'blast_lock' already contains rows for br_id %s"
                  % (self.br_id,))
            # Nothing has been written yet; end the transaction so the
            # table lock is not held while the process shuts down.
            self.unlock()
            sys.exit(1)

        insert = """INSERT INTO blast_lock
        (br_id, set_id, query_set_id, chunk_size, max_chunk, chunk)
        VALUES
        (%(br_id)s, %(set_id)s, %(query_set_id)s, %(chunk_size)s,
        %(max_chunk)s, %(chunk)s)"""

        base = {'br_id': self.br_id,
                'set_id': self.set_id,
                'query_set_id': self.query_set_id,
                'chunk_size': self.chunk_size,
                'max_chunk': self.max_chunk}
        for chunk in range(0, self.max_chunk + 1):
            # Fresh dict per row so the driver never sees a mutated params object.
            self.dbw.execute(insert, dict(base, chunk=chunk))
        # All inserts are committed together, along with the lock release.
        self.unlock()

    def next_chunk_i(self, nretries=3):
        """Claim the lowest-numbered unclaimed chunk and return its index.

        If self.br_id is set, only that run's chunks are considered;
        otherwise a chunk from any run may be claimed.  In both cases
        self.br_id, self.set_id, self.query_set_id and self.chunk_size are
        overwritten with the claimed row's values, so a later
        complete_chunk_i() targets the right run.  The claimed row gets
        start_time = NOW() and this host's name.

        nretries -- currently unused; kept for interface compatibility.

        Returns the chunk index, or None when no unclaimed chunk remains.
        """
        self.lock()

        q = """SELECT br_id, set_id, query_set_id, chunk, chunk_size
        FROM blast_lock WHERE start_time is NULL"""
        # `is not None` so that a br_id of 0 still restricts the search.
        if self.br_id is not None:
            q += """ AND br_id=%(br_id)s"""
        q += " ORDER BY chunk LIMIT 1"

        # No SELECT ... FOR UPDATE is needed: the self-conflicting table lock
        # taken above keeps other workers out until unlock(), so the SELECT
        # and UPDATE below cannot interleave with another claim.
        self.dbw.execute(q, {'br_id': self.br_id})
        rows = self.dbw.fetchall()

        if not rows:
            print("No more chunks.")
            self.unlock()
            return None

        (self.br_id, self.set_id, self.query_set_id,
         chunk, self.chunk_size) = rows[0]

        u = """UPDATE blast_lock
        SET start_time = NOW(), hostname = %(hostname)s
        WHERE br_id=%(br_id)s AND chunk=%(chunk)s"""
        self.dbw.execute(u, {'hostname': socket.gethostname(),
                             'br_id': self.br_id,
                             'chunk': chunk})

        self.unlock()
        return chunk

    def complete_chunk_i(self, chunk):
        """Mark chunk `chunk` of run self.br_id finished (end_time = NOW()).

        The row is kept, not deleted, so start/end times stay available.
        """
        self.lock()
        u = """UPDATE blast_lock
        SET end_time=NOW()
        WHERE br_id=%(br_id)s AND chunk=%(chunk)s"""
        self.dbw.execute(u, {'br_id': self.br_id,
                             'chunk': chunk})
        self.unlock()

    def delete_all(self):
        "Remove all blast_lock rows for the current br_id"
        self.lock()
        d = """DELETE FROM blast_lock where br_id=%(br_id)s"""
        self.dbw.execute(d, {'br_id': self.br_id})
        # unlock() commits, which both persists the DELETE and drops the
        # lock; a separate commit() beforehand is redundant.
        self.unlock()

    def reset_all(self):
        """Reset all incomplete jobs in blast_lock.

        Every claimed-but-unfinished chunk (start_time set, end_time NULL)
        is returned to the unclaimed state.  NOTE: this is not filtered by
        self.br_id; it affects every run in the table.
        """
        self.lock()
        u = """UPDATE blast_lock
        SET start_time=NULL, hostname=NULL
        WHERE start_time is not NULL
        AND end_time is NULL"""
        self.dbw.execute(u)
        self.unlock()
        
