Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/scapy/automaton.py: 37%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# SPDX-License-Identifier: GPL-2.0-only
2# This file is part of Scapy
3# See https://scapy.net/ for more information
4# Copyright (C) Philippe Biondi <phil@secdev.org>
5# Copyright (C) Gabriel Potter
7"""
8Automata with states, transitions and actions.
10TODO:
11 - add documentation for ioevent, as_supersocket...
12"""
14import ctypes
15import itertools
16import logging
17import os
18import random
19import socket
20import sys
21import threading
22import time
23import traceback
24import types
26import select
27from collections import deque
29from scapy.config import conf
30from scapy.consts import WINDOWS
31from scapy.data import MTU
32from scapy.error import log_runtime, warning
33from scapy.interfaces import _GlobInterfaceType
34from scapy.packet import Packet
35from scapy.plist import PacketList
36from scapy.supersocket import SuperSocket, StreamSocket
37from scapy.utils import do_graph
39# Typing imports
40from typing import (
41 Any,
42 Callable,
43 Deque,
44 Dict,
45 Generic,
46 Iterable,
47 Iterator,
48 List,
49 Optional,
50 Set,
51 Tuple,
52 Type,
53 TypeVar,
54 Union,
55 cast,
56)
57from scapy.compat import DecoratorCallable
60def select_objects(inputs, remain):
61 # type: (Iterable[Any], Union[float, int, None]) -> List[Any]
62 """
63 Select objects. Same than:
64 ``select.select(inputs, [], [], remain)``
66 But also works on Windows, only on objects whose fileno() returns
67 a Windows event. For simplicity, just use `ObjectPipe()` as a queue
68 that you can select on whatever the platform is.
70 If you want an object to be always included in the output of
71 select_objects (i.e. it's not selectable), just make fileno()
72 return a strictly negative value.
74 Example:
76 >>> a, b = ObjectPipe("a"), ObjectPipe("b")
77 >>> b.send("test")
78 >>> select_objects([a, b], 1)
79 [b]
81 :param inputs: objects to process
82 :param remain: timeout. If 0, poll. If None, block.
83 """
84 if not WINDOWS:
85 return select.select(inputs, [], [], remain)[0]
86 natives = []
87 events = []
88 results = set()
89 for i in list(inputs):
90 if getattr(i, "__selectable_force_select__", False):
91 natives.append(i)
92 elif i.fileno() < 0:
93 # Special case: On Windows, we consider that an object that returns
94 # a negative fileno (impossible), is always readable. This is used
95 # in very few places but important (e.g. PcapReader), where we have
96 # no valid fileno (and will stop on EOFError).
97 results.add(i)
98 else:
99 events.append(i)
100 if natives:
101 results = results.union(set(select.select(natives, [], [], remain)[0]))
102 if results:
103 # We have native results, poll.
104 remain = 0
105 if events:
106 # 0xFFFFFFFF = INFINITE
107 remainms = int(remain * 1000 if remain is not None else 0xFFFFFFFF)
108 if len(events) == 1:
109 res = ctypes.windll.kernel32.WaitForSingleObject(
110 ctypes.c_void_p(events[0].fileno()),
111 remainms
112 )
113 else:
114 # Sadly, the only way to emulate select() is to first check
115 # if any object is available using WaitForMultipleObjects
116 # then poll the others.
117 res = ctypes.windll.kernel32.WaitForMultipleObjects(
118 len(events),
119 (ctypes.c_void_p * len(events))(
120 *[x.fileno() for x in events]
121 ),
122 False,
123 remainms
124 )
125 if res != 0xFFFFFFFF and res != 0x00000102: # Failed or Timeout
126 results.add(events[res])
127 if len(events) > 1:
128 # Now poll the others, if any
129 for evt in events:
130 res = ctypes.windll.kernel32.WaitForSingleObject(
131 ctypes.c_void_p(evt.fileno()),
132 0 # poll: don't wait
133 )
134 if res == 0:
135 results.add(evt)
136 return list(results)
139_T = TypeVar("_T")
142class ObjectPipe(Generic[_T]):
143 def __init__(self, name=None):
144 # type: (Optional[str]) -> None
145 self.name = name or "ObjectPipe"
146 self.closed = False
147 self.__rd, self.__wr = os.pipe()
148 self.__queue = deque() # type: Deque[_T]
149 if WINDOWS:
150 self._wincreate()
152 if WINDOWS:
153 def _wincreate(self):
154 # type: () -> None
155 self._fd = cast(int, ctypes.windll.kernel32.CreateEventA(
156 None, True, False,
157 ctypes.create_string_buffer(b"ObjectPipe %f" % random.random())
158 ))
160 def _winset(self):
161 # type: () -> None
162 if ctypes.windll.kernel32.SetEvent(ctypes.c_void_p(self._fd)) == 0:
163 warning(ctypes.FormatError(ctypes.GetLastError()))
165 def _winreset(self):
166 # type: () -> None
167 if ctypes.windll.kernel32.ResetEvent(ctypes.c_void_p(self._fd)) == 0:
168 warning(ctypes.FormatError(ctypes.GetLastError()))
170 def _winclose(self):
171 # type: () -> None
172 if ctypes.windll.kernel32.CloseHandle(ctypes.c_void_p(self._fd)) == 0:
173 warning(ctypes.FormatError(ctypes.GetLastError()))
175 def fileno(self):
176 # type: () -> int
177 if WINDOWS:
178 return self._fd
179 return self.__rd
181 def send(self, obj):
182 # type: (_T) -> int
183 self.__queue.append(obj)
184 if WINDOWS:
185 self._winset()
186 os.write(self.__wr, b"X")
187 return 1
189 def write(self, obj):
190 # type: (_T) -> None
191 self.send(obj)
193 def empty(self):
194 # type: () -> bool
195 return not bool(self.__queue)
197 def flush(self):
198 # type: () -> None
199 pass
201 def recv(self, n=0, options=socket.MsgFlag(0)):
202 # type: (Optional[int], socket.MsgFlag) -> Optional[_T]
203 if self.closed:
204 raise EOFError
205 if options & socket.MSG_PEEK:
206 if self.__queue:
207 return self.__queue[0]
208 return None
209 os.read(self.__rd, 1)
210 elt = self.__queue.popleft()
211 if WINDOWS and not self.__queue:
212 self._winreset()
213 return elt
215 def read(self, n=0):
216 # type: (Optional[int]) -> Optional[_T]
217 return self.recv(n)
219 def clear(self):
220 # type: () -> None
221 if not self.closed:
222 while not self.empty():
223 self.recv()
225 def close(self):
226 # type: () -> None
227 if not self.closed:
228 self.closed = True
229 os.close(self.__rd)
230 os.close(self.__wr)
231 if WINDOWS:
232 try:
233 self._winclose()
234 except ImportError:
235 # Python is shutting down
236 pass
238 def __repr__(self):
239 # type: () -> str
240 return "<%s at %s>" % (self.name, id(self))
242 def __del__(self):
243 # type: () -> None
244 self.close()
246 @staticmethod
247 def select(sockets, remain=conf.recv_poll_rate):
248 # type: (List[SuperSocket], Optional[float]) -> List[SuperSocket]
249 # Only handle ObjectPipes
250 results = []
251 for s in sockets:
252 if s.closed: # allow read to trigger EOF
253 results.append(s)
254 if results:
255 return results
256 return select_objects(sockets, remain)
259class Message:
260 type = None # type: str
261 pkt = None # type: Packet
262 result = None # type: str
263 state = None # type: Message
264 exc_info = None # type: Union[Tuple[None, None, None], Tuple[BaseException, Exception, types.TracebackType]] # noqa: E501
266 def __init__(self, **args):
267 # type: (Any) -> None
268 self.__dict__.update(args)
270 def __repr__(self):
271 # type: () -> str
272 return "<Message %s>" % " ".join(
273 "%s=%r" % (k, v)
274 for k, v in self.__dict__.items()
275 if not k.startswith("_")
276 )
279class Timer():
280 def __init__(self, time, prio=0, autoreload=False):
281 # type: (Union[int, float], int, bool) -> None
282 self._timeout = float(time) # type: float
283 self._time = 0 # type: float
284 self._just_expired = True
285 self._expired = True
286 self._prio = prio
287 self._func = _StateWrapper()
288 self._autoreload = autoreload
290 def get(self):
291 # type: () -> float
292 return self._timeout
294 def set(self, val):
295 # type: (float) -> None
296 self._timeout = val
298 def _reset(self):
299 # type: () -> None
300 self._time = self._timeout
301 self._expired = False
302 self._just_expired = False
304 def _reset_just_expired(self):
305 # type: () -> None
306 self._just_expired = False
308 def _running(self):
309 # type: () -> bool
310 return self._time > 0
312 def _remaining(self):
313 # type: () -> float
314 return max(self._time, 0)
316 def _decrement(self, time):
317 # type: (float) -> None
318 self._time -= time
319 if self._time <= 0:
320 if not self._expired:
321 self._just_expired = True
322 if self._autoreload:
323 # take overshoot into account
324 self._time = self._timeout + self._time
325 else:
326 self._expired = True
327 self._time = 0
329 def __lt__(self, obj):
330 # type: (Timer) -> bool
331 return ((self._time < obj._time) if self._time != obj._time
332 else (self._prio < obj._prio))
334 def __gt__(self, obj):
335 # type: (Timer) -> bool
336 return ((self._time > obj._time) if self._time != obj._time
337 else (self._prio > obj._prio))
339 def __eq__(self, obj):
340 # type: (Any) -> bool
341 if not isinstance(obj, Timer):
342 raise NotImplementedError()
343 return (self._time == obj._time) and (self._prio == obj._prio)
345 def __repr__(self):
346 # type: () -> str
347 return "<Timer %f(%f)>" % (self._time, self._timeout)
350class _TimerList():
351 def __init__(self):
352 # type: () -> None
353 self.timers = [] # type: list[Timer]
355 def add_timer(self, timer):
356 # type: (Timer) -> None
357 self.timers.append(timer)
359 def reset(self):
360 # type: () -> None
361 for t in self.timers:
362 t._reset()
364 def decrement(self, time):
365 # type: (float) -> None
366 for t in self.timers:
367 t._decrement(time)
369 def expired(self):
370 # type: () -> list[Timer]
371 lst = [t for t in self.timers if t._just_expired]
372 lst.sort(key=lambda x: x._prio, reverse=True)
373 for t in lst:
374 t._reset_just_expired()
375 return lst
377 def until_next(self):
378 # type: () -> Optional[float]
379 try:
380 return min([t._remaining() for t in self.timers if t._running()])
381 except ValueError:
382 return None # None means blocking
384 def count(self):
385 # type: () -> int
386 return len(self.timers)
388 def __iter__(self):
389 # type: () -> Iterator[Timer]
390 return self.timers.__iter__()
392 def __repr__(self):
393 # type: () -> str
394 return self.timers.__repr__()
397class _instance_state:
398 def __init__(self, instance):
399 # type: (Any) -> None
400 self.__self__ = instance.__self__
401 self.__func__ = instance.__func__
402 self.__self__.__class__ = instance.__self__.__class__
404 def __getattr__(self, attr):
405 # type: (str) -> Any
406 return getattr(self.__func__, attr)
408 def __call__(self, *args, **kargs):
409 # type: (Any, Any) -> Any
410 return self.__func__(self.__self__, *args, **kargs)
412 def breaks(self):
413 # type: () -> Any
414 return self.__self__.add_breakpoints(self.__func__)
416 def intercepts(self):
417 # type: () -> Any
418 return self.__self__.add_interception_points(self.__func__)
420 def unbreaks(self):
421 # type: () -> Any
422 return self.__self__.remove_breakpoints(self.__func__)
424 def unintercepts(self):
425 # type: () -> Any
426 return self.__self__.remove_interception_points(self.__func__)
429##############
430# Automata #
431##############
433class _StateWrapper:
434 __name__ = None # type: str
435 atmt_type = None # type: str
436 atmt_state = None # type: str
437 atmt_initial = None # type: int
438 atmt_final = None # type: int
439 atmt_stop = None # type: int
440 atmt_error = None # type: int
441 atmt_origfunc = None # type: _StateWrapper
442 atmt_prio = None # type: int
443 atmt_as_supersocket = None # type: Optional[str]
444 atmt_condname = None # type: str
445 atmt_ioname = None # type: str
446 atmt_timeout = None # type: Timer
447 atmt_cond = None # type: Dict[str, int]
448 __code__ = None # type: types.CodeType
449 __call__ = None # type: Callable[..., ATMT.NewStateRequested]
452class ATMT:
453 STATE = "State"
454 ACTION = "Action"
455 CONDITION = "Condition"
456 RECV = "Receive condition"
457 TIMEOUT = "Timeout condition"
458 EOF = "EOF condition"
459 IOEVENT = "I/O event"
461 class NewStateRequested(Exception):
462 def __init__(self, state_func, automaton, *args, **kargs):
463 # type: (Any, ATMT, Any, Any) -> None
464 self.func = state_func
465 self.state = state_func.atmt_state
466 self.initial = state_func.atmt_initial
467 self.error = state_func.atmt_error
468 self.stop = state_func.atmt_stop
469 self.final = state_func.atmt_final
470 Exception.__init__(self, "Request state [%s]" % self.state)
471 self.automaton = automaton
472 self.args = args
473 self.kargs = kargs
474 self.action_parameters() # init action parameters
476 def action_parameters(self, *args, **kargs):
477 # type: (Any, Any) -> ATMT.NewStateRequested
478 self.action_args = args
479 self.action_kargs = kargs
480 return self
482 def run(self):
483 # type: () -> Any
484 return self.func(self.automaton, *self.args, **self.kargs)
486 def __repr__(self):
487 # type: () -> str
488 return "NewStateRequested(%s)" % self.state
490 @staticmethod
491 def state(initial=0, # type: int
492 final=0, # type: int
493 stop=0, # type: int
494 error=0 # type: int
495 ):
496 # type: (...) -> Callable[[DecoratorCallable], DecoratorCallable]
497 def deco(f, initial=initial, final=final):
498 # type: (_StateWrapper, int, int) -> _StateWrapper
499 f.atmt_type = ATMT.STATE
500 f.atmt_state = f.__name__
501 f.atmt_initial = initial
502 f.atmt_final = final
503 f.atmt_stop = stop
504 f.atmt_error = error
506 def _state_wrapper(self, *args, **kargs):
507 # type: (ATMT, Any, Any) -> ATMT.NewStateRequested
508 return ATMT.NewStateRequested(f, self, *args, **kargs)
510 state_wrapper = cast(_StateWrapper, _state_wrapper)
511 state_wrapper.__name__ = "%s_wrapper" % f.__name__
512 state_wrapper.atmt_type = ATMT.STATE
513 state_wrapper.atmt_state = f.__name__
514 state_wrapper.atmt_initial = initial
515 state_wrapper.atmt_final = final
516 state_wrapper.atmt_stop = stop
517 state_wrapper.atmt_error = error
518 state_wrapper.atmt_origfunc = f
519 return state_wrapper
520 return deco # type: ignore
522 @staticmethod
523 def action(cond, prio=0):
524 # type: (Any, int) -> Callable[[_StateWrapper, _StateWrapper], _StateWrapper] # noqa: E501
525 def deco(f, cond=cond):
526 # type: (_StateWrapper, _StateWrapper) -> _StateWrapper
527 if not hasattr(f, "atmt_type"):
528 f.atmt_cond = {}
529 f.atmt_type = ATMT.ACTION
530 f.atmt_cond[cond.atmt_condname] = prio
531 return f
532 return deco
534 @staticmethod
535 def condition(state, prio=0):
536 # type: (Any, int) -> Callable[[_StateWrapper, _StateWrapper], _StateWrapper] # noqa: E501
537 def deco(f, state=state):
538 # type: (_StateWrapper, _StateWrapper) -> Any
539 f.atmt_type = ATMT.CONDITION
540 f.atmt_state = state.atmt_state
541 f.atmt_condname = f.__name__
542 f.atmt_prio = prio
543 return f
544 return deco
546 @staticmethod
547 def receive_condition(state, prio=0):
548 # type: (_StateWrapper, int) -> Callable[[_StateWrapper, _StateWrapper], _StateWrapper] # noqa: E501
549 def deco(f, state=state):
550 # type: (_StateWrapper, _StateWrapper) -> _StateWrapper
551 f.atmt_type = ATMT.RECV
552 f.atmt_state = state.atmt_state
553 f.atmt_condname = f.__name__
554 f.atmt_prio = prio
555 return f
556 return deco
558 @staticmethod
559 def ioevent(state, # type: _StateWrapper
560 name, # type: str
561 prio=0, # type: int
562 as_supersocket=None # type: Optional[str]
563 ):
564 # type: (...) -> Callable[[_StateWrapper, _StateWrapper], _StateWrapper] # noqa: E501
565 def deco(f, state=state):
566 # type: (_StateWrapper, _StateWrapper) -> _StateWrapper
567 f.atmt_type = ATMT.IOEVENT
568 f.atmt_state = state.atmt_state
569 f.atmt_condname = f.__name__
570 f.atmt_ioname = name
571 f.atmt_prio = prio
572 f.atmt_as_supersocket = as_supersocket
573 return f
574 return deco
576 @staticmethod
577 def timeout(state, timeout):
578 # type: (_StateWrapper, Union[int, float]) -> Callable[[_StateWrapper, _StateWrapper, Timer], _StateWrapper] # noqa: E501
579 def deco(f, state=state, timeout=Timer(timeout)):
580 # type: (_StateWrapper, _StateWrapper, Timer) -> _StateWrapper
581 f.atmt_type = ATMT.TIMEOUT
582 f.atmt_state = state.atmt_state
583 f.atmt_timeout = timeout
584 f.atmt_timeout._func = f
585 f.atmt_condname = f.__name__
586 return f
587 return deco
589 @staticmethod
590 def timer(state, timeout, prio=0):
591 # type: (_StateWrapper, Union[float, int], int) -> Callable[[_StateWrapper, _StateWrapper, Timer], _StateWrapper] # noqa: E501
592 def deco(f, state=state, timeout=Timer(timeout, prio=prio, autoreload=True)):
593 # type: (_StateWrapper, _StateWrapper, Timer) -> _StateWrapper
594 f.atmt_type = ATMT.TIMEOUT
595 f.atmt_state = state.atmt_state
596 f.atmt_timeout = timeout
597 f.atmt_timeout._func = f
598 f.atmt_condname = f.__name__
599 return f
600 return deco
602 @staticmethod
603 def eof(state):
604 # type: (_StateWrapper) -> Callable[[_StateWrapper, _StateWrapper], _StateWrapper] # noqa: E501
605 def deco(f, state=state):
606 # type: (_StateWrapper, _StateWrapper) -> _StateWrapper
607 f.atmt_type = ATMT.EOF
608 f.atmt_state = state.atmt_state
609 f.atmt_condname = f.__name__
610 return f
611 return deco
614class _ATMT_Command:
615 RUN = "RUN"
616 NEXT = "NEXT"
617 FREEZE = "FREEZE"
618 STOP = "STOP"
619 FORCESTOP = "FORCESTOP"
620 END = "END"
621 EXCEPTION = "EXCEPTION"
622 SINGLESTEP = "SINGLESTEP"
623 BREAKPOINT = "BREAKPOINT"
624 INTERCEPT = "INTERCEPT"
625 ACCEPT = "ACCEPT"
626 REPLACE = "REPLACE"
627 REJECT = "REJECT"
630class _ATMT_supersocket(SuperSocket):
631 def __init__(self,
632 name, # type: str
633 ioevent, # type: str
634 automaton, # type: Type[Automaton]
635 proto, # type: Callable[[bytes], Any]
636 *args, # type: Any
637 **kargs # type: Any
638 ):
639 # type: (...) -> None
640 self.name = name
641 self.ioevent = ioevent
642 self.proto = proto
643 # write, read
644 self.spa, self.spb = ObjectPipe[Any]("spa"), \
645 ObjectPipe[Any]("spb")
646 kargs["external_fd"] = {ioevent: (self.spa, self.spb)}
647 kargs["is_atmt_socket"] = True
648 kargs["atmt_socket"] = self.name
649 self.atmt = automaton(*args, **kargs)
650 self.atmt.runbg()
652 def send(self, s):
653 # type: (Any) -> int
654 return self.spa.send(s)
656 def fileno(self):
657 # type: () -> int
658 return self.spb.fileno()
660 # note: _ATMT_supersocket may return bytes in certain cases, which
661 # is expected. We cheat on typing.
662 def recv(self, n=MTU, **kwargs): # type: ignore
663 # type: (int, **Any) -> Any
664 r = self.spb.recv(n)
665 if self.proto is not None and r is not None:
666 r = self.proto(r, **kwargs)
667 return r
669 def close(self):
670 # type: () -> None
671 if not self.closed:
672 self.atmt.stop()
673 self.atmt.destroy()
674 self.spa.close()
675 self.spb.close()
676 self.closed = True
678 @staticmethod
679 def select(sockets, remain=conf.recv_poll_rate):
680 # type: (List[SuperSocket], Optional[float]) -> List[SuperSocket]
681 return select_objects(sockets, remain)
684class _ATMT_to_supersocket:
685 def __init__(self, name, ioevent, automaton):
686 # type: (str, str, Type[Automaton]) -> None
687 self.name = name
688 self.ioevent = ioevent
689 self.automaton = automaton
691 def __call__(self, proto, *args, **kargs):
692 # type: (Callable[[bytes], Any], Any, Any) -> _ATMT_supersocket
693 return _ATMT_supersocket(
694 self.name, self.ioevent, self.automaton,
695 proto, *args, **kargs
696 )
699class Automaton_metaclass(type):
700 def __new__(cls, name, bases, dct):
701 # type: (str, Tuple[Any], Dict[str, Any]) -> Type[Automaton]
702 cls = super(Automaton_metaclass, cls).__new__( # type: ignore
703 cls, name, bases, dct
704 )
705 cls.states = {}
706 cls.recv_conditions = {} # type: Dict[str, List[_StateWrapper]]
707 cls.conditions = {} # type: Dict[str, List[_StateWrapper]]
708 cls.ioevents = {} # type: Dict[str, List[_StateWrapper]]
709 cls.timeout = {} # type: Dict[str, _TimerList]
710 cls.eofs = {} # type: Dict[str, _StateWrapper]
711 cls.actions = {} # type: Dict[str, List[_StateWrapper]]
712 cls.initial_states = [] # type: List[_StateWrapper]
713 cls.stop_state = None # type: Optional[_StateWrapper]
714 cls.ionames = []
715 cls.iosupersockets = []
717 members = {}
718 classes = [cls]
719 while classes:
720 c = classes.pop(0) # order is important to avoid breaking method overloading # noqa: E501
721 classes += list(c.__bases__)
722 for k, v in c.__dict__.items(): # type: ignore
723 if k not in members:
724 members[k] = v
726 decorated = [v for v in members.values()
727 if hasattr(v, "atmt_type")]
729 for m in decorated:
730 if m.atmt_type == ATMT.STATE:
731 s = m.atmt_state
732 cls.states[s] = m
733 cls.recv_conditions[s] = []
734 cls.ioevents[s] = []
735 cls.conditions[s] = []
736 cls.timeout[s] = _TimerList()
737 if m.atmt_initial:
738 cls.initial_states.append(m)
739 if m.atmt_stop:
740 if cls.stop_state is not None:
741 raise ValueError("There can only be a single stop state !")
742 cls.stop_state = m
743 elif m.atmt_type in [ATMT.CONDITION, ATMT.RECV, ATMT.TIMEOUT, ATMT.IOEVENT, ATMT.EOF]: # noqa: E501
744 cls.actions[m.atmt_condname] = []
746 for m in decorated:
747 if m.atmt_type == ATMT.CONDITION:
748 cls.conditions[m.atmt_state].append(m)
749 elif m.atmt_type == ATMT.RECV:
750 cls.recv_conditions[m.atmt_state].append(m)
751 elif m.atmt_type == ATMT.EOF:
752 cls.eofs[m.atmt_state] = m
753 elif m.atmt_type == ATMT.IOEVENT:
754 cls.ioevents[m.atmt_state].append(m)
755 cls.ionames.append(m.atmt_ioname)
756 if m.atmt_as_supersocket is not None:
757 cls.iosupersockets.append(m)
758 elif m.atmt_type == ATMT.TIMEOUT:
759 cls.timeout[m.atmt_state].add_timer(m.atmt_timeout)
760 elif m.atmt_type == ATMT.ACTION:
761 for co in m.atmt_cond:
762 cls.actions[co].append(m)
764 for v in itertools.chain(
765 cls.conditions.values(),
766 cls.recv_conditions.values(),
767 cls.ioevents.values()
768 ):
769 v.sort(key=lambda x: x.atmt_prio)
770 for condname, actlst in cls.actions.items():
771 actlst.sort(key=lambda x: x.atmt_cond[condname])
773 for ioev in cls.iosupersockets:
774 setattr(cls, ioev.atmt_as_supersocket,
775 _ATMT_to_supersocket(
776 ioev.atmt_as_supersocket,
777 ioev.atmt_ioname,
778 cast(Type["Automaton"], cls)))
780 # Inject signature
781 try:
782 import inspect
783 cls.__signature__ = inspect.signature(cls.parse_args) # type: ignore # noqa: E501
784 except (ImportError, AttributeError):
785 pass
787 return cast(Type["Automaton"], cls)
789 def build_graph(self):
790 # type: () -> str
791 s = 'digraph "%s" {\n' % self.__class__.__name__
793 se = "" # Keep initial nodes at the beginning for better rendering
794 for st in self.states.values():
795 if st.atmt_initial:
796 se = ('\t"%s" [ style=filled, fillcolor=blue, shape=box, root=true];\n' % st.atmt_state) + se # noqa: E501
797 elif st.atmt_final:
798 se += '\t"%s" [ style=filled, fillcolor=green, shape=octagon ];\n' % st.atmt_state # noqa: E501
799 elif st.atmt_error:
800 se += '\t"%s" [ style=filled, fillcolor=red, shape=octagon ];\n' % st.atmt_state # noqa: E501
801 elif st.atmt_stop:
802 se += '\t"%s" [ style=filled, fillcolor=orange, shape=box, root=true ];\n' % st.atmt_state # noqa: E501
803 s += se
805 for st in self.states.values():
806 names = list(
807 st.atmt_origfunc.__code__.co_names +
808 st.atmt_origfunc.__code__.co_consts
809 )
810 while names:
811 n = names.pop()
812 if n in self.states:
813 s += '\t"%s" -> "%s" [ color=green ];\n' % (st.atmt_state, n)
814 elif n in self.__dict__:
815 # function indirection
816 if callable(self.__dict__[n]):
817 names.extend(self.__dict__[n].__code__.co_names)
818 names.extend(self.__dict__[n].__code__.co_consts)
820 for c, sty, k, v in (
821 [("purple", "solid", k, v) for k, v in self.conditions.items()] +
822 [("red", "solid", k, v) for k, v in self.recv_conditions.items()] +
823 [("orange", "solid", k, v) for k, v in self.ioevents.items()] +
824 [("black", "dashed", k, [v]) for k, v in self.eofs.items()]
825 ):
826 for f in v:
827 names = list(f.__code__.co_names + f.__code__.co_consts)
828 while names:
829 n = names.pop()
830 if n in self.states:
831 line = f.atmt_condname
832 for x in self.actions[f.atmt_condname]:
833 line += "\\l>[%s]" % x.__name__
834 s += '\t"%s" -> "%s" [label="%s", color=%s, style=%s];\n' % (
835 k,
836 n,
837 line,
838 c,
839 sty,
840 )
841 elif n in self.__dict__:
842 # function indirection
843 if callable(self.__dict__[n]) and hasattr(self.__dict__[n], "__code__"): # noqa: E501
844 names.extend(self.__dict__[n].__code__.co_names)
845 names.extend(self.__dict__[n].__code__.co_consts)
846 for k, timers in self.timeout.items():
847 for timer in timers:
848 for n in (timer._func.__code__.co_names +
849 timer._func.__code__.co_consts):
850 if n in self.states:
851 line = "%s/%.1fs" % (timer._func.atmt_condname,
852 timer.get())
853 for x in self.actions[timer._func.atmt_condname]:
854 line += "\\l>[%s]" % x.__name__
855 s += '\t"%s" -> "%s" [label="%s",color=blue];\n' % (k, n, line) # noqa: E501
856 s += "}\n"
857 return s
859 def graph(self, **kargs):
860 # type: (Any) -> Optional[str]
861 s = self.build_graph()
862 return do_graph(s, **kargs)
865class Automaton(metaclass=Automaton_metaclass):
866 states = {} # type: Dict[str, _StateWrapper]
867 state = None # type: ATMT.NewStateRequested
868 recv_conditions = {} # type: Dict[str, List[_StateWrapper]]
869 conditions = {} # type: Dict[str, List[_StateWrapper]]
870 eofs = {} # type: Dict[str, _StateWrapper]
871 ioevents = {} # type: Dict[str, List[_StateWrapper]]
872 timeout = {} # type: Dict[str, _TimerList]
873 actions = {} # type: Dict[str, List[_StateWrapper]]
874 initial_states = [] # type: List[_StateWrapper]
875 stop_state = None # type: Optional[_StateWrapper]
876 ionames = [] # type: List[str]
877 iosupersockets = [] # type: List[SuperSocket]
879 # used for spawn()
880 pkt_cls = conf.raw_layer
881 socketcls = StreamSocket
883 # Internals
884 def __init__(self, *args, **kargs):
885 # type: (Any, Any) -> None
886 external_fd = kargs.pop("external_fd", {})
887 if "sock" in kargs:
888 # We use a bi-directional sock
889 self.sock = kargs["sock"]
890 else:
891 # Separate sockets
892 self.sock = None
893 self.send_sock_class = kargs.pop("ll", conf.L3socket)
894 self.recv_sock_class = kargs.pop("recvsock", conf.L2listen)
895 self.listen_sock = None # type: Optional[SuperSocket]
896 self.send_sock = None # type: Optional[SuperSocket]
897 self.is_atmt_socket = kargs.pop("is_atmt_socket", False)
898 self.atmt_socket = kargs.pop("atmt_socket", None)
899 self.started = threading.Lock()
900 self.threadid = None # type: Optional[int]
901 self.breakpointed = None
902 self.breakpoints = set() # type: Set[_StateWrapper]
903 self.interception_points = set() # type: Set[_StateWrapper]
904 self.intercepted_packet = None # type: Union[None, Packet]
905 self.debug_level = 0
906 self.init_args = args
907 self.init_kargs = kargs
908 self.io = type.__new__(type, "IOnamespace", (), {})
909 self.oi = type.__new__(type, "IOnamespace", (), {})
910 self.cmdin = ObjectPipe[Message]("cmdin")
911 self.cmdout = ObjectPipe[Message]("cmdout")
912 self.ioin = {}
913 self.ioout = {}
914 self.packets = PacketList() # type: PacketList
915 for n in self.__class__.ionames:
916 extfd = external_fd.get(n)
917 if not isinstance(extfd, tuple):
918 extfd = (extfd, extfd)
919 ioin, ioout = extfd
920 if ioin is None:
921 ioin = ObjectPipe("ioin")
922 else:
923 ioin = self._IO_fdwrapper(ioin, None)
924 if ioout is None:
925 ioout = ObjectPipe("ioout")
926 else:
927 ioout = self._IO_fdwrapper(None, ioout)
929 self.ioin[n] = ioin
930 self.ioout[n] = ioout
931 ioin.ioname = n
932 ioout.ioname = n
933 setattr(self.io, n, self._IO_mixer(ioout, ioin))
934 setattr(self.oi, n, self._IO_mixer(ioin, ioout))
936 for stname in self.states:
937 setattr(self, stname,
938 _instance_state(getattr(self, stname)))
940 self.start()
942 def parse_args(self, debug=0, store=0, **kargs):
943 # type: (int, int, Any) -> None
944 self.debug_level = debug
945 if debug:
946 conf.logLevel = logging.DEBUG
947 self.socket_kargs = kargs
948 self.store_packets = store
950 @classmethod
951 def spawn(cls,
952 port: int,
953 iface: Optional[_GlobInterfaceType] = None,
954 bg: bool = False,
955 **kwargs: Any) -> Optional[socket.socket]:
956 """
957 Spawn a TCP server that listens for connections and start the automaton
958 for each new client.
960 :param port: the port to listen to
961 :param bg: background mode? (default: False)
962 """
963 from scapy.arch import get_if_addr
964 # create server sock and bind it
965 ssock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
966 local_ip = get_if_addr(iface or conf.iface)
967 try:
968 ssock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
969 except OSError:
970 pass
971 ssock.bind((local_ip, port))
972 ssock.listen(5)
973 clients = []
974 if kwargs.get("verb", True):
975 print(conf.color_theme.green(
976 "Server %s started listening on %s" % (
977 cls.__name__,
978 (local_ip, port),
979 )
980 ))
982 def _run() -> None:
983 # Wait for clients forever
984 try:
985 while True:
986 clientsocket, address = ssock.accept()
987 if kwargs.get("verb", True):
988 print(conf.color_theme.gold(
989 "\u2503 Connection received from %s" % repr(address)
990 ))
991 # Start atmt class with socket
992 sock = cls.socketcls(clientsocket, cls.pkt_cls)
993 atmt_server = cls(
994 sock=sock,
995 iface=iface, **kwargs
996 )
997 clients.append((atmt_server, clientsocket))
998 # start atmt
999 atmt_server.runbg()
1000 except KeyboardInterrupt:
1001 print("X Exiting.")
1002 ssock.shutdown(socket.SHUT_RDWR)
1003 except OSError:
1004 print("X Server closed.")
1005 finally:
1006 for atmt, clientsocket in clients:
1007 try:
1008 atmt.forcestop(wait=False)
1009 except Exception:
1010 pass
1011 try:
1012 clientsocket.shutdown(socket.SHUT_RDWR)
1013 clientsocket.close()
1014 except Exception:
1015 pass
1016 ssock.close()
1017 if bg:
1018 # Background
1019 threading.Thread(target=_run).start()
1020 return ssock
1021 else:
1022 # Non-background
1023 _run()
1024 return None
1026 def master_filter(self, pkt):
1027 # type: (Packet) -> bool
1028 return True
1030 def my_send(self, pkt, **kwargs):
1031 # type: (Packet, **Any) -> None
1032 if not self.send_sock:
1033 raise ValueError("send_sock is None !")
1034 self.send_sock.send(pkt, **kwargs)
1036 def update_sock(self, sock):
1037 # type: (SuperSocket) -> None
1038 """
1039 Update the socket used by the automata.
1040 Typically used in an eof event to reconnect.
1041 """
1042 self.sock = sock
1043 if self.listen_sock is not None:
1044 self.listen_sock = self.sock
1045 if self.send_sock:
1046 self.send_sock = self.sock
1048 def timer_by_name(self, name):
1049 # type: (str) -> Optional[Timer]
1050 for _, timers in self.timeout.items():
1051 for timer in timers: # type: Timer
1052 if timer._func.atmt_condname == name:
1053 return timer
1054 return None
1056 # Utility classes and exceptions
1057 class _IO_fdwrapper:
1058 def __init__(self,
1059 rd, # type: Union[int, ObjectPipe[bytes], None]
1060 wr # type: Union[int, ObjectPipe[bytes], None]
1061 ):
1062 # type: (...) -> None
1063 self.rd = rd
1064 self.wr = wr
1065 if isinstance(self.rd, socket.socket):
1066 self.__selectable_force_select__ = True
1068 def fileno(self):
1069 # type: () -> int
1070 if isinstance(self.rd, int):
1071 return self.rd
1072 elif self.rd:
1073 return self.rd.fileno()
1074 return 0
1076 def read(self, n=65535):
1077 # type: (int) -> Optional[bytes]
1078 if isinstance(self.rd, int):
1079 return os.read(self.rd, n)
1080 elif self.rd:
1081 return self.rd.recv(n)
1082 return None
1084 def write(self, msg):
1085 # type: (bytes) -> int
1086 if isinstance(self.wr, int):
1087 return os.write(self.wr, msg)
1088 elif self.wr:
1089 return self.wr.send(msg)
1090 return 0
1092 def recv(self, n=65535):
1093 # type: (int) -> Optional[bytes]
1094 return self.read(n)
1096 def send(self, msg):
1097 # type: (bytes) -> int
1098 return self.write(msg)
1100 class _IO_mixer:
1101 def __init__(self,
1102 rd, # type: ObjectPipe[Any]
1103 wr, # type: ObjectPipe[Any]
1104 ):
1105 # type: (...) -> None
1106 self.rd = rd
1107 self.wr = wr
1109 def fileno(self):
1110 # type: () -> Any
1111 if isinstance(self.rd, ObjectPipe):
1112 return self.rd.fileno()
1113 return self.rd
1115 def recv(self, n=None):
1116 # type: (Optional[int]) -> Any
1117 return self.rd.recv(n)
1119 def read(self, n=None):
1120 # type: (Optional[int]) -> Any
1121 return self.recv(n)
1123 def send(self, msg):
1124 # type: (str) -> int
1125 return self.wr.send(msg)
1127 def write(self, msg):
1128 # type: (str) -> int
1129 return self.send(msg)
1131 class AutomatonException(Exception):
1132 def __init__(self, msg, state=None, result=None):
1133 # type: (str, Optional[Message], Optional[str]) -> None
1134 Exception.__init__(self, msg)
1135 self.state = state
1136 self.result = result
1138 class AutomatonError(AutomatonException):
1139 pass
1141 class ErrorState(AutomatonException):
1142 pass
1144 class Stuck(AutomatonException):
1145 pass
1147 class AutomatonStopped(AutomatonException):
1148 pass
1150 class Breakpoint(AutomatonStopped):
1151 pass
1153 class Singlestep(AutomatonStopped):
1154 pass
1156 class InterceptionPoint(AutomatonStopped):
1157 def __init__(self, msg, state=None, result=None, packet=None):
1158 # type: (str, Optional[Message], Optional[str], Optional[Packet]) -> None
1159 Automaton.AutomatonStopped.__init__(self, msg, state=state, result=result)
1160 self.packet = packet
1162 class CommandMessage(AutomatonException):
1163 pass
1165 # Services
1166 def debug(self, lvl, msg):
1167 # type: (int, str) -> None
1168 if self.debug_level >= lvl:
1169 log_runtime.debug(msg)
1171 def isrunning(self):
1172 # type: () -> bool
1173 return self.started.locked()
1175 def send(self, pkt, **kwargs):
1176 # type: (Packet, **Any) -> None
1177 if self.state.state in self.interception_points:
1178 self.debug(3, "INTERCEPT: packet intercepted: %s" % pkt.summary())
1179 self.intercepted_packet = pkt
1180 self.cmdout.send(
1181 Message(type=_ATMT_Command.INTERCEPT,
1182 state=self.state, pkt=pkt)
1183 )
1184 cmd = self.cmdin.recv()
1185 if not cmd:
1186 self.debug(3, "CANCELLED")
1187 return
1188 self.intercepted_packet = None
1189 if cmd.type == _ATMT_Command.REJECT:
1190 self.debug(3, "INTERCEPT: packet rejected")
1191 return
1192 elif cmd.type == _ATMT_Command.REPLACE:
1193 pkt = cmd.pkt
1194 self.debug(3, "INTERCEPT: packet replaced by: %s" % pkt.summary()) # noqa: E501
1195 elif cmd.type == _ATMT_Command.ACCEPT:
1196 self.debug(3, "INTERCEPT: packet accepted")
1197 else:
1198 raise self.AutomatonError("INTERCEPT: unknown verdict: %r" % cmd.type) # noqa: E501
1199 self.my_send(pkt, **kwargs)
1200 self.debug(3, "SENT : %s" % pkt.summary())
1202 if self.store_packets:
1203 self.packets.append(pkt.copy())
1205 def __iter__(self):
1206 # type: () -> Automaton
1207 return self
1209 def __del__(self):
1210 # type: () -> None
1211 self.destroy()
1213 def _run_condition(self, cond, *args, **kargs):
1214 # type: (_StateWrapper, Any, Any) -> None
1215 try:
1216 self.debug(5, "Trying %s [%s]" % (cond.atmt_type, cond.atmt_condname)) # noqa: E501
1217 cond(self, *args, **kargs)
1218 except ATMT.NewStateRequested as state_req:
1219 self.debug(2, "%s [%s] taken to state [%s]" % (cond.atmt_type, cond.atmt_condname, state_req.state)) # noqa: E501
1220 if cond.atmt_type == ATMT.RECV:
1221 if self.store_packets:
1222 self.packets.append(args[0])
1223 for action in self.actions[cond.atmt_condname]:
1224 self.debug(2, " + Running action [%s]" % action.__name__)
1225 action(self, *state_req.action_args, **state_req.action_kargs)
1226 raise
1227 except Exception as e:
1228 self.debug(2, "%s [%s] raised exception [%s]" % (cond.atmt_type, cond.atmt_condname, e)) # noqa: E501
1229 raise
1230 else:
1231 self.debug(2, "%s [%s] not taken" % (cond.atmt_type, cond.atmt_condname)) # noqa: E501
1233 def _do_start(self, *args, **kargs):
1234 # type: (Any, Any) -> None
1235 ready = threading.Event()
1236 _t = threading.Thread(
1237 target=self._do_control,
1238 args=(ready,) + (args),
1239 kwargs=kargs,
1240 name="scapy.automaton _do_start"
1241 )
1242 _t.daemon = True
1243 _t.start()
1244 ready.wait()
1246 def _do_control(self, ready, *args, **kargs):
1247 # type: (threading.Event, Any, Any) -> None
1248 with self.started:
1249 self.threadid = threading.current_thread().ident
1250 if self.threadid is None:
1251 self.threadid = 0
1253 # Update default parameters
1254 a = args + self.init_args[len(args):]
1255 k = self.init_kargs.copy()
1256 k.update(kargs)
1257 self.parse_args(*a, **k)
1259 # Start the automaton
1260 self.state = self.initial_states[0](self)
1261 self.send_sock = self.sock or self.send_sock_class(**self.socket_kargs)
1262 if self.recv_conditions:
1263 # Only start a receiving socket if we have at least one recv_conditions
1264 self.listen_sock = self.sock or self.recv_sock_class(**self.socket_kargs) # noqa: E501
1265 self.packets = PacketList(name="session[%s]" % self.__class__.__name__)
1267 singlestep = True
1268 iterator = self._do_iter()
1269 self.debug(3, "Starting control thread [tid=%i]" % self.threadid)
1270 # Sync threads
1271 ready.set()
1272 try:
1273 while True:
1274 c = self.cmdin.recv()
1275 if c is None:
1276 return None
1277 self.debug(5, "Received command %s" % c.type)
1278 if c.type == _ATMT_Command.RUN:
1279 singlestep = False
1280 elif c.type == _ATMT_Command.NEXT:
1281 singlestep = True
1282 elif c.type == _ATMT_Command.FREEZE:
1283 continue
1284 elif c.type == _ATMT_Command.STOP:
1285 if self.stop_state:
1286 # There is a stop state
1287 self.state = self.stop_state()
1288 iterator = self._do_iter()
1289 else:
1290 # Act as FORCESTOP
1291 break
1292 elif c.type == _ATMT_Command.FORCESTOP:
1293 break
1294 while True:
1295 state = next(iterator)
1296 if isinstance(state, self.CommandMessage):
1297 break
1298 elif isinstance(state, self.Breakpoint):
1299 c = Message(type=_ATMT_Command.BREAKPOINT, state=state) # noqa: E501
1300 self.cmdout.send(c)
1301 break
1302 if singlestep:
1303 c = Message(type=_ATMT_Command.SINGLESTEP, state=state) # noqa: E501
1304 self.cmdout.send(c)
1305 break
1306 except (StopIteration, RuntimeError):
1307 c = Message(type=_ATMT_Command.END,
1308 result=self.final_state_output)
1309 self.cmdout.send(c)
1310 except Exception as e:
1311 exc_info = sys.exc_info()
1312 self.debug(3, "Transferring exception from tid=%i:\n%s" % (self.threadid, "".join(traceback.format_exception(*exc_info)))) # noqa: E501
1313 m = Message(type=_ATMT_Command.EXCEPTION, exception=e, exc_info=exc_info) # noqa: E501
1314 self.cmdout.send(m)
1315 self.debug(3, "Stopping control thread (tid=%i)" % self.threadid)
1316 self.threadid = None
1317 if self.listen_sock:
1318 self.listen_sock.close()
1319 if self.send_sock:
1320 self.send_sock.close()
1322 def _do_iter(self):
1323 # type: () -> Iterator[Union[Automaton.AutomatonException, Automaton.AutomatonStopped, ATMT.NewStateRequested, None]] # noqa: E501
1324 while True:
1325 try:
1326 self.debug(1, "## state=[%s]" % self.state.state)
1328 # Entering a new state. First, call new state function
1329 if self.state.state in self.breakpoints and self.state.state != self.breakpointed: # noqa: E501
1330 self.breakpointed = self.state.state
1331 yield self.Breakpoint("breakpoint triggered on state %s" % self.state.state, # noqa: E501
1332 state=self.state.state)
1333 self.breakpointed = None
1334 state_output = self.state.run()
1335 if self.state.error:
1336 raise self.ErrorState("Reached %s: [%r]" % (self.state.state, state_output), # noqa: E501
1337 result=state_output, state=self.state.state) # noqa: E501
1338 if self.state.final:
1339 self.final_state_output = state_output
1340 return
1342 if state_output is None:
1343 state_output = ()
1344 elif not isinstance(state_output, list):
1345 state_output = state_output,
1347 timers = self.timeout[self.state.state]
1348 # If there are commandMessage, we should skip immediate
1349 # conditions.
1350 if not select_objects([self.cmdin], 0):
1351 # Then check immediate conditions
1352 for cond in self.conditions[self.state.state]:
1353 self._run_condition(cond, *state_output)
1355 # If still there and no conditions left, we are stuck!
1356 if (len(self.recv_conditions[self.state.state]) == 0 and
1357 len(self.ioevents[self.state.state]) == 0 and
1358 timers.count() == 0):
1359 raise self.Stuck("stuck in [%s]" % self.state.state,
1360 state=self.state.state,
1361 result=state_output)
1363 # Finally listen and pay attention to timeouts
1364 timers.reset()
1365 time_previous = time.time()
1367 fds = [self.cmdin] # type: List[Union[SuperSocket, ObjectPipe[Any]]]
1368 select_func = select_objects
1369 if self.listen_sock and self.recv_conditions[self.state.state]:
1370 fds.append(self.listen_sock)
1371 select_func = self.listen_sock.select # type: ignore
1372 for ioev in self.ioevents[self.state.state]:
1373 fds.append(self.ioin[ioev.atmt_ioname])
1374 while True:
1375 time_current = time.time()
1376 timers.decrement(time_current - time_previous)
1377 time_previous = time_current
1378 for timer in timers.expired():
1379 self._run_condition(timer._func, *state_output)
1380 remain = timers.until_next()
1382 self.debug(5, "Select on %r" % fds)
1383 r = select_func(fds, remain)
1384 self.debug(5, "Selected %r" % r)
1385 for fd in r:
1386 self.debug(5, "Looking at %r" % fd)
1387 if fd == self.cmdin:
1388 yield self.CommandMessage("Received command message") # noqa: E501
1389 elif fd == self.listen_sock:
1390 try:
1391 pkt = self.listen_sock.recv()
1392 except EOFError:
1393 # Socket was closed abruptly. This will likely only
1394 # ever happen when a client socket is passed to the
1395 # automaton (not the case when the automaton is
1396 # listening on a promiscuous conf.L2sniff)
1397 self.listen_sock.close()
1398 # False so that it is still reset by update_sock
1399 self.listen_sock = False # type: ignore
1400 fds.remove(fd)
1401 if self.state.state in self.eofs:
1402 # There is an eof state
1403 eof = self.eofs[self.state.state]
1404 self.debug(2, "Condition EOF [%s] taken" % eof.__name__) # noqa: E501
1405 raise self.eofs[self.state.state](self)
1406 else:
1407 # There isn't. Therefore, it's a closing condition.
1408 raise EOFError("Socket ended arbruptly.")
1409 if pkt is not None:
1410 if self.master_filter(pkt):
1411 self.debug(3, "RECVD: %s" % pkt.summary()) # noqa: E501
1412 for rcvcond in self.recv_conditions[self.state.state]: # noqa: E501
1413 self._run_condition(rcvcond, pkt, *state_output) # noqa: E501
1414 else:
1415 self.debug(4, "FILTR: %s" % pkt.summary()) # noqa: E501
1416 else:
1417 self.debug(3, "IOEVENT on %s" % fd.ioname)
1418 for ioevt in self.ioevents[self.state.state]:
1419 if ioevt.atmt_ioname == fd.ioname:
1420 self._run_condition(ioevt, fd, *state_output) # noqa: E501
1422 except ATMT.NewStateRequested as state_req:
1423 self.debug(2, "switching from [%s] to [%s]" % (self.state.state, state_req.state)) # noqa: E501
1424 self.state = state_req
1425 yield state_req
1427 def __repr__(self):
1428 # type: () -> str
1429 return "<Automaton %s [%s]>" % (
1430 self.__class__.__name__,
1431 ["HALTED", "RUNNING"][self.isrunning()]
1432 )
1434 # Public API
1435 def add_interception_points(self, *ipts):
1436 # type: (Any) -> None
1437 for ipt in ipts:
1438 if hasattr(ipt, "atmt_state"):
1439 ipt = ipt.atmt_state
1440 self.interception_points.add(ipt)
1442 def remove_interception_points(self, *ipts):
1443 # type: (Any) -> None
1444 for ipt in ipts:
1445 if hasattr(ipt, "atmt_state"):
1446 ipt = ipt.atmt_state
1447 self.interception_points.discard(ipt)
1449 def add_breakpoints(self, *bps):
1450 # type: (Any) -> None
1451 for bp in bps:
1452 if hasattr(bp, "atmt_state"):
1453 bp = bp.atmt_state
1454 self.breakpoints.add(bp)
1456 def remove_breakpoints(self, *bps):
1457 # type: (Any) -> None
1458 for bp in bps:
1459 if hasattr(bp, "atmt_state"):
1460 bp = bp.atmt_state
1461 self.breakpoints.discard(bp)
1463 def start(self, *args, **kargs):
1464 # type: (Any, Any) -> None
1465 if self.isrunning():
1466 raise ValueError("Already started")
1467 # Start the control thread
1468 self._do_start(*args, **kargs)
1470 def run(self,
1471 resume=None, # type: Optional[Message]
1472 wait=True # type: Optional[bool]
1473 ):
1474 # type: (...) -> Any
1475 if resume is None:
1476 resume = Message(type=_ATMT_Command.RUN)
1477 self.cmdin.send(resume)
1478 if wait:
1479 try:
1480 c = self.cmdout.recv()
1481 if c is None:
1482 return None
1483 except KeyboardInterrupt:
1484 self.cmdin.send(Message(type=_ATMT_Command.FREEZE))
1485 return None
1486 if c.type == _ATMT_Command.END:
1487 return c.result
1488 elif c.type == _ATMT_Command.INTERCEPT:
1489 raise self.InterceptionPoint("packet intercepted", state=c.state.state, packet=c.pkt) # noqa: E501
1490 elif c.type == _ATMT_Command.SINGLESTEP:
1491 raise self.Singlestep("singlestep state=[%s]" % c.state.state, state=c.state.state) # noqa: E501
1492 elif c.type == _ATMT_Command.BREAKPOINT:
1493 raise self.Breakpoint("breakpoint triggered on state [%s]" % c.state.state, state=c.state.state) # noqa: E501
1494 elif c.type == _ATMT_Command.EXCEPTION:
1495 # this code comes from the `six` module (`.reraise()`)
1496 # to raise an exception with specified exc_info.
1497 value = c.exc_info[0]() if c.exc_info[1] is None else c.exc_info[1] # type: ignore # noqa: E501
1498 if value.__traceback__ is not c.exc_info[2]:
1499 raise value.with_traceback(c.exc_info[2])
1500 raise value
1501 return None
1503 def runbg(self, resume=None, wait=False):
1504 # type: (Optional[Message], Optional[bool]) -> None
1505 self.run(resume, wait)
1507 def __next__(self):
1508 # type: () -> Any
1509 return self.run(resume=Message(type=_ATMT_Command.NEXT))
1511 def _flush_inout(self):
1512 # type: () -> None
1513 # Flush command pipes
1514 for cmd in [self.cmdin, self.cmdout]:
1515 cmd.clear()
1517 def destroy(self):
1518 # type: () -> None
1519 """
1520 Destroys a stopped Automaton: this cleanups all opened file descriptors.
1521 Required on PyPy for instance where the garbage collector behaves differently.
1522 """
1523 if not hasattr(self, "started"):
1524 return # was never started.
1525 if self.isrunning():
1526 raise ValueError("Can't close running Automaton ! Call stop() beforehand")
1527 # Close command pipes
1528 self.cmdin.close()
1529 self.cmdout.close()
1530 self._flush_inout()
1531 # Close opened ioins/ioouts
1532 for i in itertools.chain(self.ioin.values(), self.ioout.values()):
1533 if isinstance(i, ObjectPipe):
1534 i.close()
1536 def stop(self, wait=True):
1537 # type: (bool) -> None
1538 try:
1539 self.cmdin.send(Message(type=_ATMT_Command.STOP))
1540 except OSError:
1541 pass
1542 if wait:
1543 with self.started:
1544 self._flush_inout()
1546 def forcestop(self, wait=True):
1547 # type: (bool) -> None
1548 try:
1549 self.cmdin.send(Message(type=_ATMT_Command.FORCESTOP))
1550 except OSError:
1551 pass
1552 if wait:
1553 with self.started:
1554 self._flush_inout()
1556 def restart(self, *args, **kargs):
1557 # type: (Any, Any) -> None
1558 self.stop()
1559 self.start(*args, **kargs)
1561 def accept_packet(self,
1562 pkt=None, # type: Optional[Packet]
1563 wait=False # type: Optional[bool]
1564 ):
1565 # type: (...) -> Any
1566 rsm = Message()
1567 if pkt is None:
1568 rsm.type = _ATMT_Command.ACCEPT
1569 else:
1570 rsm.type = _ATMT_Command.REPLACE
1571 rsm.pkt = pkt
1572 return self.run(resume=rsm, wait=wait)
1574 def reject_packet(self,
1575 wait=False # type: Optional[bool]
1576 ):
1577 # type: (...) -> Any
1578 rsm = Message(type=_ATMT_Command.REJECT)
1579 return self.run(resume=rsm, wait=wait)