Created August 22, 2013 03:35
-
-
Save ssoroka/6302948 to your computer and use it in GitHub Desktop.
State machine in under 100 lines of Ruby. No guards (they could be implemented trivially).
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Minimal state machine mixin for ActiveRecord-style models (no guards).
#
# Including the module extends the class with ClassMethods and registers an
# +after_initialize+ callback that assigns the first declared state to
# records whose state attribute is nil.
#
# Each +state+ declaration defines, on the including class:
#   * +<state>!+ - transition into the state (see #transition_to)
#   * +<state>?+ - whether the record is currently in the state
#   * +transition_from_<from>_to_<state>+ - one method per allowed source
#     state; it runs exit hooks, assigns the attribute, then runs enter hooks
# A transition that was never declared has no method, so attempting it raises
# NoMethodError.
module StateMachine
  def self.included(klass)
    klass.send(:extend, ClassMethods)
    klass.instance_eval do
      after_initialize :set_initial_state
    end
  end

  module ClassMethods
    # Name of the attribute that stores the current state. Override this
    # class method in the including class to use a different attribute.
    def state_column
      'state'
    end

    # Declares a state. The first state declared becomes the initial state.
    #
    # @param state [String, Symbol] name of the state
    # @param options [Hash] +:from+ - states allowed to transition into this
    #   one; when omitted, every state declared so far (including this one)
    def state(state, options = {})
      @initial_state ||= state
      @after_state_change ||= []
      @on_enter_state ||= HashWithIndifferentAccess.new
      @on_exit_state ||= HashWithIndifferentAccess.new
      @states ||= {}
      @states[state] = options

      # These depend only on the target state, so define them once rather than
      # once per source state; this also keeps them defined when :from is empty.
      define_method("#{state}!") { transition_to(state) }
      define_method("#{state}?") { is?(state) }

      from = options[:from] || @states.keys
      from.each do |from_state|
        define_method("transition_from_#{from_state}_to_#{state}") do
          exit_hooks = self.class.instance_variable_get("@on_exit_state")
          enter_hooks = self.class.instance_variable_get("@on_enter_state")
          Array(exit_hooks.try(:[], from_state)).each { |blk| instance_eval(&blk) }
          send("#{state_column}=", state)
          Array(enter_hooks.try(:[], state)).each { |blk| instance_eval(&blk) }
        end
      end
    end

    # Registers a block, instance_eval'd on the record, that runs right after
    # the state attribute is set to +state+ (before save!). Safe to call
    # before any +state+ declaration. Returns the hook registry.
    def on_enter_state(state, &block)
      hooks = (@on_enter_state ||= HashWithIndifferentAccess.new)
      (hooks[state] ||= []).push(block)
      hooks
    end

    # Registers a block, instance_eval'd on the record, that runs when leaving
    # +state+, before the attribute changes. Safe to call before any +state+
    # declaration. Returns the hook registry.
    def on_exit_state(state, &block)
      hooks = (@on_exit_state ||= HashWithIndifferentAccess.new)
      (hooks[state] ||= []).push(block)
      hooks
    end

    # Registers a block that runs after each transition has been saved; it is
    # instance_exec'd with (old_state, new_state). Positional arguments are
    # accepted for compatibility but ignored.
    def after_state_change(*_args, &block)
      (@after_state_change ||= []).push(block)
    end
  end

  # Instance-level shortcut for the class's state_column.
  def state_column
    self.class.state_column
  end

  # after_initialize callback: assigns the initial state only when the state
  # attribute is nil, leaving existing values untouched.
  def set_initial_state
    current_value = send(state_column)
    if current_value.nil?
      send("#{state_column}=", self.class.instance_variable_get("@initial_state"))
    end
  end

  # Whether the record is currently in +state+. Values are compared as
  # strings, so symbols and strings are interchangeable.
  def is?(state)
    send(state_column).to_s == state.to_s
  end

  # Whether a transition from the current state to +new_state+ was declared.
  # It checks for the generated transition method, so the answer matches
  # what #transition_to can dispatch to; unknown states return false.
  def can_transition_to?(new_state)
    respond_to?("transition_from_#{send(state_column)}_to_#{new_state}")
  end

  # Moves to +new_state+: runs exit hooks, sets the attribute, runs enter
  # hooks, persists with save!, then runs after_state_change callbacks with
  # (old_state, new_state). Raises NoMethodError if the transition was not
  # declared. Returns self.
  def transition_to(new_state)
    old_state = send(state_column)
    send("transition_from_#{old_state}_to_#{new_state}")
    save!
    Array(self.class.instance_variable_get("@after_state_change")).each do |blk|
      instance_exec(old_state, new_state, &blk)
    end
    self
  end
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
require 'test_helper' | |
require 'state_machine' | |
# Tests for the StateMachine mixin against a throwaway ActiveRecord table.
class StateMachineTest < ActiveSupport::TestCase
  setup do
    # force: true drops a leftover table first, so setup does not fail when
    # the table survived a previous test or run.
    ActiveRecord::Migration.create_table "testtest", force: true do |t|
      t.string :state
    end
    @klass = Class.new(ActiveRecord::Base) do
      self.table_name = 'testtest'
      include StateMachine
      state 'new'
      state 'wip', from: %w(new)
      state 'cancelled', from: %w(wip new)
      state 'done', from: %w(wip)
    end
  end

  test "state defaults to first state" do
    klass = @klass.new
    assert_equal 'new', klass.state
    assert klass.can_transition_to?('wip')
    assert klass.can_transition_to?(:wip)
    refute klass.can_transition_to?(:done)
  end

  test "can override state column" do
    @klass.instance_eval do
      attr_accessor :task_state
      def state_column
        'task_state'
      end
    end
    klass = @klass.new
    assert_equal 'new', klass.task_state
  end

  test "on_enter_state" do
    @klass.instance_eval do
      attr_accessor :wipped
      on_enter_state :wip do
        self.wipped = true
      end
    end
    k = @klass.new
    k.wipped = false
    assert_equal false, k.wipped
    k.wip!
    assert_equal true, k.wipped
    assert_equal 'wip', k.state
  end

  test "transitioning to an invalid state throws an error" do
    assert_raises(NoMethodError) do
      @klass.new.done!
    end
  end

  test "multiple transitions" do
    k = @klass.new
    k.wip!
    k.done!
    assert_equal 'done', k.state
    assert k.is? 'done'
    assert k.is? :done
  end

  test "after_state_change" do
    @klass.instance_eval do
      attr_accessor :from_state, :to_state
      after_state_change do |from_state, to_state|
        self.from_state = from_state
        self.to_state = to_state
      end
    end
    k = @klass.new
    assert_nil k.from_state
    assert_nil k.to_state
    k.wip!
    assert_equal 'new', k.from_state
    assert_equal 'wip', k.to_state
    k.cancelled!
    assert_equal 'wip', k.from_state
    assert_equal 'cancelled', k.to_state
    assert_raises NoMethodError do
      k.wip!
    end
    # make sure it's unchanged:
    assert_equal 'wip', k.from_state
    assert_equal 'cancelled', k.to_state
  end
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# example usage. | |
# Example usage: a task whose time spent in the :running state is logged.
class Task < ApplicationModel
  has_many :state_histories
  # Scope bodies must be callable (Rails 4+). The lambda also defers building
  # the query until the scope is used, instead of at class-load time.
  scope :active, -> { where(state: %w(new running paused deferred)) }
  include StateMachine
  state :new
  state :running, from: %w(new paused deferred)
  state :paused, from: %w(running)
  state :deferred, from: %w(new running paused deferred)
  state :completed, from: %w(new running paused deferred)
  state :closed, from: %w(new running paused deferred)

  # Record every transition as a history row.
  after_state_change do |from_state, to_state|
    state_histories.create!(from_state: from_state, to_state: to_state)
  end

  # Start the timer when entering :running.
  on_enter_state :running do
    self.timer_started_at = Time.now.utc
  end

  # On leaving :running, log the elapsed whole seconds and clear the timer.
  on_exit_state :running do
    running_time = (Time.now.utc - timer_started_at).round
    # NOTE(review): the time_logs association is not declared in the visible
    # part of this class; presumably it lives in the elided section below.
    time_logs.create!(task: self, seconds: running_time)
    self.timer_started_at = nil
  end
  #...
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.