Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/scapy/automaton.py: 37%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1015 statements  

1# SPDX-License-Identifier: GPL-2.0-only 

2# This file is part of Scapy 

3# See https://scapy.net/ for more information 

4# Copyright (C) Philippe Biondi <phil@secdev.org> 

5# Copyright (C) Gabriel Potter 

6 

7""" 

8Automata with states, transitions and actions. 

9 

10TODO: 

11 - add documentation for ioevent, as_supersocket... 

12""" 

13 

14import ctypes 

15import itertools 

16import logging 

17import os 

18import random 

19import socket 

20import sys 

21import threading 

22import time 

23import traceback 

24import types 

25 

26import select 

27from collections import deque 

28 

29from scapy.config import conf 

30from scapy.consts import WINDOWS 

31from scapy.data import MTU 

32from scapy.error import log_runtime, warning 

33from scapy.interfaces import _GlobInterfaceType 

34from scapy.packet import Packet 

35from scapy.plist import PacketList 

36from scapy.supersocket import SuperSocket, StreamSocket 

37from scapy.utils import do_graph 

38 

39# Typing imports 

40from typing import ( 

41 Any, 

42 Callable, 

43 Deque, 

44 Dict, 

45 Generic, 

46 Iterable, 

47 Iterator, 

48 List, 

49 Optional, 

50 Set, 

51 Tuple, 

52 Type, 

53 TypeVar, 

54 Union, 

55 cast, 

56) 

57from scapy.compat import DecoratorCallable 

58 

59 

def select_objects(inputs, remain):
    # type: (Iterable[Any], Union[float, int, None]) -> List[Any]
    """
    Select objects. Same as:
    ``select.select(inputs, [], [], remain)``

    But also works on Windows, only on objects whose fileno() returns
    a Windows event. For simplicity, just use `ObjectPipe()` as a queue
    that you can select on whatever the platform is.

    On Windows, if you want an object to be always included in the output
    of select_objects (i.e. it's not selectable), just make fileno()
    return a strictly negative value. On other platforms ``inputs`` is
    handed to ``select.select`` unchanged, so every object must expose a
    valid file descriptor there.

    Example:

    >>> a, b = ObjectPipe("a"), ObjectPipe("b")
    >>> b.send("test")
    >>> select_objects([a, b], 1)
    [b]

    :param inputs: objects to process
    :param remain: timeout. If 0, poll. If None, block.
    """
    if not WINDOWS:
        return select.select(inputs, [], [], remain)[0]
    natives = []
    events = []
    results = set()
    for i in list(inputs):
        if getattr(i, "__selectable_force_select__", False):
            natives.append(i)
        elif i.fileno() < 0:
            # Special case: On Windows, we consider that an object that returns
            # a negative fileno (impossible), is always readable. This is used
            # in very few places but important (e.g. PcapReader), where we have
            # no valid fileno (and will stop on EOFError).
            results.add(i)
        else:
            events.append(i)
    if results:
        # Some objects are always readable: only poll the others instead of
        # blocking (possibly forever when remain is None).
        remain = 0
    if natives:
        results = results.union(set(select.select(natives, [], [], remain)[0]))
    if results:
        # We already have results, poll.
        remain = 0
    if events:
        # 0xFFFFFFFF = INFINITE
        remainms = int(remain * 1000 if remain is not None else 0xFFFFFFFF)
        if len(events) == 1:
            res = ctypes.windll.kernel32.WaitForSingleObject(
                ctypes.c_void_p(events[0].fileno()),
                remainms
            )
        else:
            # Sadly, the only way to emulate select() is to first check
            # if any object is available using WaitForMultipleObjects
            # then poll the others.
            res = ctypes.windll.kernel32.WaitForMultipleObjects(
                len(events),
                (ctypes.c_void_p * len(events))(
                    *[x.fileno() for x in events]
                ),
                False,
                remainms
            )
        # On success the result is WAIT_OBJECT_0 (0) + index of the signaled
        # object. WAIT_TIMEOUT (0x102) and WAIT_FAILED fall outside that range.
        # ctypes returns a signed C int by default, so WAIT_FAILED
        # (0xFFFFFFFF) is seen here as -1: compare against the valid index
        # range rather than against the raw error constants.
        if 0 <= res < len(events):
            results.add(events[res])
            if len(events) > 1:
                # Now poll the others, if any
                for evt in events:
                    res = ctypes.windll.kernel32.WaitForSingleObject(
                        ctypes.c_void_p(evt.fileno()),
                        0  # poll: don't wait
                    )
                    if res == 0:
                        results.add(evt)
    return list(results)

137 

138 

# Type of the objects carried by an ObjectPipe (ObjectPipe is Generic[_T]).
_T = TypeVar("_T")

140 

141 

class ObjectPipe(Generic[_T]):
    """
    Selectable FIFO of Python objects, usable like a pseudo socket.

    Objects are kept in a deque. Every ``send`` also writes one byte to an
    OS pipe and every consuming ``recv`` reads one back, so the read end of
    the pipe is readable whenever objects are pending. On Windows, a
    manual-reset event (set on send, reset once the queue is drained) is
    what ``fileno`` exposes, for use with ``select_objects``.
    """

    def __init__(self, name=None):
        # type: (Optional[str]) -> None
        self.name = name or "ObjectPipe"
        self.closed = False
        self.__rd, self.__wr = os.pipe()
        self.__queue = deque()  # type: Deque[_T]
        if WINDOWS:
            self._wincreate()

    if WINDOWS:
        def _wincreate(self):
            # type: () -> None
            # Manual-reset event, initially non-signaled, randomly named.
            event_name = ctypes.create_string_buffer(
                b"ObjectPipe %f" % random.random()
            )
            handle = ctypes.windll.kernel32.CreateEventA(
                None, True, False, event_name
            )
            self._fd = cast(int, handle)

        def _winset(self):
            # type: () -> None
            handle = ctypes.c_void_p(self._fd)
            if ctypes.windll.kernel32.SetEvent(handle) == 0:
                warning(ctypes.FormatError(ctypes.GetLastError()))

        def _winreset(self):
            # type: () -> None
            handle = ctypes.c_void_p(self._fd)
            if ctypes.windll.kernel32.ResetEvent(handle) == 0:
                warning(ctypes.FormatError(ctypes.GetLastError()))

        def _winclose(self):
            # type: () -> None
            handle = ctypes.c_void_p(self._fd)
            if ctypes.windll.kernel32.CloseHandle(handle) == 0:
                warning(ctypes.FormatError(ctypes.GetLastError()))

    def fileno(self):
        # type: () -> int
        """Selectable handle: the Windows event, else the pipe's read end."""
        return self._fd if WINDOWS else self.__rd

    def send(self, obj):
        # type: (_T) -> int
        """Queue ``obj`` and signal readers. Always returns 1."""
        # Queue first, then signal: a signaled reader always finds an item.
        self.__queue.append(obj)
        if WINDOWS:
            self._winset()
        os.write(self.__wr, b"X")
        return 1

    def write(self, obj):
        # type: (_T) -> None
        """File-like alias of ``send`` (the result is discarded)."""
        self.send(obj)

    def empty(self):
        # type: () -> bool
        """Return True when no object is pending."""
        return len(self.__queue) == 0

    def flush(self):
        # type: () -> None
        """No-op, for file-like compatibility."""
        pass

    def recv(self, n=0, options=socket.MsgFlag(0)):
        # type: (Optional[int], socket.MsgFlag) -> Optional[_T]
        """
        Pop the oldest pending object, blocking until one is available.

        With ``socket.MSG_PEEK`` in ``options``, return the oldest object
        without consuming it (None if the queue is empty). ``n`` is ignored.
        Raises EOFError once the pipe is closed.
        """
        if self.closed:
            raise EOFError
        if options & socket.MSG_PEEK:
            return self.__queue[0] if self.__queue else None
        # Consume one signal byte; blocks until a sender wrote one.
        os.read(self.__rd, 1)
        item = self.__queue.popleft()
        if WINDOWS and not self.__queue:
            self._winreset()
        return item

    def read(self, n=0):
        # type: (Optional[int]) -> Optional[_T]
        """File-like alias of ``recv``."""
        return self.recv(n)

    def clear(self):
        # type: () -> None
        """Drop every pending object (no-op on a closed pipe)."""
        if self.closed:
            return
        while not self.empty():
            self.recv()

    def close(self):
        # type: () -> None
        """Close the OS pipe (and the Windows event); idempotent."""
        if self.closed:
            return
        self.closed = True
        os.close(self.__rd)
        os.close(self.__wr)
        if WINDOWS:
            try:
                self._winclose()
            except ImportError:
                # Python is shutting down
                pass

    def __repr__(self):
        # type: () -> str
        return "<%s at %s>" % (self.name, id(self))

    def __del__(self):
        # type: () -> None
        self.close()

    @staticmethod
    def select(sockets, remain=conf.recv_poll_rate):
        # type: (List[SuperSocket], Optional[float]) -> List[SuperSocket]
        """
        Select among ObjectPipes. Closed pipes are returned right away so
        that reading them raises EOFError; otherwise defer to select_objects.
        """
        already_closed = [sock for sock in sockets if sock.closed]
        return already_closed or select_objects(sockets, remain)

