Created
November 3, 2012 02:11
-
-
Save kputnam/4005520 to your computer and use it in GitHub Desktop.
Patch ActiveRecord + Arel to support INSERT ... ON DUPLICATE KEY UPDATE ...
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#
# I'd prefer to add a parameter to ActiveRecord::Base.create (and .save)
# which would look like this:
#
#   user.save(on_duplicate_key_update: true)
#
#   User.create(attrs, on_duplicate_key_update: %w(updated_at))
#
# But those methods are already overloaded, and any methods, even new ones we
# could add, would require some painful plumbing to pass these parameters down
# to ActiveRecord::Relation#insert.
#
# So for now, we cheat with a per-table stack stored in Thread.current (which
# is actually fiber-local, not thread-local, in Ruby 1.9+):
#
#   # The returned record gets the existing row's id; the existing row isn't modified
#   User.on_duplicate_key_update do
#     User.create(...)
#   end
#
#   # Same, but also overwrites the existing row's updated_at and first_name
#   User.on_duplicate_key_update(%w(updated_at first_name)) do
#     User.create(...)
#   end
#
# The setting is stackable, so you can override the current setting within a
# delimited scope:
#
#   User.on_duplicate_key_update(true) do
#     # duplicate inserts will be merged
#     User.create(...)
#
#     User.on_duplicate_key_update(false) do
#       # duplicate inserts will fail
#       User.create(...)
#     end
#   end
#
# An Arel::InsertManager whose AST is an InsertOrOverwriteStatement, so that
# visitors dispatch to visit_Arel_Nodes_InsertOrOverwriteStatement instead of
# the ordinary insert visitor.
class Arel::InsertOrOverwriteManager < Arel::InsertManager
  def initialize(engine)
    super
    # Replace the plain InsertStatement that the superclass just built
    @ast = Arel::Nodes::InsertOrOverwriteStatement.new
  end

  # Column names whose inserted values should overwrite the existing row when
  # the insert hits a duplicate key. Stored on the underlying statement node.
  def overwrites
    @ast.overwrites
  end

  def overwrites=(names)
    @ast.overwrites = names
  end
end
# An INSERT statement node that also carries the list of columns to overwrite
# on a duplicate key.
class Arel::Nodes::InsertOrOverwriteStatement < Arel::Nodes::InsertStatement
  # Column names whose inserted values should be copied onto the existing row
  # when the insert collides with a unique key. Starts out empty.
  attr_accessor :overwrites

  def initialize
    super
    self.overwrites = []
  end

  # Invoked by dup/clone after instance variables were shallow-copied: give
  # the copy its own overwrites list so the two statements don't share it.
  def initialize_copy(other)
    super
    @overwrites &&= @overwrites.clone
  end
end
class Arel::Visitors::ToSQL
  # Fallback for databases whose visitor does not override this method.
  # Raises NotImplementedError rather than silently emitting a plain INSERT.
  def visit_Arel_Nodes_InsertOrOverwriteStatement(_o)
    message = "INSERT ... ON DUPLICATE KEY not implemented for this db"
    raise NotImplementedError, message
  end
end
# arel-3.0.2/lib/arel/visitors/to_sql.rb
# arel-3.0.2/lib/arel/visitors/mysql.rb
class Arel::Visitors::MySQL
  # Compiles an InsertOrOverwriteStatement into MySQL's upsert form:
  #
  #   INSERT INTO `users` (`a`, `b`) VALUES (1, 2)
  #   ON DUPLICATE KEY UPDATE `a` = 1, `id` = LAST_INSERT_ID(`id`)
  #
  # Only columns named in o.overwrites that are also being inserted are
  # assigned in the UPDATE list (in o.overwrites order); other names are
  # ignored.
  #
  # When o.values is not an Arel::Nodes::Values node (presumably the SqlLiteral
  # ActiveRecord uses for an empty insert, e.g. "VALUES ()" -- confirm against
  # the adapter), that node is emitted as-is and there are no columns to
  # overwrite.
  #
  # Raises ArgumentError when the table has no primary key and no column is
  # overwritten, because the UPDATE list would otherwise be empty.
  def visit_Arel_Nodes_InsertOrOverwriteStatement(o)
    if Arel::Nodes::Values === o.values
      # [column name, SQL for its value], in insert order. Each expression is
      # visited exactly once and in order; the resulting strings are reused
      # for the UPDATE list instead of visiting the nodes a second time.
      # NOTE(review): presumably needed because bind parameters are
      # substituted positionally by the adapter's visitor -- confirm.
      pairs = o.columns.zip(o.values.expressions).map do |attr, expr|
        sql = if Arel::Nodes::SqlLiteral === expr
                visit(expr)
              else
                quote(expr, attr && column_for(attr))
              end
        [attr.name, sql]
      end
      # Build both lists from the same pairs so their lengths always match
      columns = pairs.map { |name, _| quote_column_name(name) }.join(", ")
      insert  = "(#{columns}) VALUES (#{pairs.map(&:last).join(", ")})"
    else
      pairs  = []
      insert = visit(o.values)
    end

    values = Hash[pairs]
    assignments = values.slice(*o.overwrites.map(&:to_s)).map do |column, sql|
      "#{quote_column_name(column)} = #{sql}"
    end

    # This is essential, otherwise connection.last_insert_id will be 0.
    # Arel's Table#primary_key may be nil for tables without one; skip the
    # clause then instead of crashing on nil.name.
    if (pkey = o.relation.primary_key)
      pkey = quote_column_name(pkey.name)
      assignments << "#{pkey} = LAST_INSERT_ID(#{pkey})"
    end

    if assignments.empty?
      raise ArgumentError,
        "INSERT ... ON DUPLICATE KEY UPDATE needs a primary key or a column to overwrite"
    end

    ["INSERT INTO #{visit o.relation}",
     insert,
     "ON DUPLICATE KEY UPDATE",
     assignments.join(", ")].join(" ")
  end
end
# Obnoxious plumbing: carry the per-model setting down to Arel's insert manager
################################################################################
class Arel::SelectManager
  # Builds the insert manager for this query's engine. The engine is
  # presumably the ActiveRecord model class, since it must respond to
  # on_duplicate_key_update -- confirm.
  #
  # A blank setting (nil, false, []) yields a plain Arel::InsertManager.
  # Any other setting yields an Arel::InsertOrOverwriteManager; when the
  # setting is an Array, it becomes that manager's overwrite list.
  def create_insert
    setting = @engine.on_duplicate_key_update
    return Arel::InsertManager.new(@engine) if setting.blank?

    manager = Arel::InsertOrOverwriteManager.new(@engine)
    manager.overwrites = setting if setting.is_a?(Array)
    manager
  end
end
class ActiveRecord::Base
  # Reads, or temporarily sets, the INSERT ... ON DUPLICATE KEY UPDATE
  # setting for this model's table.
  #
  # Without a block, returns the innermost active setting, or nil when none
  # is active. +columns+ is ignored in that case.
  #
  # With a block, pushes +columns+ for the duration of the block and returns
  # the block's value. The setting is popped afterwards, even if the block
  # raises.
  #
  # columns - true to merge duplicate inserts (nothing is overwritten);
  #           an Array of column names to overwrite on the existing row;
  #           or false/nil for an ordinary INSERT.
  #
  # Settings live in Thread.current (fiber-local in Ruby 1.9+) and are keyed
  # by table name, so models sharing a table share the setting.
  def self.on_duplicate_key_update(columns = true)
    key   = ["on_duplicate_key_update", table_name].join
    stack = Thread.current[key] ||= []
    return stack.last unless block_given?

    stack.push(columns)
    begin
      yield
    ensure
      # Scoped to after the push, so an earlier failure (e.g. in table_name)
      # is never masked by calling pop on a nil stack.
      stack.pop
    end
  end
end
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
require "spec_helper"

