1import re
2import inspect
3from .backwardscompat import callable
4
5# metaclass implementation idea from
6# http://blog.ianbicking.org/more-on-python-metaprogramming-comment-14.html
# Pending ``transition(...)`` declarations, collected while a StateMachine
# subclass body executes; MetaStateMachine consumes and resets this list
# when the class is created.
_transition_gatherer = []


def transition(event, from_, to, action=None, guard=None):
    """Declare a transition inside a StateMachine subclass body.

    :param event: name of the event; becomes a method on the machine class.
    :param from_: source state name, or a list/tuple of source state names.
    :param to: target state name.
    :param action: optional action(s) run after the target state is entered.
    :param guard: optional condition(s) that must all hold for the
        transition to be taken.
    """
    declaration = list((event, from_, to, action, guard))
    _transition_gatherer.append(declaration)
11
# Pending ``state(...)`` declarations, collected while a StateMachine
# subclass body executes; MetaStateMachine consumes and resets this list
# when the class is created.
_state_gatherer = []


def state(name, enter=None, exit=None):
    """Declare a state inside a StateMachine subclass body.

    :param name: the state's name.
    :param enter: optional action(s) run when the machine enters the state.
    :param exit: optional action(s) run when the machine leaves the state.
    """
    declaration = list((name, enter, exit))
    _state_gatherer.append(declaration)
16
17
class MetaStateMachine(type):
    """Metaclass that attaches declared states and transitions to a class.

    While a class body runs, the module-level ``state(...)`` and
    ``transition(...)`` helpers append their arguments to the module-level
    gatherer lists. When the class is created, this metaclass registers
    those declarations as the class's own ``_class_states`` and
    ``_class_transitions`` and empties the gatherers for the next class.
    """

    def __new__(cls, name, bases, dictionary):
        global _transition_gatherer, _state_gatherer
        # Detach the gathered declarations *before* doing anything that can
        # raise (e.g. a transition naming an undefined state raises KeyError
        # in _add_class_transition). Otherwise a failed class definition
        # would leave its declarations behind and they would be attached to
        # the next class that gets created.
        gathered_states, _state_gatherer = _state_gatherer, []
        gathered_transitions, _transition_gatherer = _transition_gatherer, []
        Machine = super(MetaStateMachine, cls).__new__(cls, name, bases, dictionary)
        # Fresh containers per class: definitions are not merged from bases.
        Machine._class_transitions = []
        Machine._class_states = {}
        # States first: transitions look their endpoints up in _class_states.
        for state_args in gathered_states:
            Machine._add_class_state(*state_args)
        for transition_args in gathered_transitions:
            Machine._add_class_transition(*transition_args)
        return Machine
32
33
# Build the base class by calling the metaclass directly rather than with
# class-statement metaclass syntax, which differs between Python 2 and
# Python 3 (presumably for the 2/3 compatibility this package targets).
StateMachineBase = MetaStateMachine('StateMachineBase', (object,), dict())
35
36
class StateMachine(StateMachineBase):
    """Base class for finite state machines.

    Subclasses declare states and transitions in their class body with the
    module-level ``state(...)`` / ``transition(...)`` helpers and must define
    ``initial_state`` as a state name or as a callable returning one. Each
    transition's event name becomes a method that fires the event, and each
    state ``name`` gets an ``is_<name>()`` predicate on the instance. States
    and transitions can also be added per instance with ``add_state`` and
    ``add_transition``.
    """

    def __init__(self):
        self._bring_definitions_to_object_level()
        self._inject_into_parts()
        self._validate_machine_definitions()
        if callable(self.initial_state):
            self.initial_state = self.initial_state()
        self._current_state_object = self._state_by_name(self.initial_state)
        if self._current_state_object is None:
            # Report an unknown initial state as a configuration error rather
            # than failing with AttributeError on ``None.run_enter`` below.
            raise InvalidConfiguration(
                'Initial state %r is not a defined state' % (self.initial_state,))
        self._current_state_object.run_enter(self)
        self._create_state_getters()

    def __new__(cls, *args, **kwargs):
        # The per-instance containers are created here, so they already
        # exist before any __init__ code runs.
        obj = super(StateMachine, cls).__new__(cls)
        obj._states = {}
        obj._transitions = []
        return obj

    def _bring_definitions_to_object_level(self):
        """Copy the class-level states and transitions into this instance.

        Only the containers are copied; the _State/_Transition objects are
        shared with the class.
        """
        self._states.update(self.__class__._class_states)
        self._transitions.extend(self.__class__._class_transitions)

    def _inject_into_parts(self):
        """Set ``.machine`` to this instance on every state and transition."""
        for collection in [self._states.values(), self._transitions]:
            for component in collection:
                component.machine = self

    def _validate_machine_definitions(self):
        """Raise InvalidConfiguration unless there are at least two states
        and a (truthy) ``initial_state`` attribute."""
        if len(self._states) < 2:
            raise InvalidConfiguration('There must be at least two states')
        if not getattr(self, 'initial_state', None):
            raise InvalidConfiguration('There must exist an initial state')

    @classmethod
    def _add_class_state(cls, name, enter, exit):
        """Register a state declared in the class body (called by the metaclass)."""
        cls._class_states[name] = _State(name, enter, exit)

    def add_state(self, name, enter=None, exit=None):
        """Add a state to this instance only and install its ``is_<name>`` getter."""
        state = _State(name, enter, exit)
        setattr(self, state.getter_name(), state.getter_method().__get__(self, self.__class__))
        self._states[name] = state

    def _current_state_name(self):
        return self._current_state_object.name

    current_state = property(_current_state_name)

    def changing_state(self, from_, to):
        """Hook called with the old and new state names on every state change.

        It runs after the old state's exit actions and before the new state's
        enter actions. The default implementation does nothing; override it
        in subclasses.
        """
        pass

    def _new_state(self, state):
        """Notify ``changing_state`` and make *state* the current state."""
        self.changing_state(self._current_state_object.name, state.name)
        self._current_state_object = state

    def _state_objects(self):
        return list(self._states.values())

    def states(self):
        """Return the names of all states known to this machine."""
        return [s.name for s in self._state_objects()]

    @classmethod
    def _add_class_transition(cls, event, from_, to, action, guard):
        """Register a transition declared in the class body and install its
        event method on the class (called by the metaclass).

        Raises KeyError if *from_* or *to* names an undeclared state.
        """
        transition = _Transition(event, [cls._class_states[s] for s in _listize(from_)],
                                 cls._class_states[to], action, guard)
        cls._class_transitions.append(transition)
        setattr(cls, event, transition.event_method())

    def add_transition(self, event, from_, to, action=None, guard=None):
        """Add a transition to this instance only and install its event method.

        :param event: event name; becomes a bound method on this instance.
        :param from_: source state name or list/tuple of names.
        :param to: target state name.
        :param action: optional action(s) run after entering *to*.
        :param guard: optional condition(s) that must all hold.
        :raises InvalidConfiguration: if *from_* or *to* names a state this
            machine does not have. States are resolved here, once, so an
            unknown name would otherwise leave a permanently broken transition.
        """
        from_states = [self._state_by_name(s) for s in _listize(from_)]
        to_state = self._state_by_name(to)
        if to_state is None or any(s is None for s in from_states):
            undefined = [n for n in list(_listize(from_)) + [to]
                         if self._state_by_name(n) is None]
            raise InvalidConfiguration(
                'Transition %r refers to undefined state(s): %s'
                % (event, ', '.join(repr(n) for n in undefined)))
        transition = _Transition(event, from_states, to_state, action, guard)
        self._transitions.append(transition)
        setattr(self, event, transition.event_method().__get__(self, self.__class__))

    def _process_transitions(self, event_name, *args, **kwargs):
        """Fire *event_name*: pick the single applicable transition and run it.

        :raises InvalidTransition: no transition for this event leaves the
            current state (or none is registered under this event name).
        :raises GuardNotSatisfied: no candidate transition passes its guard.
        :raises ForkedTransition: more than one candidate passes its guard.
        """
        transitions = self._transitions_by_name(event_name)
        if not transitions:
            # Nothing registered under this name on this instance; without
            # this check _ensure_from_validity would hit IndexError.
            raise InvalidTransition("Cannot %s from %s" % (
                event_name, self.current_state))
        transitions = self._ensure_from_validity(transitions)
        this_transition = self._check_guards(transitions)
        this_transition.run(self, *args, **kwargs)

    def _create_state_getters(self):
        """Install an ``is_<name>`` predicate on this instance for every state."""
        for state in self._state_objects():
            setattr(self, state.getter_name(), state.getter_method().__get__(self, self.__class__))

    def _state_by_name(self, name):
        """Return the state object called *name*, or None if there is none."""
        for state in self._state_objects():
            if state.name == name:
                return state
        return None

    def _transitions_by_name(self, name):
        """Return every transition registered for event *name*, in order."""
        return [t for t in self._transitions if t.event == name]

    def _ensure_from_validity(self, transitions):
        """Keep only the (non-empty) *transitions* that may leave the current
        state; raise InvalidTransition if none can."""
        valid_transitions = [t for t in transitions
                             if t.is_valid_from(self._current_state_object)]
        if not valid_transitions:
            raise InvalidTransition("Cannot %s from %s" % (
                transitions[0].event, self.current_state))
        return valid_transitions

    def _check_guards(self, transitions):
        """Return the single transition whose guard passes.

        Every guard is evaluated, so ambiguity is always detected.
        """
        allowed_transitions = [t for t in transitions if t.check_guard(self)]
        if not allowed_transitions:
            raise GuardNotSatisfied("Guard is not satisfied for this transition")
        elif len(allowed_transitions) > 1:
            raise ForkedTransition("More than one transition was allowed for this event")
        return allowed_transitions[0]
