-
-
Save SamMolokanov/13c897d7ae336d21b3bf7604af50f2a2 to your computer and use it in GitHub Desktop.
Arel Helpers
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
module ArelHelpers | |
extend self | |
def self.included(base) | |
base.extend self | |
end | |
def asterisk(arel_table_or_model) | |
arel_table, columns = case arel_table_or_model | |
when Arel::Table | |
[arel_table_or_model, arel_table_or_model.engine.columns] | |
when ->(possible_model) { UtilitiesHelper.is_model?(possible_model) } | |
[arel_table_or_model.arel_table, arel_table_or_model.columns] | |
else | |
raise ArgumentError, "Must pass in an arel table or model" | |
end | |
columns.map { |c| arel_table[c.name] } | |
end | |
def greatest(*args) | |
Arel::Nodes::NamedFunction.new "greatest", args | |
end | |
def least(*args) | |
Arel::Nodes::NamedFunction.new "least", args | |
end | |
def cast(pred, type) | |
Arel::Nodes::NamedFunction.new "cast", [pred.as(type)] | |
end | |
def null_if(column, value) | |
Arel::Nodes::NamedFunction.new "NULLIF", [column, value] | |
end | |
def predicate(pred, true_value, false_value) | |
Arel::Nodes::SqlLiteral.new("CASE WHEN #{sqlv(pred)} THEN #{sqlv(true_value)} ELSE #{sqlv(false_value)} END") | |
end | |
def tsrange(lower_or_range, upper = nil) | |
Arel::Nodes::NamedFunction.new "tsrange", range_params(lower_or_range, upper) | |
end | |
def tstzrange(lower_or_range, upper = nil) | |
Arel::Nodes::NamedFunction.new "tstzrange", range_params(lower_or_range, upper) | |
end | |
def overlap(a, b) | |
Arel::Nodes::InfixOperation.new "&&", a, b | |
end | |
def coalesce(*args) | |
Arel::Nodes::NamedFunction.new "coalesce", args | |
end | |
def hstore_key(hstore, key) | |
Arel::Nodes::InfixOperation.new "->", hstore, cloneable(key) | |
end | |
def concat(*args) | |
Arel::Nodes::NamedFunction.new "concat", args | |
end | |
def mod(a, b) | |
Arel::Nodes::InfixOperation.new "%", a, b | |
end | |
def to_char(input, format) | |
Arel::Nodes::NamedFunction.new "to_char", [input, format] | |
end | |
def string_agg(input, delimiter) | |
Arel::Nodes::NamedFunction.new "string_agg", [input, delimiter] | |
end | |
def between(pred, lower_or_range, upper = nil) | |
Arel::Nodes::Between.new(pred, Arel::Nodes::And.new(range_params(lower_or_range, upper))) | |
end | |
def unnest(array) | |
Arel::Nodes::NamedFunction.new "unnest", [array] | |
end | |
def array_agg(expression) | |
Arel::Nodes::NamedFunction.new "array_agg", [expression] | |
end | |
def lower(expression) | |
Arel::Nodes::NamedFunction.new "lower", [expression] | |
end | |
def accumulative_or(array) | |
array.inject do |expressions, expression| | |
if expressions === expression | |
expression | |
else | |
expressions.or(expression) | |
end | |
end | |
end | |
def array_intersect(a1, a2, opts = {}) | |
select1 = unnest(sqlv(a1)) | |
select2 = unnest(sqlv(a2)) | |
if !opts[:case_sensitive] | |
select1 = lower cast(select1, "text") | |
select2 = lower cast(select2, "text") | |
end | |
Arel::Nodes::SqlLiteral.new <<-SQL | |
ARRAY( | |
SELECT #{sqlv(select1)} INTERSECT | |
SELECT #{sqlv(select2)} | |
) | |
SQL | |
end | |
def descendants_search(table, id, max_depth: 999) | |
tree_sql = Arel::Nodes::SqlLiteral.new <<-SQL | |
WITH RECURSIVE descendants_search(id, path) AS ( | |
SELECT id, ARRAY[id] | |
FROM #{table.name} | |
WHERE id = #{id} | |
UNION ALL | |
SELECT #{table.name}.id, (path || #{table.name}.id) | |
FROM descendants_search | |
JOIN #{table.name} | |
ON descendants_search.id = #{table.name}.reports_to_id | |
WHERE NOT #{table.name}.id = ANY(path) | |
AND NOT array_length(path,1) > #{max_depth} | |
) | |
SELECT id | |
FROM descendants_search | |
WHERE id != #{id} | |
ORDER BY array_length(path, 1), path | |
SQL | |
table[:id].in(tree_sql) | |
end | |
def ancestor_search(table, id) | |
tree_sql = Arel::Nodes::SqlLiteral.new <<-SQL | |
WITH RECURSIVE ancestor_search(id, reports_to_id, path) AS ( | |
SELECT id, reports_to_id, ARRAY[id] | |
FROM #{table.name} | |
WHERE id = #{id} | |
UNION ALL | |
SELECT #{table.name}.id, #{table.name}.reports_to_id, (path || #{table.name}.id) | |
FROM ancestor_search | |
JOIN #{table.name} | |
ON ancestor_search.reports_to_id = #{table.name}.id | |
WHERE NOT #{table.name}.id = ANY(path) | |
) | |
SELECT id | |
FROM ancestor_search | |
WHERE id != #{id} | |
ORDER BY array_length(path, 1), path | |
SQL | |
table[:id].in(tree_sql) | |
end | |
def sqlv(node) | |
case node | |
when ->(n) { n.respond_to?(:to_sql) } | |
node.to_sql | |
when Arel::Attributes::Attribute | |
Arel::Nodes::SqlLiteral.new "\"#{node.relation.name}\".\"#{node.name}\"" | |
when Array, Range | |
value = node.map { |x| x.is_a?(String) ? "'#{x}'" : x }.join(",") | |
Arel::Nodes::SqlLiteral.new "ARRAY[#{value}]" | |
when Time, DateTime, Date | |
Arel::Nodes.build_quoted node | |
when String | |
Arel::Nodes.build_quoted node | |
else | |
Arel::Nodes::SqlLiteral.new node.to_s | |
end | |
end | |
def array_agg(expression) | |
Arel::Nodes::NamedFunction.new "array_agg", [expression] | |
end | |
def between(pred, lower_or_range, upper = nil) | |
Arel::Nodes::Between.new(pred, Arel::Nodes::And.new(range_params(lower_or_range, upper))) | |
end | |
# This is a special ordering SQL used inside methods like array_agg | |
# http://www.postgresql.org/docs/current/static/sql-expressions.html#SYNTAX-AGGREGATES | |
def order_by(a, b) | |
Arel::Nodes::SqlLiteral.new "#{sqlv(a)} ORDER BY #{sqlv(b)}" | |
end | |
def range_params(lower_or_range, upper = nil) | |
case lower_or_range | |
when Range | |
lower = lower_or_range.min | |
upper = lower_or_range.max | |
else | |
lower = lower_or_range | |
end | |
[sqlv(lower), sqlv(upper)] | |
end | |
def cloneable(obj) | |
case obj | |
when Symbol | |
Arel::Nodes.build_quoted obj.to_s | |
else | |
obj | |
end | |
end | |
def self.sort(node, order) | |
case order.try(:to_sym) | |
when :asc | |
Arel::Nodes::Ascending.new node | |
when :desc | |
Arel::Nodes::Descending.new node | |
else | |
raise ArgumentError, "Must pass in either :asc or :desc" | |
end | |
end | |
end |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment