#!/usr/bin/env python

from JJutil import mysqlutils

def _expand_accession_rows(rows, primary_acc=1):
    """Expand raw accession rows into de-duplicated insert tuples.

    Args:
        rows: iterable of ``(seq_id, source_name_id, accession, taxid_list)``
            tuples as returned by the accession query; ``taxid_list`` is a
            comma-separated string that is split on ','.
        primary_acc: value stored in the ``primary_acc`` column (1 is used
            to mark primary accessions).

    Returns:
        A list of ``(seq_id, source_name_id, primary_acc, accession, tax_id)``
        tuples, one per tax ID, with duplicates removed and first-seen order
        kept.  Rows whose accession is malformed are reported and skipped.
    """
    import re

    # Official UniProtKB accession format (6- and 10-character forms).
    uniprot_acc = re.compile(
        r"(?:[OPQ][0-9][A-Z0-9]{3}[0-9]"
        r"|[A-NR-Z][0-9](?:[A-Z][A-Z0-9]{2}[0-9]){1,2})\Z")

    def is_accession(acc):
        # Empty/None values are malformed (indexing acc[0] used to raise).
        if not acc:
            return False
        # Keep accepting the "letter followed by digits" form checked
        # previously, but stop rejecting valid accessions such as Q9Y6K9
        # that contain letters after the first character.
        if acc[0].isalpha() and acc[1:].isdigit():
            return True
        return uniprot_acc.match(acc) is not None

    seen = set()
    expanded = []
    for seq_id, source_name_id, accession, taxid_list in rows:
        if not is_accession(accession):
            print("Skipping malformed SwissProt accession: %s" % (accession,))
            continue
        for tax_id in taxid_list.split(','):
            record = (seq_id, source_name_id, primary_acc, accession, tax_id)
            # SQL DISTINCT was too slow, so de-duplicate here; the set gives
            # O(1) membership tests instead of rescanning the output list.
            if record not in seen:
                seen.add(record)
                expanded.append(record)
    return expanded


def insert_prot_seq_accession(host=None, user=None, password=None):
    """Load SwissProt 48.8 primary accessions into prot_seq_accession.

    Joins the DurandLab2 prot_seq / prot_seq_source / prot_seq_source_name
    tables against DurandLab.prot_seq (matching length, description and
    sequence), takes each row's ``sourceID1`` as a primary accession
    (``primary_acc`` = 1), emits one row per tax ID in the comma-separated
    ``GB_taxID_list``, and inserts the de-duplicated rows.

    Args:
        host: MySQL host; defaults to the module-level ``dbhost``.
        user: MySQL user; defaults to the module-level ``dbuser``.
        password: MySQL password; defaults to the module-level ``dbpass``.
            The globals are read only when the matching argument is
            omitted, so the original zero-argument call keeps working.

    The connection is always closed; ``commit`` is skipped if any query or
    insert raises.
    """
    if host is None:
        host = dbhost
    if user is None:
        user = dbuser
    if password is None:
        password = dbpass

    dbw = mysqlutils.dbwrap(host, user, password, "DurandLab2")

    # NOTE(review): q0 ends in "limit 100", so only the first 100 joined rows
    # are processed -- this looks like a debugging limit; confirm before a
    # full load.
    q0 = """select prot_seq.seq_id, prot_seq_source.source_name_id,
sourceID1, GB_taxID_list
from prot_seq, prot_seq_source, prot_seq_source_name, DurandLab.prot_seq
where source_name = 'swissprot' and version='48.8'
and source = 'sp'
and prot_seq_source.source_name_id = prot_seq_source_name.source_name_id
and prot_seq_source.source_id = prot_seq.source_id
and DurandLab2.prot_seq.length = DurandLab.prot_seq.length
and DurandLab2.prot_seq.description = DurandLab.prot_seq.description
and DurandLab2.prot_seq.sequence = DurandLab.prot_seq.sequence limit 100"""

    # TODO(review): q1 (the sourceID2 accessions) is built but never
    # executed; presumably meant for non-primary accessions.
    q1 = """select prot_seq.seq_id, prot_seq_source.source_name_id,
sourceID2, GB_taxID_list
from prot_seq, prot_seq_source, prot_seq_source_name, DurandLab.prot_seq
where source_name = 'swissprot' and version='48.8'
and source = 'sp'
and prot_seq_source.source_name_id = prot_seq_source_name.source_name_id
and prot_seq_source.source_id = prot_seq.source_id
and DurandLab2.prot_seq.length = DurandLab.prot_seq.length
and DurandLab2.prot_seq.description = DurandLab.prot_seq.description
and DurandLab2.prot_seq.sequence = DurandLab.prot_seq.sequence"""

    # NOTE(review): values are %-interpolated into quoted SQL literals, so a
    # value containing a quote breaks the statement.  Switch to parameterized
    # execution if mysqlutils.dbwrap.execute supports it (API not visible
    # here).
    insert_sql = """insert into prot_seq_accession
(seq_id,source_name_id,primary_acc,accession,gb_tax_id) values
('%s','%s','%s','%s','%s')"""

    try:
        # handle the sourceID1 field, marking them as primary accessions
        dbw.execute(q0)
        rows = dbw.fetchall()

        print("query done")

        records = _expand_accession_rows(rows, primary_acc=1)

        print(records[:10])
        for record in records:
            dbw.execute(insert_sql % record)

        # Reached only if every insert succeeded.
        dbw.commit()
    finally:
        # Release the connection even when a query or insert raises.
        dbw.close()


if __name__ == "__main__":
    import os

    # Connection settings; environment variables take precedence so the
    # credentials need not live in source control.
    # NOTE(review): the password below was committed in plain text -- it
    # should be rotated and the literal default removed.
    dbhost = os.environ.get("DURANDLAB_DB_HOST", "goby.compbio.cs.cmu.edu")
    dbuser = os.environ.get("DURANDLAB_DB_USER", "jmjoseph")
    dbpass = os.environ.get("DURANDLAB_DB_PASS", "pQ1rng8c")

    insert_prot_seq_accession()
