Last active
August 5, 2016 18:47
-
-
Save dkubb/874e9d1dc9631c8c16f9022ef0779e7e to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Methods that assert how many entries an enumerable-like receiver holds.
#
# The receiver must respond to `#take(n)`. `#min_one` also calls `#to_a`,
# and the default failure path of `#exactly` calls `#count` on the receiver.
module EnumerableExtensions
  # An exception raised when an invalid number of entries is returned
  class InvalidCountError < StandardError
    # Initialize an exception to report an invalid enumerable count
    #
    # @param expectation [#to_s]
    #   description of the expected count (e.g. 1 or 'one or more')
    # @param entries [#count]
    #   the entries that were found; only their count is reported
    #
    # @return [undefined]
    #
    # @api public
    def initialize(expectation, entries)
      super(
        format(
          'Found %{count}, expected %{expectation}',
          count: entries.count,
          expectation: expectation
        )
      )
    end
  end # InvalidCountError

  # Sentinel that distinguishes "no default given" from an explicit nil default
  Undefined = Object.new.freeze

  # Return exactly one entry from the enumerable
  #
  # @param default [Object]
  #   returned when the receiver does not hold exactly one entry
  #
  # @yield [count, entries]
  #   called when the receiver does not hold exactly one entry
  #
  # @yieldparam [Integer] count
  #   the expected count (always 1)
  # @yieldparam [Enumerable] entries
  #   the receiver itself (not a truncated sample)
  #
  # @yieldreturn [Object]
  #   return the default from the block, if provided
  #
  # @return [Object]
  #   the single entry, or the default / block value otherwise
  #
  # @raise [InvalidCountError]
  #   raised if zero or more than one entry is found and there is no default
  # @raise [ArgumentError]
  #   raised if both a block and a default are given
  #
  # @api public
  def one(default = Undefined, &block)
    result =
      if block
        # Wrap the block's value in an array so it is unwrapped by #fetch(0)
        # exactly like a matched entry would be.
        exactly(1, default) { |*args| [block.call(*args)] }
      elsif default.equal?(Undefined)
        exactly(1)
      else
        # Wrap the default for the same reason; this also lets a nil/false
        # default be returned faithfully.
        exactly(1, [default])
      end
    result.fetch(0)
  end

  # Return one or more entries from the enumerable
  #
  # @return [Array]
  #   all entries, returned if there is at least one
  #
  # @raise [InvalidCountError]
  #   raised if zero entries are found
  #
  # @api public
  def min_one
    entries = to_a
    # #empty? rather than #none?: a blockless #none? is also true when every
    # entry is nil/false, which would wrongly reject e.g. [nil].
    raise InvalidCountError.new('one or more', entries) if entries.empty?
    entries
  end

  # Return zero or one entry from the enumerable
  #
  # @return [Object, nil]
  #   the single entry, or nil when there are none
  #
  # @raise [InvalidCountError]
  #   raised if more than one entry is found
  #
  # @api public
  def max_one
    # Two entries are enough to detect "more than one" without loading all.
    entries = take(2).to_a
    # #size rather than ActiveSupport's #many?, so plain Ruby works too.
    raise InvalidCountError.new('zero or one', entries) if entries.size > 1
    entries.first
  end

  # Return an exact number of entries from the enumerable
  #
  # @param count [Integer]
  #   the required number of entries
  # @param default [Object]
  #   returned (as-is) when the number of entries differs from count
  #
  # @yield [count, entries]
  #   called when the number of entries differs from count and no default
  #   was given
  #
  # @yieldparam [Integer] count
  #   the expected count
  # @yieldparam [Enumerable] entries
  #   the receiver itself (not the truncated sample)
  #
  # @yieldreturn [Object]
  #   return the default from the block, if provided
  #
  # @return [Array, Object]
  #   the first count entries if exactly count are found; otherwise the
  #   default or the block's value
  #
  # @raise [InvalidCountError]
  #   raised if an invalid number of entries is found and neither a default
  #   nor a block was given
  # @raise [ArgumentError]
  #   raised if both a block and a default are given
  #
  # @api public
  def exactly(count, default = Undefined, &block)
    assert_default_or_block(default, &block)
    # Taking one extra entry is enough to tell "exactly count" from "more".
    entries = take(count.succ).to_a
    return entries if entries.size == count
    return default unless default.equal?(Undefined)
    block ||= ->(*args) { raise InvalidCountError.new(*args) }
    block.call(count, self)
  end

  private

  # Assert that a block and default argument cannot be provided together
  #
  # @param default [Object]
  #
  # @return [undefined]
  #
  # @raise [ArgumentError]
  #   raised if a block and default value are provided
  #
  # @api private
  def assert_default_or_block(default)
    return unless block_given? && !default.equal?(Undefined)
    raise ArgumentError, 'Must pass in a block or a default argument, not both'
  end
end # EnumerableExtensions
# Install the extensions: ActiveRecord model classes gain them as
# class-level methods, and every Array gains them as instance methods.
ActiveRecord::Base.extend EnumerableExtensions
Array.class_eval do
  include EnumerableExtensions
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment