Created
July 5, 2016 00:22
-
-
Save kwatch/53d991c47c42132e6352a6d0520ceb92 to your computer and use it in GitHub Desktop.
SQLおじさんのサンプルSQLをO/Rマッパーで書いてみた
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*- | |
""" | |
SQLAlchemy example code. | |
Requirements: | |
* Python3 | |
* PostgreSQL | |
* SQLAlchemy | |
* Psycopg2 | |
""" | |
import sys, os, re | |
import sqlalchemy | |
import sqlalchemy.orm | |
import sqlalchemy.ext.declarative | |
class Config(object):
    """Application settings read when the engine is created."""

    ## SQLAlchemy
    # Database URL.  Use the "postgresql://" scheme: SQLAlchemy 1.4+ removed
    # the old "postgres://" alias, and "postgresql://" works on older
    # versions as well.
    sa_url = 'postgresql://user2@localhost/example2'  # CHANGE HERE
    # Passed as create_engine(echo=...); True makes the engine log its SQL.
    sa_echo = True
config = Config()

# Engine built from the settings above (URL and SQL echo flag).
engine = sqlalchemy.create_engine(
    config.sa_url,
    echo=config.sa_echo,
)

# Declarative base shared by every mapped class in this module.
Base = sqlalchemy.ext.declarative.declarative_base()

# Session factory already bound to the engine; calling it opens a new Session.
DBSession = sqlalchemy.orm.sessionmaker(bind=engine)
from sqlalchemy import ( | |
Column, ForeignKey, UniqueConstraint, | |
String, Text, Integer, Date, DateTime, Time, Boolean, | |
) | |
from sqlalchemy.orm import relationship, backref | |
class Invoice(Base):
    """Invoice header: one row per invoice, with its total amount and total tax."""

    # Reference DDL for the mapped table (kept as documentation; it is not
    # referenced elsewhere in this file).
    DDL = r"""
CREATE TABLE invoices (
  id              serial      PRIMARY KEY
, customer_id     integer     NOT NULL  --REFERENCES customers(id)
, total_amount    integer     NOT NULL
, total_tax       integer     NOT NULL
);
"""
    __tablename__ = "invoices"
    id = Column(Integer, primary_key=True)
    customer_id = Column(Integer, nullable=False)  # ForeignKey('customers.id')
    total_amount = Column(Integer, nullable=False, default=0)
    total_tax = Column(Integer, nullable=False, default=0)
    # One-to-many to InvoiceLine.  back_populates pairs this with
    # InvoiceLine.header so both sides stay in sync in Python and SQLAlchemy
    # does not treat them as two conflicting writers of invoice_id.
    lines = relationship('InvoiceLine', uselist=True, back_populates='header')

    # Tax rate applied to each line amount (item_count * unit_price).
    TAX_RATE = 0.08

    def __repr__(self):
        return "Invoice(id=%r, customer_id=%r, total_amount=%r, total_tax=%r)" % \
            (self.id, self.customer_id, self.total_amount, self.total_tax)
class InvoiceLine(Base):
    """Invoice line (detail row) belonging to one Invoice."""

    # Reference DDL for the mapped table (kept as documentation; it is not
    # referenced elsewhere in this file).  The (invoice_id, line_no)
    # uniqueness is added by the ALTER TABLE statement.
    DDL = r"""
CREATE TABLE invoice_lines (
  id              serial      NOT NULL PRIMARY KEY
, invoice_id      integer     NOT NULL REFERENCES invoices(id)
, line_no         integer     NOT NULL
, item_id         integer     NOT NULL  --REFERENCES items(id)
, item_count      integer     NOT NULL
, unit_price      integer     NOT NULL
);
ALTER TABLE invoice_lines ADD CONSTRAINT invoices_lines_compound_uniq UNIQUE(invoice_id, line_no);
"""
    __tablename__ = "invoice_lines"
    # Same constraint, with the same name, as the ALTER TABLE in DDL.
    __table_args__ = (
        UniqueConstraint('invoice_id', 'line_no', name='invoices_lines_compound_uniq'),
    )
    id = Column(Integer, primary_key=True)
    invoice_id = Column(Integer, ForeignKey(Invoice.id), nullable=False)
    line_no = Column(Integer, nullable=False)
    item_id = Column(Integer, nullable=False)  # ForeignKey('items.id')
    item_count = Column(Integer, nullable=False)
    unit_price = Column(Integer, nullable=False)
    #item = relationship('Item', uselist=False)
    # Many-to-one to the owning Invoice; paired with Invoice.lines.
    header = relationship('Invoice', uselist=False, back_populates='lines')

    def __repr__(self):
        return "InvoiceLine(id=%r, invoice_id=%r, line_no=%r, item_id=%r, item_count=%r, unit_price=%r)" % \
            (self.id, self.invoice_id, self.line_no, self.item_id, self.item_count, self.unit_price)
class TX(object):
    """Context manager that opens a session and runs one transaction in it.

    On normal exit the session is committed; if the block raises, it is
    rolled back and the exception propagates.  The session is always closed
    afterwards, so its connection is returned to the pool.

    Usage::

        with TX() as db:
            for invoice in db.query(Invoice).all():
                print(invoice)

    :param session_factory: callable returning a new session.  Defaults to the
        module-level ``DBSession`` (looked up when the block is entered).
    """

    def __init__(self, session_factory=None):
        self._session_factory = session_factory
        self.db = None

    def __enter__(self):
        factory = self._session_factory
        if factory is None:
            factory = DBSession
        self.db = db = factory()
        return db

    def __exit__(self, exc_type, exc_value, exc_tb):
        try:
            if exc_type is None:
                self.db.commit()
            else:
                self.db.rollback()
        finally:
            # Close even if commit/rollback raised, so the connection is released.
            self.db.close()
        return False  # never swallow the caller's exception
def row_number(xs, partition_by=lambda x: x):
    """Emulate SQL's ``row_number()`` window function in Python.

    Yields ``(n, x)`` for every element ``x`` of *xs*, where ``n`` is the
    1-based position of ``x`` within its run of consecutive elements whose
    ``partition_by(x)`` keys compare equal.  Numbering restarts at 1 whenever
    the key changes, so *xs* should already be grouped by the partition key.

    ex::

        >>> items = [("a", "X"), ("b", "X"), ("c", "X"),
        ...          ("d", "Y"), ("e", "Y"),]
        >>> for i, (k, v) in row_number(items, lambda t: t[1]):
        ...     print([i, k, v])
        ...
        [1, 'a', 'X']
        [2, 'b', 'X']
        [3, 'c', 'X']
        [1, 'd', 'Y']
        [2, 'e', 'Y']
    """
    unset = object()  # sentinel meaning "no partition seen yet"
    current_key = unset
    rank = 0
    for item in xs:
        key = partition_by(item)
        if current_key is unset or current_key != key:
            # First element, or a new partition: restart the numbering.
            current_key, rank = key, 1
        else:
            rank += 1
        yield rank, item
class Operation:
    """Common base for operation classes; keeps the session they query through."""

    def __init__(self, db):
        # Session object that subclasses call .query() on.
        self.db = db