257 

258 

class Message:
    """
    Lightweight attribute container: every keyword argument given to the
    constructor becomes an instance attribute. Attributes whose name starts
    with an underscore are hidden from ``repr``.
    """
    type = None  # type: str
    pkt = None  # type: Packet
    result = None  # type: str
    state = None  # type: Message
    exc_info = None  # type: Union[Tuple[None, None, None], Tuple[BaseException, Exception, types.TracebackType]]  # noqa: E501

    def __init__(self, **args):
        # type: (Any) -> None
        self.__dict__.update(args)

    def __repr__(self):
        # type: () -> str
        parts = []
        for key, value in self.__dict__.items():
            if key.startswith("_"):
                continue
            parts.append("%s=%r" % (key, value))
        return "<Message %s>" % " ".join(parts)

277 

278 

class Timer():
    """
    Countdown used for automaton timeouts and periodic timers.

    ``_timeout`` is the configured duration and ``_time`` the remaining
    time, decreased by ``_decrement``. When it reaches zero the timer is
    flagged as just expired. An autoreload timer is then re-armed (keeping
    the overshoot); otherwise it stays expired until ``_reset``.

    Equality depends on mutable state, so instances keep Python's implicit
    ``__hash__ = None`` (they are unhashable).
    """
    def __init__(self, time, prio=0, autoreload=False):
        # type: (Union[int, float], int, bool) -> None
        self._timeout = float(time)  # type: float
        self._time = 0  # type: float
        self._just_expired = True
        self._expired = True
        self._prio = prio
        # Placeholder until a timeout/timer decorator binds the condition.
        self._func = _StateWrapper()
        self._autoreload = autoreload

    def get(self):
        # type: () -> float
        """Return the configured duration."""
        return self._timeout

    def set(self, val):
        # type: (float) -> None
        """Change the configured duration (applied on the next reset)."""
        self._timeout = val

    def _reset(self):
        # type: () -> None
        self._time = self._timeout
        self._expired = False
        self._just_expired = False

    def _reset_just_expired(self):
        # type: () -> None
        self._just_expired = False

    def _running(self):
        # type: () -> bool
        return self._time > 0

    def _remaining(self):
        # type: () -> float
        return max(self._time, 0)

    def _decrement(self, time):
        # type: (float) -> None
        self._time -= time
        if self._time <= 0:
            if not self._expired:
                self._just_expired = True
                if self._autoreload:
                    # take overshoot into account
                    self._time = self._timeout + self._time
                else:
                    self._expired = True
                    self._time = 0

    def __lt__(self, obj):
        # type: (Timer) -> bool
        return ((self._time < obj._time) if self._time != obj._time
                else (self._prio < obj._prio))

    def __gt__(self, obj):
        # type: (Timer) -> bool
        return ((self._time > obj._time) if self._time != obj._time
                else (self._prio > obj._prio))

    def __eq__(self, obj):
        # type: (Any) -> bool
        # Returning NotImplemented (instead of raising) lets Python fall back
        # to its default handling, so comparing a Timer with an unrelated
        # object yields False rather than crashing (e.g. ``t in [None, t]``).
        if not isinstance(obj, Timer):
            return NotImplemented
        return (self._time == obj._time) and (self._prio == obj._prio)

    def __repr__(self):
        # type: () -> str
        return "<Timer %f(%f)>" % (self._time, self._timeout)

348 

349 

350class _TimerList(): 

351 def __init__(self): 

352 # type: () -> None 

353 self.timers = [] # type: list[Timer] 

354 

355 def add_timer(self, timer): 

356 # type: (Timer) -> None 

357 self.timers.append(timer) 

358 

359 def reset(self): 

360 # type: () -> None 

361 for t in self.timers: 

362 t._reset() 

363 

364 def decrement(self, time): 

365 # type: (float) -> None 

366 for t in self.timers: 

367 t._decrement(time) 

368 

369 def expired(self): 

370 # type: () -> list[Timer] 

371 lst = [t for t in self.timers if t._just_expired] 

372 lst.sort(key=lambda x: x._prio, reverse=True) 

373 for t in lst: 

374 t._reset_just_expired() 

375 return lst 

376 

377 def until_next(self): 

378 # type: () -> Optional[float] 

379 try: 

380 return min([t._remaining() for t in self.timers if t._running()]) 

381 except ValueError: 

382 return None # None means blocking 

383 

384 def count(self): 

385 # type: () -> int 

386 return len(self.timers) 

387 

388 def __iter__(self): 

389 # type: () -> Iterator[Timer] 

390 return self.timers.__iter__() 

391 

392 def __repr__(self): 

393 # type: () -> str 

394 return self.timers.__repr__() 

395 

396 

397class _instance_state: 

398 def __init__(self, instance): 

399 # type: (Any) -> None 

400 self.__self__ = instance.__self__ 

401 self.__func__ = instance.__func__ 

402 self.__self__.__class__ = instance.__self__.__class__ 

403 

404 def __getattr__(self, attr): 

405 # type: (str) -> Any 

406 return getattr(self.__func__, attr) 

407 

408 def __call__(self, *args, **kargs): 

409 # type: (Any, Any) -> Any 

410 return self.__func__(self.__self__, *args, **kargs) 

411 

412 def breaks(self): 

413 # type: () -> Any 

414 return self.__self__.add_breakpoints(self.__func__) 

415 

416 def intercepts(self): 

417 # type: () -> Any 

418 return self.__self__.add_interception_points(self.__func__) 

419 

420 def unbreaks(self): 

421 # type: () -> Any 

422 return self.__self__.remove_breakpoints(self.__func__) 

423 

424 def unintercepts(self): 

425 # type: () -> Any 

426 return self.__self__.remove_interception_points(self.__func__) 

427 

428 

429############## 

430# Automata # 

431############## 

432 

class _StateWrapper:
    # Mostly a typing aid: lists the attributes that the ATMT decorators
    # attach to automaton methods. A bare instance is also used as the
    # initial placeholder ``_func`` of every Timer.
    __name__ = None  # type: str
    # one of the ATMT.* kind strings (STATE, ACTION, CONDITION, ...)
    atmt_type = None  # type: str
    # name of the state this member belongs to (or implements)
    atmt_state = None  # type: str
    # state flags set by ATMT.state
    atmt_initial = None  # type: int
    atmt_final = None  # type: int
    atmt_stop = None  # type: int
    atmt_error = None  # type: int
    # on a state wrapper: the undecorated state function
    atmt_origfunc = None  # type: _StateWrapper
    atmt_prio = None  # type: int
    # ioevent: attribute name under which a supersocket factory is exposed
    atmt_as_supersocket = None  # type: Optional[str]
    # name of a condition (the decorated function's __name__)
    atmt_condname = None  # type: str
    # ioevent: name of the I/O channel
    atmt_ioname = None  # type: str
    # timeout/timer conditions: their Timer
    atmt_timeout = None  # type: Timer
    # actions: condition name -> action priority
    atmt_cond = None  # type: Dict[str, int]
    __code__ = None  # type: types.CodeType
    __call__ = None  # type: Callable[..., ATMT.NewStateRequested]

450 

451 

class ATMT:
    """
    Decorators used to describe an Automaton, plus the strings that identify
    each kind of decorated member (stored in ``atmt_type``).
    """
    STATE = "State"
    ACTION = "Action"
    CONDITION = "Condition"
    RECV = "Receive condition"
    TIMEOUT = "Timeout condition"
    EOF = "EOF condition"
    IOEVENT = "I/O event"

    class NewStateRequested(Exception):
        """
        Request to enter the state implemented by ``state_func``.

        Instances are created by calling a state wrapper produced by
        ``ATMT.state``. The extra positional/keyword arguments are kept for
        ``run``. Note that ``args`` is overwritten with those arguments,
        replacing the exception's message tuple.
        """
        def __init__(self, state_func, automaton, *args, **kargs):
            # type: (Any, ATMT, Any, Any) -> None
            self.func = state_func
            (self.state, self.initial, self.error,
             self.stop, self.final) = (state_func.atmt_state,
                                        state_func.atmt_initial,
                                        state_func.atmt_error,
                                        state_func.atmt_stop,
                                        state_func.atmt_final)
            Exception.__init__(self, "Request state [%s]" % self.state)
            self.automaton = automaton
            self.args = args
            self.kargs = kargs
            # start with empty action parameters
            self.action_parameters()

        def action_parameters(self, *args, **kargs):
            # type: (Any, Any) -> ATMT.NewStateRequested
            """
            Store arguments as ``action_args``/``action_kargs``. Returns self
            so the call can be chained.
            """
            self.action_args, self.action_kargs = args, kargs
            return self

        def run(self):
            # type: () -> Any
            """Call the state function with the automaton and stored args."""
            state_function, owner = self.func, self.automaton
            return state_function(owner, *self.args, **self.kargs)

        def __repr__(self):
            # type: () -> str
            return "NewStateRequested(%s)" % self.state

    @staticmethod
    def _mark(f, **attrs):
        # type: (Any, Any) -> Any
        """Set every item of ``attrs`` as an attribute of ``f``; return ``f``."""
        for attr_name, attr_value in attrs.items():
            setattr(f, attr_name, attr_value)
        return f

    @staticmethod
    def state(initial=0,  # type: int
              final=0,  # type: int
              stop=0,  # type: int
              error=0  # type: int
              ):
        # type: (...) -> Callable[[DecoratorCallable], DecoratorCallable]
        """
        Mark a method as an automaton state.

        The flags are set on the method and on the returned wrapper. Calling
        the wrapper returns an ``ATMT.NewStateRequested`` for this state.
        """
        def deco(f, initial=initial, final=final):
            # type: (_StateWrapper, int, int) -> _StateWrapper
            flags = {
                "atmt_type": ATMT.STATE,
                "atmt_state": f.__name__,
                "atmt_initial": initial,
                "atmt_final": final,
                "atmt_stop": stop,
                "atmt_error": error,
            }
            ATMT._mark(f, **flags)

            def _state_wrapper(self, *args, **kargs):
                # type: (ATMT, Any, Any) -> ATMT.NewStateRequested
                return ATMT.NewStateRequested(f, self, *args, **kargs)

            wrapper = cast(_StateWrapper, _state_wrapper)
            wrapper.__name__ = "%s_wrapper" % f.__name__
            ATMT._mark(wrapper, atmt_origfunc=f, **flags)
            return wrapper
        return deco  # type: ignore

    @staticmethod
    def action(cond, prio=0):
        # type: (Any, int) -> Callable[[_StateWrapper, _StateWrapper], _StateWrapper]  # noqa: E501
        """
        Run the decorated method as an action of condition ``cond``. A method
        can serve several conditions, each with its own priority (kept in
        ``atmt_cond``).
        """
        def deco(f, cond=cond):
            # type: (_StateWrapper, _StateWrapper) -> _StateWrapper
            if not hasattr(f, "atmt_type"):
                # first binding of this method: start an empty mapping
                f.atmt_cond = {}
            f.atmt_type = ATMT.ACTION
            f.atmt_cond[cond.atmt_condname] = prio
            return f
        return deco

    @staticmethod
    def condition(state, prio=0):
        # type: (Any, int) -> Callable[[_StateWrapper, _StateWrapper], _StateWrapper]  # noqa: E501
        """Mark the decorated method as a condition of ``state``."""
        def deco(f, state=state):
            # type: (_StateWrapper, _StateWrapper) -> Any
            return ATMT._mark(f,
                              atmt_type=ATMT.CONDITION,
                              atmt_state=state.atmt_state,
                              atmt_condname=f.__name__,
                              atmt_prio=prio)
        return deco

    @staticmethod
    def receive_condition(state, prio=0):
        # type: (_StateWrapper, int) -> Callable[[_StateWrapper, _StateWrapper], _StateWrapper]  # noqa: E501
        """Mark the decorated method as a receive condition of ``state``."""
        def deco(f, state=state):
            # type: (_StateWrapper, _StateWrapper) -> _StateWrapper
            return ATMT._mark(f,
                              atmt_type=ATMT.RECV,
                              atmt_state=state.atmt_state,
                              atmt_condname=f.__name__,
                              atmt_prio=prio)
        return deco

    @staticmethod
    def ioevent(state,  # type: _StateWrapper
                name,  # type: str
                prio=0,  # type: int
                as_supersocket=None  # type: Optional[str]
                ):
        # type: (...) -> Callable[[_StateWrapper, _StateWrapper], _StateWrapper]  # noqa: E501
        """
        Mark the decorated method as an I/O event condition of ``state`` on
        the channel ``name``. If ``as_supersocket`` is given, the automaton
        class also exposes a supersocket factory under that attribute name.
        """
        def deco(f, state=state):
            # type: (_StateWrapper, _StateWrapper) -> _StateWrapper
            return ATMT._mark(f,
                              atmt_type=ATMT.IOEVENT,
                              atmt_state=state.atmt_state,
                              atmt_condname=f.__name__,
                              atmt_ioname=name,
                              atmt_prio=prio,
                              atmt_as_supersocket=as_supersocket)
        return deco

    @staticmethod
    def timeout(state, timeout):
        # type: (_StateWrapper, Union[int, float]) -> Callable[[_StateWrapper, _StateWrapper, Timer], _StateWrapper]  # noqa: E501
        """Mark the decorated method as a one-shot timeout of ``state``."""
        def deco(f, state=state, timeout=Timer(timeout)):
            # type: (_StateWrapper, _StateWrapper, Timer) -> _StateWrapper
            ATMT._mark(f,
                       atmt_type=ATMT.TIMEOUT,
                       atmt_state=state.atmt_state,
                       atmt_timeout=timeout,
                       atmt_condname=f.__name__)
            f.atmt_timeout._func = f
            return f
        return deco

    @staticmethod
    def timer(state, timeout, prio=0):
        # type: (_StateWrapper, Union[float, int], int) -> Callable[[_StateWrapper, _StateWrapper, Timer], _StateWrapper]  # noqa: E501
        """Mark the decorated method as an auto-reloading timer of ``state``."""
        def deco(f, state=state, timeout=Timer(timeout, prio=prio, autoreload=True)):
            # type: (_StateWrapper, _StateWrapper, Timer) -> _StateWrapper
            ATMT._mark(f,
                       atmt_type=ATMT.TIMEOUT,
                       atmt_state=state.atmt_state,
                       atmt_timeout=timeout,
                       atmt_condname=f.__name__)
            f.atmt_timeout._func = f
            return f
        return deco

    @staticmethod
    def eof(state):
        # type: (_StateWrapper) -> Callable[[_StateWrapper, _StateWrapper], _StateWrapper]  # noqa: E501
        """Mark the decorated method as the EOF condition of ``state``."""
        def deco(f, state=state):
            # type: (_StateWrapper, _StateWrapper) -> _StateWrapper
            return ATMT._mark(f,
                              atmt_type=ATMT.EOF,
                              atmt_state=state.atmt_state,
                              atmt_condname=f.__name__)
        return deco

612 

613 

class _ATMT_Command:
    # Names of the internal control commands of the Automaton machinery.
    # NOTE(review): they presumably are the ``type`` values of the Message
    # objects exchanged on the cmdin/cmdout ObjectPipes; the code using them
    # is outside this chunk, so confirm there.
    RUN = "RUN"
    NEXT = "NEXT"
    FREEZE = "FREEZE"
    STOP = "STOP"
    FORCESTOP = "FORCESTOP"
    END = "END"
    EXCEPTION = "EXCEPTION"
    SINGLESTEP = "SINGLESTEP"
    BREAKPOINT = "BREAKPOINT"
    INTERCEPT = "INTERCEPT"
    ACCEPT = "ACCEPT"
    REPLACE = "REPLACE"
    REJECT = "REJECT"

628 

629 

class _ATMT_supersocket(SuperSocket):
    """
    SuperSocket facade over one I/O event of an automaton.

    The automaton class is instantiated with two ObjectPipes as the
    external fds of ``ioevent`` and started with ``runbg()``. ``send``
    pushes into the first pipe; ``recv`` pulls from the second and, when
    ``proto`` is set, decodes what it gets with it.
    """
    def __init__(self,
                 name,  # type: str
                 ioevent,  # type: str
                 automaton,  # type: Type[Automaton]
                 proto,  # type: Callable[[bytes], Any]
                 *args,  # type: Any
                 **kargs  # type: Any
                 ):
        # type: (...) -> None
        self.name = name
        self.ioevent = ioevent
        self.proto = proto
        # spa: written by send(), used as the automaton's input;
        # spb: the automaton's output, read by recv().
        self.spa = ObjectPipe[Any]("spa")
        self.spb = ObjectPipe[Any]("spb")
        kargs.update(
            external_fd={ioevent: (self.spa, self.spb)},
            is_atmt_socket=True,
            atmt_socket=self.name,
        )
        instance = automaton(*args, **kargs)
        self.atmt = instance
        instance.runbg()

    def send(self, s):
        # type: (Any) -> int
        """Hand ``s`` to the automaton."""
        return self.spa.send(s)

    def fileno(self):
        # type: () -> int
        """Selectable handle: readable when the automaton produced output."""
        return self.spb.fileno()

    # note: _ATMT_supersocket may return bytes in certain cases, which
    # is expected. We cheat on typing.
    def recv(self, n=MTU, **kwargs):  # type: ignore
        # type: (int, **Any) -> Any
        data = self.spb.recv(n)
        if data is None or self.proto is None:
            return data
        return self.proto(data, **kwargs)

    def close(self):
        # type: () -> None
        """Stop and destroy the automaton, then close both pipes."""
        if self.closed:
            return
        self.atmt.stop()
        self.atmt.destroy()
        self.spa.close()
        self.spb.close()
        self.closed = True

    @staticmethod
    def select(sockets, remain=conf.recv_poll_rate):
        # type: (List[SuperSocket], Optional[float]) -> List[SuperSocket]
        return select_objects(sockets, remain)

682 

683 

684class _ATMT_to_supersocket: 

685 def __init__(self, name, ioevent, automaton): 

686 # type: (str, str, Type[Automaton]) -> None 

687 self.name = name 

688 self.ioevent = ioevent 

689 self.automaton = automaton 

690 

691 def __call__(self, proto, *args, **kargs): 

692 # type: (Callable[[bytes], Any], Any, Any) -> _ATMT_supersocket 

693 return _ATMT_supersocket( 

694 self.name, self.ioevent, self.automaton, 

695 proto, *args, **kargs 

696 ) 

697 

698 

class Automaton_metaclass(type):
    """
    Metaclass of Automaton. It collects the ATMT-decorated members of a
    class (including inherited ones) into per-state lookup tables, and can
    render the resulting state machine as a graphviz graph.
    """
    def __new__(cls, name, bases, dct):
        # type: (str, Tuple[Any], Dict[str, Any]) -> Type[Automaton]
        cls = super(Automaton_metaclass, cls).__new__(  # type: ignore
            cls, name, bases, dct
        )
        cls.states = {}
        cls.recv_conditions = {}  # type: Dict[str, List[_StateWrapper]]
        cls.conditions = {}  # type: Dict[str, List[_StateWrapper]]
        cls.ioevents = {}  # type: Dict[str, List[_StateWrapper]]
        cls.timeout = {}  # type: Dict[str, _TimerList]
        cls.eofs = {}  # type: Dict[str, _StateWrapper]
        cls.actions = {}  # type: Dict[str, List[_StateWrapper]]
        cls.initial_states = []  # type: List[_StateWrapper]
        cls.stop_state = None  # type: Optional[_StateWrapper]
        cls.ionames = []
        cls.iosupersockets = []

        # Gather members breadth-first over the class and its bases; the
        # first definition found wins, which preserves method overloading.
        members = {}
        classes = [cls]
        while classes:
            c = classes.pop(0)  # order is important to avoid breaking method overloading  # noqa: E501
            classes += list(c.__bases__)
            for k, v in c.__dict__.items():  # type: ignore
                if k not in members:
                    members[k] = v

        decorated = [v for v in members.values()
                     if hasattr(v, "atmt_type")]

        # First pass: register states and an empty action list per condition
        for m in decorated:
            if m.atmt_type == ATMT.STATE:
                s = m.atmt_state
                cls.states[s] = m
                cls.recv_conditions[s] = []
                cls.ioevents[s] = []
                cls.conditions[s] = []
                cls.timeout[s] = _TimerList()
                if m.atmt_initial:
                    cls.initial_states.append(m)
                if m.atmt_stop:
                    if cls.stop_state is not None:
                        raise ValueError("There can only be a single stop state !")
                    cls.stop_state = m
            elif m.atmt_type in [ATMT.CONDITION, ATMT.RECV, ATMT.TIMEOUT, ATMT.IOEVENT, ATMT.EOF]:  # noqa: E501
                cls.actions[m.atmt_condname] = []

        # Second pass: attach conditions/timers/actions to their state
        for m in decorated:
            if m.atmt_type == ATMT.CONDITION:
                cls.conditions[m.atmt_state].append(m)
            elif m.atmt_type == ATMT.RECV:
                cls.recv_conditions[m.atmt_state].append(m)
            elif m.atmt_type == ATMT.EOF:
                cls.eofs[m.atmt_state] = m
            elif m.atmt_type == ATMT.IOEVENT:
                cls.ioevents[m.atmt_state].append(m)
                cls.ionames.append(m.atmt_ioname)
                if m.atmt_as_supersocket is not None:
                    cls.iosupersockets.append(m)
            elif m.atmt_type == ATMT.TIMEOUT:
                cls.timeout[m.atmt_state].add_timer(m.atmt_timeout)
            elif m.atmt_type == ATMT.ACTION:
                for co in m.atmt_cond:
                    cls.actions[co].append(m)

        for v in itertools.chain(
            cls.conditions.values(),
            cls.recv_conditions.values(),
            cls.ioevents.values()
        ):
            v.sort(key=lambda x: x.atmt_prio)
        for condname, actlst in cls.actions.items():
            actlst.sort(key=lambda x: x.atmt_cond[condname])

        for ioev in cls.iosupersockets:
            setattr(cls, ioev.atmt_as_supersocket,
                    _ATMT_to_supersocket(
                        ioev.atmt_as_supersocket,
                        ioev.atmt_ioname,
                        cast(Type["Automaton"], cls)))

        # Inject signature
        try:
            import inspect
            cls.__signature__ = inspect.signature(cls.parse_args)  # type: ignore  # noqa: E501
        except (ImportError, AttributeError):
            pass

        return cast(Type["Automaton"], cls)

    def _reachable_states(self, code):
        # type: (types.CodeType) -> Iterator[str]
        """
        Yield, in discovery order, every state name referenced by ``code``
        (in its names or constants). References through functions defined
        directly on this class are followed ("function indirection").

        Each indirected function is expanded at most once, so recursive or
        mutually-recursive helpers cannot make the traversal loop forever.
        Callables without bytecode (classes, builtins...) are skipped.
        """
        names = list(code.co_names + code.co_consts)
        expanded = set()  # type: Set[Any]
        while names:
            n = names.pop()
            if n in self.states:
                yield n
            elif n in self.__dict__ and n not in expanded:
                expanded.add(n)
                target = self.__dict__[n]
                if callable(target) and hasattr(target, "__code__"):
                    names.extend(target.__code__.co_names)
                    names.extend(target.__code__.co_consts)

    def build_graph(self):
        # type: () -> str
        """Return the graphviz (dot) description of this automaton class."""
        # ``self`` is the automaton class itself, so its own name titles the
        # graph (self.__class__ would be this metaclass).
        s = 'digraph "%s" {\n' % self.__name__

        se = ""  # Keep initial nodes at the beginning for better rendering
        for st in self.states.values():
            if st.atmt_initial:
                se = ('\t"%s" [ style=filled, fillcolor=blue, shape=box, root=true];\n' % st.atmt_state) + se  # noqa: E501
            elif st.atmt_final:
                se += '\t"%s" [ style=filled, fillcolor=green, shape=octagon ];\n' % st.atmt_state  # noqa: E501
            elif st.atmt_error:
                se += '\t"%s" [ style=filled, fillcolor=red, shape=octagon ];\n' % st.atmt_state  # noqa: E501
            elif st.atmt_stop:
                se += '\t"%s" [ style=filled, fillcolor=orange, shape=box, root=true ];\n' % st.atmt_state  # noqa: E501
        s += se

        # Direct transitions written inside state functions
        for st in self.states.values():
            for n in self._reachable_states(st.atmt_origfunc.__code__):
                s += '\t"%s" -> "%s" [ color=green ];\n' % (st.atmt_state, n)

        # Transitions triggered by conditions, labelled with their actions
        for c, sty, k, v in (
            [("purple", "solid", k, v) for k, v in self.conditions.items()] +
            [("red", "solid", k, v) for k, v in self.recv_conditions.items()] +
            [("orange", "solid", k, v) for k, v in self.ioevents.items()] +
            [("black", "dashed", k, [v]) for k, v in self.eofs.items()]
        ):
            for f in v:
                line = f.atmt_condname
                for x in self.actions[f.atmt_condname]:
                    line += "\\l>[%s]" % x.__name__
                for n in self._reachable_states(f.__code__):
                    s += '\t"%s" -> "%s" [label="%s", color=%s, style=%s];\n' % (
                        k,
                        n,
                        line,
                        c,
                        sty,
                    )
        # Timeout transitions, labelled with their delay
        for k, timers in self.timeout.items():
            for timer in timers:
                for n in (timer._func.__code__.co_names +
                          timer._func.__code__.co_consts):
                    if n in self.states:
                        line = "%s/%.1fs" % (timer._func.atmt_condname,
                                             timer.get())
                        for x in self.actions[timer._func.atmt_condname]:
                            line += "\\l>[%s]" % x.__name__
                        s += '\t"%s" -> "%s" [label="%s",color=blue];\n' % (k, n, line)  # noqa: E501
        s += "}\n"
        return s

    def graph(self, **kargs):
        # type: (Any) -> Optional[str]
        """Render ``build_graph()`` with do_graph (kargs are forwarded)."""
        s = self.build_graph()
        return do_graph(s, **kargs)

863 

864 

865class Automaton(metaclass=Automaton_metaclass): 

866 states = {} # type: Dict[str, _StateWrapper] 

867 state = None # type: ATMT.NewStateRequested 

868 recv_conditions = {} # type: Dict[str, List[_StateWrapper]] 

869 conditions = {} # type: Dict[str, List[_StateWrapper]] 

870 eofs = {} # type: Dict[str, _StateWrapper] 

871 ioevents = {} # type: Dict[str, List[_StateWrapper]] 

872 timeout = {} # type: Dict[str, _TimerList] 

873 actions = {} # type: Dict[str, List[_StateWrapper]] 

874 initial_states = [] # type: List[_StateWrapper] 

875 stop_state = None # type: Optional[_StateWrapper] 

876 ionames = [] # type: List[str] 

877 iosupersockets = [] # type: List[SuperSocket] 

878 

879 # used for spawn() 

880 pkt_cls = conf.raw_layer 

881 socketcls = StreamSocket 

882 

883 # Internals 

    def __init__(self, *args, **kargs):
        # type: (Any, Any) -> None
        """
        Build the automaton instance, wire its I/O channels and start it.

        Keyword arguments consumed here (the remaining ones are stored in
        ``init_kargs``):

        - ``external_fd``: maps an I/O channel name to either one object used
          for both directions or an ``(input, output)`` tuple. Channels
          without an entry get fresh ObjectPipes.
        - ``sock``: a bi-directional socket (read, but left in ``kargs``).
          Otherwise ``ll`` / ``recvsock`` give the send / receive socket
          classes (default ``conf.L3socket`` / ``conf.L2listen``).
        - ``is_atmt_socket`` / ``atmt_socket``: passed by
          ``_ATMT_supersocket`` when the automaton backs a supersocket.
        """
        external_fd = kargs.pop("external_fd", {})
        if "sock" in kargs:
            # We use a bi-directional sock
            self.sock = kargs["sock"]
        else:
            # Separate sockets
            self.sock = None
            self.send_sock_class = kargs.pop("ll", conf.L3socket)
            self.recv_sock_class = kargs.pop("recvsock", conf.L2listen)
        self.listen_sock = None  # type: Optional[SuperSocket]
        self.send_sock = None  # type: Optional[SuperSocket]
        self.is_atmt_socket = kargs.pop("is_atmt_socket", False)
        self.atmt_socket = kargs.pop("atmt_socket", None)
        self.started = threading.Lock()
        self.threadid = None  # type: Optional[int]
        self.breakpointed = None
        self.breakpoints = set()  # type: Set[_StateWrapper]
        self.interception_points = set()  # type: Set[_StateWrapper]
        self.intercepted_packet = None  # type: Union[None, Packet]
        self.debug_level = 0
        self.init_args = args
        self.init_kargs = kargs
        # Two empty namespaces that receive one attribute per I/O channel
        # (see the loop below).
        self.io = type.__new__(type, "IOnamespace", (), {})
        self.oi = type.__new__(type, "IOnamespace", (), {})
        self.cmdin = ObjectPipe[Message]("cmdin")
        self.cmdout = ObjectPipe[Message]("cmdout")
        self.ioin = {}
        self.ioout = {}
        self.packets = PacketList()  # type: PacketList
        for n in self.__class__.ionames:
            extfd = external_fd.get(n)
            if not isinstance(extfd, tuple):
                # A single object (or None) serves both directions
                extfd = (extfd, extfd)
            ioin, ioout = extfd
            # Missing ends get an in-process ObjectPipe; provided ones are
            # wrapped (_IO_fdwrapper, defined in this class) as read-only /
            # write-only ends.
            if ioin is None:
                ioin = ObjectPipe("ioin")
            else:
                ioin = self._IO_fdwrapper(ioin, None)
            if ioout is None:
                ioout = ObjectPipe("ioout")
            else:
                ioout = self._IO_fdwrapper(None, ioout)

            self.ioin[n] = ioin
            self.ioout[n] = ioout
            ioin.ioname = n
            ioout.ioname = n
            # ``io`` and ``oi`` get mirrored views of the channel: the same
            # two ends passed to _IO_mixer in opposite order (_IO_mixer is
            # defined outside this chunk).
            setattr(self.io, n, self._IO_mixer(ioout, ioin))
            setattr(self.oi, n, self._IO_mixer(ioin, ioout))

        # Replace each state method by a per-instance _instance_state proxy
        for stname in self.states:
            setattr(self, stname,
                    _instance_state(getattr(self, stname)))

        # start() is defined outside this chunk.
        self.start()

941 

942 def parse_args(self, debug=0, store=0, **kargs): 

943 # type: (int, int, Any) -> None 

944 self.debug_level = debug 

945 if debug: 

946 conf.logLevel = logging.DEBUG 

947 self.socket_kargs = kargs 

948 self.store_packets = store 

949 

@classmethod
def spawn(cls,
          port: int,
          iface: Optional[_GlobInterfaceType] = None,
          bg: bool = False,
          **kwargs: Any) -> Optional[socket.socket]:
    """
    Spawn a TCP server that listens for connections and start the automaton
    for each new client.

    The server binds to the address of ``iface`` (or ``conf.iface``). Each
    accepted connection is wrapped in ``cls.socketcls`` and drives a new
    ``cls`` instance started in the background. When the accept loop ends,
    every client automaton is force-stopped and every socket is closed
    (best effort).

    :param port: the port to listen to
    :param iface: interface whose address is bound (default: conf.iface);
        it is also forwarded to each client automaton
    :param bg: background mode? (default: False)
    :param kwargs: forwarded to the automaton constructor; ``verb``
        (default: True) also controls the messages printed here
    :return: the listening socket in background mode, None otherwise

    Errors raised while setting up the server socket propagate after that
    socket has been closed.
    """
    from scapy.arch import get_if_addr
    # create server sock and bind it
    ssock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        local_ip = get_if_addr(iface or conf.iface)
        try:
            ssock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        except OSError:
            pass  # best effort
        ssock.bind((local_ip, port))
        ssock.listen(5)
    except BaseException:
        # Don't leak the server socket if it could not be set up
        ssock.close()
        raise
    clients = []
    verb = kwargs.get("verb", True)
    if verb:
        print(conf.color_theme.green(
            "Server %s started listening on %s" % (
                cls.__name__,
                (local_ip, port),
            )
        ))

    def _run() -> None:
        # Wait for clients forever
        try:
            while True:
                clientsocket, address = ssock.accept()
                if verb:
                    print(conf.color_theme.gold(
                        "\u2503 Connection received from %s" % repr(address)
                    ))
                # Start atmt class with socket
                sock = cls.socketcls(clientsocket, cls.pkt_cls)
                atmt_server = cls(
                    sock=sock,
                    iface=iface, **kwargs
                )
                clients.append((atmt_server, clientsocket))
                # start atmt
                atmt_server.runbg()
        except KeyboardInterrupt:
            print("X Exiting.")
            try:
                ssock.shutdown(socket.SHUT_RDWR)
            except OSError:
                pass  # may fail on a listening socket; closed below anyway
        except OSError:
            print("X Server closed.")
        finally:
            # Best-effort teardown of every client automaton and socket
            for atmt, clientsocket in clients:
                try:
                    atmt.forcestop(wait=False)
                except Exception:
                    pass
                try:
                    clientsocket.shutdown(socket.SHUT_RDWR)
                except Exception:
                    pass  # may fail if already disconnected; still close it
                try:
                    clientsocket.close()
                except Exception:
                    pass
            ssock.close()

    if bg:
        # Background
        threading.Thread(target=_run).start()
        return ssock
    else:
        # Non-background
        _run()
        return None

1025 

def master_filter(self, pkt):
    # type: (Packet) -> bool
    """
    Decide whether a received packet reaches the recv conditions.

    The default policy accepts everything; subclasses may override it to
    drop unwanted packets before any condition sees them.
    """
    del pkt  # unused by the default policy
    return True

1029 

def my_send(self, pkt, **kwargs):
    # type: (Packet, **Any) -> None
    """
    Emit ``pkt`` on the automaton's sending socket.

    :param pkt: the packet to send
    :param kwargs: forwarded to the socket's ``send``
    :raises ValueError: if no sending socket is configured
    """
    sender = self.send_sock
    if sender:
        sender.send(pkt, **kwargs)
        return
    raise ValueError("send_sock is None !")

1035 

def update_sock(self, sock):
    # type: (SuperSocket) -> None
    """
    Swap in a new socket for the automaton, typically from an eof
    event in order to reconnect.

    ``self.sock`` is always replaced. The listening socket is replaced
    only if it is not None; a listener closed on EOF is left as ``False``
    precisely so that it still gets replaced here. The sending socket is
    replaced only when it is currently set (truthy).
    """
    self.sock = sock
    listener_configured = self.listen_sock is not None
    if listener_configured:
        self.listen_sock = sock
    if self.send_sock:
        self.send_sock = sock

1047 

def timer_by_name(self, name):
    # type: (str) -> Optional[Timer]
    """
    Return the first timer whose condition is named ``name``.

    The timer lists of all states in ``self.timeout`` are searched.

    :param name: the condition name (``atmt_condname``) to look for
    :return: the matching timer, or None if there is none
    """
    # Only the per-state timer collections matter, not the state keys
    for timers in self.timeout.values():
        for timer in timers:  # type: Timer
            if timer._func.atmt_condname == name:
                return timer
    return None

1055 

1056 # Utility classes and exceptions 

class _IO_fdwrapper:
    """
    File-like adapter over a read end and a write end.

    Each end may be a raw OS file descriptor (``int``), an object providing
    ``fileno``/``recv``/``send``, or None. ``recv`` is an alias of ``read``
    and ``send`` an alias of ``write``.
    """

    def __init__(self,
                 rd,  # type: Union[int, ObjectPipe[bytes], None]
                 wr  # type: Union[int, ObjectPipe[bytes], None]
                 ):
        # type: (...) -> None
        self.rd = rd
        self.wr = wr
        if isinstance(rd, socket.socket):
            # NOTE(review): presumably a marker read by select_objects so a
            # real select() is used on sockets -- confirm
            self.__selectable_force_select__ = True

    def fileno(self):
        # type: () -> int
        """Return the read end's descriptor, or 0 when there is no read end."""
        source = self.rd
        if isinstance(source, int):
            return source
        return source.fileno() if source else 0

    def read(self, n=65535):
        # type: (int) -> Optional[bytes]
        """Read up to ``n`` bytes from the read end; None if it is unset."""
        source = self.rd
        if isinstance(source, int):
            return os.read(source, n)
        if not source:
            return None
        return source.recv(n)

    def write(self, msg):
        # type: (bytes) -> int
        """Write ``msg`` to the write end; returns 0 if it is unset."""
        sink = self.wr
        if isinstance(sink, int):
            return os.write(sink, msg)
        if not sink:
            return 0
        return sink.send(msg)

    def recv(self, n=65535):
        # type: (int) -> Optional[bytes]
        """Alias of :meth:`read`."""
        return self.read(n)

    def send(self, msg):
        # type: (bytes) -> int
        """Alias of :meth:`write`."""
        return self.write(msg)

1099 

class _IO_mixer:
    """
    Join a read-side object and a write-side object into one endpoint.

    ``recv``/``read`` are served by ``rd``; ``send``/``write`` go to ``wr``.
    """

    def __init__(self,
                 rd,  # type: ObjectPipe[Any]
                 wr,  # type: ObjectPipe[Any]
                 ):
        # type: (...) -> None
        self.rd = rd
        self.wr = wr

    def fileno(self):
        # type: () -> Any
        """Return ``rd.fileno()`` for an ObjectPipe, else ``rd`` itself."""
        source = self.rd
        return source.fileno() if isinstance(source, ObjectPipe) else source

    def recv(self, n=None):
        # type: (Optional[int]) -> Any
        """Receive from the read side."""
        source = self.rd
        return source.recv(n)

    def read(self, n=None):
        # type: (Optional[int]) -> Any
        """Alias of :meth:`recv`."""
        return self.recv(n)

    def send(self, msg):
        # type: (str) -> int
        """Send ``msg`` on the write side."""
        sink = self.wr
        return sink.send(msg)

    def write(self, msg):
        # type: (str) -> int
        """Alias of :meth:`send`."""
        return self.send(msg)

1130 

class AutomatonException(Exception):
    """
    Base class of the automaton's exceptions.

    :param msg: the exception message
    :param state: the state associated with the exception, if any
    :param result: an associated result value, if any
    """

    def __init__(self, msg, state=None, result=None):
        # type: (str, Optional[Message], Optional[str]) -> None
        super().__init__(msg)
        self.state, self.result = state, result

1137 

class AutomatonError(AutomatonException):
    """Generic automaton error (e.g. an unknown interception verdict)."""

1140 

class ErrorState(AutomatonException):
    """Raised when the automaton reaches a state flagged as an error state."""

1143 

class Stuck(AutomatonException):
    """Raised when a state has no recv condition, io event or timer left."""

1146 

class AutomatonStopped(AutomatonException):
    """Base class of the pauses (breakpoint, single step, interception)."""

1149 

class Breakpoint(AutomatonStopped):
    """A breakpoint set on a state was hit."""

1152 

class Singlestep(AutomatonStopped):
    """Execution paused after one step in single-step mode."""

1155 

class InterceptionPoint(AutomatonStopped):
    """
    A packet being sent was intercepted; it is available as ``packet``.
    """

    def __init__(self, msg, state=None, result=None, packet=None):
        # type: (str, Optional[Message], Optional[str], Optional[Packet]) -> None
        super().__init__(msg, state=state, result=result)
        self.packet = packet

1161 

class CommandMessage(AutomatonException):
    """Signals that a command is pending on the command pipe."""

1164 

1165 # Services 

def debug(self, lvl, msg):
    # type: (int, str) -> None
    """Log ``msg`` at DEBUG level if ``self.debug_level`` reaches ``lvl``."""
    if self.debug_level < lvl:
        return
    log_runtime.debug(msg)

1170 

def isrunning(self):
    # type: () -> bool
    """Tell whether the control thread currently holds ``self.started``."""
    started_lock = self.started
    return started_lock.locked()

1174 

def send(self, pkt, **kwargs):
    # type: (Packet, **Any) -> None
    """
    Send ``pkt`` through :meth:`my_send`, honouring interception points.

    If the current state is an interception point, an INTERCEPT message
    carrying the packet is sent on ``cmdout``. The call then blocks on
    ``cmdin`` for a verdict:

    - falsy command (presumably None when the pipe is closed): cancelled,
      nothing is sent;
    - REJECT: the packet is dropped;
    - REPLACE: the packet carried by the command is sent instead;
    - ACCEPT: the original packet is sent.

    If ``self.store_packets`` is set, a copy of the sent packet is appended
    to ``self.packets``.

    :param pkt: the packet to send
    :param kwargs: forwarded to :meth:`my_send`
    :raises AutomatonError: on an unknown verdict
    """
    if self.state.state in self.interception_points:
        self.debug(3, "INTERCEPT: packet intercepted: %s" % pkt.summary())
        self.intercepted_packet = pkt
        self.cmdout.send(
            Message(type=_ATMT_Command.INTERCEPT,
                    state=self.state, pkt=pkt)
        )
        # Block until the controlling side returns a verdict
        cmd = self.cmdin.recv()
        if not cmd:
            # NOTE(review): intercepted_packet is not reset on this path
            self.debug(3, "CANCELLED")
            return
        self.intercepted_packet = None
        if cmd.type == _ATMT_Command.REJECT:
            self.debug(3, "INTERCEPT: packet rejected")
            return
        elif cmd.type == _ATMT_Command.REPLACE:
            pkt = cmd.pkt
            self.debug(3, "INTERCEPT: packet replaced by: %s" % pkt.summary())  # noqa: E501
        elif cmd.type == _ATMT_Command.ACCEPT:
            self.debug(3, "INTERCEPT: packet accepted")
        else:
            raise self.AutomatonError("INTERCEPT: unknown verdict: %r" % cmd.type)  # noqa: E501
    self.my_send(pkt, **kwargs)
    self.debug(3, "SENT : %s" % pkt.summary())

    if self.store_packets:
        # Store a copy so later mutations of pkt don't alter the record
        self.packets.append(pkt.copy())

1204 

def __iter__(self):
    # type: () -> Automaton
    """An automaton is its own iterator; ``next()`` runs a single step."""
    iterator = self
    return iterator

1208 

def __del__(self):
    # type: () -> None
    """Release the automaton's pipes when it is garbage collected."""
    release = self.destroy
    release()

1212 

def _run_condition(self, cond, *args, **kargs):
    # type: (_StateWrapper, Any, Any) -> None
    """
    Evaluate one condition function of the current state.

    A condition requests a transition by raising ``ATMT.NewStateRequested``.
    In that case:

    - for RECV conditions, the received packet (``args[0]``) is stored when
      ``store_packets`` is set;
    - the actions bound to the condition run with the arguments carried by
      the request;
    - the request is re-raised to the caller.

    Any other exception is logged and re-raised. Returning normally means
    the condition was not taken.

    :param cond: the condition function (it carries ``atmt_type`` and
        ``atmt_condname`` attributes)
    :param args: positional arguments passed to the condition
    :param kargs: keyword arguments passed to the condition
    """
    try:
        self.debug(5, "Trying %s [%s]" % (cond.atmt_type, cond.atmt_condname))  # noqa: E501
        cond(self, *args, **kargs)
    except ATMT.NewStateRequested as state_req:
        self.debug(2, "%s [%s] taken to state [%s]" % (cond.atmt_type, cond.atmt_condname, state_req.state))  # noqa: E501
        if cond.atmt_type == ATMT.RECV:
            if self.store_packets:
                self.packets.append(args[0])
        # Run the transition's actions before propagating the request
        for action in self.actions[cond.atmt_condname]:
            self.debug(2, " + Running action [%s]" % action.__name__)
            action(self, *state_req.action_args, **state_req.action_kargs)
        raise
    except Exception as e:
        self.debug(2, "%s [%s] raised exception [%s]" % (cond.atmt_type, cond.atmt_condname, e))  # noqa: E501
        raise
    else:
        self.debug(2, "%s [%s] not taken" % (cond.atmt_type, cond.atmt_condname))  # noqa: E501

1232 

def _do_start(self, *args, **kargs):
    # type: (Any, Any) -> None
    """
    Launch the control thread and wait until it signals it is ready.

    The thread runs :meth:`_do_control` as a daemon, receiving a
    ``threading.Event`` first, followed by ``args`` and ``kargs``.
    """
    ready = threading.Event()
    control_thread = threading.Thread(
        target=self._do_control,
        args=(ready, *args),
        kwargs=kargs,
        name="scapy.automaton _do_start",
        daemon=True,
    )
    control_thread.start()
    ready.wait()

1245 

def _do_control(self, ready, *args, **kargs):
    # type: (threading.Event, Any, Any) -> None
    """
    Body of the control thread.

    Setup, all done while holding ``self.started`` (so :meth:`isrunning`
    reports True):

    - merge ``args``/``kargs`` over ``init_args``/``init_kargs`` and call
      :meth:`parse_args`;
    - enter the first initial state;
    - open the sending socket, and the receiving one when there are recv
      conditions;
    - set ``ready`` to release the starting thread.

    It then serves commands read from ``cmdin``:

    - RUN / NEXT: step the state iterator continuously / once;
    - FREEZE: wait for the next command;
    - STOP: jump to the stop state if there is one, otherwise act as
      FORCESTOP;
    - FORCESTOP: leave the loop.

    Breakpoints, single steps, the end of the automaton and exceptions are
    reported on ``cmdout``. A None command (presumably a closed pipe)
    also leaves the loop. In every case the sockets are then closed and
    ``threadid`` is reset.

    :param ready: event set once the automaton is initialized
    """
    with self.started:
        self.threadid = threading.current_thread().ident
        if self.threadid is None:
            self.threadid = 0

        # Update default parameters
        a = args + self.init_args[len(args):]
        k = self.init_kargs.copy()
        k.update(kargs)
        self.parse_args(*a, **k)

        # Start the automaton
        self.state = self.initial_states[0](self)
        self.send_sock = self.sock or self.send_sock_class(**self.socket_kargs)
        if self.recv_conditions:
            # Only start a receiving socket if we have at least one recv_conditions
            self.listen_sock = self.sock or self.recv_sock_class(**self.socket_kargs)  # noqa: E501
        self.packets = PacketList(name="session[%s]" % self.__class__.__name__)

        singlestep = True
        iterator = self._do_iter()
        self.debug(3, "Starting control thread [tid=%i]" % self.threadid)
        # Sync threads
        ready.set()
        try:
            while True:
                c = self.cmdin.recv()
                if c is None:
                    # Leave the loop (rather than returning) so that the
                    # cleanup below still closes the sockets.
                    break
                self.debug(5, "Received command %s" % c.type)
                if c.type == _ATMT_Command.RUN:
                    singlestep = False
                elif c.type == _ATMT_Command.NEXT:
                    singlestep = True
                elif c.type == _ATMT_Command.FREEZE:
                    continue
                elif c.type == _ATMT_Command.STOP:
                    if self.stop_state:
                        # There is a stop state
                        self.state = self.stop_state()
                        iterator = self._do_iter()
                    else:
                        # Act as FORCESTOP
                        break
                elif c.type == _ATMT_Command.FORCESTOP:
                    break
                while True:
                    state = next(iterator)
                    if isinstance(state, self.CommandMessage):
                        break
                    elif isinstance(state, self.Breakpoint):
                        c = Message(type=_ATMT_Command.BREAKPOINT, state=state)  # noqa: E501
                        self.cmdout.send(c)
                        break
                    if singlestep:
                        c = Message(type=_ATMT_Command.SINGLESTEP, state=state)  # noqa: E501
                        self.cmdout.send(c)
                        break
        except (StopIteration, RuntimeError):
            # The state iterator is exhausted: a final state was reached
            c = Message(type=_ATMT_Command.END,
                        result=self.final_state_output)
            self.cmdout.send(c)
        except Exception as e:
            # Hand the exception over to the thread waiting in run()
            exc_info = sys.exc_info()
            self.debug(3, "Transferring exception from tid=%i:\n%s" % (self.threadid, "".join(traceback.format_exception(*exc_info))))  # noqa: E501
            m = Message(type=_ATMT_Command.EXCEPTION, exception=e, exc_info=exc_info)  # noqa: E501
            self.cmdout.send(m)
        self.debug(3, "Stopping control thread (tid=%i)" % self.threadid)
        self.threadid = None
        if self.listen_sock:
            self.listen_sock.close()
        if self.send_sock:
            self.send_sock.close()

1321 

def _do_iter(self):
    # type: () -> Iterator[Union[Automaton.AutomatonException, Automaton.AutomatonStopped, ATMT.NewStateRequested, None]]  # noqa: E501
    """
    Generator that drives the automaton through its states.

    For each state:

    - a Breakpoint is yielded before running it when the state is in
      ``self.breakpoints``;
    - the state function is run;
    - an error state raises ErrorState;
    - a final state stores its output in ``final_state_output`` and ends
      the generator.

    Otherwise:

    - the immediate conditions are tried (skipped while a command is
      pending on ``cmdin``);
    - Stuck is raised if the state has no recv condition, io event or
      timer;
    - the generator then loops: it fires expired timers and selects on
      ``cmdin``, on the listening socket (when the state has recv
      conditions) and on the state's io event inputs.

    A readable ``cmdin`` yields a CommandMessage. Received packets
    accepted by :meth:`master_filter` go to the recv conditions, and io
    events go to the matching ioevent conditions.

    An ``ATMT.NewStateRequested`` raised by any condition makes it the new
    ``self.state``, and it is yielded.
    """
    while True:
        try:
            self.debug(1, "## state=[%s]" % self.state.state)

            # Entering a new state. First, call new state function
            if self.state.state in self.breakpoints and self.state.state != self.breakpointed:  # noqa: E501
                self.breakpointed = self.state.state
                yield self.Breakpoint("breakpoint triggered on state %s" % self.state.state,  # noqa: E501
                                      state=self.state.state)
            self.breakpointed = None
            state_output = self.state.run()
            if self.state.error:
                raise self.ErrorState("Reached %s: [%r]" % (self.state.state, state_output),  # noqa: E501
                                      result=state_output, state=self.state.state)  # noqa: E501
            if self.state.final:
                self.final_state_output = state_output
                return

            # Normalize the state output into a sequence of extra
            # arguments for the conditions
            if state_output is None:
                state_output = ()
            elif not isinstance(state_output, list):
                state_output = state_output,

            timers = self.timeout[self.state.state]
            # If there are commandMessage, we should skip immediate
            # conditions.
            if not select_objects([self.cmdin], 0):
                # Then check immediate conditions
                for cond in self.conditions[self.state.state]:
                    self._run_condition(cond, *state_output)

            # If still there and no conditions left, we are stuck!
            if (len(self.recv_conditions[self.state.state]) == 0 and
                    len(self.ioevents[self.state.state]) == 0 and
                    timers.count() == 0):
                raise self.Stuck("stuck in [%s]" % self.state.state,
                                 state=self.state.state,
                                 result=state_output)

            # Finally listen and pay attention to timeouts
            timers.reset()
            time_previous = time.time()

            fds = [self.cmdin]  # type: List[Union[SuperSocket, ObjectPipe[Any]]]
            select_func = select_objects
            if self.listen_sock and self.recv_conditions[self.state.state]:
                fds.append(self.listen_sock)
                # Let the listening socket's own select handle all fds
                select_func = self.listen_sock.select  # type: ignore
            for ioev in self.ioevents[self.state.state]:
                fds.append(self.ioin[ioev.atmt_ioname])
            while True:
                # Age the timers by the elapsed wall time, then fire the
                # expired ones before waiting for the next one
                time_current = time.time()
                timers.decrement(time_current - time_previous)
                time_previous = time_current
                for timer in timers.expired():
                    self._run_condition(timer._func, *state_output)
                remain = timers.until_next()

                self.debug(5, "Select on %r" % fds)
                r = select_func(fds, remain)
                self.debug(5, "Selected %r" % r)
                for fd in r:
                    self.debug(5, "Looking at %r" % fd)
                    if fd == self.cmdin:
                        yield self.CommandMessage("Received command message")  # noqa: E501
                    elif fd == self.listen_sock:
                        try:
                            pkt = self.listen_sock.recv()
                        except EOFError:
                            # Socket was closed abruptly. This will likely only
                            # ever happen when a client socket is passed to the
                            # automaton (not the case when the automaton is
                            # listening on a promiscuous conf.L2sniff)
                            self.listen_sock.close()
                            # False so that it is still reset by update_sock
                            self.listen_sock = False  # type: ignore
                            fds.remove(fd)
                            if self.state.state in self.eofs:
                                # There is an eof state
                                eof = self.eofs[self.state.state]
                                self.debug(2, "Condition EOF [%s] taken" % eof.__name__)  # noqa: E501
                                # NOTE(review): presumably the eof handler
                                # raises ATMT.NewStateRequested itself -- confirm
                                raise self.eofs[self.state.state](self)
                            else:
                                # There isn't. Therefore, it's a closing condition.
                                raise EOFError("Socket ended arbruptly.")
                        if pkt is not None:
                            if self.master_filter(pkt):
                                self.debug(3, "RECVD: %s" % pkt.summary())  # noqa: E501
                                for rcvcond in self.recv_conditions[self.state.state]:  # noqa: E501
                                    self._run_condition(rcvcond, pkt, *state_output)  # noqa: E501
                            else:
                                self.debug(4, "FILTR: %s" % pkt.summary())  # noqa: E501
                    else:
                        # Remaining fds are io event inputs
                        self.debug(3, "IOEVENT on %s" % fd.ioname)
                        for ioevt in self.ioevents[self.state.state]:
                            if ioevt.atmt_ioname == fd.ioname:
                                self._run_condition(ioevt, fd, *state_output)  # noqa: E501

        except ATMT.NewStateRequested as state_req:
            self.debug(2, "switching from [%s] to [%s]" % (self.state.state, state_req.state))  # noqa: E501
            self.state = state_req
            yield state_req

1426 

def __repr__(self):
    # type: () -> str
    """Show the automaton class name and whether it is running."""
    status = "RUNNING" if self.isrunning() else "HALTED"
    return "<Automaton %s [%s]>" % (self.__class__.__name__, status)

1433 

1434 # Public API 

def add_interception_points(self, *ipts):
    # type: (Any) -> None
    """
    Intercept the packets sent while in any of the given states.

    Each item may be a state name or a state method carrying an
    ``atmt_state`` attribute.
    """
    for point in ipts:
        self.interception_points.add(getattr(point, "atmt_state", point))

1441 

def remove_interception_points(self, *ipts):
    # type: (Any) -> None
    """
    Stop intercepting packets in the given states; unknown ones are ignored.

    Each item may be a state name or a state method carrying an
    ``atmt_state`` attribute.
    """
    for point in ipts:
        self.interception_points.discard(getattr(point, "atmt_state", point))

1448 

def add_breakpoints(self, *bps):
    # type: (Any) -> None
    """
    Pause the automaton when it enters any of the given states.

    Each item may be a state name or a state method carrying an
    ``atmt_state`` attribute.
    """
    for point in bps:
        self.breakpoints.add(getattr(point, "atmt_state", point))

1455 

def remove_breakpoints(self, *bps):
    # type: (Any) -> None
    """
    Remove breakpoints from the given states; unknown ones are ignored.

    Each item may be a state name or a state method carrying an
    ``atmt_state`` attribute.
    """
    for point in bps:
        self.breakpoints.discard(getattr(point, "atmt_state", point))

1462 

def start(self, *args, **kargs):
    # type: (Any, Any) -> None
    """
    Start the control thread with the given arguments.

    :raises ValueError: if the automaton is already running
    """
    if not self.isrunning():
        self._do_start(*args, **kargs)
        return
    raise ValueError("Already started")

1469 

def run(self,
        resume=None,  # type: Optional[Message]
        wait=True  # type: Optional[bool]
        ):
    # type: (...) -> Any
    """
    Send a command to the control thread and, if ``wait``, act on its answer.

    :param resume: the command to send (default: a RUN command)
    :param wait: whether to block for the control thread's answer
    :return: the final state's output when the automaton ends. None when
        not waiting, when the answer pipe yields None, or on
        KeyboardInterrupt (a FREEZE command is then sent).
    :raises InterceptionPoint: a packet was intercepted
    :raises Singlestep: one step was executed in single-step mode
    :raises Breakpoint: a breakpoint was hit
    :raises Exception: any exception raised in the control thread is
        re-raised here with its original traceback
    """
    if resume is None:
        resume = Message(type=_ATMT_Command.RUN)
    self.cmdin.send(resume)
    if wait:
        try:
            c = self.cmdout.recv()
            if c is None:
                return None
        except KeyboardInterrupt:
            # Interrupted while waiting: the control thread will wait
            # for the next command instead of stepping
            self.cmdin.send(Message(type=_ATMT_Command.FREEZE))
            return None
        if c.type == _ATMT_Command.END:
            return c.result
        elif c.type == _ATMT_Command.INTERCEPT:
            raise self.InterceptionPoint("packet intercepted", state=c.state.state, packet=c.pkt)  # noqa: E501
        elif c.type == _ATMT_Command.SINGLESTEP:
            raise self.Singlestep("singlestep state=[%s]" % c.state.state, state=c.state.state)  # noqa: E501
        elif c.type == _ATMT_Command.BREAKPOINT:
            raise self.Breakpoint("breakpoint triggered on state [%s]" % c.state.state, state=c.state.state)  # noqa: E501
        elif c.type == _ATMT_Command.EXCEPTION:
            # this code comes from the `six` module (`.reraise()`)
            # to raise an exception with specified exc_info.
            value = c.exc_info[0]() if c.exc_info[1] is None else c.exc_info[1]  # type: ignore  # noqa: E501
            if value.__traceback__ is not c.exc_info[2]:
                raise value.with_traceback(c.exc_info[2])
            raise value
    return None

1502 

def runbg(self, resume=None, wait=False):
    # type: (Optional[Message], Optional[bool]) -> None
    """Like :meth:`run`, but by default does not wait and returns nothing."""
    self.run(resume=resume, wait=wait)

1506 

def __next__(self):
    # type: () -> Any
    """Execute a single step of the automaton (see :meth:`run`)."""
    step = Message(type=_ATMT_Command.NEXT)
    return self.run(resume=step)

1510 

def _flush_inout(self):
    # type: () -> None
    """Discard any pending message on both command pipes."""
    self.cmdin.clear()
    self.cmdout.clear()

1516 

def destroy(self):
    # type: () -> None
    """
    Release the resources of a stopped automaton.

    The command pipes are closed and flushed, as are the ObjectPipe
    instances among the io inputs/outputs. This matters on interpreters
    such as PyPy, where garbage collection does not release them promptly.

    :raises ValueError: if the automaton is still running
    """
    if not hasattr(self, "started"):
        return  # never fully set up: nothing to release
    if self.isrunning():
        raise ValueError("Can't close running Automaton ! Call stop() beforehand")
    for pipe in (self.cmdin, self.cmdout):
        pipe.close()
    self._flush_inout()
    io_pipes = (
        endpoint
        for endpoint in itertools.chain(self.ioin.values(), self.ioout.values())
        if isinstance(endpoint, ObjectPipe)
    )
    for endpoint in io_pipes:
        endpoint.close()

1535 

def stop(self, wait=True):
    # type: (bool) -> None
    """
    Ask the control thread to stop.

    The automaton moves to its stop state if it has one; otherwise it stops
    right away. An OSError while sending the command is ignored.

    :param wait: block until the control thread releases ``self.started``,
        then flush the command pipes
    """
    try:
        self.cmdin.send(Message(type=_ATMT_Command.STOP))
    except OSError:
        pass  # presumably the command pipe is already closed
    if not wait:
        return
    with self.started:
        self._flush_inout()

1545 

def forcestop(self, wait=True):
    # type: (bool) -> None
    """
    Make the control thread leave its loop without going to a stop state.

    An OSError while sending the command is ignored.

    :param wait: block until the control thread releases ``self.started``,
        then flush the command pipes
    """
    try:
        self.cmdin.send(Message(type=_ATMT_Command.FORCESTOP))
    except OSError:
        pass  # presumably the command pipe is already closed
    if not wait:
        return
    with self.started:
        self._flush_inout()

1555 

def restart(self, *args, **kargs):
    # type: (Any, Any) -> None
    """Stop the automaton (waiting for it), then start it again."""
    self.stop()
    self.start(*args, **kargs)

1560 

def accept_packet(self,
                  pkt=None,  # type: Optional[Packet]
                  wait=False  # type: Optional[bool]
                  ):
    # type: (...) -> Any
    """
    Let the intercepted packet through and resume the automaton.

    :param pkt: if given, send this packet instead of the intercepted one
    :param wait: wait for the automaton's next answer (see :meth:`run`)
    """
    rsm = Message()
    if pkt is not None:
        rsm.type = _ATMT_Command.REPLACE
        rsm.pkt = pkt
    else:
        rsm.type = _ATMT_Command.ACCEPT
    return self.run(resume=rsm, wait=wait)

1573 

def reject_packet(self,
                  wait=False  # type: Optional[bool]
                  ):
    # type: (...) -> Any
    """
    Drop the intercepted packet and resume the automaton.

    :param wait: wait for the automaton's next answer (see :meth:`run`)
    """
    verdict = Message(type=_ATMT_Command.REJECT)
    return self.run(resume=verdict, wait=wait)