# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""xla is an experimental library that provides XLA support APIs."""

import contextlib


from tensorflow.compiler.jit.ops import xla_ops
from tensorflow.compiler.jit.ops import xla_ops_grad  # pylint: disable=unused-import
from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.distribute import summary_op_util
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import compat
from tensorflow.python.util import nest
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.compat import collections_abc
from tensorflow.python.util.deprecation import deprecated
from tensorflow.python.util.tf_export import tf_export

_XLA_COMPILE_ATTR = '_xla_compile_id'
_MAX_WARNING_LINES = 5

# Operations that indicate some error in the user's graph. For example, an
# XLA computation should not have any Placeholder op.
_DENYLISTED_OPS = set([
    'Placeholder',
])

# XLA doesn't currently support reading of intermediate tensors, thus some ops
# are not supported.
_UNSUPPORTED_OPS = set([
    'AudioSummary',
    'AudioSummaryV2',
    'HistogramSummary',
    'ImageSummary',
    'MergeSummary',
    'Print',
    'ScalarSummary',
    'TensorSummary',
    'TensorSummaryV2',
])


@tf_export('xla.experimental.compile')
@deprecated(
    None, 'xla.experimental.compile is deprecated. Consider using '
    '`@tf.function(jit_compile=True)`.',
    warn_once=True)
def compile(computation, inputs=None):  # pylint: disable=redefined-builtin
  """Builds an operator that compiles and runs `computation` with XLA.

  NOTE: In eager mode, `computation` will have `@tf.function` semantics.

  Args:
    computation: A Python function that builds a computation to apply to the
      input. If the function takes n inputs, `inputs` should be a list of n
      `Tensor`s.

      `computation` may return a list of `Tensor`s and `Operation`s.
      `Tensor`s must come before `Operation`s in the returned list.

      All `Operation`s returned from `computation` will be executed when
      evaluating any of the returned output tensors.
    inputs: A list of inputs or `None` (equivalent to an empty list). Each
      input can be a nested structure containing values that can be converted
      to `Tensor`s. Note that passing an N-dimension list of compatible values
      will result in an N-dimension list of scalar `Tensor`s rather than a
      single rank-N `Tensor`. If you need a different behavior, convert parts
      of `inputs` to `Tensor`s with `tf.convert_to_tensor`.

  Returns:
    List of `Tensor`s corresponding to the `Tensor`s from
    the output of `computation` i.e. the same return value as if
    computation(*inputs) is called directly, with the following exceptions:
    * None output: a NoOp would be returned with a control dependency on
      `computation`.
    * Single value output: a tuple containing the value would be returned.
    * Operation-only outputs: a NoOp would be returned with a control
      dependency on `computation`.
    TODO(b/121383831): Investigate removing these special cases.

  Raises:
    RuntimeError: When eager execution is enabled.

  Known issues:
    When a tf.random operation is built with XLA, the implementation doesn't
    pass the user-provided seed to the XLA compiler. As such, the XLA compiler
    generates a random number and uses it as a seed when compiling the
    operation. This implementation causes a violation of the TensorFlow-defined
    semantics in two aspects. First, changing the value of the user-defined
    seed doesn't change the numbers generated by the operation. Second, when a
    seed is not specified, running the program multiple times will generate the
    same numbers.
  """
  if context.executing_eagerly():

    @def_function.function
    def xla_compile_wrapper():
      return _compile_internal(computation, inputs)

    return xla_compile_wrapper()

  return _compile_internal(computation, inputs)
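

# A minimal usage sketch of `compile` in graph mode. `_example_compile_usage`
# is a hypothetical helper for illustration only; it is not part of this
# module's API and is never called from library code. Per the deprecation
# notice above, `@tf.function(jit_compile=True)` is the modern replacement.
def _example_compile_usage():
  def computation(x, y):
    # Tensors must come before Operations in a flat output list.
    return x + y, x * y

  x = array_ops.ones([2, 2])
  y = array_ops.ones([2, 2])
  # `compile` returns the outputs of `computation`, built as a single XLA
  # cluster in the enclosing graph.
  added, multiplied = compile(computation, inputs=[x, y])
  return added, multiplied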


class XLACompileContext(control_flow_ops.XLAControlFlowContext):
  """A `ControlFlowContext` for nodes inside an XLA computation cluster.

  THIS IS ONLY FOR TENSORFLOW INTERNAL IMPLEMENTATION, DO NOT USE DIRECTLY.

  The primary role of `XLACompileContext` is to mark operators inside an
  xla.compile() computation with the attribute "_xla_compile_id=XYZ", where XYZ
  is a unique name.

  `ControlFlowContext` is used to perform the annotation since it integrates
  with TensorFlow constructs like ResourceVariables. For example, if a
  `ResourceVariable` is constructed inside an xla.compile() block, the
  `ResourceVariable` implementation can use
  `with ops.control_dependencies(None)` to build the variable's definition
  outside the compiled computation.
  """

  def __init__(self, name, pivot):
    """Builds a new XLACompileContext.

    Args:
      name: a unique name for the context, used to populate the
        `_xla_compile_id` attribute.
      pivot: a pivot node. Nodes in the XLACompileContext that do not have any
        inputs will have a control dependency on the pivot node. This ensures
        that nodes are correctly included in any enclosing control flow
        contexts.
    """
    super(XLACompileContext, self).__init__()
    self._name = name
    self._name_as_bytes = compat.as_bytes(name)
    self._unsupported_ops = []
    self._pivot = pivot

  def report_unsupported_operations(self):
    if self._unsupported_ops:
      op_str = '\n'.join([
          ' %s (%s)' % (op.type, op.name)
          for op in self._unsupported_ops[:_MAX_WARNING_LINES]
      ])
      logging.warning('%d unsupported operations found: \n%s',
                      len(self._unsupported_ops), op_str)
      if len(self._unsupported_ops) > _MAX_WARNING_LINES:
        logging.warning('... and %d more',
                        len(self._unsupported_ops) - _MAX_WARNING_LINES)

  def _RemoveExternalControlEdges(self, op):
    """Remove any external control dependency on this op."""
    internal_control_inputs = []
    external_control_inputs = []
    for x in op.control_inputs:
      # pylint: disable=protected-access
      is_internal_op = False
      ctxt = x._get_control_flow_context()
      while ctxt is not None:
        if ctxt == self:
          is_internal_op = True
          break
        ctxt = ctxt._outer_context
      if is_internal_op:
        internal_control_inputs.append(x)
      else:
        external_control_inputs.append(x)
      # pylint: enable=protected-access
    # pylint: disable=protected-access
    op._remove_all_control_inputs()
    op._add_control_inputs(internal_control_inputs)
    # pylint: enable=protected-access
    return internal_control_inputs, external_control_inputs

  def AddOp(self, op):
    """Creates op in XLACompileContext and notifies outer context recursively."""
    # pylint: disable=protected-access
    if op.type in _DENYLISTED_OPS:
      logging.error(
          'Operation of type %s (%s) is not supported in XLA. Execution will '
          'fail if this op is used in the graph. ', op.type, op.name)

    # TODO(ycao): Automatically disable summaries instead of reporting them.
    if op.type in _UNSUPPORTED_OPS:
      self._unsupported_ops.append(op)

    if any(x.dtype._is_ref_dtype for x in op.inputs):
      raise NotImplementedError(
          'Non-resource Variables are not supported inside XLA computations '
          '(operator name: %s)' % op.name)

    if _XLA_COMPILE_ATTR in op.node_def.attr:
      raise ValueError('XLA compiled computations cannot be nested, (operator '
                       'name: %s)' % op.name)

    op._set_attr(
        _XLA_COMPILE_ATTR, attr_value_pb2.AttrValue(s=self._name_as_bytes))

    op.graph.prevent_feeding(op)
    op.graph.prevent_fetching(op)

    # Remove any control edges from outer control flow contexts. These may
    # cause mismatched frame errors. An example is when one of op's inputs is
    # generated in a different While control flow context.
    (internal_control_inputs,
     external_control_inputs) = self._RemoveExternalControlEdges(op)

    if not op.inputs:
      # Add a control edge from the control pivot to this op.
      if not internal_control_inputs:
        # pylint: disable=protected-access
        op._add_control_input(self._pivot)
        # pylint: enable=protected-access
    else:
      for index in range(len(op.inputs)):
        x = op.inputs[index]
        real_x = self.AddValue(x)
        if real_x is not x:
          op._update_input(index, real_x)  # pylint: disable=protected-access

    if external_control_inputs:
      # Use an identity to pull control inputs as data inputs. Note that we
      # ignore ops which don't have outputs. TODO(phawkins): fix that.
      with ops.control_dependencies(None):
        self.Enter()
        external_control_inputs = [
            array_ops.identity(x.outputs[0]).op
            for x in external_control_inputs
            if x.outputs
        ]
        self.Exit()
      # pylint: disable=protected-access
      op._add_control_inputs(external_control_inputs)
      # pylint: enable=protected-access

    # Mark op's outputs as seen by this context and any outer contexts.
    output_names = [x.name for x in op.outputs]
    context = self
    while context is not None:
      # pylint: disable=protected-access
      context._values.update(output_names)
      context = context._outer_context
      # pylint: enable=protected-access

    if self._outer_context:
      self._outer_context.AddInnerOp(op)

  def AddValue(self, val):
    """Add `val` to the current context and its outer context recursively."""
    if val.name in self._values:
      # Use the real value if it comes from outer context.
      result = self._external_values.get(val.name)
      return val if result is None else result

    result = val
    self._values.add(val.name)
    if self._outer_context:
      result = self._outer_context.AddValue(val)
      self._values.add(result.name)

    self._external_values[val.name] = result

    return result

  def AddInnerOp(self, op):
    self.AddOp(op)
    if self._outer_context:
      self._outer_context.AddInnerOp(op)

  @property
  def grad_state(self):
    # Define the gradient loop state associated with the XLACompileContext to
    # be None, as the XLACompileContext does not get nested and the grad_state
    # outside the XLACompileContext does not affect the graph inside, so the
    # grad_state should be as if this is the top-level gradient state.
    return None

  @property
  def back_prop(self):
    """Forwards to the enclosing while context, if any."""
    if self.GetWhileContext():
      return self.GetWhileContext().back_prop
    return False
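

# An illustrative sketch of how `XLACompileContext` is typically driven
# (mirroring `_compile_internal` below). Ops created between Enter() and
# Exit() are annotated with `_xla_compile_id` so later XLA passes can group
# them into one cluster. `_example_compile_context_usage` is a hypothetical
# helper for illustration only and is never called from library code.
def _example_compile_context_usage():
  graph = ops.Graph()
  with graph.as_default():
    pivot = control_flow_ops.no_op(name='example_cluster/pivot')
    ctx = XLACompileContext(name='example_cluster', pivot=pivot)
    ctx.Enter()
    one = array_ops.ones([2])  # Annotated via AddOp; no inputs, so it gains a
                               # control dependency on the pivot.
    two = one + one            # Also annotated via AddOp.
    ctx.ExitResult([two])
    ctx.Exit()
  return two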

def _compile_internal(computation, inputs=None):
  """Builds graph operators that compile and symbolically execute computation.

  Args:
    computation: A Python function that builds the computation to compile and
      execute.
    inputs: A list of inputs or `None` (equivalent to an empty list). Each
      input can be a nested structure containing values that are convertible
      to tensors. Note that passing an N-dimension list of compatible values
      will result in an N-dimension list of scalar tensors rather than a
      single rank-N tensor. If you need different behavior, convert parts of
      inputs to tensors with `tf.convert_to_tensor`.

  Returns:
    Same data structure as if computation(*inputs) is called directly with some
    exceptions for correctness. Exceptions include: 1) None output, 2) single
    value output, and 3) operation-only outputs.

  Raises:
    ValueError: If any element in computation outputs is neither an Operation
      nor a value that can be converted to a tensor.
    ValueError: If computation outputs is non-flat and contains any Operations.
    TypeError: If `inputs` is not a list or tuple.
  """

  if inputs is None:
    inputs = []

  if not isinstance(inputs, collections_abc.Sequence):
    raise TypeError('inputs must be a list')

  # Flatten inputs.
  flat_inputs = nest.flatten(inputs)
  # Converts inputs to Tensors.
  flat_inputs = [ops.convert_to_tensor(x) for x in flat_inputs]

  cluster_name = ops.get_default_graph().unique_name('cluster')
  pivot = control_flow_ops.no_op(name=cluster_name + '/pivot')
  context = XLACompileContext(name=cluster_name, pivot=pivot)
  try:
    context.Enter()

    # Add identity ops so even unused inputs are 'consumed' by the
    # computation.
    flat_inputs = [
        array_ops.identity(x, name='input_{}'.format(i))
        for i, x in enumerate(flat_inputs)
    ]

    # Re-pack flat_inputs in same structure as 'inputs'.
    computation_inputs = nest.pack_sequence_as(
        structure=inputs, flat_sequence=flat_inputs)

    # Only resource variables work inside an XLA computation, so turn on
    # resource variables for the computation.
    vscope = variable_scope.get_variable_scope()
    saved_use_resource = vscope.use_resource
    vscope.set_use_resource(True)

    with _disable_summary_context():
      outputs = computation(*computation_inputs)

    # Restore variable scope after computation.
    vscope.set_use_resource(saved_use_resource)

    outputs_is_flat = is_flat(outputs)
    if outputs_is_flat:
      output_tensors, control_deps = _postprocess_flat_outputs(outputs)
    else:
      output_tensors, control_deps = _postprocess_non_flat_outputs(outputs)

    context.ExitResult(output_tensors)
  finally:
    context.report_unsupported_operations()
    context.Exit()

  # When the XLA computation returns only operations and no tensors, a NoOp
  # dependent on the operations in outputs is returned. Otherwise the final
  # outputs would be empty and there would be no way to trigger the returned
  # operations.
  if not output_tensors:
    return control_flow_ops.group(control_deps, name='output_0')

  output_tensors = [
      xla_ops.xla_cluster_output(o, name='output{}'.format(i))
      for i, o in enumerate(output_tensors)
  ]

  with ops.control_dependencies(control_deps):
    # Wrap the outputs in identity operators that carry the control
    # dependencies.
    output_tensors = [
        array_ops.identity(o, name='output_%d' % i)
        for i, o in enumerate(output_tensors)
    ]

  # If `computation` returned a non-flat output structure, pack output tensors
  # back into the same structure.
  if not outputs_is_flat:
    output_tensors = nest.pack_sequence_as(
        structure=outputs, flat_sequence=output_tensors)

  return output_tensors


def is_flat(outputs):
  """Checks if outputs is a flat structure.

  Following structures and values are considered flat:
  1) None
  2) A single object
  3) A list or tuple of Tensors/Operations

  The only structures that this function understands are sequences,
  dictionaries and types defined using the attrs library. E.g. this means
  that if outputs contains a single user-defined Object, it is considered to
  be flat. Errors are raised later on if that Object cannot be converted to a
  Tensor.

  Args:
    outputs: Output from `computation` inside `xla.compile`.

  Returns:
    A boolean indicating whether outputs is flat.
  """
  # If outputs is a list or tuple, check if it has any nested structure. If
  # there is, then outputs is non-flat.
  if isinstance(outputs, collections_abc.Sequence):
    for o in outputs:
      if (isinstance(o, collections_abc.Sequence) or
          isinstance(o, collections_abc.Mapping) or
          hasattr(o.__class__, '__attrs_attrs__')):
        return False

  # If outputs is a dict, it is non-flat.
  if isinstance(outputs, collections_abc.Mapping):
    return False

  # If outputs is from the attrs library, it is non-flat.
  if hasattr(outputs.__class__, '__attrs_attrs__'):
    return False

  # Getting here means either outputs itself is a single non-structured value
  # or it is a flat list of single non-structured values.
  return True
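

# A few illustrative cases of the rules above. `_example_is_flat_cases` is a
# hypothetical helper for illustration only; the assertions are never executed
# by library code.
def _example_is_flat_cases():
  assert is_flat(None)            # None is flat.
  assert is_flat(5)               # A single object is flat.
  assert is_flat([1, 2])          # A flat sequence is flat.
  assert not is_flat([[1], 2])    # A nested sequence is non-flat.
  assert not is_flat({'a': 1})    # A mapping is non-flat.
  assert not is_flat([{'a': 1}])  # A sequence containing a mapping is non-flat.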


def _postprocess_flat_outputs(outputs):
  """Validates flat outputs and adds back device assignments.

  Args:
    outputs: Output from `computation` inside `xla.compile`.

  Returns:
    Tensors and Operations extracted from outputs.
  """
  # The following code segment preserves legacy behavior. Previously we only
  # supported flat outputs, and thus for consistency it was nice to convert
  # even a single element into a tuple. But now that we support arbitrary
  # output structure, this is no longer necessary.
  # TODO(b/121383831): Migrate all legacy use cases and delete this special
  # case.
  # If the computation returns `None`, make it an empty tuple.
  if outputs is None:
    outputs = tuple()
  # If the computation only returned one value, make it a tuple.
  if not isinstance(outputs, collections_abc.Sequence):
    outputs = (outputs,)

  # Append `no_op` here so that the return value of this function always
  # contains at least one op that can trigger the XlaLaunch node.
  outputs += (control_flow_ops.no_op(),)
  try:
    outputs = [
        o if isinstance(o, ops.Operation) else ops.convert_to_tensor(o)
        for o in outputs
    ]
  except Exception as e:
    raise ValueError(
        'XLA computation function return values must all either be Operations'
        ' or convertible to Tensors. Got error: "%s"' % str(e))

  # Separates the returned Operations and Tensors.
  output_operations = [o for o in outputs if isinstance(o, ops.Operation)]
  output_tensors = [o for o in outputs if not isinstance(o, ops.Operation)]

  if outputs != output_tensors + output_operations:
    raise ValueError(
        'XLA computation function must return zero or more Tensor values '
        'followed by zero or more Operations.')

  new_output_tensors = []
  for t in output_tensors:
    with ops.device(t.device if t.device else ''):
      new_output_tensors.append(array_ops.identity(t))

  return new_output_tensors, output_operations
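

# A sketch of the flat output contract (hypothetical helper, never called
# from library code): Tensors must precede Operations, and a trailing NoOp is
# appended internally so an XlaLaunch trigger always exists.
def _example_flat_outputs():
  t = array_ops.ones([2])
  op = control_flow_ops.no_op()
  # Returns ([identity(t)], [op, <internal no_op>]).
  tensors, operations = _postprocess_flat_outputs([t, op])
  return tensors, operations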


def _postprocess_non_flat_outputs(outputs):
  """Validates non-flat outputs and adds back device assignments.

  Args:
    outputs: Output from `computation` inside `xla.compile`.

  Returns:
    Tensors extracted from outputs and an empty list, because Operations are
    not allowed in non-flat outputs.
  """
  # Convert all non-Operation outputs to Tensors.
  new_output_tensors = []
  for o in nest.flatten(outputs):
    if isinstance(o, ops.Operation):
      raise ValueError(
          'xla.compile does not support Operation as return value in non-flat '
          'output structure. You can set returned Operations as control '
          'dependencies of returned Tensors so Operations are triggered when '
          'Tensors are evaluated. Operation found: "%s"' % o.name)

    try:
      o = ops.convert_to_tensor(o)
    except Exception as e:
      raise ValueError(
          'XLA computation function return values must all either be '
          'Operations or convertible to Tensors. Got error: "%s"' % str(e))

    # Makes sure even pass-through inputs/outputs are touched in the compile
    # context by creating an Identity node inside it.
    with ops.device(o.device if o.device else ''):
      new_output_tensors.append(array_ops.identity(o))

  return new_output_tensors, []
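

# A sketch of non-flat outputs (hypothetical helper, never called from
# library code): structures such as dicts are allowed as long as they contain
# no Operations; each leaf is converted to a Tensor and identity-wrapped.
def _example_non_flat_outputs():
  outs = {'ones': array_ops.ones([2]), 'twos': 2.0 * array_ops.ones([2])}
  # `tensors` is the flattened list of identity-wrapped leaves; `operations`
  # is always empty for non-flat structures.
  tensors, operations = _postprocess_non_flat_outputs(outs)
  return tensors, operations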


@contextlib.contextmanager
def _disable_summary_context():
  """Enters a context where all summary ops are skipped.

  Summaries are not yet supported in xla.compile(), so this context manager
  skips creating summary ops. This is a temporary workaround until XLA
  supports summary ops.

  Yields:
    None.
  """
  original_skip_summary_func = summary_op_util.skip_summary
  summary_op_util.skip_summary = lambda: True

  try:
    yield
  finally:
    summary_op_util.skip_summary = original_skip_summary_func
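

# Usage sketch (hypothetical helper, never called from library code),
# mirroring how `_compile_internal` drives the context manager above: summary
# ops requested inside the block are skipped, and the original behavior is
# restored afterwards even if `computation` raises.
def _example_disable_summary_usage(computation, computation_inputs):
  with _disable_summary_context():
    return computation(*computation_inputs)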


class _CapturedObject(object):
  """A placeholder to capture an object."""

  def __init__(self):
    self._object = None

  def capture(self, o):
    if self._object:
      raise RuntimeError(
          'InternalError: _CapturedObject can capture only once. Please file '
          'a bug.')

    self._object = o

  def get(self):
    return self._object
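

# The capture-once pattern in miniature (hypothetical helper, never called
# from library code): a `_CapturedObject` is handed into user code so a value
# produced there, such as a scaffold_fn, can be retrieved later via `get()`.
def _example_captured_object_usage():
  captured = _CapturedObject()
  captured.capture(lambda: 'scaffold_fn result')
  fn = captured.get()
  return fn()  # A second `capture()` call would raise RuntimeError.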


def _get_scaffold(captured_scaffold_fn):
  """Retrieves the Scaffold from `captured_scaffold_fn`."""
  scaffold_fn = captured_scaffold_fn.get()

  if not scaffold_fn:
    return None

  scaffold = scaffold_fn()
  if scaffold is None:
    raise ValueError(
        'TPUEstimatorSpec.scaffold_fn returns None, which is not allowed')

  return scaffold


def check_function_argument_count(func, input_arity, infeed_queue):
  """Validates the number of input arguments to an XLA function.

  Args:
    func: the Python function that will be called to generate the body of an
      XLA computation graph.
    input_arity: the number of explicit arguments supplied by the caller.
    infeed_queue: if not None, the infeed queue that will supply
      additional arguments to the function.

  Returns:
    None if the function can be called with the supplied number of
      arguments, or an error string if it cannot.
  """

  def format_error(complaint, quantity):
    return '%s %d argument%s' % (complaint, quantity,
                                 '' if quantity == 1 else 's')

  num_args_supplied = input_arity
  if infeed_queue is not None:
    num_args_supplied += infeed_queue.number_of_tuple_elements
  arg_spec = tf_inspect.getargspec(func)
  num_func_args = len(arg_spec.args)
  if arg_spec.defaults is None:
    num_func_defaults = 0
  else:
    num_func_defaults = len(arg_spec.defaults)
  min_func_args = num_func_args - num_func_defaults
  if num_args_supplied < min_func_args:
    # Not enough arguments were supplied to call the function.
    if num_func_defaults == 0 and arg_spec.varargs is None:
      return format_error('exactly', num_func_args)
    else:
      return format_error('at least', min_func_args)
  if arg_spec.varargs is None and num_args_supplied > num_func_args:
    # Too many arguments were supplied to call the function.
    if num_func_defaults == 0:
      return format_error('exactly', num_func_args)
    else:
      return format_error('at most', num_func_args)
  # Reaching here means either
  # 1) There are varargs, so func can accept any number of arguments greater
  #    than the minimum.
  # 2) The number of supplied arguments falls within the acceptable argument
  #    count range of func.
  return None
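

# Worked examples of the arity check (hypothetical helper, never called from
# library code), assuming no infeed queue.
def _example_argument_count_checks():
  def two_args(a, b):
    return a, b

  def one_required(a, b=1):
    return a, b

  # Matching arity: returns None (no error).
  assert check_function_argument_count(two_args, 2, None) is None
  # Too few arguments for a function without defaults or varargs.
  assert check_function_argument_count(
      two_args, 1, None) == 'exactly 2 arguments'
  # One supplied argument satisfies the single required parameter.
  assert check_function_argument_count(one_required, 1, None) is None
  # Three supplied arguments exceed the maximum of two.
  assert check_function_argument_count(
      one_required, 3, None) == 'at most 2 arguments'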