Created
April 16, 2013 02:48
-
-
Save cuimuxi/5392970 to your computer and use it in GitHub Desktop.
facebook tornado database
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python
#
# Copyright 2009 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""A lightweight wrapper around MySQLdb.""" | |
from __future__ import absolute_import, division, with_statement | |
import copy | |
import itertools | |
import logging | |
import time | |
try: | |
import MySQLdb.constants | |
import MySQLdb.converters | |
import MySQLdb.cursors | |
except ImportError: | |
# If MySQLdb isn't available this module won't actually be useable, | |
# but we want it to at least be importable (mainly for readthedocs.org, | |
# which has limitations on third-party modules) | |
MySQLdb = None | |
class Connection(object):
    """A lightweight wrapper around MySQLdb DB-API connections.

    The main value we provide is wrapping rows in a dict/object so that
    columns can be accessed by name. Typical usage::

        db = database.Connection("localhost", "mydatabase")
        for article in db.query("SELECT * FROM articles"):
            print(article.title)

    Cursors are hidden by the implementation, but other than that, the methods
    are very similar to the DB-API.

    We explicitly set the timezone to UTC and the character encoding to
    UTF-8 on all connections to avoid time zone and encoding errors.
    """

    def __init__(self, host, database, user=None, password=None,
                 max_idle_time=7 * 3600):
        """Stores the connection settings and attempts a first connection.

        ``host`` is either a path to a MySQL socket file (anything
        containing a ``/``) or a ``host`` / ``host:port`` string; the port
        defaults to 3306. ``max_idle_time`` is the number of seconds a
        connection may sit unused before it is transparently re-opened.

        A failure of the initial connection attempt is logged, not raised;
        the next query will try to connect again.
        """
        self.host = host
        self.database = database
        self.max_idle_time = max_idle_time
        # Defined on every code path so the attribute exists for unix-socket
        # connections as well as for host(:port) ones.
        self.socket = None

        args = dict(conv=CONVERSIONS, use_unicode=True, charset="utf8",
                    db=database, init_command='SET time_zone = "+0:00"',
                    sql_mode="TRADITIONAL")
        if user is not None:
            args["user"] = user
        if password is not None:
            args["passwd"] = password

        # We accept a path to a MySQL socket file or a host(:port) string
        if "/" in host:
            args["unix_socket"] = host
        else:
            pair = host.split(":")
            if len(pair) == 2:
                args["host"] = pair[0]
                args["port"] = int(pair[1])
            else:
                args["host"] = host
                args["port"] = 3306

        self._db = None
        self._db_args = args
        self._last_use_time = time.time()
        try:
            self.reconnect()
        except Exception:
            logging.error("Cannot connect to MySQL on %s", self.host,
                          exc_info=True)

    def __del__(self):
        self.close()

    def close(self):
        """Closes this database connection."""
        # getattr() keeps this safe when called from __del__ on an instance
        # whose __init__ never got as far as assigning _db.
        db = getattr(self, "_db", None)
        if db is not None:
            # Drop our reference before closing so that a close() which
            # raises cannot leave a half-closed handle behind for reuse.
            self._db = None
            db.close()

    def reconnect(self):
        """Closes the existing database connection and re-opens it."""
        self.close()
        self._db = MySQLdb.connect(**self._db_args)
        self._db.autocommit(True)

    def iter(self, query, *parameters):
        """Returns an iterator for the given query and parameters."""
        self._ensure_connected()
        cursor = MySQLdb.cursors.SSCursor(self._db)
        try:
            self._execute(cursor, query, parameters)
            column_names = [d[0] for d in cursor.description]
            for row in cursor:
                yield Row(zip(column_names, row))
        finally:
            cursor.close()

    def query(self, query, *parameters):
        """Returns a row list for the given query and parameters."""
        cursor = self._cursor()
        try:
            self._execute(cursor, query, parameters)
            column_names = [d[0] for d in cursor.description]
            # Plain zip() (as iter() uses) works on both Python 2 and 3;
            # itertools.izip does not exist on Python 3.
            return [Row(zip(column_names, row)) for row in cursor]
        finally:
            cursor.close()

    def get(self, query, *parameters):
        """Returns the first row returned for the given query.

        Returns None when the query yields no rows and raises an Exception
        when it yields more than one.
        """
        rows = self.query(query, *parameters)
        if not rows:
            return None
        elif len(rows) > 1:
            raise Exception("Multiple rows returned for Database.get() query")
        else:
            return rows[0]

    # rowcount is a more reasonable default return value than lastrowid,
    # but for historical compatibility execute() must return lastrowid.
    def execute(self, query, *parameters):
        """Executes the given query, returning the lastrowid from the query."""
        return self.execute_lastrowid(query, *parameters)

    def execute_lastrowid(self, query, *parameters):
        """Executes the given query, returning the lastrowid from the query."""
        cursor = self._cursor()
        try:
            self._execute(cursor, query, parameters)
            return cursor.lastrowid
        finally:
            cursor.close()

    def execute_rowcount(self, query, *parameters):
        """Executes the given query, returning the rowcount from the query."""
        cursor = self._cursor()
        try:
            self._execute(cursor, query, parameters)
            return cursor.rowcount
        finally:
            cursor.close()

    def executemany(self, query, parameters):
        """Executes the given query against all the given param sequences.

        We return the lastrowid from the query.
        """
        return self.executemany_lastrowid(query, parameters)

    def executemany_lastrowid(self, query, parameters):
        """Executes the given query against all the given param sequences.

        We return the lastrowid from the query.
        """
        cursor = self._cursor()
        try:
            self._executemany(cursor, query, parameters)
            return cursor.lastrowid
        finally:
            cursor.close()

    def executemany_rowcount(self, query, parameters):
        """Executes the given query against all the given param sequences.

        We return the rowcount from the query.
        """
        cursor = self._cursor()
        try:
            self._executemany(cursor, query, parameters)
            return cursor.rowcount
        finally:
            cursor.close()

    def _ensure_connected(self):
        # Mysql by default closes client connections that are idle for
        # 8 hours, but the client library does not report this fact until
        # you try to perform a query and it fails.  Protect against this
        # case by preemptively closing and reopening the connection
        # if it has been idle for too long (7 hours by default).
        if (self._db is None or
                (time.time() - self._last_use_time > self.max_idle_time)):
            self.reconnect()
        self._last_use_time = time.time()

    def _cursor(self):
        """Returns a fresh cursor on a connection that is known to be open."""
        self._ensure_connected()
        return self._db.cursor()

    def _execute(self, cursor, query, parameters):
        """Runs ``cursor.execute``; on OperationalError closes and re-raises.

        Closing the connection means the next call reconnects instead of
        reusing a handle the server may already have dropped.
        """
        try:
            return cursor.execute(query, parameters)
        except OperationalError:
            logging.error("Error connecting to MySQL on %s", self.host)
            self.close()
            raise

    def _executemany(self, cursor, query, parameters):
        """Runs ``cursor.executemany`` with the same handling as _execute."""
        try:
            return cursor.executemany(query, parameters)
        except OperationalError:
            logging.error("Error connecting to MySQL on %s", self.host)
            self.close()
            raise