class InvoiceOp(Operation):
    """Queries over invoices and invoice lines.

    Every ``query_invoice_lines*`` method answers the same question in a
    different way.  For each invoice, the rounding difference is::

        invoices.total_tax - sum(trunc(item_count * unit_price * TAX_RATE))

    The lines of the invoice are ranked 1, 2, ... by amount (descending, ties
    broken by ``line_no``).  A line receives one unit of the difference when
    its rank is ``<= difference``.
    """

    def query_invoice_lines1(self):
        """Version that issues SQL with window functions (for Oracle or PostgreSQL).

        ref: http://qiita.com/kantomi/items/5e07641016615c073b9f#-%E7%BF%BB%E8%A8%B3%E3%81%97%E3%81%9Fsql

        Returns a query object; the SQL runs when the caller iterates it.
        Note that this has nothing to do with the database's optimizer.

        Issued SQL::

            SELECT a.invoice_id AS a_invoice_id
                 , a.line_no AS a_line_no
                 , a.item_id AS a_item_id
                 , a.unit_price AS a_unit_price
                 , a.item_count AS a_item_count
                 , a.amount AS a_amount
                 , CASE
                     WHEN (a.ord <= a.diff) THEN 1
                     ELSE 0
                   END AS distribute_flag
            FROM (
              SELECT invoice_lines.id AS id
                   , invoice_lines.invoice_id AS invoice_id
                   , invoice_lines.line_no AS line_no
                   , invoice_lines.item_id AS item_id
                   , invoice_lines.item_count AS item_count
                   , invoice_lines.unit_price AS unit_price
                   , invoice_lines.item_count * invoice_lines.unit_price
                       AS amount
                   , invoices.total_tax - sum(trunc(invoice_lines.item_count * invoice_lines.unit_price * 0.08))
                       OVER (partition BY invoice_lines.invoice_id)
                       AS diff
                   , row_number()
                       OVER (PARTITION BY invoice_lines.invoice_id
                             ORDER BY invoice_lines.item_count * invoice_lines.unit_price DESC, invoice_lines.line_no)
                       AS ord
              FROM invoice_lines
              JOIN invoices ON invoices.id = invoice_lines.invoice_id
            ) AS a
        """
        from sqlalchemy import func as fn, desc, case
        from sqlalchemy.orm import aliased
        h = Invoice      # invoice header (or: h = aliased(Invoice, name='h'))
        d = InvoiceLine  # invoice line   (or: d = aliased(InvoiceLine, name='d'))
        # item_count * unit_price -- note this is an SQL expression, not a value.
        amount = d.item_count * d.unit_price
        tax_rate = Invoice.TAX_RATE
        subcolumns = [
            amount.label('amount'),
            # Rounding difference of the invoice this line belongs to.
            (
                h.total_tax - fn.sum(fn.trunc(amount * tax_rate)).over(partition_by=d.invoice_id)
            ).label('diff'),
            # 1-based rank of the line within its invoice (amount desc, then line_no).
            (
                fn.row_number().over(partition_by=d.invoice_id, order_by=(desc(amount), d.line_no))
            ).label('ord'),
        ]
        # Subquery: invoice_lines d INNER JOIN invoices h ON d.invoice_id = h.id
        subquery = (self.db.query(InvoiceLine, *subcolumns)
                    .join(InvoiceLine.header)
                    ).subquery()
        a = aliased(subquery, name='a')
        # Main query
        columns = [
            a.c.invoice_id,
            a.c.line_no,
            a.c.item_id,
            a.c.unit_price,
            a.c.item_count,
            a.c.amount,
            # 1 if the line receives one unit of the difference, otherwise 0.
            # NOTE: the list form of case() is deprecated in SQLAlchemy 1.4 and
            # removed in 2.0, where it must be written case((cond, 1), else_=0).
            case([(a.c.ord <= a.c.diff, 1)], else_=0).label('distribute_flag'),
        ]
        query = (self.db.query(*columns)
                 #.filter(....)  # normally some narrowing condition goes here
                 )
        return query  # query object, not yet executed

    def query_invoice_lines2(self):
        """Version that uses a subquery instead of window functions (for MySQL or SQLite).

        Returns a generator of InvoiceLine objects, each given a Python-side
        ``distribute_flag`` attribute (True = receives one unit of the
        difference).  Note that this has nothing to do with the database's
        optimizer.

        NOTE(review): MySQL spells truncation TRUNCATE(x, 0), and SQLite only
        has trunc() when built with math functions -- confirm ``fn.trunc``
        works on the target database.

        Issued SQL::

            SELECT invoice_lines.id AS invoice_lines_id
                 , invoice_lines.invoice_id AS invoice_lines_invoice_id
                 , invoice_lines.line_no AS invoice_lines_line_no
                 , invoice_lines.item_id AS invoice_lines_item_id
                 , invoice_lines.item_count AS invoice_lines_item_count
                 , invoice_lines.unit_price AS invoice_lines_unit_price
                 , anon_1.tax_diff AS anon_1_tax_diff
            FROM invoice_lines
            JOIN (
              SELECT invoices.id AS invoice_id
                   , invoices.total_tax - sum(trunc(invoice_lines.item_count * invoice_lines.unit_price * 0.08))
                       AS tax_diff
              FROM invoices
              JOIN invoice_lines
                ON invoices.id = invoice_lines.invoice_id
              GROUP BY invoices.id,
                       invoices.total_tax
            ) AS anon_1 ON invoice_lines.invoice_id = anon_1.invoice_id
            ORDER BY invoice_lines.invoice_id
                   , invoice_lines.item_count * invoice_lines.unit_price DESC
                   , invoice_lines.line_no
        """
        from sqlalchemy import func as fn, desc
        tax_rate = Invoice.TAX_RATE
        amount = InvoiceLine.item_count * InvoiceLine.unit_price
        tax_diff_ = Invoice.total_tax - fn.sum(fn.trunc(amount * tax_rate))
        subcolumns = [
            Invoice.id.label('invoice_id'),
            tax_diff_.label('tax_diff'),
        ]
        subq = (self.db.query(*subcolumns)
                .join(Invoice.lines)
                #.filter(....)  # normally some narrowing condition goes here
                .group_by(Invoice.id, Invoice.total_tax)
                ).subquery()
        d = InvoiceLine
        qry = (self.db.query(InvoiceLine, subq.c.tax_diff)
               .join(subq, d.invoice_id == subq.c.invoice_id)
               .order_by(d.invoice_id, desc(amount), d.line_no)
               )
        # Rows arrive grouped by invoice and ranked, so row_number() gives the rank.
        for i, (invoice_line, tax_diff) in row_number(qry, lambda x: x[0].invoice_id):
            invoice_line.distribute_flag = i <= tax_diff
            yield invoice_line

    def query_invoice_lines3(self):
        """Version that uses neither window functions nor subqueries.

        Issues two SQL statements, so it is somewhat slower, but that should
        normally be acceptable.  Returns a generator of InvoiceLine objects,
        each given a Python-side ``distribute_flag`` attribute.  Note that
        this has nothing to do with the database's optimizer.

        Issued SQL::

            SELECT invoices.id AS invoices_id
                 , invoices.total_tax - sum(trunc(invoice_lines.item_count * invoice_lines.unit_price * 0.08))
                     AS anon_1
            FROM invoices
            JOIN invoice_lines ON invoices.id = invoice_lines.invoice_id
            GROUP BY invoices.id, invoices.total_tax
            ;
            SELECT invoice_lines.id AS invoice_lines_id
                 , invoice_lines.invoice_id AS invoice_lines_invoice_id
                 , invoice_lines.line_no AS invoice_lines_line_no
                 , invoice_lines.item_id AS invoice_lines_item_id
                 , invoice_lines.item_count AS invoice_lines_item_count
                 , invoice_lines.unit_price AS invoice_lines_unit_price
            FROM invoice_lines
            WHERE invoice_lines.invoice_id IN (101, 102, 103, ....)
            ORDER BY invoice_lines.invoice_id
                   , invoice_lines.item_count * invoice_lines.unit_price DESC
                   , invoice_lines.line_no
            ;
        """
        tax_diffs = {invoice_id: tax_diff
                     for invoice_id, tax_diff in self._tax_diff_query()}
        if not tax_diffs:  # no invoices: nothing to yield
            return
        lines = self._query_invoice_lines4(tuple(tax_diffs))
        yield from self._flag_lines(lines, tax_diffs)

    def query_invoice_lines4(self, batch_size=100):
        """query_invoice_lines3() processed *batch_size* invoices at a time.

        Each round fetches up to *batch_size* invoices' differences (LIMIT /
        OFFSET, ordered by invoice id so pages neither overlap nor skip
        invoices), then fetches and flags those invoices' lines.  Returns a
        generator of InvoiceLine objects with ``distribute_flag`` set.  Note
        that this has nothing to do with the database's optimizer.

        The issued SQL is almost the same as query_invoice_lines3(); the first
        statement additionally has ORDER BY, LIMIT and OFFSET.
        """
        offset = 0
        while True:
            tax_diffs = {invoice_id: tax_diff
                         for invoice_id, tax_diff in self._query_invoices4(batch_size, offset)}
            if not tax_diffs:  # no more invoices
                break
            invoice_ids = tuple(tax_diffs)
            lines = self._query_invoice_lines4(invoice_ids)
            yield from self._flag_lines(lines, tax_diffs)
            if len(invoice_ids) < batch_size:  # short page: this was the last one
                break
            offset += batch_size

    def _tax_diff_query(self):
        """Return a query yielding ``(invoice_id, tax_diff)`` for each invoice with lines."""
        from sqlalchemy import func as fn
        tax_rate = Invoice.TAX_RATE
        d = InvoiceLine
        amount = d.item_count * d.unit_price
        tax_diff = Invoice.total_tax - fn.sum(fn.trunc(amount * tax_rate))
        return (self.db.query(Invoice.id, tax_diff)
                .join(Invoice.lines)
                #.filter(....)  # normally some narrowing condition goes here
                .group_by(Invoice.id, Invoice.total_tax)
                )

    @staticmethod
    def _flag_lines(lines, tax_diffs):
        """Set ``distribute_flag`` on lines (grouped by invoice, ranked) and yield them.

        :param lines: InvoiceLine objects ordered by invoice_id, amount desc, line_no.
        :param tax_diffs: mapping invoice_id -> rounding difference.
        """
        for i, invoice_line in row_number(lines, lambda x: x.invoice_id):
            tax_diff = tax_diffs[invoice_line.invoice_id]
            # Ranks are 1-based, so the top `tax_diff` lines get the flag.
            invoice_line.distribute_flag = i <= tax_diff
            yield invoice_line

    def _query_invoices4(self, limit, offset):
        """Return one page of ``(invoice_id, tax_diff)`` rows, ordered by invoice id."""
        # A stable ORDER BY is required for LIMIT/OFFSET paging to be deterministic.
        return (self._tax_diff_query()
                .order_by(Invoice.id)
                .limit(limit)
                .offset(offset)
                )

    def _query_invoice_lines4(self, invoice_ids):
        """Return the lines of *invoice_ids*, ordered for ranking within each invoice."""
        from sqlalchemy import desc
        d = InvoiceLine
        order_by = [d.invoice_id, desc(d.item_count * d.unit_price), d.line_no]
        return (self.db.query(InvoiceLine)
                .filter(d.invoice_id.in_(invoice_ids))
                .order_by(*order_by)
                )
def _main(args): | |
if not args: | |
script_name = os.path.basename(sys.argv[0]) | |
sys.stderr.write("Usage: python %s [1-4]\n" % script_name) | |
return 1 | |
arg = args[0] | |
with TX() as db: | |
if arg == '1': | |
for row in InvoiceOp(db).query_invoice_lines1(): | |
print(row) | |
elif arg == '2': | |
for invoice_line in InvoiceOp(db).query_invoice_lines2(): | |
print(invoice_line) | |
print(invoice_line.distribute_flag) | |
elif arg == '3': | |
for invoice_line in InvoiceOp(db).query_invoice_lines3(): | |
print(invoice_line) | |
print(invoice_line.distribute_flag) | |
elif arg == '4': | |
for invoice_line in InvoiceOp(db).query_invoice_lines4(): | |
print(invoice_line) | |
print(invoice_line.distribute_flag) | |
else: | |
sys.stderr.write("%s: Unexpected argument.\n" % arg) | |
return 1 | |
return 0 | |
if __name__ == '__main__':
    # Propagate _main()'s return value as the process exit status (0 = success).
    sys.exit(_main(sys.argv[1:]))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment