Created
October 27, 2010 03:26
-
-
Save wjzhangq/648379 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
# | |
# Copyright 2009 Facebook | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); you may | |
# not use this file except in compliance with the License. You may obtain | |
# a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | |
# License for the specific language governing permissions and limitations | |
# under the License. | |
"""A lightweight wrapper around MySQLdb.""" | |
import copy | |
import MySQLdb | |
import MySQLdb.constants | |
import MySQLdb.converters | |
import MySQLdb.cursors | |
import itertools | |
import logging | |
from time import time | |
class Connection(object): | |
"""A lightweight wrapper around MySQLdb DB-API connections. | |
The main value we provide is wrapping rows in a dict/object so that | |
columns can be accessed by name. Typical usage: | |
db = database.Connection("localhost", "mydatabase") | |
for article in db.query("SELECT * FROM articles"): | |
print article.title | |
Cursors are hidden by the implementation, but other than that, the methods | |
are very similar to the DB-API. | |
We explicitly set the timezone to UTC and the character encoding to | |
UTF-8 on all connections to avoid time zone and encoding errors. | |
""" | |
def __init__(self, host, database, user=None, password=None): | |
self.host = host | |
self.database = database | |
args = dict(conv=CONVERSIONS, use_unicode=True, charset="utf8", | |
db=database, init_command='SET time_zone = "+8:00"', | |
sql_mode="TRADITIONAL") | |
if user is not None: | |
args["user"] = user | |
if password is not None: | |
args["passwd"] = password | |
# We accept a path to a MySQL socket file or a host(:port) string | |
if "/" in host: | |
args["unix_socket"] = host | |
else: | |
self.socket = None | |
pair = host.split(":") | |
if len(pair) == 2: | |
args["host"] = pair[0] | |
args["port"] = int(pair[1]) | |
else: | |
args["host"] = host | |
args["port"] = 3306 | |
self._db = None | |
self._db_args = args | |
try: | |
self.reconnect() | |
except: | |
logging.error("Cannot connect to MySQL on %s", self.host, | |
exc_info=True) | |
def __del__(self): | |
self.close() | |
def close(self): | |
"""Closes this database connection.""" | |
if self._db is not None: | |
self._db.close() | |
self._db = None | |
def commit(self): | |
if self._db is not None: | |
try: | |
self._db.ping() | |
except: | |
self.reconnect() | |
try: | |
self._db.commit() | |
except Exception,e: | |
self._db.rollback() | |
logging.exception("Can not commit",e) | |
def rollback(self): | |
if self._db is not None: | |
try: | |
self._db.rollback() | |
except Exception,e: | |
logging.error("Can not rollback") | |
def reconnect(self): | |
"""Closes the existing database connection and re-opens it.""" | |
self.close() | |
self._db = MySQLdb.connect(**self._db_args) | |
self._db.autocommit(False) | |
def iter(self, query, *parameters): | |
"""Returns an iterator for the given query and parameters.""" | |
if self._db is None: self.reconnect() | |
cursor = MySQLdb.cursors.SSCursor(self._db) | |
try: | |
self._execute(cursor, query, parameters) | |
column_names = [d[0] for d in cursor.description] | |
for row in cursor: | |
yield Row(zip(column_names, row)) | |
finally: | |
cursor.close() | |
def query(self, query, *parameters): | |
"""Returns a row list for the given query and parameters.""" | |
cursor = self._cursor() | |
try: | |
self._execute(cursor, query, parameters) | |
column_names = [d[0] for d in cursor.description] | |
return [Row(itertools.izip(column_names, row)) for row in cursor] | |
finally: | |
cursor.close() | |
def get(self, query, *parameters): | |
"""Returns the first row returned for the given query.""" | |
rows = self.query(query, *parameters) | |
if not rows: | |
return None | |
elif len(rows) > 1: | |
raise Exception("Multiple rows returned for Database.get() query") | |
else: | |
return rows[0] | |
def execute(self, query, *parameters): | |
"""Executes the given query, returning the lastrowid from the query.""" | |
cursor = self._cursor() | |
try: | |
self._execute(cursor, query, parameters) | |
return cursor.lastrowid | |
finally: | |
cursor.close() | |
def count(self,query, *parameters): | |
"""Executes the given query, returning the count value from the query.""" | |
cursor = self._cursor() | |
try: | |
cursor.execute(query, parameters) | |
return cursor.fetchone()[0] | |
finally: | |
cursor.close() | |
def __getattr__(self,tablename): | |
''' | |
return single table queryer for select table | |
''' | |
return TableQueryer(self,tablename) | |
def fromQuery(self,Select): | |
''' | |
return single table queryer for query | |
''' | |
return TableQueryer(self,Select) | |
def insert(self,table,**datas): | |
''' | |
Executes the given parameters to an insert SQL and execute it | |
''' | |
return Insert(self,table)(**datas) | |
def executemany(self, query, parameters): | |
"""Executes the given query against all the given param sequences. | |
We return the lastrowid from the query. | |
""" | |
cursor = self._cursor() | |
try: | |
cursor.executemany(query, parameters) | |
return cursor.lastrowid | |
finally: | |
cursor.close() | |
def _cursor(self): | |
if self._db is None: self.reconnect() | |
try: | |
self._db.ping() | |
except: | |
self.reconnect() | |
return self._db.cursor() | |
def _execute(self, cursor, query, parameters): | |
try: | |
return cursor.execute(query, parameters) | |
except OperationalError: | |
logging.error("Error connecting to MySQL on %s", self.host) | |
self.close() | |
raise | |
class TableQueryer: | |
''' | |
Support for single table simple querys | |
''' | |
def __init__(self,db,tablename): | |
self.tablename=tablename | |
self.db=db | |
def get_one(self,query): | |
rs=Select(self.db,self.tablename,query)() | |
if len(rs)>0: | |
if len(rs)>1: | |
raise OperationalError,"returned multi row when fetch one result" | |
return rs[0] | |
return None | |
def insert(self,**fields): | |
return Insert(self.db,self.tablename)(**fields) | |
def __call__(self,query=None): | |
return Operater(self.db,self.tablename,query) | |
def __getattr__(self,field_name): | |
return conds(field_name) | |
class Operater: | |
def __init__(self,db,tablename,query): | |
self.insert=Insert(db,tablename) | |
self.count=Count(db,tablename,query) | |
self.select=Select(db,tablename,query) | |
self.update=Update(db,tablename,query) | |
self.delete=Delete(db,tablename,query) | |
class Count (object): | |
''' | |
Count with current where clouse | |
''' | |
def __init__(self,db,tablename,where): | |
self.db=db | |
self.tablename=tablename | |
self.where=where | |
def __call__(self): | |
if self.where: | |
_sql="".join(["SELECT count(1) FROM ",self.tablename," where ",self.where.get_sql()]) | |
_params=self.where.get_params() | |
return self.db.count(_sql,*_params) | |
else: | |
_sql="",join(["SELECT count(1) FROM ",self.tablename]) | |
return self.db.count(_sql) | |
class Select: | |
''' | |
Select list with current where clouse | |
''' | |
def __init__(self,db,tablename,where): | |
self.db=db | |
self._tablename=tablename | |
self._where=where | |
self._sort_fields=[] | |
self._limit=None | |
self._fields=[] | |
self._groups=[] | |
self._having=None | |
def sort(self,**fields): | |
del self._sort_fields[:] | |
for key in fields.keys(): | |
self._sort_fields.append("".join(["`",key,"` ",fields[key]])) | |
return self | |
def limit(self,start,count): | |
self._limit="".join(["LIMIT ",str(start),",",str(count)]) | |
return self | |
def collect(self,*fields): | |
if len(fields): | |
self._fields+=fields | |
return self | |
def group_by(self,*fields): | |
if len(fields)<1: | |
raise OperationalError,"Must have a field" | |
for f in fields: | |
self._groups.append(f) | |
def having(self,cond): | |
self._having=cond | |
def __getslice__(self,pid,pcount): | |
if pid<1: | |
raise OperationalError,"Wrong page id,page id can not lower than 1" | |
if pcount<1: | |
raise OperationalError,"Wrong page size,page size can not lower than 1" | |
_start=(pid-1)*pcount | |
self._limit="".join(["LIMIT ",str(_start),",",str(pcount)]) | |
return self | |
def get_sql(self): | |
_sql_slice=["SELECT "] | |
if self._fields: | |
_sql_slice.append(",".join(["".join(["`",str(f),"`"]) for f in self._fields])) | |
else: | |
_sql_slice.append("*") | |
_sql_slice.append(" FROM `") | |
if str(self._tablename.__class__)=="database.Select": | |
_sql_slice.append("(") | |
_sql_slice.append(self._tablename.get_sql()) | |
_sql_slice.append(")t") | |
else: | |
_sql_slice.append(self._tablename) | |
_sql_slice.append("`") | |
if self._where: | |
_sql_slice.append(" WHERE ") | |
if str(self._tablename.__class__)=="database.Select": | |
_sql_slice.append(self._where.get_sql(tn='t')) | |
else: | |
_sql_slice.append(self._where.get_sql()) | |
_sql_slice.append(" ") | |
if len(self._groups)>0: | |
_sql_slice.append("GROUP BY ") | |
if str(self._tablename.__class__)=="database.Select": | |
_sql_slice.append(",".join([f.get_sql(tn="t") for f in self._groups])) | |
else: | |
_sql_slice.append(",".join([f.get_sql() for f in self._groups])) | |
if self._having: | |
_sql_slice.append(" HAVING ") | |
_sql_slice.append(self._having.get_sql()) | |
_sql_slice.append(" ") | |
if self._sort_fields: | |
_sql_slice.append("ORDER BY ") | |
if str(self._tablename.__class__)=="database.Select": | |
_sql_slice.append(",".join([self._add_tb('t',s) for s in self._sort_fields])) | |
else: | |
_sql_slice.append(",".join([s for s in self._sort_fields])) | |
if self._limit: | |
_sql_slice.append(" ") | |
_sql_slice.append(self._limit) | |
return "".join(_sql_slice) | |
def _add_tb(self,tn,src): | |
import re | |
p=compile(r'`(\w*?)`') | |
return p.sub(r'`%s.\1`'%tn,src) | |
def __call__(self): | |
_sql=self.get_sql() | |
_plist=[] | |
if str(self._tablename.__class__)=="database.Select": | |
for p in self._tablename.get_sql(): | |
_plist.append(p) | |
if self._where: | |
for p in self._where: | |
_plist.append(p) | |
if self._having: | |
for p in self._having.get_params(): | |
_plist.append(p) | |
if _plist: | |
return self.db.query(_sql,*_plist) | |
else: | |
return self.db.query(_sql) | |
class Update: | |
''' | |
Update Query Generator | |
''' | |
def __init__(self,db,tablename,where): | |
self.db=db | |
self._tablename=tablename | |
self._where=where | |
self._update_cols=None | |
def __call__(self,*fields): | |
if len(fields)<1: | |
raise OperationalError,"Must have unless 1 field to update" | |
_params=[] | |
_cols=[] | |
for f in fields: | |
_cols.append(f.get_sql()) | |
_params.append(f.get_params()[0]) | |
_sql_slice=["UPDATE ",self._tablename," SET ",",".join(_cols)] | |
if self._where: | |
_sql_slice.append(" WHERE ") | |
_sql_slice.append(self._where.get_sql()) | |
for p in self._where.get_params(): | |
_params.append(p) | |
_sql="".join(_sql_slice) | |
return self.db.execute(_sql,*_params) | |
class Delete: | |
def __init__(self,db,tablename,where): | |
self.db=db | |
self._tablename=tablename | |
self._where=where | |
def __call__(self): | |
_sql_slice=["DELETE FROM `",self._tablename,"`"] | |
if self._where: | |
_sql_slice.append(" WHERE ") | |
_sql_slice.append(self._where.get_sql()) | |
_sql="".join(_sql_slice) | |
return self.db.execute(_sql,self._where.get_params()) | |
class Insert: | |
''' | |
Insert Query Generator | |
''' | |
def __init__(self,db,tablename): | |
self.db=db | |
self.tablename=tablename | |
def __call__(self,**fileds): | |
columns=fileds.keys() | |
_prefix="".join(['INSERT INTO `',self.tablename,'`']) | |
_fields=",".join(["".join(['`',column,'`']) for column in columns]) | |
_values=",".join(["%s" for i in range(len(columns))]) | |
_sql="".join([_prefix,"(",_fields,") VALUES (",_values,")"]) | |
_params=[fileds[key] for key in columns] | |
return self.db.execute(_sql,*tuple(_params)) | |
class Row(dict): | |
"""A dict that allows for object-like property access syntax.""" | |
def __getattr__(self, name): | |
try: | |
return self[name] | |
except KeyError: | |
raise AttributeError(name) | |
class conds: | |
def __init__(self,field): | |
self.field_name=field | |
self._sql="" | |
self._params=[] | |
self._has_value=False | |
self._sub_conds=[] | |
self._no_value=False | |
def _prepare(self,sql,value): | |
if not self._has_value: | |
self._sql=sql | |
self._params.append(value) | |
self._has_value=True | |
return self | |
raise OperationalError,"Multiple Operate conditions" | |
def __str__(self): | |
if self._has_value: | |
return self._sql | |
else: | |
return self.field_name | |
def get_sql(self,tn=None): | |
_sql_slice=[] | |
_sql_slice.append(self._sql) | |
_sql_slice.append(" ") | |
if len(self._sub_conds): | |
for cond in self._sub_conds: | |
_sql_slice.append(cond[1]) | |
_sql_slice.append(cond[0].get_sql()) | |
_where = "".join(_sql_slice) | |
if tn: | |
import re | |
p=compile(r'`(\w*?)`') | |
_where = p.sub(r'`%s.\1`'%tn,_where) | |
return _where | |
def get_params(self): | |
_my_params=[]+self._params | |
if len(self._sub_conds): | |
for cond in self._sub_conds: | |
_my_params+=cond[0].get_params() | |
return _my_params | |
def __sub__(self,value): | |
return self._prepare("".join(["`",self.field_name,"`-%s"]),value) | |
def __add__(self,value): | |
return self._prepare("".join(["`",self.field_name,"`+%s"]),value) | |
def __ne__(self,value): | |
return self._prepare("".join(["`",self.field_name,'`','<>%s']),value) | |
def __eq__(self,value): | |
if not self._has_value: | |
if str(value.__class__)=="database.conds": | |
self._sql="".join(["`",self.field_name,'`','=',value.get_sql()]) | |
self._params.append(value.get_params()[0]) | |
else: | |
self._sql="".join(["`",self.field_name,'`','=%s']) | |
self._params.append(value) | |
self._has_value=True | |
return self | |
raise OperationalError,"Multiple Operate conditions" | |
def like(self,value): | |
return self._prepare("".join(["`",self.field_name,'`',' like %s']),value) | |
def DL(self,format,value): | |
return self._prepare("".join(["DATE_FORMAT(`",self.field_name,"`,'",format,"')",'<=%s']),value) | |
def DG(self,format,value): | |
return self._prepare("".join(["DATE_FORMAT(`",self.field_name,"`,'",format,"')",'>=%s']),value) | |
def DE(self,format,value): | |
return self._prepare("".join(["DATE_FORMAT(`",self.field_name,"`,'",format,"')",'=%s']),value) | |
def __le__(self,value): | |
return self._prepare("".join(["`",self.field_name,'`','<=%s']),value) | |
def __lt__(self,value): | |
return self._prepare("".join(["`",self.field_name,'`','<%s']),value) | |
def __gt__(self,value): | |
return self._prepare("".join(["`",self.field_name,'`','>%s']),value) | |
def __ge__(self,value): | |
return self._prepare("".join(["`",self.field_name,'`','>=%s']),value) | |
def In(self,array): | |
if not self._has_value: | |
if str(array.__class__)=="database.Select": | |
self._sql="".join(["`",self.field_name,'`',' in (',array.get_sql(),")"]) | |
for p in array.get_params(): | |
self._params.append(p) | |
else: | |
_values=",".join(["".join(['\'',i,'\'']) for i in array]) | |
self._sql="".join(["`",self.field_name,'`',' in (',_values,")"]) | |
self._has_value=True | |
return self | |
raise OperationalError,"Multiple Operate conditions" | |
def Not_In(self,array): | |
if not self._has_value: | |
if str(array.__class__)=="database.Select": | |
self._sql="".join(["`",self.field_name,'`',' not in (',array.get_sql(),")"]) | |
for p in array.get_params(): | |
self._params.append(p) | |
else: | |
_values=",".join(["".join(['\'',i,'\'']) for i in array]) | |
self._sql="".join(["`",self.field_name,'`',' not in (',_values,")"]) | |
self._has_value=True | |
return self | |
raise OperationalError,"Multiple Operate conditions" | |
def __getattr__(self,func_name): | |
if not self._has_value: | |
if str(array.__class__)=="database.Select": | |
self.self.field_name="".join([func_name,"(t.",self.field_name,") as ",func_name,"_",self.field_name]) | |
else: | |
self.self.field_name="".join([func_name,"(",self.field_name,") as ",func_name,"_",self.field_name]) | |
return self | |
raise OperationalError,"Multiple Operate conditions" | |
def __and__(self,cond): | |
if self._has_value: | |
self._sub_conds.append((cond," AND ")) | |
return self | |
raise OperationalError,"Operation with no value" | |
def __or__(self,cond): | |
if self._has_value: | |
self._sub_conds.append((cond," OR ")) | |
return self | |
raise OperationalError,"Operation with no value" | |
# Fix the access conversions to properly recognize unicode/binary | |
FIELD_TYPE = MySQLdb.constants.FIELD_TYPE | |
FLAG = MySQLdb.constants.FLAG | |
CONVERSIONS = copy.deepcopy(MySQLdb.converters.conversions) | |
for field_type in \ | |
[FIELD_TYPE.BLOB, FIELD_TYPE.STRING, FIELD_TYPE.VAR_STRING] + \ | |
([FIELD_TYPE.VARCHAR] if 'VARCHAR' in vars(FIELD_TYPE) else []): | |
CONVERSIONS[field_type].insert(0, (FLAG.BINARY, str)) | |
# Alias some common MySQL exceptions | |
IntegrityError = MySQLdb.IntegrityError | |
OperationalError = MySQLdb.OperationalError |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment