Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/scapy/automaton.py: 36%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# SPDX-License-Identifier: GPL-2.0-only
2# This file is part of Scapy
3# See https://scapy.net/ for more information
4# Copyright (C) Philippe Biondi <phil@secdev.org>
5# Copyright (C) Gabriel Potter
7"""
8Automata with states, transitions and actions.
10TODO:
11 - add documentation for ioevent, as_supersocket...
12"""
14import ctypes
15import itertools
16import logging
17import os
18import random
19import socket
20import sys
21import threading
22import time
23import traceback
24import types
26import select
27from collections import deque
29from scapy.config import conf
30from scapy.consts import WINDOWS
31from scapy.data import MTU
32from scapy.error import log_runtime, warning
33from scapy.interfaces import _GlobInterfaceType
34from scapy.packet import Packet
35from scapy.plist import PacketList
36from scapy.supersocket import SuperSocket, StreamSocket
37from scapy.utils import do_graph
39# Typing imports
40from typing import (
41 Any,
42 Callable,
43 Deque,
44 Dict,
45 Generic,
46 Iterable,
47 Iterator,
48 List,
49 Optional,
50 Set,
51 Tuple,
52 Type,
53 TypeVar,
54 Union,
55 cast,
56)
57from scapy.compat import DecoratorCallable
# Winsock network-event bit meaning "data is available for reading"
# (FD_READ in winsock.h); passed to WSAEventSelect() in select_objects().
FD_READ = 1 << 0
def select_objects(inputs, remain):
    # type: (Iterable[Any], Union[float, int, None]) -> List[Any]
    """
    Select objects. Same as:
    ``select.select(inputs, [], [], remain)``

    But also works on Windows, only on objects whose fileno() returns
    a Windows event. For simplicity, just use `ObjectPipe()` as a queue
    that you can select on whatever the platform is.

    On Windows, if you want an object to be always included in the output
    of select_objects (i.e. it's not selectable), just make fileno()
    return a strictly negative value. (On other platforms the call is
    delegated to select.select(), which rejects negative descriptors.)

    Example:

    >>> a, b = ObjectPipe("a"), ObjectPipe("b")
    >>> b.send("test")
    1
    >>> select_objects([a, b], 1)
    [b]

    :param inputs: objects to process
    :param remain: timeout. If 0, poll. If None, block.
    :return: the subset of ``inputs`` that is ready to be read
    """
    if not WINDOWS:
        return select.select(inputs, [], [], remain)[0]
    inputs = list(inputs)
    events = []  # Windows handles to wait on
    owners = []  # owners[k] is the input object events[k] belongs to
    created = []  # WSA events created here, closed before returning
    results = set()
    try:
        for i in inputs:
            if getattr(i, "__selectable_force_select__", False):
                # Native socket.socket object. We would normally use
                # select.select: bind a WSA event to its FD_READ instead.
                evt = ctypes.windll.ws2_32.WSACreateEvent()
                created.append(evt)
                res = ctypes.windll.ws2_32.WSAEventSelect(
                    ctypes.c_void_p(i.fileno()),
                    evt,
                    FD_READ
                )
                if res == 0:
                    # Was a socket
                    events.append(evt)
                else:
                    # Fallback to normal event
                    events.append(i.fileno())
                owners.append(i)
            elif i.fileno() < 0:
                # Special case: On Windows, we consider that an object that
                # returns a negative fileno (impossible), is always readable.
                # This is used in very few places but important (e.g.
                # PcapReader), where we have no valid fileno (and will stop
                # on EOFError). Such objects get no event, hence the
                # separate `owners` list to map wait results back.
                results.add(i)
                remain = 0
            else:
                events.append(i.fileno())
                owners.append(i)
        if events:
            # 0xFFFFFFFF = INFINITE
            remainms = int(remain * 1000 if remain is not None else 0xFFFFFFFF)
            if len(events) == 1:
                res = ctypes.windll.kernel32.WaitForSingleObject(
                    ctypes.c_void_p(events[0]),
                    remainms
                )
            else:
                # Sadly, the only way to emulate select() is to first check
                # if any object is available using WaitForMultipleObjects
                # then poll the others.
                res = ctypes.windll.kernel32.WaitForMultipleObjects(
                    len(events),
                    (ctypes.c_void_p * len(events))(
                        *events
                    ),
                    False,
                    remainms
                )
            # ctypes returns a signed C int by default: normalize to the
            # DWORD value so that WAIT_FAILED (0xFFFFFFFF) is recognized.
            res &= 0xFFFFFFFF
            # 0xFFFFFFFF = WAIT_FAILED, 0x102 = WAIT_TIMEOUT
            if res != 0xFFFFFFFF and res != 0x00000102 and res < len(owners):
                # WAIT_OBJECT_0 + k: events[k] is signalled
                results.add(owners[res])
            if len(events) > 1:
                # Now poll the others, if any
                for owner, evt in zip(owners, events):
                    res = ctypes.windll.kernel32.WaitForSingleObject(
                        ctypes.c_void_p(evt),
                        0  # poll: don't wait
                    )
                    if res == 0:
                        results.add(owner)
    finally:
        # Cleanup created events, if any (even if waiting failed)
        for evt in created:
            ctypes.windll.ws2_32.WSACloseEvent(evt)
    return list(results)
# Type variable for the element type carried by ObjectPipe[...]
_T = TypeVar("_T")
class ObjectPipe(Generic[_T]):
    """
    Selectable FIFO of Python objects.

    Objects are kept in an in-memory deque. For every queued object one
    byte is written to an ``os.pipe()``, so the read end becomes readable
    while objects are pending. On Windows an event (created with
    bManualReset=True) is additionally signalled and is what fileno()
    returns, so the pipe works with select_objects() on every platform.
    """
    def __init__(self, name=None):
        # type: (Optional[str]) -> None
        self.name = name or "ObjectPipe"
        # Report "closed" until the OS pipe exists: if os.pipe() raises
        # (e.g. out of file descriptors), close() -- and thus __del__ --
        # becomes a no-op instead of failing on the missing descriptors.
        self.closed = True
        self.__rd, self.__wr = os.pipe()
        self.closed = False
        self.__queue = deque()  # type: Deque[_T]
        if WINDOWS:
            self._wincreate()

    if WINDOWS:
        def _wincreate(self):
            # type: () -> None
            """Create the (initially non-signalled) event used by fileno()."""
            self._fd = cast(int, ctypes.windll.kernel32.CreateEventA(
                None, True, False,
                ctypes.create_string_buffer(b"ObjectPipe %f" % random.random())
            ))

        def _winset(self):
            # type: () -> None
            """Signal the event; warn on failure."""
            if ctypes.windll.kernel32.SetEvent(ctypes.c_void_p(self._fd)) == 0:
                warning(ctypes.FormatError(ctypes.GetLastError()))

        def _winreset(self):
            # type: () -> None
            """Reset the event to non-signalled; warn on failure."""
            if ctypes.windll.kernel32.ResetEvent(ctypes.c_void_p(self._fd)) == 0:
                warning(ctypes.FormatError(ctypes.GetLastError()))

        def _winclose(self):
            # type: () -> None
            """Close the event handle; warn on failure."""
            if ctypes.windll.kernel32.CloseHandle(ctypes.c_void_p(self._fd)) == 0:
                warning(ctypes.FormatError(ctypes.GetLastError()))

    def fileno(self):
        # type: () -> int
        """Return the event handle on Windows, the pipe's read end elsewhere."""
        if WINDOWS:
            return self._fd
        return self.__rd

    def send(self, obj):
        # type: (_T) -> int
        """Queue ``obj`` and make the pipe readable. Always returns 1."""
        self.__queue.append(obj)
        if WINDOWS:
            self._winset()
        os.write(self.__wr, b"X")
        return 1

    def write(self, obj):
        # type: (_T) -> None
        """Alias of send() that discards the return value."""
        self.send(obj)

    def empty(self):
        # type: () -> bool
        """Return True when no object is queued."""
        return not bool(self.__queue)

    def flush(self):
        # type: () -> None
        """No-op, present for file-like compatibility."""
        pass

    def recv(self, n=0, options=socket.MsgFlag(0)):
        # type: (Optional[int], socket.MsgFlag) -> Optional[_T]
        """
        Pop and return the oldest queued object.

        :param n: ignored (file/socket-like signature)
        :param options: if it contains ``socket.MSG_PEEK``, return the oldest
            object (or None when empty) without consuming it
        :raises EOFError: if the pipe is closed
        """
        if self.closed:
            raise EOFError
        if options & socket.MSG_PEEK:
            if self.__queue:
                return self.__queue[0]
            return None
        # One byte per queued object: blocks until an object is available
        os.read(self.__rd, 1)
        elt = self.__queue.popleft()
        if WINDOWS and not self.__queue:
            self._winreset()
        return elt

    def read(self, n=0):
        # type: (Optional[int]) -> Optional[_T]
        """Alias of recv()."""
        return self.recv(n)

    def clear(self):
        # type: () -> None
        """Drop every queued object (no-op when closed)."""
        if not self.closed:
            while not self.empty():
                self.recv()

    def close(self):
        # type: () -> None
        """Release the OS pipe (and Windows event). Safe to call twice."""
        if not self.closed:
            self.closed = True
            os.close(self.__rd)
            os.close(self.__wr)
            if WINDOWS:
                try:
                    self._winclose()
                except ImportError:
                    # Python is shutting down
                    pass

    def __repr__(self):
        # type: () -> str
        return "<%s at %s>" % (self.name, id(self))

    def __del__(self):
        # type: () -> None
        self.close()

    @staticmethod
    def select(sockets, remain=conf.recv_poll_rate):
        # type: (List[SuperSocket], Optional[float]) -> List[SuperSocket]
        """
        Select among ObjectPipes; closed ones are returned immediately so
        that a following read raises EOFError.
        """
        # Only handle ObjectPipes
        results = []
        for s in sockets:
            if s.closed:  # allow read to trigger EOF
                results.append(s)
        if results:
            return results
        return select_objects(sockets, remain)
class Message:
    """
    Free-form attribute bag: every keyword given to the constructor
    becomes an instance attribute. Class-level defaults are None.
    """
    type = None  # type: str
    pkt = None  # type: Packet
    result = None  # type: str
    state = None  # type: Message
    exc_info = None  # type: Union[Tuple[None, None, None], Tuple[BaseException, Exception, types.TracebackType]] # noqa: E501

    def __init__(self, **args):
        # type: (Any) -> None
        vars(self).update(args)

    def __repr__(self):
        # type: () -> str
        # Only attributes not starting with "_" are shown, in set order.
        public = [(key, val) for key, val in vars(self).items()
                  if not key.startswith("_")]
        body = " ".join("%s=%r" % pair for pair in public)
        return "<Message %s>" % body
class Timer():
    """
    Countdown attached to an automaton state.

    ``_time`` holds the remaining time and is only counted down by
    ``_decrement()``. A fresh Timer is expired with ``_time == 0`` until
    ``_reset()`` re-arms it with the full ``_timeout``. With ``autoreload``,
    an expiring timer is re-armed immediately (keeping the overshoot).
    """
    def __init__(self, time, prio=0, autoreload=False):
        # type: (Union[int, float], int, bool) -> None
        self._timeout = float(time)  # type: float
        self._time = 0  # type: float
        self._just_expired = True
        self._expired = True
        self._prio = prio
        self._func = _StateWrapper()
        self._autoreload = autoreload

    def get(self):
        # type: () -> float
        """Return the configured timeout period."""
        return self._timeout

    def set(self, val):
        # type: (float) -> None
        """Change the timeout period (applies on the next reset/reload)."""
        self._timeout = val

    def _reset(self):
        # type: () -> None
        """Re-arm with the full timeout and clear the expiry flags."""
        self._time = self._timeout
        self._expired = False
        self._just_expired = False

    def _reset_just_expired(self):
        # type: () -> None
        self._just_expired = False

    def _running(self):
        # type: () -> bool
        return self._time > 0

    def _remaining(self):
        # type: () -> float
        return max(self._time, 0)

    def _decrement(self, time):
        # type: (float) -> None
        """Count down by ``time``; flag ``_just_expired`` when reaching 0."""
        self._time -= time
        if self._time <= 0:
            if not self._expired:
                self._just_expired = True
                if self._autoreload:
                    # take overshoot into account
                    self._time = self._timeout + self._time
                else:
                    self._expired = True
                    self._time = 0

    # Ordering: by remaining time, ties broken by priority.
    def __lt__(self, obj):
        # type: (Timer) -> bool
        return ((self._time < obj._time) if self._time != obj._time
                else (self._prio < obj._prio))

    def __gt__(self, obj):
        # type: (Timer) -> bool
        return ((self._time > obj._time) if self._time != obj._time
                else (self._prio > obj._prio))

    def __eq__(self, obj):
        # type: (Any) -> bool
        if not isinstance(obj, Timer):
            # Let Python try the reflected comparison / identity fallback
            # instead of crashing on e.g. ``timer == None``.
            return NotImplemented
        return (self._time == obj._time) and (self._prio == obj._prio)

    def __repr__(self):
        # type: () -> str
        return "<Timer %f(%f)>" % (self._time, self._timeout)
367class _TimerList():
368 def __init__(self):
369 # type: () -> None
370 self.timers = [] # type: list[Timer]
372 def add_timer(self, timer):
373 # type: (Timer) -> None
374 self.timers.append(timer)
376 def reset(self):
377 # type: () -> None
378 for t in self.timers:
379 t._reset()
381 def decrement(self, time):
382 # type: (float) -> None
383 for t in self.timers:
384 t._decrement(time)
386 def expired(self):
387 # type: () -> list[Timer]
388 lst = [t for t in self.timers if t._just_expired]
389 lst.sort(key=lambda x: x._prio, reverse=True)
390 for t in lst:
391 t._reset_just_expired()
392 return lst
394 def until_next(self):
395 # type: () -> Optional[float]
396 try:
397 return min([t._remaining() for t in self.timers if t._running()])
398 except ValueError:
399 return None # None means blocking
401 def count(self):
402 # type: () -> int
403 return len(self.timers)
405 def __iter__(self):
406 # type: () -> Iterator[Timer]
407 return self.timers.__iter__()
409 def __repr__(self):
410 # type: () -> str
411 return self.timers.__repr__()
414class _instance_state:
415 def __init__(self, instance):
416 # type: (Any) -> None
417 self.__self__ = instance.__self__
418 self.__func__ = instance.__func__
419 self.__self__.__class__ = instance.__self__.__class__
421 def __getattr__(self, attr):
422 # type: (str) -> Any
423 return getattr(self.__func__, attr)
425 def __call__(self, *args, **kargs):
426 # type: (Any, Any) -> Any
427 return self.__func__(self.__self__, *args, **kargs)
429 def breaks(self):
430 # type: () -> Any
431 return self.__self__.add_breakpoints(self.__func__)
433 def intercepts(self):
434 # type: () -> Any
435 return self.__self__.add_interception_points(self.__func__)
437 def unbreaks(self):
438 # type: () -> Any
439 return self.__self__.remove_breakpoints(self.__func__)
441 def unintercepts(self):
442 # type: () -> Any
443 return self.__self__.remove_interception_points(self.__func__)
446##############
447# Automata #
448##############
class _StateWrapper:
    """
    Describes the ``atmt_*`` attributes that the ATMT decorators set on
    decorated functions (mainly for type checking). Timer.__init__ also
    instantiates it as a placeholder ``_func`` until ATMT.timeout/timer
    assigns the real function.
    """
    __name__ = None  # type: str
    # One of ATMT.STATE / ACTION / CONDITION / RECV / TIMEOUT / EOF / IOEVENT
    atmt_type = None  # type: str
    # Name of the state (for a state) or of the source state (for a transition)
    atmt_state = None  # type: str
    # State flags, as passed to ATMT.state(...)
    atmt_initial = None  # type: int
    atmt_final = None  # type: int
    atmt_stop = None  # type: int
    atmt_error = None  # type: int
    # Undecorated state function, kept by ATMT.state on its wrapper
    atmt_origfunc = None  # type: _StateWrapper
    # Ordering priority of a transition (lower sorts first in the metaclass)
    atmt_prio = None  # type: int
    # Attribute name under which an ioevent is exposed as a supersocket
    atmt_as_supersocket = None  # type: Optional[str]
    # Name of a transition function (key of the actions table)
    atmt_condname = None  # type: str
    # I/O channel name of an ioevent
    atmt_ioname = None  # type: str
    # Timer of a timeout/timer transition
    atmt_timeout = None  # type: Timer
    # For actions: condition name -> action priority
    atmt_cond = None  # type: Dict[str, int]
    __code__ = None  # type: types.CodeType
    __call__ = None  # type: Callable[..., ATMT.NewStateRequested]
class ATMT:
    """
    Decorators used to declare the states and transitions of an Automaton.

    Each decorator only tags the decorated function with ``atmt_*``
    attributes (``atmt_type`` is one of the constants below).
    Automaton_metaclass collects the tagged members when an Automaton
    subclass is created.
    """
    # Values of the ``atmt_type`` tag
    STATE = "State"
    ACTION = "Action"
    CONDITION = "Condition"
    RECV = "Receive condition"
    TIMEOUT = "Timeout condition"
    EOF = "EOF condition"
    IOEVENT = "I/O event"

    class NewStateRequested(Exception):
        """
        Request to move ``automaton`` to the state implemented by
        ``state_func``.

        Calling a state wrapper produced by ATMT.state returns one of these.
        NOTE(review): how the run loop consumes it (returned vs raised) is
        outside this chunk.
        """
        def __init__(self, state_func, automaton, *args, **kargs):
            # type: (Any, ATMT, Any, Any) -> None
            # Copy the flags ATMT.state set on the target state function
            self.func = state_func
            self.state = state_func.atmt_state
            self.initial = state_func.atmt_initial
            self.error = state_func.atmt_error
            self.stop = state_func.atmt_stop
            self.final = state_func.atmt_final
            Exception.__init__(self, "Request state [%s]" % self.state)
            self.automaton = automaton
            # NOTE: this overwrites Exception.args (the message set just
            # above) with the positional arguments for the state function.
            self.args = args
            self.kargs = kargs
            self.action_parameters()  # init action parameters

        def action_parameters(self, *args, **kargs):
            # type: (Any, Any) -> ATMT.NewStateRequested
            """
            Store ``action_args`` / ``action_kargs`` (presumably forwarded to
            the transition's actions; the consumer is outside this chunk).
            Returns ``self`` so the call can be chained.
            """
            self.action_args = args
            self.action_kargs = kargs
            return self

        def run(self):
            # type: () -> Any
            """Call the target state function on the automaton with the stored arguments."""
            return self.func(self.automaton, *self.args, **self.kargs)

        def __repr__(self):
            # type: () -> str
            return "NewStateRequested(%s)" % self.state

    @staticmethod
    def state(initial=0,  # type: int
              final=0,  # type: int
              stop=0,  # type: int
              error=0  # type: int
              ):
        # type: (...) -> Callable[[DecoratorCallable], DecoratorCallable]
        """
        Declare a state.

        The decorated function is tagged and replaced by a wrapper carrying
        the same tags (plus ``atmt_origfunc``); calling the wrapper returns
        an ATMT.NewStateRequested targeting the original function.

        :param initial: the state is an initial state
        :param final: the state is a final state
        :param stop: the state is the (single) stop state
        :param error: the state is an error state
        """
        def deco(f, initial=initial, final=final):
            # type: (_StateWrapper, int, int) -> _StateWrapper
            f.atmt_type = ATMT.STATE
            f.atmt_state = f.__name__
            f.atmt_initial = initial
            f.atmt_final = final
            f.atmt_stop = stop
            f.atmt_error = error

            def _state_wrapper(self, *args, **kargs):
                # type: (ATMT, Any, Any) -> ATMT.NewStateRequested
                return ATMT.NewStateRequested(f, self, *args, **kargs)

            state_wrapper = cast(_StateWrapper, _state_wrapper)
            state_wrapper.__name__ = "%s_wrapper" % f.__name__
            state_wrapper.atmt_type = ATMT.STATE
            state_wrapper.atmt_state = f.__name__
            state_wrapper.atmt_initial = initial
            state_wrapper.atmt_final = final
            state_wrapper.atmt_stop = stop
            state_wrapper.atmt_error = error
            state_wrapper.atmt_origfunc = f
            return state_wrapper
        return deco  # type: ignore

    @staticmethod
    def action(cond, prio=0):
        # type: (Any, int) -> Callable[[_StateWrapper, _StateWrapper], _StateWrapper] # noqa: E501
        """
        Register the decorated method as an action of transition ``cond``.

        May be stacked to attach one method to several transitions: the
        ``atmt_cond`` mapping (condition name -> ``prio``) is only created
        on the first application.
        """
        def deco(f, cond=cond):
            # type: (_StateWrapper, _StateWrapper) -> _StateWrapper
            if not hasattr(f, "atmt_type"):
                f.atmt_cond = {}
                f.atmt_type = ATMT.ACTION
            f.atmt_cond[cond.atmt_condname] = prio
            return f
        return deco

    @staticmethod
    def condition(state, prio=0):
        # type: (Any, int) -> Callable[[_StateWrapper, _StateWrapper], _StateWrapper] # noqa: E501
        """Declare a condition transition out of ``state`` (sorted by ``prio``)."""
        def deco(f, state=state):
            # type: (_StateWrapper, _StateWrapper) -> Any
            f.atmt_type = ATMT.CONDITION
            f.atmt_state = state.atmt_state
            f.atmt_condname = f.__name__
            f.atmt_prio = prio
            return f
        return deco

    @staticmethod
    def receive_condition(state, prio=0):
        # type: (_StateWrapper, int) -> Callable[[_StateWrapper, _StateWrapper], _StateWrapper] # noqa: E501
        """Declare a receive-condition transition out of ``state``."""
        def deco(f, state=state):
            # type: (_StateWrapper, _StateWrapper) -> _StateWrapper
            f.atmt_type = ATMT.RECV
            f.atmt_state = state.atmt_state
            f.atmt_condname = f.__name__
            f.atmt_prio = prio
            return f
        return deco

    @staticmethod
    def ioevent(state,  # type: _StateWrapper
                name,  # type: str
                prio=0,  # type: int
                as_supersocket=None  # type: Optional[str]
                ):
        # type: (...) -> Callable[[_StateWrapper, _StateWrapper], _StateWrapper] # noqa: E501
        """
        Declare an I/O-event transition out of ``state`` on channel ``name``.

        :param as_supersocket: if set, Automaton_metaclass exposes a
            supersocket factory for this channel under that attribute name
        """
        def deco(f, state=state):
            # type: (_StateWrapper, _StateWrapper) -> _StateWrapper
            f.atmt_type = ATMT.IOEVENT
            f.atmt_state = state.atmt_state
            f.atmt_condname = f.__name__
            f.atmt_ioname = name
            f.atmt_prio = prio
            f.atmt_as_supersocket = as_supersocket
            return f
        return deco

    @staticmethod
    def timeout(state, timeout):
        # type: (_StateWrapper, Union[int, float]) -> Callable[[_StateWrapper, _StateWrapper, Timer], _StateWrapper] # noqa: E501
        """
        Declare a one-shot timeout transition out of ``state``.

        The Timer is built once, as a default argument, when this decorator
        is called, and is shared through ``f.atmt_timeout``.
        """
        def deco(f, state=state, timeout=Timer(timeout)):
            # type: (_StateWrapper, _StateWrapper, Timer) -> _StateWrapper
            f.atmt_type = ATMT.TIMEOUT
            f.atmt_state = state.atmt_state
            f.atmt_timeout = timeout
            f.atmt_timeout._func = f
            f.atmt_condname = f.__name__
            return f
        return deco

    @staticmethod
    def timer(state, timeout, prio=0):
        # type: (_StateWrapper, Union[float, int], int) -> Callable[[_StateWrapper, _StateWrapper, Timer], _StateWrapper] # noqa: E501
        """
        Declare a periodic (autoreload) timer transition out of ``state``.

        Same as ATMT.timeout, but the Timer re-arms itself on expiry and
        carries a priority used to order simultaneously expired timers.
        """
        def deco(f, state=state, timeout=Timer(timeout, prio=prio, autoreload=True)):
            # type: (_StateWrapper, _StateWrapper, Timer) -> _StateWrapper
            f.atmt_type = ATMT.TIMEOUT
            f.atmt_state = state.atmt_state
            f.atmt_timeout = timeout
            f.atmt_timeout._func = f
            f.atmt_condname = f.__name__
            return f
        return deco

    @staticmethod
    def eof(state):
        # type: (_StateWrapper) -> Callable[[_StateWrapper, _StateWrapper], _StateWrapper] # noqa: E501
        """Declare the EOF transition of ``state`` (one per state)."""
        def deco(f, state=state):
            # type: (_StateWrapper, _StateWrapper) -> _StateWrapper
            f.atmt_type = ATMT.EOF
            f.atmt_state = state.atmt_state
            f.atmt_condname = f.__name__
            return f
        return deco
class _ATMT_Command:
    # Names of the control commands of an automaton.
    # NOTE(review): the sender/handler of these is outside this chunk;
    # presumably they travel as Message.type over the cmdin/cmdout
    # ObjectPipes created in Automaton.__init__ -- confirm.
    RUN = "RUN"
    NEXT = "NEXT"
    FREEZE = "FREEZE"
    STOP = "STOP"
    FORCESTOP = "FORCESTOP"
    END = "END"
    EXCEPTION = "EXCEPTION"
    SINGLESTEP = "SINGLESTEP"
    BREAKPOINT = "BREAKPOINT"
    INTERCEPT = "INTERCEPT"
    ACCEPT = "ACCEPT"
    REPLACE = "REPLACE"
    REJECT = "REJECT"
class _ATMT_supersocket(SuperSocket):
    """
    SuperSocket facade over one I/O event of an automaton.

    The automaton is instantiated with the I/O event ``ioevent`` wired to
    two ObjectPipes and started in the background. Objects sent on this
    socket are written to ``spa`` (the automaton's input for that event);
    objects the automaton writes to the event are read back from ``spb``.
    """
    def __init__(self,
                 name,  # type: str
                 ioevent,  # type: str
                 automaton,  # type: Type[Automaton]
                 proto,  # type: Callable[[bytes], Any]
                 *args,  # type: Any
                 **kargs  # type: Any
                 ):
        # type: (...) -> None
        self.name = name
        self.ioevent = ioevent
        self.proto = proto
        # write, read
        self.spa, self.spb = ObjectPipe[Any]("spa"), \
            ObjectPipe[Any]("spb")
        kargs["external_fd"] = {ioevent: (self.spa, self.spb)}
        kargs["is_atmt_socket"] = True
        kargs["atmt_socket"] = self.name
        try:
            self.atmt = automaton(*args, **kargs)
        except BaseException:
            # Don't leak the pipes if the automaton cannot be built; mark
            # closed so a later close() does not touch the missing atmt.
            self.spa.close()
            self.spb.close()
            self.closed = True
            raise
        self.atmt.runbg()

    def send(self, s):
        # type: (Any) -> int
        """Feed ``s`` to the automaton's I/O event."""
        return self.spa.send(s)

    def fileno(self):
        # type: () -> int
        """Selectable on the automaton's output pipe."""
        return self.spb.fileno()

    # note: _ATMT_supersocket may return bytes in certain cases, which
    # is expected. We cheat on typing.
    def recv(self, n=MTU, **kwargs):  # type: ignore
        # type: (int, **Any) -> Any
        """Read the next object emitted by the automaton, decoded with ``proto`` if set."""
        r = self.spb.recv(n)
        if self.proto is not None and r is not None:
            r = self.proto(r, **kwargs)
        return r

    def close(self):
        # type: () -> None
        """Stop and destroy the automaton, then release both pipes."""
        if not self.closed:
            try:
                self.atmt.stop()
                self.atmt.destroy()
            finally:
                # Release the pipes even if stopping the automaton failed
                self.spa.close()
                self.spb.close()
                self.closed = True

    @staticmethod
    def select(sockets, remain=conf.recv_poll_rate):
        # type: (List[SuperSocket], Optional[float]) -> List[SuperSocket]
        return select_objects(sockets, remain)
701class _ATMT_to_supersocket:
702 def __init__(self, name, ioevent, automaton):
703 # type: (str, str, Type[Automaton]) -> None
704 self.name = name
705 self.ioevent = ioevent
706 self.automaton = automaton
708 def __call__(self, proto, *args, **kargs):
709 # type: (Callable[[bytes], Any], Any, Any) -> _ATMT_supersocket
710 return _ATMT_supersocket(
711 self.name, self.ioevent, self.automaton,
712 proto, *args, **kargs
713 )
class Automaton_metaclass(type):
    """
    Metaclass of Automaton: at class creation it scans the class and its
    bases for ATMT-decorated members and builds the per-class tables
    (states, transitions grouped by source state, actions grouped by
    condition name, timers, I/O names and supersocket factories).
    """
    def __new__(cls, name, bases, dct):
        # type: (str, Tuple[Any], Dict[str, Any]) -> Type[Automaton]
        cls = super(Automaton_metaclass, cls).__new__(  # type: ignore
            cls, name, bases, dct
        )
        cls.states = {}
        cls.recv_conditions = {}  # type: Dict[str, List[_StateWrapper]]
        cls.conditions = {}  # type: Dict[str, List[_StateWrapper]]
        cls.ioevents = {}  # type: Dict[str, List[_StateWrapper]]
        cls.timeout = {}  # type: Dict[str, _TimerList]
        cls.eofs = {}  # type: Dict[str, _StateWrapper]
        cls.actions = {}  # type: Dict[str, List[_StateWrapper]]
        cls.initial_states = []  # type: List[_StateWrapper]
        cls.stop_state = None  # type: Optional[_StateWrapper]
        cls.ionames = []
        cls.iosupersockets = []

        # Breadth-first walk of the class and its bases; the first class
        # defining a name wins.
        members = {}
        classes = [cls]
        while classes:
            c = classes.pop(0)  # order is important to avoid breaking method overloading # noqa: E501
            classes += list(c.__bases__)
            for k, v in c.__dict__.items():  # type: ignore
                if k not in members:
                    members[k] = v

        decorated = [v for v in members.values()
                     if hasattr(v, "atmt_type")]

        # First pass: states, and an empty action list per transition
        for m in decorated:
            if m.atmt_type == ATMT.STATE:
                s = m.atmt_state
                cls.states[s] = m
                cls.recv_conditions[s] = []
                cls.ioevents[s] = []
                cls.conditions[s] = []
                cls.timeout[s] = _TimerList()
                if m.atmt_initial:
                    cls.initial_states.append(m)
                if m.atmt_stop:
                    if cls.stop_state is not None:
                        raise ValueError("There can only be a single stop state !")
                    cls.stop_state = m
            elif m.atmt_type in [ATMT.CONDITION, ATMT.RECV, ATMT.TIMEOUT, ATMT.IOEVENT, ATMT.EOF]:  # noqa: E501
                cls.actions[m.atmt_condname] = []

        # Second pass: attach transitions to their state, actions to conditions
        for m in decorated:
            if m.atmt_type == ATMT.CONDITION:
                cls.conditions[m.atmt_state].append(m)
            elif m.atmt_type == ATMT.RECV:
                cls.recv_conditions[m.atmt_state].append(m)
            elif m.atmt_type == ATMT.EOF:
                cls.eofs[m.atmt_state] = m
            elif m.atmt_type == ATMT.IOEVENT:
                cls.ioevents[m.atmt_state].append(m)
                cls.ionames.append(m.atmt_ioname)
                if m.atmt_as_supersocket is not None:
                    cls.iosupersockets.append(m)
            elif m.atmt_type == ATMT.TIMEOUT:
                cls.timeout[m.atmt_state].add_timer(m.atmt_timeout)
            elif m.atmt_type == ATMT.ACTION:
                for co in m.atmt_cond:
                    cls.actions[co].append(m)

        for v in itertools.chain(
            cls.conditions.values(),
            cls.recv_conditions.values(),
            cls.ioevents.values()
        ):
            v.sort(key=lambda x: x.atmt_prio)
        for condname, actlst in cls.actions.items():
            actlst.sort(key=lambda x: x.atmt_cond[condname])

        for ioev in cls.iosupersockets:
            setattr(cls, ioev.atmt_as_supersocket,
                    _ATMT_to_supersocket(
                        ioev.atmt_as_supersocket,
                        ioev.atmt_ioname,
                        cast(Type["Automaton"], cls)))

        # Inject signature
        try:
            import inspect
            cls.__signature__ = inspect.signature(cls.parse_args)  # type: ignore # noqa: E501
        except (ImportError, AttributeError):
            pass

        return cast(Type["Automaton"], cls)

    def build_graph(self):
        # type: () -> str
        """
        Return a Graphviz "dot" description of the automaton class.

        Edges are discovered by scanning the names and constants of the
        state / transition functions for state names, following references
        to other functions defined directly in this class. Colors: green =
        state function, purple = condition, red = receive condition,
        orange = ioevent, black dashed = eof, blue = timeout/timer.
        """
        def _referenced_states(code):
            # type: (types.CodeType) -> Iterator[Any]
            # Yield state names referenced by `code`, expanding class-level
            # helper functions. Each helper is expanded at most once so that
            # recursive / mutually-referencing helpers cannot loop forever,
            # and callables without __code__ (e.g. classes) are skipped.
            names = list(code.co_names + code.co_consts)
            expanded = set()  # type: Set[Any]
            while names:
                n = names.pop()
                if n in self.states:
                    yield n
                elif n in self.__dict__ and n not in expanded:
                    # function indirection
                    expanded.add(n)
                    func = self.__dict__[n]
                    if callable(func) and hasattr(func, "__code__"):
                        names.extend(func.__code__.co_names)
                        names.extend(func.__code__.co_consts)

        # `self` is the Automaton class itself (this is a metaclass method)
        s = 'digraph "%s" {\n' % self.__name__

        se = ""  # Keep initial nodes at the beginning for better rendering
        for st in self.states.values():
            if st.atmt_initial:
                se = ('\t"%s" [ style=filled, fillcolor=blue, shape=box, root=true];\n' % st.atmt_state) + se  # noqa: E501
            elif st.atmt_final:
                se += '\t"%s" [ style=filled, fillcolor=green, shape=octagon ];\n' % st.atmt_state  # noqa: E501
            elif st.atmt_error:
                se += '\t"%s" [ style=filled, fillcolor=red, shape=octagon ];\n' % st.atmt_state  # noqa: E501
            elif st.atmt_stop:
                se += '\t"%s" [ style=filled, fillcolor=orange, shape=box, root=true ];\n' % st.atmt_state  # noqa: E501
        s += se

        for st in self.states.values():
            for n in _referenced_states(st.atmt_origfunc.__code__):
                s += '\t"%s" -> "%s" [ color=green ];\n' % (st.atmt_state, n)

        for c, sty, k, v in (
            [("purple", "solid", k, v) for k, v in self.conditions.items()] +
            [("red", "solid", k, v) for k, v in self.recv_conditions.items()] +
            [("orange", "solid", k, v) for k, v in self.ioevents.items()] +
            [("black", "dashed", k, [v]) for k, v in self.eofs.items()]
        ):
            for f in v:
                for n in _referenced_states(f.__code__):
                    line = f.atmt_condname
                    for x in self.actions[f.atmt_condname]:
                        line += "\\l>[%s]" % x.__name__
                    s += '\t"%s" -> "%s" [label="%s", color=%s, style=%s];\n' % (
                        k,
                        n,
                        line,
                        c,
                        sty,
                    )
        for k, timers in self.timeout.items():
            for timer in timers:
                for n in (timer._func.__code__.co_names +
                          timer._func.__code__.co_consts):
                    if n in self.states:
                        line = "%s/%.1fs" % (timer._func.atmt_condname,
                                             timer.get())
                        for x in self.actions[timer._func.atmt_condname]:
                            line += "\\l>[%s]" % x.__name__
                        s += '\t"%s" -> "%s" [label="%s",color=blue];\n' % (k, n, line)  # noqa: E501
        s += "}\n"
        return s

    def graph(self, **kargs):
        # type: (Any) -> Optional[str]
        """Render build_graph() with scapy.utils.do_graph (``kargs`` forwarded)."""
        s = self.build_graph()
        return do_graph(s, **kargs)
class Automaton(metaclass=Automaton_metaclass):
    # Class-level tables. Except for ``state``, Automaton_metaclass
    # re-creates each of these for every class it builds, from the
    # ATMT-decorated members; the values below are typed defaults.
    states = {}  # type: Dict[str, _StateWrapper]
    # NOTE(review): presumably the current state request; set outside this chunk
    state = None  # type: ATMT.NewStateRequested
    recv_conditions = {}  # type: Dict[str, List[_StateWrapper]]
    conditions = {}  # type: Dict[str, List[_StateWrapper]]
    eofs = {}  # type: Dict[str, _StateWrapper]
    ioevents = {}  # type: Dict[str, List[_StateWrapper]]
    timeout = {}  # type: Dict[str, _TimerList]
    actions = {}  # type: Dict[str, List[_StateWrapper]]
    initial_states = []  # type: List[_StateWrapper]
    stop_state = None  # type: Optional[_StateWrapper]
    ionames = []  # type: List[str]
    # ioevent functions declared with ``as_supersocket`` (not sockets)
    iosupersockets = []  # type: List[_StateWrapper]

    # used for spawn()
    pkt_cls = conf.raw_layer
    socketcls = StreamSocket
    # Internals
    def __init__(self, *args, **kargs):
        # type: (Any, Any) -> None
        """
        Set up sockets, control pipes and I/O channels, then call start().

        Keyword arguments consumed here: ``external_fd`` (ioname -> fd/pipe
        or (in, out) tuple), ``ll`` / ``recvsock`` (socket classes, only
        when no ``sock`` is given), ``is_atmt_socket``, ``atmt_socket`` and
        ``session``. ``sock`` is read but left in the kwargs. Remaining
        arguments are kept in ``init_args`` / ``init_kargs``.
        """
        external_fd = kargs.pop("external_fd", {})
        if "sock" in kargs:
            # We use a bi-directional sock
            self.sock = kargs["sock"]
        else:
            # Separate sockets
            self.sock = None
            self.send_sock_class = kargs.pop("ll", conf.L3socket)
            self.recv_sock_class = kargs.pop("recvsock", conf.L2listen)
        self.listen_sock = None  # type: Optional[SuperSocket]
        self.send_sock = None  # type: Optional[SuperSocket]
        self.is_atmt_socket = kargs.pop("is_atmt_socket", False)
        self.atmt_socket = kargs.pop("atmt_socket", None)
        self.started = threading.Lock()
        self.threadid = None  # type: Optional[int]
        self.breakpointed = None
        self.breakpoints = set()  # type: Set[_StateWrapper]
        self.interception_points = set()  # type: Set[_StateWrapper]
        self.intercepted_packet = None  # type: Union[None, Packet]
        self.debug_level = 0
        self.init_args = args
        # NOTE: init_kargs aliases `kargs`, so the pop("session") below also
        # removes "session" from init_kargs. NOTE(review): confirm this is
        # intended wherever init_kargs is consumed (outside this chunk).
        self.init_kargs = kargs
        # Empty namespaces receiving one _IO_mixer attribute per ioname
        self.io = type.__new__(type, "IOnamespace", (), {})
        self.oi = type.__new__(type, "IOnamespace", (), {})
        self.cmdin = ObjectPipe[Message]("cmdin")
        self.cmdout = ObjectPipe[Message]("cmdout")
        self.ioin = {}
        self.ioout = {}
        self.packets = PacketList()  # type: PacketList
        self.atmt_session = kargs.pop("session", None)
        for n in self.__class__.ionames:
            # A single external fd is used for both directions
            extfd = external_fd.get(n)
            if not isinstance(extfd, tuple):
                extfd = (extfd, extfd)
            ioin, ioout = extfd
            # Missing ends get a private ObjectPipe; provided ones are wrapped
            if ioin is None:
                ioin = ObjectPipe("ioin")
            else:
                ioin = self._IO_fdwrapper(ioin, None)
            if ioout is None:
                ioout = ObjectPipe("ioout")
            else:
                ioout = self._IO_fdwrapper(None, ioout)

            self.ioin[n] = ioin
            self.ioout[n] = ioout
            ioin.ioname = n
            ioout.ioname = n
            # self.io.<n> reads ioout / writes ioin; self.oi.<n> is the reverse
            setattr(self.io, n, self._IO_mixer(ioout, ioin))
            setattr(self.oi, n, self._IO_mixer(ioin, ioout))

        # Replace each state wrapper by a per-instance proxy
        for stname in self.states:
            setattr(self, stname,
                    _instance_state(getattr(self, stname)))

        self.start()
960 def parse_args(self, debug=0, store=0, session=None, **kargs):
961 # type: (int, int, Any, Any) -> None
962 self.debug_level = debug
963 if debug:
964 conf.logLevel = logging.DEBUG
965 self.atmt_session = session
966 self.socket_kargs = kargs
967 self.store_packets = store
969 @classmethod
970 def spawn(cls,
971 port: int,
972 iface: Optional[_GlobInterfaceType] = None,
973 local_ip: Optional[str] = None,
974 bg: bool = False,
975 **kwargs: Any) -> Optional[socket.socket]:
976 """
977 Spawn a TCP server that listens for connections and start the automaton
978 for each new client.
980 :param port: the port to listen to
981 :param bg: background mode? (default: False)
983 Note that in background mode, you shall close the TCP server as such::
985 srv = MyAutomaton.spawn(8080, bg=True)
986 srv.shutdown(socket.SHUT_RDWR) # important
987 srv.close()
988 """
989 from scapy.arch import get_if_addr
990 # create server sock and bind it
991 ssock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
992 if local_ip is None:
993 local_ip = get_if_addr(iface or conf.iface)
994 try:
995 ssock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
996 except OSError:
997 pass
998 ssock.bind((local_ip, port))
999 ssock.listen(5)
1000 clients = []
1001 if kwargs.get("verb", True):
1002 print(conf.color_theme.green(
1003 "Server %s started listening on %s" % (
1004 cls.__name__,
1005 (local_ip, port),
1006 )
1007 ))
1009 def _run() -> None:
1010 # Wait for clients forever
1011 try:
1012 while True:
1013 atmt_server = None
1014 clientsocket, address = ssock.accept()
1015 if kwargs.get("verb", True):
1016 print(conf.color_theme.gold(
1017 "\u2503 Connection received from %s" % repr(address)
1018 ))
1019 try:
1020 # Start atmt class with socket
1021 if cls.socketcls is not None:
1022 sock = cls.socketcls(clientsocket, cls.pkt_cls)
1023 else:
1024 sock = clientsocket
1025 atmt_server = cls(
1026 sock=sock,
1027 iface=iface, **kwargs
1028 )
1029 except OSError:
1030 if atmt_server is not None:
1031 atmt_server.destroy()
1032 if kwargs.get("verb", True):
1033 print("X Connection aborted.")
1034 if kwargs.get("debug", 0) > 0:
1035 traceback.print_exc()
1036 continue
1037 clients.append((atmt_server, clientsocket))
1038 # start atmt
1039 atmt_server.runbg()
1040 # housekeeping
1041 for atmt, clientsocket in clients:
1042 if not atmt.isrunning():
1043 atmt.destroy()
1044 except KeyboardInterrupt:
1045 print("X Exiting.")
1046 ssock.shutdown(socket.SHUT_RDWR)
1047 except OSError:
1048 print("X Server closed.")
1049 if kwargs.get("debug", 0) > 0:
1050 traceback.print_exc()
1051 finally:
1052 for atmt, clientsocket in clients:
1053 try:
1054 atmt.forcestop(wait=False)
1055 atmt.destroy()
1056 except Exception:
1057 pass
1058 try:
1059 clientsocket.shutdown(socket.SHUT_RDWR)
1060 clientsocket.close()
1061 except Exception:
1062 pass
1063 ssock.close()
1064 if bg:
1065 # Background
1066 threading.Thread(target=_run).start()
1067 return ssock
1068 else:
1069 # Non-background
1070 _run()
1071 return None
1073 def master_filter(self, pkt):
1074 # type: (Packet) -> bool
1075 return True
1077 def my_send(self, pkt, **kwargs):
1078 # type: (Packet, **Any) -> None
1079 if not self.send_sock:
1080 raise ValueError("send_sock is None !")
1081 self.send_sock.send(pkt, **kwargs)
1083 def update_sock(self, sock):
1084 # type: (SuperSocket) -> None
1085 """
1086 Update the socket used by the automata.
1087 Typically used in an eof event to reconnect.
1088 """
1089 self.sock = sock
1090 if self.listen_sock is not None:
1091 self.listen_sock = self.sock
1092 if self.send_sock:
1093 self.send_sock = self.sock
def timer_by_name(self, name):
    # type: (str) -> Optional[Timer]
    """
    Find a timer by the name of the condition it triggers.

    All per-state timer collections in ``self.timeout`` are scanned. The
    first timer whose function has ``atmt_condname == name`` is returned,
    or ``None`` when there is none.
    """
    matches = (
        timer
        for timers in self.timeout.values()
        for timer in timers
        if timer._func.atmt_condname == name
    )
    return next(matches, None)
1103 # Utility classes and exceptions
class _IO_fdwrapper:
    """
    Give a raw file descriptor or a socket-like object one I/O interface.

    ``rd`` and ``wr`` may each be an ``int`` file descriptor (driven through
    ``os.read``/``os.write``), an object with ``recv``/``send``/``fileno``,
    or ``None``. ``read``/``recv`` use ``rd``; ``write``/``send`` use ``wr``.
    """

    def __init__(self,
                 rd,  # type: Union[int, ObjectPipe[bytes], None]
                 wr  # type: Union[int, ObjectPipe[bytes], None]
                 ):
        # type: (...) -> None
        self.rd = rd
        self.wr = wr
        if isinstance(rd, socket.socket):
            # NOTE(review): marker for real sockets; presumably consulted by
            # select_objects when choosing how to select - confirm there.
            self.__selectable_force_select__ = True

    def fileno(self):
        # type: () -> int
        """Descriptor of the reading side; 0 when there is no reader."""
        src = self.rd
        if isinstance(src, int):
            return src
        return src.fileno() if src else 0

    def read(self, n=65535):
        # type: (int) -> Optional[bytes]
        """Read up to ``n`` bytes; ``None`` when there is no reader."""
        src = self.rd
        if isinstance(src, int):
            return os.read(src, n)
        return src.recv(n) if src else None

    def write(self, msg):
        # type: (bytes) -> int
        """Write ``msg`` and return what the sink reports; 0 without one."""
        dst = self.wr
        if isinstance(dst, int):
            return os.write(dst, msg)
        return dst.send(msg) if dst else 0

    def recv(self, n=65535):
        # type: (int) -> Optional[bytes]
        """Socket-style spelling of ``read``."""
        return self.read(n)

    def send(self, msg):
        # type: (bytes) -> int
        """Socket-style spelling of ``write``."""
        return self.write(msg)
class _IO_mixer:
    """
    Join a reading pipe and a writing pipe into one duplex object.

    ``recv``/``read`` are served by ``rd`` and ``send``/``write`` by ``wr``.
    ``fileno`` returns ``rd.fileno()`` when ``rd`` is an ``ObjectPipe`` and
    ``rd`` itself otherwise.
    """

    def __init__(self,
                 rd,  # type: ObjectPipe[Any]
                 wr,  # type: ObjectPipe[Any]
                 ):
        # type: (...) -> None
        self.rd, self.wr = rd, wr

    def fileno(self):
        # type: () -> Any
        reader = self.rd
        return reader.fileno() if isinstance(reader, ObjectPipe) else reader

    def recv(self, n=None):
        # type: (Optional[int]) -> Any
        return self.rd.recv(n)

    def read(self, n=None):
        # type: (Optional[int]) -> Any
        # Delegates through recv() so subclasses overriding recv are honoured
        return self.recv(n)

    def send(self, msg):
        # type: (str) -> int
        return self.wr.send(msg)

    def write(self, msg):
        # type: (str) -> int
        # Delegates through send() for the same reason as read()
        return self.send(msg)
class AutomatonException(Exception):
    """
    Root of the exceptions raised by, or used inside, an Automaton.

    :param msg: message stored as the exception's sole argument.
    :param state: state the automaton was in, when known.
    :param result: associated result (e.g. a state's output), when known.
    """

    def __init__(self, msg, state=None, result=None):
        # type: (str, Optional[Message], Optional[str]) -> None
        super().__init__(msg)
        self.state = state
        self.result = result
class AutomatonError(AutomatonException):
    """Internal error, e.g. an unknown interception verdict in ``send``."""


class ErrorState(AutomatonException):
    """Raised by the main loop when an error state has been reached."""


class Stuck(AutomatonException):
    """Raised when a state has no recv condition, ioevent or timer left."""


class AutomatonStopped(AutomatonException):
    """Base class for the ways ``run`` reports that execution paused."""


class Breakpoint(AutomatonStopped):
    """Execution paused on a state registered with ``add_breakpoints``."""


class Singlestep(AutomatonStopped):
    """Execution paused after a single step (``next()`` on the automaton)."""
class InterceptionPoint(AutomatonStopped):
    """
    Raised by ``run`` when ``send`` intercepted a packet.

    The intercepted packet is kept in ``packet``; answer it with
    ``accept_packet`` or ``reject_packet``.
    """

    def __init__(self, msg, state=None, result=None, packet=None):
        # type: (str, Optional[Message], Optional[str], Optional[Packet]) -> None
        super().__init__(msg, state=state, result=result)
        self.packet = packet
class CommandMessage(AutomatonException):
    """Yielded by the main loop when a command is waiting on ``cmdin``."""
1212 # Services
def debug(self, lvl, msg):
    # type: (int, str) -> None
    """
    Log ``msg`` via ``log_runtime.debug``, but only when the automaton's
    ``debug_level`` is at least ``lvl``.
    """
    if self.debug_level < lvl:
        return
    log_runtime.debug(msg)
def isrunning(self):
    # type: () -> bool
    """True while the control thread holds the ``started`` lock."""
    lock = self.started
    return lock.locked()
def send(self, pkt, **kwargs):
    # type: (Packet, **Any) -> None
    """
    Send ``pkt`` through ``my_send``, honouring interception points.

    If the current state is an interception point, an INTERCEPT message is
    posted on ``cmdout`` and a verdict is read from ``cmdin``:

    * a falsy verdict cancels the send;
    * REJECT drops the packet;
    * REPLACE sends ``verdict.pkt`` instead;
    * ACCEPT sends the packet unchanged.

    Sent packets are copied into ``self.packets`` when ``store_packets``
    is set.

    :raises AutomatonError: on any other verdict type.
    """
    if self.state.state in self.interception_points:
        self.debug(3, "INTERCEPT: packet intercepted: %s" % pkt.summary())
        self.intercepted_packet = pkt
        self.cmdout.send(
            Message(type=_ATMT_Command.INTERCEPT, state=self.state, pkt=pkt)
        )
        verdict = self.cmdin.recv()
        if not verdict:
            # intercepted_packet is intentionally left as-is on cancel
            self.debug(3, "CANCELLED")
            return
        self.intercepted_packet = None
        kind = verdict.type
        if kind == _ATMT_Command.REJECT:
            self.debug(3, "INTERCEPT: packet rejected")
            return
        if kind == _ATMT_Command.REPLACE:
            pkt = verdict.pkt
            self.debug(3, "INTERCEPT: packet replaced by: %s" % pkt.summary())  # noqa: E501
        elif kind == _ATMT_Command.ACCEPT:
            self.debug(3, "INTERCEPT: packet accepted")
        else:
            raise self.AutomatonError("INTERCEPT: unknown verdict: %r" % kind)  # noqa: E501
    self.my_send(pkt, **kwargs)
    self.debug(3, "SENT : %s" % pkt.summary())

    if self.store_packets:
        self.packets.append(pkt.copy())
def __iter__(self):
    # type: () -> Automaton
    """The automaton is its own iterator: every ``next()`` runs one step."""
    iterator = self
    return iterator
def __del__(self):
    # type: () -> None
    """On garbage collection, release resources through ``destroy()``."""
    cleanup = self.destroy
    cleanup()
def _run_condition(self, cond, *args, **kargs):
    # type: (_StateWrapper, Any, Any) -> None
    """
    Evaluate one condition and log what happened.

    A condition requests a transition by raising ``ATMT.NewStateRequested``.
    When that happens:

    * for RECV conditions with ``store_packets`` set, the packet (first
      positional argument) is appended to ``self.packets``;
    * every action bound to the condition runs with the request's
      arguments;
    * the request is re-raised.

    Any other exception is logged and re-raised unchanged.
    """
    ctype, cname = cond.atmt_type, cond.atmt_condname
    try:
        self.debug(5, "Trying %s [%s]" % (ctype, cname))
        cond(self, *args, **kargs)
    except ATMT.NewStateRequested as req:
        self.debug(2, "%s [%s] taken to state [%s]" % (ctype, cname, req.state))  # noqa: E501
        if ctype == ATMT.RECV and self.store_packets:
            self.packets.append(args[0])
        for act in self.actions[cname]:
            self.debug(2, " + Running action [%s]" % act.__name__)
            act(self, *req.action_args, **req.action_kargs)
        raise
    except Exception as exc:
        self.debug(2, "%s [%s] raised exception [%s]" % (ctype, cname, exc))
        raise
    else:
        self.debug(2, "%s [%s] not taken" % (ctype, cname))
def _do_start(self, *args, **kargs):
    # type: (Any, Any) -> None
    """
    Spawn the daemon thread running ``_do_control`` and block until that
    thread sets the ``ready`` event it is given.
    """
    ready = threading.Event()
    worker = threading.Thread(
        target=self._do_control,
        args=(ready, *args),
        kwargs=kargs,
        name="scapy.automaton _do_start",
        daemon=True,
    )
    worker.start()
    ready.wait()
def _do_control(self, ready, *args, **kargs):
    # type: (threading.Event, Any, Any) -> None
    """
    Body of the control thread spawned by ``_do_start``.

    The whole run happens while holding ``self.started``, which is what
    ``isrunning()`` reports.

    Setup:

    * merge ``args``/``kargs`` over ``init_args``/``init_kargs`` and call
      ``parse_args``;
    * enter the first initial state;
    * open the send socket, plus a receive socket when there are recv
      conditions;
    * set ``ready`` to release ``_do_start``.

    Commands are then read from ``cmdin``:

    * RUN resumes free running;
    * NEXT runs a single step;
    * FREEZE is ignored;
    * STOP switches to the stop state when there is one, otherwise it acts
      like FORCESTOP;
    * FORCESTOP ends the thread.

    Outcomes are posted on ``cmdout`` as BREAKPOINT, SINGLESTEP, END or
    EXCEPTION messages.

    Guarantees:

    * if setup raises (e.g. a socket cannot be opened), the error is posted
      as an EXCEPTION message, so the next ``run()`` re-raises it;
    * ``ready`` is always set, so ``_do_start`` cannot block forever;
    * ``threadid`` is reset and the sockets are closed on every exit path,
      including when ``cmdin.recv()`` returns ``None``.
    """
    with self.started:
        self.threadid = threading.current_thread().ident
        if self.threadid is None:
            self.threadid = 0
        try:
            try:
                # Update default parameters
                a = args + self.init_args[len(args):]
                k = self.init_kargs.copy()
                k.update(kargs)
                self.parse_args(*a, **k)

                # Start the automaton
                self.state = self.initial_states[0](self)
                self.send_sock = self.sock or self.send_sock_class(**self.socket_kargs)  # noqa: E501
                if self.recv_conditions:
                    # Only start a receiving socket if we have at least one
                    # recv_conditions
                    self.listen_sock = self.sock or self.recv_sock_class(**self.socket_kargs)  # noqa: E501
                self.packets = PacketList(name="session[%s]" % self.__class__.__name__)  # noqa: E501
                iterator = self._do_iter()
            except Exception as e:
                # Setup failed: hand the error to the caller's next run()
                # (which re-raises EXCEPTION messages). ``ready`` is set in
                # the finally clause below.
                exc_info = sys.exc_info()
                self.debug(3, "Transferring exception from tid=%i:\n%s" % (self.threadid, "".join(traceback.format_exception(*exc_info))))  # noqa: E501
                self.cmdout.send(Message(type=_ATMT_Command.EXCEPTION, exception=e, exc_info=exc_info))  # noqa: E501
                return None

            singlestep = True
            self.debug(3, "Starting control thread [tid=%i]" % self.threadid)
            # Sync threads
            ready.set()
            try:
                while True:
                    c = self.cmdin.recv()
                    if c is None:
                        # Nothing more to read from cmdin (presumably the
                        # pipe was closed): leave, still running cleanup.
                        break
                    self.debug(5, "Received command %s" % c.type)
                    if c.type == _ATMT_Command.RUN:
                        singlestep = False
                    elif c.type == _ATMT_Command.NEXT:
                        singlestep = True
                    elif c.type == _ATMT_Command.FREEZE:
                        continue
                    elif c.type == _ATMT_Command.STOP:
                        if self.stop_state:
                            # There is a stop state
                            self.state = self.stop_state()
                            iterator = self._do_iter()
                        else:
                            # Act as FORCESTOP
                            break
                    elif c.type == _ATMT_Command.FORCESTOP:
                        break
                    while True:
                        state = next(iterator)
                        if isinstance(state, self.CommandMessage):
                            break
                        elif isinstance(state, self.Breakpoint):
                            c = Message(type=_ATMT_Command.BREAKPOINT, state=state)  # noqa: E501
                            self.cmdout.send(c)
                            break
                        if singlestep:
                            c = Message(type=_ATMT_Command.SINGLESTEP, state=state)  # noqa: E501
                            self.cmdout.send(c)
                            break
            except (StopIteration, RuntimeError):
                # _do_iter finished (final state reached)
                c = Message(type=_ATMT_Command.END,
                            result=self.final_state_output)
                self.cmdout.send(c)
            except Exception as e:
                exc_info = sys.exc_info()
                self.debug(3, "Transferring exception from tid=%i:\n%s" % (self.threadid, "".join(traceback.format_exception(*exc_info))))  # noqa: E501
                m = Message(type=_ATMT_Command.EXCEPTION, exception=e, exc_info=exc_info)  # noqa: E501
                self.cmdout.send(m)
        finally:
            # Whatever happened above, never leave _do_start waiting.
            ready.set()
            self.debug(3, "Stopping control thread (tid=%i)" % self.threadid)
            self.threadid = None
            if self.listen_sock:
                self.listen_sock.close()
            if self.send_sock:
                self.send_sock.close()
1369 def _do_iter(self):
1370 # type: () -> Iterator[Union[Automaton.AutomatonException, Automaton.AutomatonStopped, ATMT.NewStateRequested, None]] # noqa: E501
# Main state-machine generator; _do_control drives it with next(iterator).
# Each pass of the outer loop "enters" self.state:
#   1. run the state function;
#   2. try the immediate conditions;
#   3. wait on cmdin / listen_sock / ioevent pipes while servicing timers.
# Yields:
#   - Breakpoint when a breakpoint state is reached;
#   - CommandMessage when a command is pending on cmdin;
#   - the ATMT.NewStateRequested object of every transition.
# Ends (plain return) after a final state has run, once its output is
# stored in self.final_state_output.
# Raises:
#   - ErrorState when an error state is reached;
#   - Stuck when the state has no recv condition, ioevent or timer;
#   - EOFError when listen_sock hits EOF and the state has no eof handler.
1371 while True:
1372 try:
1373 self.debug(1, "## state=[%s]" % self.state.state)
1375 # Entering a new state. First, call new state function
# breakpointed is set while paused on a breakpoint and cleared on resume.
1376 if self.state.state in self.breakpoints and self.state.state != self.breakpointed: # noqa: E501
1377 self.breakpointed = self.state.state
1378 yield self.Breakpoint("breakpoint triggered on state %s" % self.state.state, # noqa: E501
1379 state=self.state.state)
1380 self.breakpointed = None
1381 state_output = self.state.run()
1382 if self.state.error:
1383 raise self.ErrorState("Reached %s: [%r]" % (self.state.state, state_output), # noqa: E501
1384 result=state_output, state=self.state.state) # noqa: E501
1385 if self.state.final:
1386 self.final_state_output = state_output
1387 return
# Normalize the state output into extra positional arguments for every
# condition below: None -> (), a list is kept, anything else -> 1-tuple.
1389 if state_output is None:
1390 state_output = ()
1391 elif not isinstance(state_output, list):
1392 state_output = state_output,
1394 timers = self.timeout[self.state.state]
1395 # If there are commandMessage, we should skip immediate
1396 # conditions.
1397 if not select_objects([self.cmdin], 0):
1398 # Then check immediate conditions
# A condition that fires raises ATMT.NewStateRequested, handled by the
# except clause at the bottom of this loop.
1399 for cond in self.conditions[self.state.state]:
1400 self._run_condition(cond, *state_output)
1402 # If still there and no conditions left, we are stuck!
1403 if (len(self.recv_conditions[self.state.state]) == 0 and
1404 len(self.ioevents[self.state.state]) == 0 and
1405 timers.count() == 0):
1406 raise self.Stuck("stuck in [%s]" % self.state.state,
1407 state=self.state.state,
1408 result=state_output)
1410 # Finally listen and pay attention to timeouts
1411 timers.reset()
1412 time_previous = time.time()
# Always watch cmdin; add listen_sock only if this state has recv
# conditions, plus one input pipe per ioevent of this state.
1414 fds = [self.cmdin] # type: List[Union[SuperSocket, ObjectPipe[Any]]]
1415 select_func = select_objects
1416 if self.listen_sock and self.recv_conditions[self.state.state]:
1417 fds.append(self.listen_sock)
# When listening on a socket, use that socket's own select implementation.
1418 select_func = self.listen_sock.select # type: ignore
1419 for ioev in self.ioevents[self.state.state]:
1420 fds.append(self.ioin[ioev.atmt_ioname])
1421 while True:
# Age the timers by the wall-clock time elapsed since the last pass,
# then run the timeout conditions that expired.
1422 time_current = time.time()
1423 timers.decrement(time_current - time_previous)
1424 time_previous = time_current
1425 for timer in timers.expired():
1426 self._run_condition(timer._func, *state_output)
# NOTE(review): remain is used as the select timeout; presumably None
# when no timer is pending (blocking select) - confirm in the Timer code.
1427 remain = timers.until_next()
1429 self.debug(5, "Select on %r" % fds)
1430 r = select_func(fds, remain)
1431 self.debug(5, "Selected %r" % r)
1432 for fd in r:
1433 self.debug(5, "Looking at %r" % fd)
1434 if fd == self.cmdin:
1435 yield self.CommandMessage("Received command message") # noqa: E501
1436 elif fd == self.listen_sock:
1437 try:
1438 pkt = self.listen_sock.recv()
1439 except EOFError:
1440 # Socket was closed abruptly. This will likely only
1441 # ever happen when a client socket is passed to the
1442 # automaton (not the case when the automaton is
1443 # listening on a promiscuous conf.L2sniff)
1444 self.listen_sock.close()
1445 # False so that it is still reset by update_sock
1446 self.listen_sock = False # type: ignore
1447 fds.remove(fd)
1448 if self.state.state in self.eofs:
1449 # There is an eof state
# NOTE(review): the eof handler is called and its return value raised,
# so it presumably returns an ATMT.NewStateRequested - confirm.
1450 eof = self.eofs[self.state.state]
1451 self.debug(2, "Condition EOF [%s] taken" % eof.__name__) # noqa: E501
1452 raise self.eofs[self.state.state](self)
1453 else:
1454 # There isn't. Therefore, it's a closing condition.
# NOTE(review): "arbruptly" is a typo in this runtime message.
1455 raise EOFError("Socket ended arbruptly.")
1456 if self.atmt_session is not None:
1457 # Apply session if provided
1458 pkt = self.atmt_session.process(pkt)
# A None packet (e.g. swallowed by the session) is silently ignored.
1459 if pkt is not None:
1460 if self.master_filter(pkt):
1461 self.debug(3, "RECVD: %s" % pkt.summary()) # noqa: E501
1462 for rcvcond in self.recv_conditions[self.state.state]: # noqa: E501
1463 self._run_condition(rcvcond, pkt, *state_output) # noqa: E501
1464 else:
1465 self.debug(4, "FILTR: %s" % pkt.summary()) # noqa: E501
1466 else:
# Anything else in r is an ioevent input pipe; dispatch by io name.
1467 self.debug(3, "IOEVENT on %s" % fd.ioname)
1468 for ioevt in self.ioevents[self.state.state]:
1469 if ioevt.atmt_ioname == fd.ioname:
1470 self._run_condition(ioevt, fd, *state_output) # noqa: E501
# Every transition (immediate, timer, recv, ioevent or eof) lands here.
# The request object itself becomes self.state (it is used above through
# .state/.run()/.error/.final); it is reported to the consumer, then the
# outer loop enters the new state.
1472 except ATMT.NewStateRequested as state_req:
1473 self.debug(2, "switching from [%s] to [%s]" % (self.state.state, state_req.state)) # noqa: E501
1474 self.state = state_req
1475 yield state_req
def __repr__(self):
    # type: () -> str
    """Class name plus RUNNING/HALTED, e.g. ``<Automaton Foo [HALTED]>``."""
    status = "RUNNING" if self.isrunning() else "HALTED"
    return "<Automaton %s [%s]>" % (self.__class__.__name__, status)
