#!/usr/bin/env python
# Jacob Joseph
# 2008-02-13

import socket,sys
#from JJutil import mysqlutils
from DurandDB import blastq

class nc_lock:
    """Handle selection of each NC query chunk, coordinated through the
    nc_lock database table.

    Each nc_lock row describes one chunk of query sequences for an nc_id:
    its chunk index, the first and last sequence offsets it covers
    (inclusive), and the start_time / end_time / hostname stamped by the
    worker that claims and completes it.  Every read-modify-write of the
    table happens between lock() and unlock(), so concurrent workers never
    claim the same chunk.
    """
    dbw = None          # DB handle taken from blastq (execute/fetchall/commit)
    bq = None           # blastq.blastq instance for run/sequence lookups
    nc_id = None        # NC run this worker is bound to; next_chunk() may set it
    chunk_size = None   # sequences per chunk; only used by init_chunks()
    nc_info = None      # last result of bq.fetch_nc_info_d(nc_id)
    blast_info = None   # last result of bq.fetch_blast_info_d(nc_info['br_id'])

    def __init__(self, nc_id=None, chunk_size=None):
        """Create a lock handler.

        nc_id      -- NC run to work on.  If None, next_chunk() takes the
                      first free chunk of any run and adopts its nc_id.
        chunk_size -- number of sequences per chunk (init_chunks() only).
        """
        self.nc_id = nc_id
        self.chunk_size = chunk_size

        self.bq = blastq.blastq()
        self.dbw = self.bq.dbw

    def lock( self):
        """Lock the nc_lock table in SHARE ROW EXCLUSIVE mode.

        This mode is self-conflicting, so only one session at a time can be
        between lock() and unlock().  The lock lasts until the current
        transaction ends.
        """
        self.dbw.execute("LOCK TABLE nc_lock in SHARE ROW EXCLUSIVE MODE")

    def unlock( self):
        """Release the table lock by committing the current transaction.

        The commit also makes any pending nc_lock changes permanent.
        """
        self.dbw.commit()

    @staticmethod
    def _chunk_bounds(num_seqs, chunk_size):
        """Split sequence offsets 0..num_seqs-1 into consecutive chunks.

        Returns a list of (chunk, chunk_min, chunk_max) tuples with
        inclusive bounds and at most chunk_size offsets per chunk.  The
        list is empty when num_seqs <= 0.  chunk_size must be positive.
        """
        return [(chunk, chunk_min, min(chunk_min + chunk_size - 1, num_seqs - 1))
                for (chunk, chunk_min) in enumerate(range(0, num_seqs, chunk_size))]

    def init_chunks( self):
        """Insert one unclaimed nc_lock row per chunk of self.nc_id.

        The number of sequences comes from the run's blast info
        ('num_sequences').  Exits the process with status 1 in two cases:
        nc_id or chunk_size is missing or not positive, or nc_lock already
        has rows for nc_id.  A run with no sequences creates no rows.
        """
        if not (self.nc_id and self.chunk_size and self.chunk_size > 0):
            print("Error: Specify nc_id and chunk_size to create a new run")
            sys.exit(1)

        self.nc_info = self.bq.fetch_nc_info_d( self.nc_id)
        self.blast_info = self.bq.fetch_blast_info_d( self.nc_info['br_id'])

        self.lock()
        # check if we have already been initialized
        q = """SELECT * from nc_lock
        WHERE nc_id='%d'"""
        self.dbw.execute( q % self.nc_id)
        if self.dbw.fetchall():
            print("Error: Table 'nc_lock' already contains rows for nc_id %s"
                  % self.nc_id)
            # Release the table lock before exiting so other workers are
            # not blocked until this connection is torn down.
            self.unlock()
            sys.exit(1)

        num_seqs = self.blast_info['num_sequences']
        bounds = self._chunk_bounds(num_seqs, self.chunk_size)

        # With no chunks the statement would be the invalid
        # "INSERT INTO nc_lock VALUES ", so only insert when there is work.
        if bounds:
            # The three trailing NULLs fill the remaining columns
            # (presumably start_time, end_time, hostname -- verify the
            # table's column order).
            i_row = "('%d', '%d', '%d', '%d', NULL, NULL, NULL)"
            rows = [i_row % (self.nc_id, chunk, chunk_min, chunk_max)
                    for (chunk, chunk_min, chunk_max) in bounds]
            self.dbw.execute("INSERT INTO nc_lock VALUES " + ", ".join(rows))
        self.unlock()
        return

    def next_query_set( self):
        """Claim the next chunk and fetch its query sequences.

        Returns (nc_id, chunk, seq_set), where seq_set is the result of
        bq.fetch_seq_set over the chunk's offset window.  Returns None
        when no unclaimed chunk remains.
        """
        next_c = self.next_chunk()
        if next_c is None:
            return None

        (chunk_i, chunk_min, chunk_max) = next_c

        # Refresh run metadata: next_chunk() may have switched self.nc_id.
        self.nc_info = self.bq.fetch_nc_info_d( self.nc_id)
        self.blast_info = self.bq.fetch_blast_info_d( self.nc_info['br_id'])
        set_id = self.blast_info['set_id']

        # chunk_max is inclusive, hence the +1.
        seq_set = self.bq.fetch_seq_set( set_id=set_id,
                                         limit=chunk_max - chunk_min + 1,
                                         offset=chunk_min)
        return (self.nc_id, chunk_i, seq_set)

    def next_chunk( self):
        """Claim the lowest-numbered chunk whose start_time is NULL.

        If self.nc_id is set, only that run is searched.  Otherwise any run
        qualifies, and self.nc_id is set to the claimed row's nc_id.  The
        claimed row gets start_time = NOW() and this host's name.  Selection
        and stamping happen under the table lock.

        Returns (chunk, chunk_min, chunk_max), or None if nothing is left.
        """
        self.lock()

        q = """SELECT nc_id, chunk, chunk_min, chunk_max
        FROM nc_lock WHERE start_time is NULL"""
        if self.nc_id is not None:
            q += """ AND nc_id='%d'""" % self.nc_id
        q += " ORDER BY chunk LIMIT 1"

        self.dbw.execute(q)
        ret = self.dbw.fetchall()

        if not ret:
            print("No more chunks.")
            self.unlock()
            return None

        (self.nc_id, chunk, chunk_min, chunk_max) = ret[0]

        u = """UPDATE nc_lock
        SET start_time = NOW(), hostname = '%s'
        WHERE nc_id='%d' AND chunk='%d'"""
        self.dbw.execute( u % (socket.gethostname(), self.nc_id, chunk))

        self.unlock()
        return (chunk, chunk_min, chunk_max)

    def complete_chunk_i( self, chunk):
        """Mark chunk `chunk` of self.nc_id as finished (end_time = NOW())."""
        self.lock()

        u = """UPDATE nc_lock
        SET end_time=NOW()
        WHERE nc_id='%d' AND chunk='%d'"""
        self.dbw.execute( u % (self.nc_id, chunk))
        self.unlock()
        return

    def delete_all( self):
        "Remove all nc_lock rows for the current nc_id"
        self.lock()
        d = """DELETE FROM nc_lock where nc_id='%d'"""
        self.dbw.execute(d % self.nc_id)
        # unlock() commits, which both persists the DELETE and drops the lock.
        self.unlock()

    def reset_all( self):
        """Reset all incomplete jobs in nc_lock.

        Rows that were started but never finished get start_time and
        hostname cleared, so next_chunk() can hand them out again.  This
        applies to every nc_id in the table, not just self.nc_id.
        """
        self.lock()
        u = """UPDATE nc_lock
        SET start_time=NULL, hostname=NULL
        WHERE start_time is not NULL
        AND end_time is NULL"""
        self.dbw.execute(u)
        self.unlock()
        return
        