# Specs for ActiveRecord::Base.on_duplicate_key_update and the SQL it causes
# ActiveRecord to emit. Each insert is aborted by a stub raising
# "good enough" once the interesting call has been observed.
describe ActiveRecord::Base do
  before do
    @base = Class.new(ActiveRecord::Base) do
      def self.name; "User"; end
    end
  end

  describe "on_duplicate_key_update" do
    it "returns block's result" do
      @base.on_duplicate_key_update { :ok }.should == :ok
    end

    it "returns nil when no setting is active" do
      @base.on_duplicate_key_update.should be_nil
    end

    it "returns the active setting inside a block" do
      @base.on_duplicate_key_update(%w(updated_at)) do
        @base.on_duplicate_key_update
      end.should == %w(updated_at)
    end

    it "defaults setting to true" do
      @base.on_duplicate_key_update do
        @base.on_duplicate_key_update
      end.should be_true
    end

    it "stores settings in stack order" do
      @base.on_duplicate_key_update(%w(a b c)) do
        @base.on_duplicate_key_update(false) do
          @base.on_duplicate_key_update(true) do
            @base.on_duplicate_key_update.should be_true
            @base.on_duplicate_key_update
          end.should be_true
          @base.on_duplicate_key_update.should be_false
          @base.on_duplicate_key_update
        end.should be_false
        @base.on_duplicate_key_update.should == %w(a b c)
        @base.on_duplicate_key_update
      end.should == %w(a b c)
    end

    it "ignores the argument when no block is given" do
      @base.on_duplicate_key_update(%w(last_name))
      @base.on_duplicate_key_update.should be_nil
    end

    it "restores setting after the block returns" do
      @base.on_duplicate_key_update(%w(first_name)) { :ok }
      @base.on_duplicate_key_update.should be_nil
    end

    it "restores setting and re-raises when the block raises" do
      lambda do
        @base.on_duplicate_key_update(%w(first_name)) do
          raise "something happened"
        end
      end.should raise_error("something happened")
      @base.on_duplicate_key_update.should be_nil
    end

    context "when not set" do
      it "emits an ordinary insert statement" do
        Arel::InsertManager.should_receive(:new).and_raise("good enough")
        Arel::InsertOrOverwriteManager.should_not_receive(:new)
        lambda { @base.new.save(validate: false) }.should \
          raise_error("good enough")
      end
    end

    context "when false" do
      it "emits an ordinary insert statement" do
        Arel::InsertManager.should_receive(:new).and_raise("good enough")
        Arel::InsertOrOverwriteManager.should_not_receive(:new)
        lambda do
          @base.on_duplicate_key_update(false) { @base.new.save(validate: false) }
        end.should raise_error("good enough")
      end
    end

    context "when true" do
      subject { @base.on_duplicate_key_update { @base.new.save(validate: false) } }

      it "builds an InsertOrOverwriteManager" do
        Arel::InsertOrOverwriteManager.should_receive(:new).and_raise("good enough")
        lambda { subject }.should raise_error("good enough")
      end

      it "only re-assigns the primary key on duplicate" do
        @base.connection.
          stub(:exec_insert).
          with(/ ON DUPLICATE KEY UPDATE `id` = LAST_INSERT_ID\(`id`\)$/, "SQL", []).
          and_raise("good enough")
        lambda { subject }.should raise_error("good enough")
      end
    end

    context "when set to list of columns" do
      subject { @base.on_duplicate_key_update(%w(first_name)) { @base.new.save(validate: false) } }

      it "passes the columns to the InsertOrOverwriteManager" do
        im = double(:InsertOrOverwriteManager)
        im.should_receive(:overwrites=).
          with(%w(first_name)).
          and_raise("good enough")
        Arel::InsertOrOverwriteManager.should_receive(:new).and_return(im)
        lambda { subject }.should raise_error("good enough")
      end

      it "overwrites the listed columns and re-assigns the primary key" do
        @base.connection.
          stub(:exec_insert).
          with(/ ON DUPLICATE KEY UPDATE `first_name` = NULL, `id` = LAST_INSERT_ID\(`id`\)$/, "SQL", []).
          and_raise("good enough")
        lambda { subject }.should raise_error("good enough")
      end
    end
  end
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Hmm, that `define_method` isn't thread safe.