#!/usr/bin/env python

# Jacob Joseph
# 16 May 2009
#
# Postgres database tools

import functools
import getpass
import logging
import os
import sys
import tempfile

import psycopg2
# support SQL_IN, and other additions
import psycopg2.extensions


# http://www.python.org/dev/peps/pep-0342/
def consumer(func):
    """Decorator: build the generator and advance it to the first
    'yield' so callers may .send() to it immediately (PEP 342)."""
    @functools.wraps(func)
    def wrapper(*args, **kw):
        gen = func(*args, **kw)
        next(gen)  # prime the generator (Python 2 spelled this gen.next())
        return gen
    return wrapper


class dbwrap:
    """Convenience wrapper around a psycopg2 connection.

    Holds a DSN, opens the connection lazily (prompting for credentials
    on an authentication failure), and provides fetch helpers plus
    COPY-based bulk loading via generator 'consumers'.
    """

    # The environment variables PGUSER, PGPASSWORD, PGDATABASE, and
    # PGHOST should be set if these are not supplied as arguments here.
    def __init__(self, dbhost=None, dbport=None, dbname=None,
                 dbuser=None, dbpasswd=None, debug=False, logger=None):
        self.log = logging.getLogger(
            "%s%s" % (logger.name + ':' if logger is not None else '',
                      self.__module__))

        # FIXME: We probably shouldn't adjust the root logger
        if debug:
            rootlogger = logging.getLogger()
            rootlogger.setLevel(min(logging.INFO, rootlogger.level))
        if len(self.log.handlers) == 0:
            self.log.addHandler(logging.StreamHandler())

        self.dsn = {'host': dbhost,
                    'port': dbport,
                    'dbname': dbname,
                    'user': dbuser,
                    'password': dbpasswd}
        self.log.debug("Database DSN: %s", self.dsn)

        self.conn = None
        self.curs = None   # client-side cursor
        self.ncurs = None  # named (server-side) cursor for streamed queries
        self.ready = 0     # 1 while a connection is open

    def open(self, get_pass=True):
        """Open the connection described by self.dsn.

        On an authentication failure, prompt once for a username and
        password (unless get_pass is false) and retry."""
        self.log.info("Opening PgSQL server connection (%s)",
                      self.dsn['host'])

        dsn_string = " ".join("%s=%s" % (k, v)
                              for k, v in self.dsn.items()
                              if v is not None)
        try:
            self.conn = psycopg2.connect(dsn_string)
        except psycopg2.OperationalError as m:
            if get_pass and "password" in str(m):
                user = input("Database username [%s]:" % getpass.getuser())
                if len(user) == 0:
                    user = getpass.getuser()
                self.dsn['user'] = user
                self.dsn['password'] = getpass.getpass("Database password:")
                self.open(get_pass=False)
            else:
                raise
        else:
            #self.conn.autocommit = 0
            # A cursor must be named to support server-side queries;
            # however, named cursors may only be used once.
            self.curs = self.conn.cursor()
            self.ready = 1

    def named_curs_open(self):
        """Overwrite the existing named (server-side) cursor.  Should
        be called before executing a streamed query."""
        if self.ncurs is not None:
            try:
                self.ncurs.close()
            except psycopg2.ProgrammingError as e:
                # The named cursor dies with its transaction; that
                # particular failure is expected and ignorable.
                if str(e).find("named cursor isn't valid anymore") >= 0:
                    pass
                else:
                    raise
        self.conn.commit()  # the current transaction must end
        self.ncurs = self.conn.cursor("named_cursor")

    def close(self):
        """Close the connection and drop the cursor references."""
        self.conn.close()
        self.ready = 0
        self.conn = None
        self.curs = None

    def commit(self):
        """Commit the current transaction, if a connection is open."""
        # this will fail if the connection is not open, but what else
        # can be done?
        if self.ready:
            self.conn.commit()

    def _not_ready(self, caller):
        """Log and raise for an operation attempted on a closed connection."""
        m = ("%s: database is not open --\n"
             "   was last op a query as it should have been?" % caller)
        self.log.error(m)
        # The original code raised the bare string, which is a
        # TypeError in modern Python; raise a real exception instead.
        raise RuntimeError(m)

    def execute(self, s, args=None, stream=False):
        """Execute statement s with parameter sequence/map args.

        If stream is true, run it on a fresh named (server-side)
        cursor so results may be fetched incrementally."""
        if not self.ready:
            self.open()
        try:
            if stream:
                self.named_curs_open()
                self.ncurs.execute(s, args)
            else:
                self.curs.execute(s, args)
        except Exception:
            self.log.error(
                "Execute: Error with the following statement:\n\n%s",
                self.curs.mogrify(s, args))
            raise
        return

    def fetchall_iter(self, q=None, args=None, batch_size=1000):
        """Execute q (if given) on a named cursor and yield rows one at
        a time, fetching batch_size rows per server round trip."""
        if q is not None:
            self.execute(q, args, stream=True)
        if not self.ready:
            self._not_ready("fetchall")
        while True:
            results = self.ncurs.fetchmany(batch_size)
            if not results:
                break
            for result in results:
                yield result

    def fetchall(self, q=None, args=None):
        """Execute q (if given) and return all result rows."""
        if q is not None:
            self.execute(q, args)
        #self.ping()
        if not self.ready:
            self._not_ready("fetchall")
        return self.curs.fetchall()

    def fetchone(self, q=None, args=None):
        """Fetch one row."""
        if q is not None:
            self.execute(q, args)
        #self.ping()
        if not self.ready:
            self._not_ready("fetchone")
        return self.curs.fetchone()

    def fetchone_d(self, q=None, args=None):
        """Fetch one row, as a dictionary of field names.  Returns
        None when no row is available."""
        if q is not None:
            self.execute(q, args)
        #self.ping()
        if not self.ready:
            self._not_ready("fetchone_d")
        row = self.curs.fetchone()
        if row is None:
            return None
        field_names = [field[0] for field in self.curs.description]
        return dict(zip(field_names, row))

    def fetchmany(self, q=None, args=None):
        """Execute q (if given) and fetch the cursor's next batch of rows."""
        if q is not None:
            self.execute(q, args)
        #self.ping()
        if not self.ready:
            self._not_ready("fetchmany")
        return self.curs.fetchmany()

    def fetchsingle(self, q=None, args=None):
        """Fetch a single value.  Previous execute must only return a
        single row,column."""
        if q is not None:
            self.execute(q, args)
        #self.ping()
        if not self.ready:
            self._not_ready("fetchsingle")
        tup = self.curs.fetchone()
        if not tup:
            return None
        if len(tup) == 1:
            return tup[0]
        m = "fetchsingle: more than one column returned"
        self.log.error(m)
        raise RuntimeError(m)

    def lastrowid(self):
        assert False, "lastrowid() Unimplemented. Use e.g., INSERT INTO ... RETURNING seq_id"

    def fetchcolumn(self, q=None, args=None):
        """Fetch a single column, returned as an array."""
        if q is not None:
            self.execute(q, args)
        #self.ping()
        if not self.ready:
            self._not_ready("fetchcolumn")
        return [a[0] for a in self.curs.fetchall()]

    def insert_dict(self, table, rows):
        """Insert a single row, from a dictionary with keys
        corresponding to column names.

        If 'rows' is a dictionary, insert one row.  'rows' may also be
        a list of dictionaries, each of which will be inserted
        simultaneously (all rows must specify the same set of columns).

        FIXME: never finished; kept only as a sketch."""
        if isinstance(rows, dict):
            cols = list(rows.keys())
        elif isinstance(rows, list):
            cols = list(rows[0].keys())
        else:
            assert False, "Unknown rows type: '%s'" % type(rows)

        i_base = "INSERT INTO %s (%s)\n  VALUES (%s)"
        i_cols = ", ".join(cols)  # was reduce(lambda x, y: x+', '+y, cols)
        #i_val
        assert False, "Not fully implemented"

    # FIXME: Streaming queries were used with mysql.  Are these needed
    # with postgres?
    def ssexecute(self, s):
        self.execute(s)

    def ssfetchmany(self, count):
        return self.fetchmany()

    @consumer
    def copy_from(self, table, columns, batch_size=100000, format_str=None):
        """Returns a generator that is used to send data by COPY back
        to the database.

        Repeatedly call .send() with a tuple of row data.  Rows are
        cached in a temporary file, and executed on the server when
        batch_size rows are queued, or .close() is called.

        BE SURE TO CALL .close()"""
        if not self.ready:
            self.open()

        if format_str is None:
            # One %r per column, tab separated (COPY text format).
            # NOTE(review): %r adds quotes around strings — presumably
            # callers pass a suitable format_str instead; verify.
            format_str = "\t".join(["%r"] * len(columns))

        # Python 3: TemporaryFile defaults to binary mode, but COPY
        # data is written as text here.
        tmp = tempfile.TemporaryFile(mode='w+t')
        lines = 0
        try:
            while True:
                row_tup = yield
                tmp.write(format_str % row_tup + '\n')
                lines += 1
                if lines >= batch_size:
                    tmp.seek(0)
                    self.curs.copy_from(tmp, table, columns=columns)
                    self.commit()
                    lines = 0
                    tmp.close()
                    tmp = tempfile.TemporaryFile(mode='w+t')
        except GeneratorExit:
            # .close() was called: flush whatever remains
            tmp.seek(0)
            self.curs.copy_from(tmp, table, columns=columns)
            self.commit()
            tmp.close()
            return

    @consumer
    def copy_from_innertry(self, table, columns, batch_size=100000,
                           format_str=None):
        """Returns a generator that is used to send data by COPY back
        to the database.  Same contract as copy_from(), with the try
        block inside the receive loop.

        BE SURE TO CALL .close()"""
        # This is somewhat slower
        if not self.ready:
            self.open()

        if format_str is None:
            format_str = "\t".join(["%r"] * len(columns))

        tmp = tempfile.TemporaryFile(mode='w+t')
        lines = 0
        while True:
            try:
                row_tup = yield
                tmp.write(format_str % row_tup + '\n')
                lines += 1
            except GeneratorExit:
                # .close() was called: flush whatever remains
                tmp.seek(0)
                self.curs.copy_from(tmp, table, columns=columns)
                self.commit()
                tmp.close()
                return
            else:
                if lines >= batch_size:
                    tmp.seek(0)
                    self.curs.copy_from(tmp, table, columns=columns)
                    self.commit()
                    lines = 0
                    tmp.close()
                    tmp = tempfile.TemporaryFile(mode='w+t')


if __name__ == "__main__":
    dbw = dbwrap()
    dbw.open()