Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/atheris/instrument_bytecode.py: 85%

462 statements  

« prev     ^ index     » next       coverage.py v7.0.5, created at 2023-01-17 06:13 +0000

1# Copyright 2021 Google LLC 

2# Copyright 2021 Fraunhofer FKIE 

3# 

4# Licensed under the Apache License, Version 2.0 (the "License"); 

5# you may not use this file except in compliance with the License. 

6# You may obtain a copy of the License at 

7# 

8# http://www.apache.org/licenses/LICENSE-2.0 

9# 

10# Unless required by applicable law or agreed to in writing, software 

11# distributed under the License is distributed on an "AS IS" BASIS, 

12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

13# See the License for the specific language governing permissions and 

14# limitations under the License. 

15"""This module provides the instrumentation functionality for atheris. 

16 

17Mainly the function patch_code(), which can instrument a code object and the 

18helper class Instrumentor. 

19""" 

20import collections 

21import dis 

22import gc 

23import sys 

24import types 

25from typing import Any, Callable, Iterator, List, Optional, Tuple, TypeVar, Union 

26 

27from . import utils 

28from .native import _reserve_counter # type: ignore[import] 

29from .version_dependent import add_bytes_to_jump_arg 

30from .version_dependent import CONDITIONAL_JUMPS 

31from .version_dependent import ENDS_FUNCTION 

32from .version_dependent import get_code_object 

33from .version_dependent import get_lnotab 

34from .version_dependent import HAVE_ABS_REFERENCE 

35from .version_dependent import HAVE_REL_REFERENCE 

36from .version_dependent import REL_REFERENCE_IS_INVERTED 

37from .version_dependent import rel_reference_scale 

38from .version_dependent import jump_arg_bytes 

39from .version_dependent import REVERSE_CMP_OP 

40from .version_dependent import UNCONDITIONAL_JUMPS 

41from .version_dependent import rot_n 

42from .version_dependent import call 

43from .version_dependent import cache_count 

44from .version_dependent import caches 

45from .version_dependent import get_instructions 

46from .version_dependent import generate_exceptiontable 

47from .version_dependent import parse_exceptiontable 

48from .version_dependent import ExceptionTable 

49from .version_dependent import ExceptionTableEntry 

50from .version_dependent import args_terminator 

51from .version_dependent import CALLABLE_STACK_ENTRIES 

52 

53_TARGET_MODULE = "atheris" 

54_COVERAGE_FUNCTION = "_trace_branch" 

55_COMPARE_FUNCTION = "_trace_cmp" 

56 

57# TODO(b/207008147): Use NewType to differentiate the many int and str types. 

58 

59 

class Instruction:
  """A single bytecode instruction after every EXTENDED_ARG has been resolved.

  It is assumed that all instructions are always 2*n bytes long.

  Sometimes the Python interpreter pads instructions with 'EXTENDED_ARG 0',
  so instructions must be able to carry a minimum size.

  Attributes:
    lineno:
      Line number in the original source code.
    offset:
      Offset of an instruction in bytes.
    opcode:
      Integer identifier of the bytecode operation.
    mnemonic:
      Human readable name of the opcode.
    arg:
      Optional (default 0) argument to the instruction. This may index into
      CodeType.co_consts or it may be the address for jump instructions.
    reference:
      For jump instructions, the absolute address in bytes of the target. For
      other instructions, None.
    positions:
      Opaque value handed to the constructor and stored unchanged.
  """

  @classmethod
  def get_fixed_size(cls) -> int:
    """Returns the size in bytes of one opcode/argument pair."""
    return 2

  def __init__(
      self,
      lineno: int,
      offset: int,
      opcode: int,
      arg: int = 0,
      min_size: int = 0,
      positions=None,
  ):
    """Stores the fields and derives the jump target, if there is one.

    Args:
      lineno: Source line number this instruction belongs to.
      offset: Byte offset of the instruction (including its EXTENDED_ARGs).
      opcode: Numeric opcode; its name is looked up in dis.opname.
      arg: Fully combined argument (EXTENDED_ARG bytes already merged in).
      min_size: Lower bound in bytes for the encoded size of the instruction.
      positions: Stored as-is on the instance.
    """
    self.lineno = lineno
    self.offset = offset
    self.opcode = opcode
    self.mnemonic = dis.opname[opcode]
    self.arg = arg
    self._min_size = min_size
    self.positions = positions

    # _is_relative is tri-state: True for relative jumps, False for absolute
    # jumps and None for instructions that do not reference an address.
    if self.mnemonic in HAVE_REL_REFERENCE:
      self._is_relative: Optional[bool] = True
      # A relative target is measured from the end of this instruction.
      self.reference: Optional[int] = (
          self.offset
          + self.get_size()
          + jump_arg_bytes(self.arg) * rel_reference_scale(self.mnemonic)
      )
    elif self.mnemonic in HAVE_ABS_REFERENCE:
      self._is_relative = False
      self.reference = jump_arg_bytes(self.arg)
    else:
      self._is_relative = None
      self.reference = None

    self.check_state()

  def __repr__(self) -> str:
    return (
        f"{self.mnemonic}(arg={self.arg} offset={self.offset} "
        + f"reference={self.reference} getsize={self.get_size()} positions={self.positions})"
    )

  def has_argument(self) -> bool:
    """Returns True if the opcode is at or above dis.HAVE_ARGUMENT."""
    return self.opcode >= dis.HAVE_ARGUMENT

  def _get_arg_size(self) -> int:
    """Returns the bytes needed to encode `arg`: 2 per started argument byte."""
    if self.arg >= (1 << 24):
      return 8
    elif self.arg >= (1 << 16):
      return 6
    elif self.arg >= (1 << 8):
      return 4
    else:
      return 2

  def get_size(self) -> int:
    """Returns the encoded size in bytes, honoring the minimum size."""
    return max(self._get_arg_size(), self._min_size)

  def get_stack_effect(self) -> int:
    """Returns the stack effect of this instruction as reported by dis."""
    # dis.stack_effect does not work for EXTENDED_ARG and NOP
    if self.mnemonic in ["EXTENDED_ARG", "NOP"]:
      return 0

    return dis.stack_effect(self.opcode,
                            (self.arg if self.has_argument() else None))

  def to_bytes(self) -> bytes:
    """Returns this instruction as bytes."""
    size = self._get_arg_size()
    arg = self.arg
    ret = [self.opcode, arg & 0xff]

    # Prepend one EXTENDED_ARG per additional argument byte; the most
    # significant byte ends up first.
    for _ in range(size // 2 - 1):
      arg >>= 8
      ret = [dis.opmap["EXTENDED_ARG"], arg & 0xff] + ret

    # Pad with 'EXTENDED_ARG 0' until the minimum size is reached.
    while len(ret) < self._min_size:
      ret = [dis.opmap["EXTENDED_ARG"], 0] + ret

    assert len(ret) == self.get_size()

    return bytes(ret)

  def adjust(self, changed_offset: int, size: int, keep_ref: bool) -> None:
    """Compensates the offsets in this instruction for a resize elsewhere.

    Relative offsets may be invalidated due to two main events:
    (1) Insertion of instructions
    (2) Change of size of a single, already existing instruction

    (1) Some instructions of size `size` (in bytes) have been inserted at offset
        `changed_offset` in the instruction listing.

    (2) An instruction at offset `changed_offset` - 0.5 has increased in size.
        If `changed_offset` is self.offset + 0.5, then self has increased.

    Either way, adjust the current offset, reference and argument
    accordingly.

    TODO(aidenhall): Replace the pattern of using +0.5 as a sentinel.

    Args:
      changed_offset: The offset where instructions are inserted.
      size: The number of bytes of instructions inserted.
      keep_ref: if True, leave this instruction's reference and argument
        untouched for an insertion (case 1); only its own offset may move.
    """
    old_offset = self.offset
    old_reference = self.reference

    # Case (2) for this very instruction: `changed_offset` lies strictly
    # between our offset and offset + 1, i.e. it is the +0.5 sentinel.
    if old_offset < changed_offset < (old_offset + 1):
      if old_reference is not None:
        if self._is_relative:
          if self.mnemonic not in REL_REFERENCE_IS_INVERTED:
            self.reference += size  # type: ignore[operator]
          else:
            self.arg = add_bytes_to_jump_arg(self.arg, size)
        elif old_reference > old_offset:
          self.reference += size  # type: ignore[operator]
          self.arg = add_bytes_to_jump_arg(self.arg, size)

      return

    # Everything at or after the change point moves back by `size`.
    if changed_offset <= old_offset:
      self.offset += size

    if old_reference is not None and not keep_ref:
      # The target itself moved if it is at or after the change point.
      if changed_offset <= old_reference:
        self.reference += size  # type: ignore[operator]

      if self._is_relative:
        # A relative argument only changes when the change point lies between
        # this instruction and its target.
        if self.mnemonic not in REL_REFERENCE_IS_INVERTED and (
            old_offset < changed_offset <= old_reference
        ):
          self.arg = add_bytes_to_jump_arg(self.arg, size)
        elif self.mnemonic in REL_REFERENCE_IS_INVERTED and (
            old_offset >= changed_offset >= old_reference
        ):
          self.arg = add_bytes_to_jump_arg(self.arg, size)
      else:
        if changed_offset <= old_reference:
          self.arg = add_bytes_to_jump_arg(self.arg, size)

  def check_state(self) -> None:
    """Asserts that internal state is consistent."""
    assert self.mnemonic != "EXTENDED_ARG"
    assert 0 <= self.arg <= 0x7fffffff
    assert 0 <= self.opcode < 256

    # `reference` must always be re-derivable from offset, size and arg.
    if self.reference is not None:
      if self._is_relative:
        assert (
            self.offset
            + self.get_size()
            + jump_arg_bytes(self.arg) * rel_reference_scale(self.mnemonic)
            == self.reference
        )
      else:
        assert jump_arg_bytes(self.arg) == self.reference

  def is_jump(self) -> bool:
    """Returns True for conditional and unconditional jump mnemonics."""
    return self.mnemonic in CONDITIONAL_JUMPS or self.mnemonic in UNCONDITIONAL_JUMPS

  def make_nop(self) -> None:
    """Turns this instruction into a NOP without any argument or reference."""
    self.opcode = dis.opmap["NOP"]
    self.mnemonic = "NOP"
    self.arg = 0
    self._is_relative = None
    self.reference = None
    self.check_state()

  def cache_count(self) -> int:
    """Returns cache_count() of this instruction's opcode."""
    return cache_count(self.opcode)

258 

259 

class BasicBlock:
  """A run of bytecode instructions plus the addresses control may go to next.

  Attributes:
    instructions: The instructions that make up this block, in order.
    id: Offset of the first instruction at construction time.
    edges: Offsets of the possible successors of this block.
  """

  def __init__(self, instructions: List[Instruction], last_one: bool):
    self.instructions = instructions
    self.id = instructions[0].offset

    tail = instructions[-1]

    if last_one or tail.mnemonic in ENDS_FUNCTION:
      # Nothing follows the final block or a function-ending instruction.
      successors = []
    elif tail.mnemonic in CONDITIONAL_JUMPS:
      # Jump target and fall-through; the set collapses them if they coincide.
      successors = list({tail.reference, tail.offset + tail.get_size()})
    elif tail.reference is not None:
      successors = [tail.reference]
    else:
      successors = [tail.offset + tail.get_size()]

    self.edges = successors

  def __iter__(self) -> Iterator[Instruction]:
    return iter(self.instructions)

  def __repr__(self) -> str:
    return (f"BasicBlock(id={self.id}, edges={self.edges}, "
            f"instructions={self.instructions})")

286 

287 

288_SizeAndInstructions = Tuple[int, List[Instruction]] 

289 

290 

class Instrumentor:
  """Implements the core instrumentation functionality.

  It gets a single code object, builds a CFG of the bytecode and
  can instrument the code for coverage collection via trace_control_flow()
  and for data-flow tracing via trace_data_flow().

  How to insert code:
    1. Select a target basic block
    2. Build up the new code as a list of `Instruction` objects.
       Make sure to get the offsets right.
    3. Calculate the overall size needed by your new code (in bytes)
    4. Call _adjust() with your target offset and calculated size
    5. Insert your instruction list into the instruction list of the basic
       block
    6. Call _handle_size_changes()
  Take a look at trace_control_flow() and trace_data_flow() for examples.

  Note that Instrumentor only supports insertions, not deletions.
  """

  def __init__(self, code: types.CodeType):
    """Builds and validates the control flow graph of `code`."""
    # Maps the original offset of a block's first instruction to the block.
    self._cfg: collections.OrderedDict = collections.OrderedDict()
    self.consts = list(code.co_consts)
    self._names = list(code.co_names)
    self.num_counters = 0
    self._code = code

    self._build_cfg()
    self._check_state()

  def _insert_instruction(self, to_insert, lineno, offset, opcode, arg=0):
    """Appends one instruction, and whatever caches(opcode) yields, to a list.

    Args:
      to_insert: List of Instruction objects that is extended in place.
      lineno: Line number assigned to the new instruction(s).
      offset: Byte offset at which the new instruction starts.
      opcode: Numeric opcode of the new instruction.
      arg: Argument of the new instruction.

    Returns:
      The byte offset directly after everything that was appended.
    """
    to_insert.append(Instruction(lineno, offset, opcode, arg))
    offset += to_insert[-1].get_size()
    return self._insert_instructions(to_insert, lineno, offset, caches(opcode))

  def _insert_instructions(self, to_insert, lineno, offset, tuples):
    """Appends a sequence of (opcode, arg) pairs via _insert_instruction().

    Returns:
      The byte offset directly after the last appended instruction.
    """
    for t in tuples:
      offset = self._insert_instruction(to_insert, lineno, offset, t[0], t[1])
    return offset

  def _build_cfg(self) -> None:
    """Builds control flow graph."""
    lineno = self._code.co_firstlineno
    # `arg`, `offset` and `length` accumulate a pending EXTENDED_ARG prefix.
    arg = None
    offset = None
    length = Instruction.get_fixed_size()
    instr_list = []
    basic_block_borders = []
    jump_targets = set()

    self.exception_table = parse_exceptiontable(self._code)

    for instruction in get_instructions(self._code):
      if instruction.starts_line is not None:
        lineno = instruction.starts_line

      if instruction.opname == "EXTENDED_ARG":
        if arg is None:
          arg = 0
          # The combined instruction starts at its first EXTENDED_ARG.
          offset = instruction.offset

        arg <<= 8
        arg |= instruction.arg  # type: ignore[operator]
        length += Instruction.get_fixed_size()  # type: ignore[operator]

        continue

      elif arg is not None:
        assert offset is not None
        combined_arg = 0
        # https://bugs.python.org/issue45757 can cause .arg to be None
        if instruction.arg is not None:
          combined_arg = (arg << 8) | instruction.arg  # type: ignore[operator]
        instr_list.append(
            Instruction(
                lineno,
                offset,
                instruction.opcode,
                combined_arg,
                min_size=length,
                positions=getattr(instruction, "positions", None),
            )
        )
        arg = None
        offset = None
        length = Instruction.get_fixed_size()

      else:
        instr_list.append(
            Instruction(
                lineno,
                instruction.offset,
                instruction.opcode,
                instruction.arg or 0,
                positions=getattr(instruction, "positions", None),
            )
        )

      if instr_list[-1].reference is not None:
        jump_targets.add(instr_list[-1].reference)

    # A block starts at offset 0, at every jump target and right after every
    # jump instruction.
    did_jump = False
    for c, instr in enumerate(instr_list):
      if instr.offset == 0 or instr.offset in jump_targets or did_jump:
        basic_block_borders.append(c)

      did_jump = instr.is_jump()

    basic_block_borders.append(len(instr_list))

    for i in range(len(basic_block_borders) - 1):
      start_of_bb = basic_block_borders[i]
      end_of_bb = basic_block_borders[i + 1]
      bb = BasicBlock(instr_list[start_of_bb:end_of_bb],
                      i == len(basic_block_borders) - 2)
      self._cfg[bb.id] = bb

  def _check_state(self) -> None:
    """Asserts that the Instrumentor is in a valid state."""
    assert self._cfg, "Control flow graph empty."
    seen_ids = set()

    for basic_block in self._cfg.values():
      assert basic_block.instructions, "BasicBlock has no instructions."

      assert basic_block.id not in seen_ids
      seen_ids.add(basic_block.id)

      for edge in basic_block.edges:
        assert edge in self._cfg, (
            f"{basic_block} has an edge, {edge}, not in CFG {self._cfg}.")

    listing = self._get_linear_instruction_listing()

    assert listing[0].offset == 0

    # Instructions must be laid out back to back. As before, the very last
    # instruction is not passed through check_state() here.
    for current, following in zip(listing, listing[1:]):
      assert current.offset + current.get_size() == following.offset
      current.check_state()

  def _get_name(self, name: str) -> int:
    """Returns an offset to `name` in co_names, appending if necessary."""
    try:
      return self._names.index(name)
    except ValueError:
      self._names.append(name)
      return len(self._names) - 1

  def _get_const(self, constant: Union[int, types.ModuleType]) -> int:
    """Returns the index of `constant` in self.consts, inserting if needed.

    An existing entry is only reused when it has exactly the same type as
    `constant` and compares equal to it. The exact type match matters because
    bool is a subclass of int and True == 1 / False == 0: an isinstance()
    check would hand out the slot of `True` when the integer 1 is requested.
    """
    wanted_type = type(constant)
    for index, existing in enumerate(self.consts):
      if type(existing) is wanted_type and existing == constant:
        return index

    self.consts.append(constant)
    return len(self.consts) - 1

  def _get_counter(self) -> int:
    """Reserves a counter and returns the co_consts index holding its value."""
    counter = _reserve_counter()
    return self._get_const(counter)

  def _adjust(self, offset: float, size: int, *keep_refs: "Instruction") -> None:
    """Adjust for `size` bytes of instructions inserted at `offset`.

    Signal all instructions that some instructions of size `size` (in bytes)
    will be inserted at offset `offset`. Sometimes it is necessary that some
    instructions do not change their reference when a new insertion happens.

    All those Instruction-objects whose reference shall not change must be
    in `keep_refs`.

    Args:
      offset: Location that new instructions are inserted at. A value of
        <instruction offset> + 0.5 signals that the instruction at that offset
        grew instead (see Instruction.adjust()).
      size: How many bytes of new instructions are being inserted.
      *keep_refs: The Instructions whose reference shall not change.
    """
    for basic_block in self._cfg.values():
      for instr in basic_block:
        instr.adjust(offset, size, instr in keep_refs)

    entry: ExceptionTableEntry
    for entry in self.exception_table.entries:
      if entry.start_offset > offset:
        entry.start_offset += size
      if entry.end_offset >= offset:
        entry.end_offset += size
      if entry.target > offset:
        entry.target += size

  def _handle_size_changes(self) -> None:
    """Fixes instructions whose size increased with the last insertion.

    After insertions have been made it could be that the argument of some
    instructions crossed certain boundaries so that more EXTENDED_ARGs are
    required to build the oparg. This function identifies all of those
    instructions whose size increased with the latest insertion and adjusts all
    other instructions to the new size.
    """
    listing = self._get_linear_instruction_listing()

    # Growing one instruction can push another argument over a size
    # boundary, so repeat until a full pass finds nothing left to fix.
    while True:
      found_invalid = False

      for current, following in zip(listing, listing[1:]):
        next_offset = current.offset + current.get_size()

        assert next_offset >= following.offset, (
            "Something weird happened with the offsets at offset " +
            f"{current.offset}")

        if next_offset > following.offset:
          delta = next_offset - following.offset
          # The +0.5 sentinel marks `current` itself as the grown instruction.
          self._adjust(current.offset + 0.5, delta)
          found_invalid = True

      if not found_invalid:
        break

  def _get_linear_instruction_listing(self) -> "List[Instruction]":
    """Returns all instructions of all basic blocks in CFG order."""
    return [instr
            for basic_block in self._cfg.values()
            for instr in basic_block]

  def to_code(self) -> types.CodeType:
    """Returns the instrumented code object."""
    self._check_state()
    listing = self._get_linear_instruction_listing()
    chunks = []
    stacksize = 0

    for instr in listing:
      chunks.append(instr.to_bytes())
      # Only ever grows: effects are summed up along the linear listing and
      # negative effects are ignored.
      stacksize = max(stacksize, stacksize + instr.get_stack_effect())

    # Joining once avoids re-copying the growing bytes object per instruction.
    code = b"".join(chunks)

    co_exceptiontable = generate_exceptiontable(
        self._code, self.exception_table.entries
    )

    return get_code_object(
        self._code,
        stacksize,
        code,
        # The marker constant lets patch_code() recognize instrumented code.
        tuple(self.consts + ["__ATHERIS_INSTRUMENTED__"]),
        tuple(self._names),
        get_lnotab(self._code, listing),
        co_exceptiontable,
    )

  def _generate_trace_branch_invocation(self, lineno: int,
                                        offset: int) -> "_SizeAndInstructions":
    """Builds the bytecode that calls atheris._trace_branch().

    Args:
      lineno: The line number to assign to the generated instructions.
      offset: The byte offset at which the generated code will start.

    Returns:
      The size of the instructions to insert (in bytes),
      The instructions to insert
    """
    to_insert = []
    start_offset = offset
    const_atheris = self._get_const(sys.modules[_TARGET_MODULE])
    name_cov = self._get_name(_COVERAGE_FUNCTION)

    offset = self._insert_instructions(
        to_insert, lineno, offset, args_terminator()
    )

    offset = self._insert_instruction(
        to_insert, lineno, offset, dis.opmap["LOAD_CONST"], const_atheris
    )
    offset = self._insert_instruction(
        to_insert, lineno, offset, dis.opmap["LOAD_ATTR"], name_cov
    )

    offset = self._insert_instruction(
        to_insert, lineno, offset, dis.opmap["LOAD_CONST"], self._get_counter()
    )

    offset = self._insert_instructions(to_insert, lineno, offset, call(1))
    # Discard the return value of the call.
    offset = self._insert_instruction(
        to_insert, lineno, offset, dis.opmap["POP_TOP"], 0
    )

    return offset - start_offset, to_insert

  def _generate_cmp_invocation(self, op: int, lineno: int,
                               offset: int) -> "_SizeAndInstructions":
    """Builds the bytecode that calls atheris._trace_cmp().

    Only call this if the two objects being compared are non-constants.

    Args:
      op: The comparison operation
      lineno: The line number of the operation
      offset: The offset to the operation instruction

    Returns:
      The size of the instructions to insert,
      The instructions to insert
    """
    to_insert = []
    start_offset = offset
    const_atheris = self._get_const(sys.modules[_TARGET_MODULE])
    name_cmp = self._get_name(_COMPARE_FUNCTION)
    const_op = self._get_const(op)
    const_counter = self._get_counter()
    const_false = self._get_const(False)

    offset = self._insert_instructions(
        to_insert, lineno, offset, args_terminator()
    )
    offset = self._insert_instruction(
        to_insert, lineno, offset, dis.opmap["LOAD_CONST"], const_atheris
    )
    offset = self._insert_instruction(
        to_insert, lineno, offset, dis.opmap["LOAD_ATTR"], name_cmp
    )
    # Move the callable below the two operands already on the stack.
    rot = rot_n(2 + CALLABLE_STACK_ENTRIES, CALLABLE_STACK_ENTRIES)
    offset = self._insert_instructions(to_insert, lineno, offset, rot)

    offset = self._insert_instruction(
        to_insert, lineno, offset, dis.opmap["LOAD_CONST"], const_op
    )
    offset = self._insert_instruction(
        to_insert, lineno, offset, dis.opmap["LOAD_CONST"], const_counter
    )
    offset = self._insert_instruction(
        to_insert, lineno, offset, dis.opmap["LOAD_CONST"], const_false
    )

    offset = self._insert_instructions(to_insert, lineno, offset, call(5))

    return offset - start_offset, to_insert

  def _generate_const_cmp_invocation(self, op: int, lineno: int, offset: int,
                                     switch: bool) -> "_SizeAndInstructions":
    """Builds the bytecode that calls atheris._trace_cmp().

    Only call this if one of the objects being compared is a constant coming
    from co_consts. If `switch` is true the constant is the second argument and
    needs to be switched with the first argument.

    Args:
      op: The comparison operation.
      lineno: The line number of the operation
      offset: The byte offset at which the generated code will start.
      switch: bool whether the second arg is constant instead of the first.

    Returns:
      The number of bytes to insert, and the instructions.
    """
    to_insert = []
    start_offset = offset
    const_atheris = self._get_const(sys.modules[_TARGET_MODULE])
    name_cmp = self._get_name(_COMPARE_FUNCTION)
    const_counter = self._get_counter()
    const_true = self._get_const(True)

    # Swapping the operands requires the mirrored comparison operator.
    if switch:
      const_op = self._get_const(REVERSE_CMP_OP[op])
    else:
      const_op = self._get_const(op)

    offset = self._insert_instructions(
        to_insert, lineno, offset, args_terminator()
    )
    offset = self._insert_instruction(
        to_insert, lineno, offset, dis.opmap["LOAD_CONST"], const_atheris
    )
    offset = self._insert_instruction(
        to_insert, lineno, offset, dis.opmap["LOAD_ATTR"], name_cmp
    )
    # Move the callable below the two operands already on the stack.
    rot = rot_n(2 + CALLABLE_STACK_ENTRIES, CALLABLE_STACK_ENTRIES)
    offset = self._insert_instructions(to_insert, lineno, offset, rot)

    if switch:
      offset = self._insert_instructions(to_insert, lineno, offset, rot_n(2))

    offset = self._insert_instruction(
        to_insert, lineno, offset, dis.opmap["LOAD_CONST"], const_op
    )
    offset = self._insert_instruction(
        to_insert, lineno, offset, dis.opmap["LOAD_CONST"], const_counter
    )
    offset = self._insert_instruction(
        to_insert, lineno, offset, dis.opmap["LOAD_CONST"], const_true
    )

    offset = self._insert_instructions(to_insert, lineno, offset, call(5))

    return offset - start_offset, to_insert

  def trace_control_flow(self) -> None:
    """Inserts calls to atheris._trace_branch() for coverage collection.

    One call is placed at the start of the code object (after any leading
    RESUME / GEN_START instructions) and one at the start of every basic block
    that is a successor of a block with exactly two outgoing edges.

    The argument of _trace_branch() is an id for the branch.

    The following bytecode gets inserted (the exact call sequence is provided
    by the version_dependent helpers):
      LOAD_CONST     atheris
      LOAD_ATTR      _trace_branch
      LOAD_CONST     <id>
      CALL_FUNCTION  1
      POP_TOP                   ; _trace_branch() returns None, remove the
                                  return value
    """
    already_instrumented = set()

    # Insert at the first point after a RESUME instruction
    first_real_instr = None
    first_real_instr_slot = None
    for i in range(len(self._cfg[0].instructions)):
      bb_instr = self._cfg[0].instructions[i]
      if bb_instr.mnemonic not in ("RESUME", "GEN_START"):
        first_real_instr = bb_instr
        first_real_instr_slot = i
        break

    if first_real_instr is None:
      # This was an empty code object (e.g. empty module)
      return
    assert first_real_instr_slot is not None

    total_size, to_insert = self._generate_trace_branch_invocation(
        first_real_instr.lineno, first_real_instr.offset
    )
    self._adjust(first_real_instr.offset, total_size)
    self._cfg[0].instructions = (
        self._cfg[0].instructions[0:first_real_instr_slot]
        + to_insert
        + self._cfg[0].instructions[first_real_instr_slot:]
    )

    for basic_block in self._cfg.values():
      if len(basic_block.edges) == 2:
        for edge in basic_block.edges:
          bb = self._cfg[edge]

          if bb.id not in already_instrumented:
            already_instrumented.add(bb.id)
            source_instr = []
            offset = bb.instructions[0].offset

            # Jumps that target this block keep their reference, so that they
            # land on the inserted call instead of skipping over it.
            for source_bb in self._cfg.values():
              if bb.id in source_bb.edges and source_bb.instructions[
                  -1].reference == offset:
                source_instr.append(source_bb.instructions[-1])

            total_size, to_insert = self._generate_trace_branch_invocation(
                bb.instructions[0].lineno, offset)

            self._adjust(offset, total_size, *source_instr)

            bb.instructions = to_insert + bb.instructions

    self._handle_size_changes()

  def trace_data_flow(self) -> None:
    """Instruments bytecode for data-flow tracing.

    This works by replacing the instruction COMPARE_OP with a call to
    atheris._trace_cmp(). The arguments for _trace_cmp() are as follows:
        - obj1 and obj2: The two values to compare
        - opid: argument to COMPARE_OP
        - counter: The counter for this comparison.
        - is_const: whether obj1 is a constant in co_consts.

    To detect if any of the values being compared is a constant, all push and
    pop operations have to be analyzed. If a constant appears in a comparison it
    must always be given as obj1 to _trace_cmp().

    The bytecode that gets inserted looks like this (the exact rotate and call
    sequences are provided by the version_dependent helpers):
      LOAD_CONST     atheris
      LOAD_ATTR      _trace_cmp
      ROT_THREE                 ; move atheris._trace_cmp below the two
                                  objects
      LOAD_CONST     <opid>
      LOAD_CONST     <counter index>
      LOAD_CONST     <is_const>
      CALL_FUNCTION  5
    """
    stack_size = 0
    # Stack depths (before the push) at which a LOAD_CONST placed a value.
    seen_consts = []

    for basic_block in self._cfg.values():
      # NOTE: the instruction list is extended while it is being enumerated.
      # New instructions are inserted at index `c`, so the loop subsequently
      # visits them and then reaches the (by then NOP-ed) comparison again.
      for c, instr in enumerate(basic_block.instructions):
        if instr.mnemonic == "LOAD_CONST":
          seen_consts.append(stack_size)
        elif instr.mnemonic == "COMPARE_OP" and instr.arg <= 5:
          # If the instruction has CACHEs afterward, we'll need to NOP them too.
          # NOTE(review): cache_count() receives the mnemonic here, whereas
          # Instruction.cache_count() passes the opcode - confirm that the
          # helper accepts both.
          instr_caches = [
              basic_block.instructions[i]
              for i in range(c + 1, c + 1 + cache_count(instr.mnemonic))
          ]

          # Determine the two values on the top of the stack before COMPARE_OP
          consts_on_stack = [
              depth for depth in seen_consts
              if stack_size - 2 <= depth < stack_size
          ]
          tos_is_constant = stack_size - 1 in consts_on_stack
          tos1_is_constant = stack_size - 2 in consts_on_stack

          # Comparisons between two constants are left untouched.
          if not (tos_is_constant and tos1_is_constant):
            offset = instr.offset

            # Both items are non-constants
            if (not tos_is_constant) and (not tos1_is_constant):
              total_size, to_insert = self._generate_cmp_invocation(
                  instr.arg, instr.lineno, offset)

            # One item is constant, one is non-constant
            else:
              total_size, to_insert = self._generate_const_cmp_invocation(
                  instr.arg, instr.lineno, offset, tos_is_constant)

            self._adjust(offset, total_size)

            for i, new_instr in enumerate(to_insert):
              basic_block.instructions.insert(c + i, new_instr)

            # The inserted call replaces the comparison.
            instr.make_nop()
            for cache_instr in instr_caches:
              cache_instr.make_nop()

        stack_size += instr.get_stack_effect()
        # Forget constants that have been popped off the stack again.
        seen_consts = [depth for depth in seen_consts if depth < stack_size]

    self._handle_size_changes()

  def _print_disassembly(self) -> None:
    """Prints disassembly."""
    print(f"Disassembly of {self._code.co_filename}:{self._code.co_name}")
    for basic_block in self._cfg.values():
      print(" -bb-")
      for instr in basic_block:
        print(f"  L.{instr.lineno} [{instr.offset}] {instr.mnemonic} ", end="")

        if instr.has_argument():
          print(f"{instr.arg} ", end="")

          # Only relative jumps get their resolved target printed.
          if instr._is_relative:
            print(f"(to {instr.reference})", end="")

        print()

843 

844 

def patch_code(code: types.CodeType,
               trace_dataflow: bool,
               nested: bool = False) -> types.CodeType:
  """Returns code, patched with Atheris instrumentation.

  Code objects that already carry the "__ATHERIS_INSTRUMENTED__" marker
  constant are returned unchanged.

  Args:
    code: The byte code to instrument.
    trace_dataflow: Whether to trace dataflow or not.
    nested: If False, nested code objects named "<module>" are patched too.
      Recursive calls to this function are considered nested.

  Returns:
    The instrumented code object, or `code` itself if it was instrumented
    before.
  """
  # If this code object has already been instrumented, skip it. Checking the
  # constants up front avoids building a control flow graph that would be
  # thrown away again.
  for const in code.co_consts:
    # This avoids comparison between str and bytes (BytesWarning).
    if isinstance(const, str) and const == "__ATHERIS_INSTRUMENTED__":
      return code

  inst = Instrumentor(code)

  inst.trace_control_flow()

  if trace_dataflow:
    inst.trace_data_flow()

  # Repeat this for all nested code objects
  for i, const in enumerate(inst.consts):
    if not isinstance(const, types.CodeType):
      continue

    name = const.co_name
    # Names wrapped in angle brackets are skipped, except for lambdas and -
    # on the outermost call - modules. startswith()/endswith() also cope
    # with an empty name, which indexing would not.
    if (name == "<lambda>" or
        (not nested and name == "<module>") or
        not name.startswith("<") or
        not name.endswith(">")):
      inst.consts[i] = patch_code(const, trace_dataflow, nested=True)

  return inst.to_code()

879 

880 

881T = TypeVar("T") 

882 

883 

def instrument_func(func: Callable[..., T]) -> Callable[..., T]:
  """Instruments `func` in place with Atheris tracing and returns it.

  The function's code object is replaced by the result of patch_code() with
  data-flow tracing enabled, treated as a nested call.
  """
  instrumented = patch_code(func.__code__, trace_dataflow=True, nested=True)
  func.__code__ = instrumented
  return func

889 

890 

891def _is_instrumentable(obj: Any) -> bool: 

892 """Returns True if this object can be instrumented.""" 

893 try: 

894 # Only callables can be instrumented 

895 if not hasattr(obj, "__call__"): 

896 return False 

897 # Only objects with a __code__ member of type CodeType can be instrumented 

898 if not hasattr(obj, "__code__"): 

899 return False 

900 if not isinstance(obj.__code__, types.CodeType): 

901 return False 

902 # Only code in a real module can be instrumented 

903 if not hasattr(obj, "__module__"): 

904 return False 

905 if obj.__module__ not in sys.modules: 

906 return False 

907 # Bound methods can't be instrumented - instrument the real func instead 

908 if hasattr(obj, "__self__"): 

909 return False 

910 # Only Python functions and methods can be instrumented, nothing native 

911 if (not isinstance(obj, types.FunctionType)) and (not isinstance( 

912 obj, types.MethodType)): 

913 return False 

914 except Exception: # pylint: disable=broad-except 

915 # If accessing any of those fields produced an exception, the object 

916 # probably can't be instrumented 

917 return False 

918 

919 return True 

920 

921 

def instrument_all() -> None:
  """Add Atheris instrumentation to all Python code already imported.

  This function is experimental.

  This function is able to instrument core library functions that can't be
  instrumented by instrument_func or instrument_imports, as those functions are
  used in the implementation of the instrumentation.

  Every object reported by gc.get_objects() that passes _is_instrumentable()
  is handed to instrument_func(). Failures are reported on stderr and do not
  stop the run.
  """
  renderer = None

  targets = [
      candidate for candidate in gc.get_objects()
      if _is_instrumentable(candidate)
  ]

  # A progress display is only used when stderr is a terminal.
  if sys.stderr.isatty():
    sys.stderr.write("INFO: Instrumenting functions: ")
    renderer = utils.ProgressRenderer(sys.stderr, len(targets))
  else:
    sys.stderr.write(f"INFO: Instrumenting {len(targets)} functions...\n")

  for done, func in enumerate(targets, start=1):
    try:
      instrument_func(func)
    except Exception as e:  # pylint: disable=broad-except
      if renderer:
        renderer.drop()
      sys.stderr.write(f"ERROR: Failed to instrument function {func}: {e}\n")
    if renderer:
      renderer.count = done

  if renderer:
    renderer.drop()
  else:
    print("INFO: Instrumentation complete.")