1484 # Public API
def add_interception_points(self, *ipts):
    # type: (Any) -> None
    """
    Mark states as interception points for ``send``.

    Each argument is a state identifier or an object carrying an
    ``atmt_state`` attribute, in which case that attribute is used.
    """
    self.interception_points.update(
        getattr(ipt, "atmt_state", ipt) for ipt in ipts
    )
def remove_interception_points(self, *ipts):
    # type: (Any) -> None
    """
    Unmark interception points; unknown states are silently ignored.

    Arguments are resolved like in ``add_interception_points``.
    """
    points = self.interception_points
    for ipt in ipts:
        points.discard(getattr(ipt, "atmt_state", ipt))
def add_breakpoints(self, *bps):
    # type: (Any) -> None
    """
    Register states on which the automaton pauses (``Breakpoint``).

    Each argument is a state identifier or an object carrying an
    ``atmt_state`` attribute, in which case that attribute is used.
    """
    self.breakpoints.update(getattr(bp, "atmt_state", bp) for bp in bps)
def remove_breakpoints(self, *bps):
    # type: (Any) -> None
    """
    Unregister breakpoints; states that are not registered are ignored.

    Arguments are resolved like in ``add_breakpoints``.
    """
    registered = self.breakpoints
    for bp in bps:
        registered.discard(getattr(bp, "atmt_state", bp))
def start(self, *args, **kargs):
    # type: (Any, Any) -> None
    """
    Launch the control thread; arguments are forwarded to ``_do_start``.

    :raises ValueError: if the automaton is already running.
    """
    if not self.isrunning():
        # Start the control thread
        self._do_start(*args, **kargs)
        return
    raise ValueError("Already started")
def run(self,
        resume=None,  # type: Optional[Message]
        wait=True  # type: Optional[bool]
        ):
    # type: (...) -> Any
    """
    Send a command (RUN by default) to the control thread and, when
    ``wait`` is true, translate its answer.

    Returns ``None`` when:

    * not waiting;
    * the answer is ``None``;
    * the wait is interrupted by Ctrl-C, which sends FREEZE.

    Otherwise:

    * END returns the final state output;
    * INTERCEPT, SINGLESTEP and BREAKPOINT raise ``InterceptionPoint``,
      ``Singlestep`` and ``Breakpoint``;
    * EXCEPTION re-raises the control thread's exception with its original
      traceback.
    """
    if resume is None:
        resume = Message(type=_ATMT_Command.RUN)
    self.cmdin.send(resume)
    if not wait:
        return None
    try:
        c = self.cmdout.recv()
    except KeyboardInterrupt:
        self.cmdin.send(Message(type=_ATMT_Command.FREEZE))
        return None
    if c is None:
        return None
    kind = c.type
    if kind == _ATMT_Command.END:
        return c.result
    if kind == _ATMT_Command.INTERCEPT:
        raise self.InterceptionPoint("packet intercepted", state=c.state.state, packet=c.pkt)  # noqa: E501
    if kind == _ATMT_Command.SINGLESTEP:
        raise self.Singlestep("singlestep state=[%s]" % c.state.state, state=c.state.state)  # noqa: E501
    if kind == _ATMT_Command.BREAKPOINT:
        raise self.Breakpoint("breakpoint triggered on state [%s]" % c.state.state, state=c.state.state)  # noqa: E501
    if kind == _ATMT_Command.EXCEPTION:
        # Same approach as six.reraise: rebuild the exception if only its
        # type is known, then raise it with the captured traceback.
        exc_type, exc_value, exc_tb = c.exc_info
        if exc_value is None:
            exc_value = exc_type()  # type: ignore
        raise exc_value.with_traceback(exc_tb)
    return None
def runbg(self, resume=None, wait=False):
    # type: (Optional[Message], Optional[bool]) -> None
    """
    Variant of ``run`` that by default does not wait for the outcome.
    Whatever ``run`` returns is discarded.
    """
    self.run(resume, wait)
    return None
def __next__(self):
    # type: () -> Any
    """Advance the automaton by a single step (a NEXT command to ``run``)."""
    step = Message(type=_ATMT_Command.NEXT)
    return self.run(resume=step)
def _flush_inout(self):
    # type: () -> None
    """Discard every message still queued on ``cmdin`` then ``cmdout``."""
    self.cmdin.clear()
    self.cmdout.clear()
def destroy(self):
    # type: () -> None
    """
    Release the resources of a stopped Automaton.

    It closes and flushes the command pipes and closes every ``ObjectPipe``
    among the ``ioin``/``ioout`` endpoints. This matters where garbage
    collection is not prompt (e.g. PyPy). It does nothing if ``started`` was
    never set.

    :raises ValueError: if the automaton is still running.
    """
    if not hasattr(self, "started"):
        return  # was never started.
    if self.isrunning():
        raise ValueError("Can't close running Automaton ! Call stop() beforehand")  # noqa: E501
    # Close the command pipes, then drop whatever is still queued in them
    for pipe in (self.cmdin, self.cmdout):
        pipe.close()
    self._flush_inout()
    # Close the ObjectPipe-based io endpoints (others are left alone)
    endpoints = itertools.chain(self.ioin.values(), self.ioout.values())
    for endpoint in endpoints:
        if isinstance(endpoint, ObjectPipe):
            endpoint.close()
def stop(self, wait=True):
    # type: (bool) -> None
    """
    Ask the control thread to stop.

    It switches to the stop state when there is one, otherwise it halts
    like ``forcestop``. An ``OSError`` while posting the command (e.g. the
    pipe is already closed) is ignored. With ``wait``, block until the
    control thread releases ``started``, then flush the command pipes.
    """
    request = Message(type=_ATMT_Command.STOP)
    try:
        self.cmdin.send(request)
    except OSError:
        pass
    if not wait:
        return
    with self.started:
        self._flush_inout()
def forcestop(self, wait=True):
    # type: (bool) -> None
    """
    Ask the control thread to halt immediately, skipping any stop state.

    An ``OSError`` while posting the command is ignored. With ``wait``,
    block until the control thread releases ``started``, then flush the
    command pipes.
    """
    request = Message(type=_ATMT_Command.FORCESTOP)
    try:
        self.cmdin.send(request)
    except OSError:
        pass
    if not wait:
        return
    with self.started:
        self._flush_inout()
def restart(self, *args, **kargs):
    # type: (Any, Any) -> None
    """
    Stop the automaton (waiting for it), then start it again with the
    given arguments.
    """
    self.stop()
    return self.start(*args, **kargs)
def accept_packet(self,
                  pkt=None,  # type: Optional[Packet]
                  wait=False  # type: Optional[bool]
                  ):
    # type: (...) -> Any
    """
    Answer a pending interception by letting a packet through.

    Without ``pkt`` the intercepted packet is sent unchanged (ACCEPT);
    otherwise ``pkt`` is sent in its place (REPLACE). The verdict goes
    through ``run``, whose result is returned.
    """
    verdict = Message()
    if pkt is not None:
        verdict.type = _ATMT_Command.REPLACE
        verdict.pkt = pkt
    else:
        verdict.type = _ATMT_Command.ACCEPT
    return self.run(resume=verdict, wait=wait)
def reject_packet(self,
                  wait=False  # type: Optional[bool]
                  ):
    # type: (...) -> Any
    """
    Answer a pending interception by dropping the packet (REJECT).
    The result of ``run`` is returned.
    """
    verdict = Message(type=_ATMT_Command.REJECT)
    outcome = self.run(resume=verdict, wait=wait)
    return outcome