150
151
class _Transition(object):
    """An ``event`` edge from any of the ``from_`` states to the ``to`` state."""

    def __init__(self, event, from_, to, action, guard):
        self.event = event
        self.from_ = from_  # list of _State objects this transition may start from
        self.to = to        # target _State
        self.action = action  # action spec(s) run after ``to`` has been entered
        self.guard = _Guard(guard)

    def event_method(self):
        """Return a function that fires this transition's event on a machine.

        The function is installed as a method named after the event. It
        forwards its positional and keyword arguments to the machine, which
        chooses among all transitions registered under the same event name.
        """
        def generated_event(machine, *args, **kwargs):
            machine._process_transitions(self.event, *args, **kwargs)
        generated_event.__doc__ = 'event %s' % self.event
        generated_event.__name__ = self.event
        return generated_event

    def is_valid_from(self, from_):
        """Return True if the state object *from_* is a source of this transition."""
        return from_ in _listize(self.from_)

    def check_guard(self, machine):
        """Evaluate this transition's guard against *machine*."""
        return self.guard.check(machine)

    def run(self, machine, *args, **kwargs):
        """Perform the transition on *machine*.

        Order: exit actions of the current state, the machine's state switch
        (which calls ``changing_state``), enter actions of the target, and
        finally this transition's own action(s) with the event's arguments.
        """
        machine._current_state_object.run_exit(machine)
        machine._new_state(self.to)
        self.to.run_enter(machine)
        _ActionRunner(machine).run(self.action, *args, **kwargs)
179
180
class _Guard(object):
    """Condition(s) that must all hold for a transition to be taken."""

    def __init__(self, action):
        # None (no guard), a single condition, or a list/tuple of conditions.
        self.action = action

    def check(self, machine):
        """Evaluate the guard's conditions in order against *machine*.

        Returns True when there is no guard. Otherwise evaluation stops at
        the first falsy condition, whose value is returned; if every
        condition holds, the value of the last one is returned.
        """
        if self.action is None:
            return True
        outcome = True
        for condition in _listize(self.action):
            outcome = self._evaluate(machine, condition)
            if not outcome:
                break
        return outcome

    def _evaluate(self, machine, item):
        """Evaluate a single condition.

        A callable is called with the machine. Anything else is the name of
        an attribute on the machine: if that attribute is callable (e.g. a
        method) it is called without arguments, otherwise its value is used.
        """
        if not callable(item):
            value = getattr(machine, item)
            return value() if callable(value) else value
        return item(machine)
203
204
class _State(object):
    """A named state with optional enter and exit actions."""

    def __init__(self, name, enter, exit):
        self.name = name
        # Action specs executed through _ActionRunner when the state is
        # entered / left (None means no action).
        self.enter, self.exit = enter, exit

    def getter_name(self):
        """Return the attribute name of this state's ``is_<name>`` predicate."""
        return 'is_%s' % self.name

    def getter_method(self):
        """Build a function that, given a machine, reports whether that
        machine's current state is this state."""
        owner = self

        def state_getter(self_machine):
            return self_machine.current_state == owner.name
        return state_getter

    def run_enter(self, machine):
        """Run this state's enter action(s) against *machine*."""
        self._run_actions(machine, self.enter)

    def run_exit(self, machine):
        """Run this state's exit action(s) against *machine*."""
        self._run_actions(machine, self.exit)

    @staticmethod
    def _run_actions(machine, actions):
        # Shared by run_enter/run_exit.
        _ActionRunner(machine).run(actions)
225
226
class _ActionRunner(object):
    """Runs action specs (callables or names of machine methods) on a machine."""

    def __init__(self, machine):
        self.machine = machine

    def run(self, action_param, *args, **kwargs):
        """Run *action_param*, a single action or a list/tuple of actions.

        Any falsy value (None, '', empty list) means "no action".
        """
        if not action_param:
            return
        action_items = _listize(action_param)
        for action_item in action_items:
            self._run_action(action_item, *args, **kwargs)

    def _run_action(self, action, *args, **kwargs):
        """Run one action.

        A callable receives the machine followed by *args/**kwargs; any other
        value names a method on the machine, which receives *args/**kwargs.
        """
        if callable(action):
            self._try_to_run_with_args(action, self.machine, *args, **kwargs)
        else:
            self._try_to_run_with_args(getattr(self.machine, action), *args, **kwargs)

    def _try_to_run_with_args(self, action, *args, **kwargs):
        """Call ``action(*args, **kwargs)``, or ``action()`` if it cannot accept them.

        Whether the arguments fit is decided up front by binding them to the
        action's signature. A TypeError raised *inside* the action therefore
        propagates, instead of being mistaken for an argument mismatch and
        causing the action to run a second time with no arguments.
        """
        signature = self._signature_of(action)
        if signature is None:
            # Signature cannot be introspected: fall back to trial and error.
            try:
                action(*args, **kwargs)
            except TypeError:
                action()
            return
        try:
            signature.bind(*args, **kwargs)
        except TypeError:
            action()
        else:
            action(*args, **kwargs)

    @staticmethod
    def _signature_of(action):
        """Return ``inspect.signature(action)``, or None if it is unavailable."""
        get_signature = getattr(inspect, 'signature', None)
        if get_signature is None:
            # inspect.signature does not exist on old Pythons (added in 3.3).
            return None
        try:
            return get_signature(action)
        except (TypeError, ValueError):
            # Unsupported object / no signature (e.g. some builtins).
            return None
250
251
class InvalidConfiguration(Exception):
    """Raised when a state machine definition is incomplete or inconsistent
    (e.g. fewer than two states, or no initial state)."""


class InvalidTransition(Exception):
    """Raised when an event is fired from a state that none of its
    transitions can leave."""


class GuardNotSatisfied(Exception):
    """Raised when no candidate transition for an event passes its guard."""


class ForkedTransition(Exception):
    """Raised when more than one candidate transition for an event passes
    its guard, so the next state is ambiguous."""
266
267
def _listize(value):
    """Return *value* as a sequence of items.

    Lists and tuples (including subclasses) are returned unchanged, so an
    empty list or tuple stays empty. Any other value, including None, is
    wrapped in a one-element list.
    """
    # An explicit conditional is required here: the ``cond and a or b`` idiom
    # yields ``[value]`` when ``value`` is an empty (falsy) list/tuple, i.e.
    # ``[[]]`` or ``[()]``.
    if isinstance(value, (list, tuple)):
        return value
    return [value]
270