class Row(dict):
    """A dict whose keys can also be read as attributes (``row.col``)."""

    def __getattr__(self, name):
        # Only reached when normal attribute lookup has already failed, so
        # real dict methods and attributes always take precedence over keys.
        if name in self:
            return self[name]
        raise AttributeError(name)
if MySQLdb is not None:
    # Alias some common MySQL exceptions so callers can catch them through
    # this module.
    OperationalError = MySQLdb.OperationalError
    IntegrityError = MySQLdb.IntegrityError

    FLAG = MySQLdb.constants.FLAG
    FIELD_TYPE = MySQLdb.constants.FIELD_TYPE

    # Fix the access conversions to properly recognize unicode/binary.
    # We work on a copy of the driver's table and assign fresh lists into
    # it, so MySQLdb's own module-level conversions are left untouched.
    CONVERSIONS = copy.copy(MySQLdb.converters.conversions)

    field_types = [FIELD_TYPE.BLOB, FIELD_TYPE.STRING, FIELD_TYPE.VAR_STRING]
    # VARCHAR is only added when this MySQLdb build defines it.
    if 'VARCHAR' in vars(FIELD_TYPE):
        field_types += [FIELD_TYPE.VARCHAR]

    for field_type in field_types:
        # Put the (FLAG.BINARY, str) rule ahead of the existing rules for
        # this field type.
        CONVERSIONS[field_type] = [(FLAG.BINARY, str)] + CONVERSIONS[field_type]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment