Created March 4, 2016 21:10
-
-
Save mohiji/0f74ab1fd98655e0466b to your computer and use it in GitHub Desktop.
GameplayKit style state machines in pure Swift
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import Foundation | |
/// Base class for one state in a `StateMachine`, in the style of GameplayKit's `GKState`.
///
/// Subclasses override the hooks below. The base implementations accept every
/// transition and otherwise do nothing.
class State {
    /// The machine this state belongs to; `StateMachine`'s initializer assigns it
    /// for every state it is given.
    ///
    /// Held `weak` because the machine already owns its states strongly through its
    /// `states` array. A strong reference here would form a retain cycle, so neither
    /// the machine nor any of its states could ever be deallocated. It becomes `nil`
    /// once the owning machine is released.
    weak var stateMachine: StateMachine? = nil

    /// Returns whether the machine may move from this state to `stateType`.
    /// The default allows any transition.
    func isValidNextState(stateType: State.Type) -> Bool {
        return true
    }

    /// Called after the machine has made this the current state.
    /// - Parameter previousState: The state that was current before, or `nil` if there was none.
    func didEnter(previousState: State?) {}

    /// Called while this state is still current, just before the machine switches to `nextState`.
    func willExit(nextState: State) {}

    /// Called by the machine's `update` on whichever state is current.
    func update(deltaTime: Float) {}
}
/// Owns a fixed set of `State` instances and moves between them, calling the
/// outgoing state's `willExit` and the incoming state's `didEnter` on each transition.
class StateMachine {
    /// Every state this machine can be in. States are looked up by exact dynamic type.
    let states: [State]

    /// The active state, or `nil` until the first successful `enterState`.
    /// NOTE(review): this is publicly settable; assigning it directly skips both the
    /// transition check and the `willExit`/`didEnter` hooks that `enterState` performs.
    var currentState: State? = nil

    init(_ states: [State]) {
        self.states = states
        // Give each state a back-reference to this machine.
        for state in states {
            state.stateMachine = self
        }
    }

    /// Returns the registered state whose dynamic type is exactly `stateType`, or `nil`.
    ///
    /// Instances of subclasses of `stateType` do not match. If several registered states
    /// share a type, the first one in `states` wins.
    func getState<T : State>(stateType: T.Type) -> T? {
        // NOTE: `dynamicType` is Swift 2 spelling; Swift 3+ writes `type(of: state)`.
        for state in states where state.dynamicType == stateType {
            return state as? T
        }
        return nil
    }

    /// Returns whether `enterState(stateType)` would succeed right now: the type must be
    /// registered, and the current state (if any) must accept it via `isValidNextState`.
    func canEnterState(stateType: State.Type) -> Bool {
        // Before anything else, the requested type must exist in our states list.
        guard getState(stateType) != nil else {
            return false
        }
        // With no current state, any registered state is a valid starting point.
        return currentState?.isValidNextState(stateType) ?? true
    }

    /// Transitions to the registered state of type `stateType`.
    ///
    /// Calls `willExit` on the outgoing state (if any), updates `currentState`, then
    /// calls `didEnter` on the incoming state with the outgoing one (or `nil`).
    /// - Returns: `true` on success. Returns `false` when `canEnterState(stateType)` is
    ///   `false`; in that case no hooks run and `currentState` is left unchanged.
    func enterState(stateType: State.Type) -> Bool {
        // Delegate the validity check so the transition rules live in one place.
        guard canEnterState(stateType), let nextState = getState(stateType) else {
            return false
        }
        let previousState = currentState
        previousState?.willExit(nextState)
        currentState = nextState
        nextState.didEnter(previousState)
        return true
    }

    /// Forwards `deltaTime` to the current state's `update`; does nothing when there is none.
    func update(deltaTime: Float) {
        currentState?.update(deltaTime)
    }
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import XCTest | |
/********************************************
 * The state graph used in these tests
 *
 * StateOne -----> StateTwo -----> StateThree
 *    ^               |  ^             |
 *    |               |  |             |
 *    -----------------  ---------------
 *
 * StateThree sets a flag whenever it receives one of the lifecycle messages
 * (didEnter, willExit, update) so that we can make sure they were called.
 *
 * StateFour is a valid State, but isn't part of the state machine.
 *
 * A type that isn't a State at all (e.g. a NotAState) is not tested: the
 * machine's methods take a `State.Type`, so such a type is rejected at
 * compile time.
 */
/// A `State` that overrides nothing, so it exercises the base-class defaults.
class DoNothingState: State {
}
/// Entry point of the test graph: the only legal move is to `StateTwo`.
class StateOne: State {
    override func isValidNextState(stateType: State.Type) -> Bool {
        guard stateType == StateTwo.self else {
            return false
        }
        return true
    }
}
/// Middle of the test graph: may move back to `StateOne` or on to `StateThree`.
class StateTwo: State {
    override func isValidNextState(stateType: State.Type) -> Bool {
        let successors: [State.Type] = [StateOne.self, StateThree.self]
        return successors.contains { $0 == stateType }
    }
}
/// End of the test graph: may only move back to `StateTwo`. It also records which
/// lifecycle hooks have been invoked on it, so tests can observe them.
class StateThree: State {
    // Each flag is set to true whenever the matching hook runs and is never reset.
    var didEnterWasCalled: Bool = false
    var willExitWasCalled: Bool = false
    var updateWasCalled: Bool = false

    override func didEnter(previousState: State?) {
        self.didEnterWasCalled = true
    }

    override func willExit(nextState: State) {
        self.willExitWasCalled = true
    }

    override func update(deltaTime: Float) {
        self.updateWasCalled = true
    }

    override func isValidNextState(stateType: State.Type) -> Bool {
        let onlySuccessor: State.Type = StateTwo.self
        return stateType == onlySuccessor
    }
}
/// A legitimate `State` subclass that the tests leave out of the machines they build.
class StateFour: State {
}
/// Exercises `State`'s defaults and `StateMachine`'s lookup, transition, and
/// lifecycle behaviour against the test graph described above.
class StateMachineTests: XCTestCase {
    // A State with no overrides accepts any transition.
    func testDefaultStateMethods() {
        let doNothingState = DoNothingState()
        XCTAssertTrue(doNothingState.isValidNextState(DoNothingState.self))
    }

    // getState returns the very instance registered for a type, and nil for unknown types.
    func testGetState() {
        let doNothingState = DoNothingState()
        let machine = StateMachine([doNothingState])
        let state = machine.getState(DoNothingState.self)
        XCTAssertNotNil(state)
        XCTAssertTrue(state === doNothingState)
        XCTAssertNil(machine.getState(StateFour.self))
    }

    func testAllowedNextStates() {
        let stateOne = StateOne()
        XCTAssertTrue(stateOne.isValidNextState(StateTwo.self))
        XCTAssertFalse(stateOne.isValidNextState(StateOne.self))
        XCTAssertFalse(stateOne.isValidNextState(StateThree.self))
    }

    func testCanEnterInitialState() {
        let stateOne = StateOne()
        let states: [State] = [stateOne, StateTwo(), StateThree()]
        let machine = StateMachine(states)
        // With no current state, every registered state is a legal entry point.
        XCTAssertTrue(machine.canEnterState(StateOne.self))
        XCTAssertTrue(machine.canEnterState(StateTwo.self))
        XCTAssertTrue(machine.canEnterState(StateThree.self))
        XCTAssertTrue(machine.enterState(StateOne.self))
        // Assert directly rather than inside `if let`, so a nil currentState fails the test.
        XCTAssertTrue(machine.currentState === stateOne)
    }

    func testCanEnterNextState() {
        let states: [State] = [StateOne(), StateTwo(), StateThree()]
        let machine = StateMachine(states)
        XCTAssertTrue(machine.enterState(StateOne.self))
        XCTAssertFalse(machine.canEnterState(StateThree.self))
        XCTAssertTrue(machine.enterState(StateTwo.self))
        XCTAssertTrue(machine.canEnterState(StateThree.self))
    }

    func testLifecycleMethods() {
        let stateThree = StateThree()
        let states: [State] = [StateOne(), StateTwo(), stateThree]
        let machine = StateMachine(states)
        // Check every transition, so a failure points at the step that broke
        // rather than at a later hook assertion.
        XCTAssertTrue(machine.enterState(StateOne.self))
        XCTAssertTrue(machine.enterState(StateTwo.self))
        XCTAssertFalse(stateThree.didEnterWasCalled)
        XCTAssertTrue(machine.enterState(StateThree.self))
        XCTAssertTrue(stateThree.didEnterWasCalled)

        XCTAssertFalse(stateThree.updateWasCalled)
        machine.update(1.0)
        XCTAssertTrue(stateThree.updateWasCalled)

        XCTAssertFalse(stateThree.willExitWasCalled)
        XCTAssertTrue(machine.enterState(StateTwo.self))
        XCTAssertTrue(stateThree.willExitWasCalled)
    }

    func testCantEnterInvalidState() {
        let stateOne = StateOne()
        let states: [State] = [stateOne, StateTwo(), StateThree()]
        let machine = StateMachine(states)
        // StateFour is a State, but it is not registered with this machine.
        XCTAssertFalse(machine.enterState(StateFour.self))
        XCTAssertNil(machine.currentState)
        XCTAssertTrue(machine.enterState(StateOne.self))
        // StateOne only permits StateTwo; a refused transition must leave the machine where it was.
        XCTAssertFalse(machine.enterState(StateThree.self))
        XCTAssertTrue(machine.currentState === stateOne)
    }

    func testStateMachinePropertySet() {
        let state = StateOne()
        let machine = StateMachine([state])
        // Identity check without force-unwrapping, so a missing back-reference fails instead of crashing.
        XCTAssertTrue(state.stateMachine === machine)
    }
}
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment