# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=unidiomatic-typecheck
"""Prototype decorator for defining legacy-graph-mode functions."""

import weakref

from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.core.protobuf import struct_pb2
from tensorflow.python.eager import context
from tensorflow.python.eager import function
from tensorflow.python.eager import lift_to_graph
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import func_graph
from tensorflow.python.framework import importer
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import nested_structure_coder
from tensorflow.python.trackable import data_structures
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export


class VariableHolder(object):
  """Holds variables for a python function."""

  def __init__(self, fn=None, share_variables=False):
    self._fn = fn

    self._share_variables = share_variables
    self._variables_by_name = data_structures.Mapping()

  @property
  def variables(self):
    return self._variables_by_name

  def variable_creator_scope(self, next_creator, **kwargs):
    """Creates variables & adds them to collections to match legacy code."""
    collections = kwargs.pop("collections", None)
    v = None

    # Get expected variable name.
    with ops.name_scope(
        kwargs.get("name", None), "Variable", skip_on_eager=False) as name:
      variable_name = ops.name_from_scope_name(name)
      kwargs["name"] = name

    if self._share_variables:
      v = self._variables_by_name.get(variable_name, None)

    if v is None:
      v = next_creator(**kwargs)
      self._variables_by_name[variable_name] = v

    if collections is None:
      collections = [ops.GraphKeys.GLOBAL_VARIABLES]
    if v.trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections:
      collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES]

    ops.add_to_collections(collections, v)

    return v

  def __call__(self, *args, **kwargs):
    return self.call_with_variable_creator_scope(self._fn)(*args, **kwargs)

  def call_with_variable_creator_scope(self, fn):

    def wrapped(*args, **kwargs):
      with variable_scope.variable_creator_scope(self.variable_creator_scope):
        return fn(*args, **kwargs)

    return wrapped
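

# A minimal usage sketch (illustrative, not part of the original module):
# VariableHolder intercepts variable creation so that variables made inside a
# wrapped function are recorded by name, and looked up again when
# `share_variables=True`. Assuming the public `tf` namespace is available:
#
#   import tensorflow as tf
#
#   holder = VariableHolder(share_variables=True)
#
#   def build(x):
#     v = tf.Variable(1.0, name="v")  # routed through the creator scope
#     return v + x
#
#   with func_graph.FuncGraph("sketch").as_default():
#     holder.call_with_variable_creator_scope(build)(tf.constant(2.0))
#   assert "v" in holder.variables  # recorded under its scope name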



def _get_element_from_tensor_info(tensor_info, graph):
  """Simplified copy of the deprecated `get_tensor_from_tensor_info`."""
  encoding = tensor_info.WhichOneof("encoding")
  if encoding == "name":
    # We may get operations here in some cases. TensorInfo is a bit of a
    # misnomer if so.
    return graph.as_graph_element(tensor_info.name)
  elif encoding == "coo_sparse":
    return sparse_tensor.SparseTensor(
        graph.get_tensor_by_name(tensor_info.coo_sparse.indices_tensor_name),
        graph.get_tensor_by_name(tensor_info.coo_sparse.values_tensor_name),
        graph.get_tensor_by_name(
            tensor_info.coo_sparse.dense_shape_tensor_name))
  elif encoding == "composite_tensor":
    spec_proto = struct_pb2.StructuredValue(
        type_spec_value=tensor_info.composite_tensor.type_spec)
    spec = nested_structure_coder.decode_proto(spec_proto)
    components = [graph.get_tensor_by_name(component.name) for component in
                  tensor_info.composite_tensor.components]
    return spec._from_components(components)  # pylint: disable=protected-access
  else:
    raise ValueError(f"Invalid TensorInfo.encoding: {encoding}. Valid "
                     "encodings are 'name', 'coo_sparse', and "
                     "'composite_tensor'.")
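

# An illustrative note on the helper above: with the plain "name" encoding,
# resolution goes through `graph.as_graph_element`, so a tensor name like
# "x:0" yields a Tensor while a bare op name like "x" yields an Operation
# (names here are hypothetical):
#
#   info = meta_graph_pb2.TensorInfo(name="x:0")
#   element = _get_element_from_tensor_info(info, some_graph)  # -> a Tensor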



def _lift_single_variable(old_variable, graph, variable_holder):
  """Lifts `old_variable` out of the `FuncGraph` `graph`."""
  new_variable = resource_variable_ops.UninitializedVariable(
      shape=old_variable.shape,
      dtype=old_variable.dtype,
      name=old_variable.op.name,
      trainable=old_variable.trainable,
      extra_handle_data=old_variable.handle)
  new_variable._initializer_op = old_variable._initializer_op  # pylint: disable=protected-access
  graph.add_capture(new_variable.handle, old_variable.handle)
  # Now that we've added the new variable to graph.captures,
  # graph.capture will use that cached value and do some post-processing
  # on the capture like recording it on the tape.
  graph.capture(new_variable.handle)
  # pylint: disable=protected-access
  variable_name = new_variable.name.split(":")[0]
  variable_holder._variables_by_name[variable_name] = new_variable
  graph._weak_variables.append(weakref.ref(new_variable))
  # pylint: enable=protected-access
  graph.watch_variable(new_variable)
  return new_variable


def _lift_unlifted_variables(graph, variable_holder):
  """Finds resource variables and lifts them into the outer context.

  When we import a GraphDef inside a wrap_function, no Python graph building
  code runs. This means we get VarHandleOps which create variable resources,
  but no corresponding Python objects. Leaving them like this works but gives
  the user no way to interact with or modify the variables outside the graph.

  This method searches for variables and lifts them out as regular variable
  objects when possible, indicating to the FuncGraph that they are captures.

  Args:
    graph: The FuncGraph to lift variables from.
    variable_holder: A VariableHolder to record the lifted variables in.
  """
  with graph.as_default():
    global_collection_variables = ops.get_collection(
        ops.GraphKeys.GLOBAL_VARIABLES)
    local_collection_variables = ops.get_collection(
        ops.GraphKeys.LOCAL_VARIABLES)
    existing_captures = {id(c) for c in graph.internal_captures}
    lifted_variables = {}

    def _should_lift_variable(v):
      return ((v._in_graph_mode  # pylint: disable=protected-access
               and v.graph.building_function)
              and isinstance(v, resource_variable_ops.BaseResourceVariable)
              and id(v.handle) not in existing_captures)

    for old_variable in global_collection_variables:
      if _should_lift_variable(old_variable):
        new_variable = _lift_single_variable(
            old_variable, graph, variable_holder)
        lifted_variables[id(old_variable)] = new_variable
        existing_captures.add(id(old_variable.handle))

    for old_variable in local_collection_variables:
      if _should_lift_variable(old_variable):
        new_variable = _lift_single_variable(
            old_variable, graph, variable_holder)
        lifted_variables[id(old_variable)] = new_variable
        existing_captures.add(id(old_variable.handle))
        if new_variable._in_graph_mode:  # pylint: disable=protected-access
          outer_graph = new_variable.graph
          # Variables are added to the global collection by default. In this
          # case we only want the variable in the local collection, so we'll
          # pop it out.
          global_collection = outer_graph.get_collection_ref(
              ops.GraphKeys.GLOBAL_VARIABLES)
          global_collection.remove(new_variable)
          outer_graph.add_to_collection(
              ops.GraphKeys.LOCAL_VARIABLES, new_variable)

    # Update the FuncGraph's collections, partly for the user and partly so
    # this function is idempotent when it runs again in prune() calls.
    for collection_name in [
        ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.LOCAL_VARIABLES
    ]:
      mutable_collection = ops.get_collection_ref(collection_name)
      for index, current in enumerate(mutable_collection):
        mutable_collection[index] = lifted_variables.get(id(current), current)
        if not resource_variable_ops.is_resource_variable(
            mutable_collection[index]):
          logging.log_first_n(
              logging.WARN,
              "Unable to create a python object for variable {} because it is "
              "a reference variable. It may not be visible to training APIs. "
              "If this is a problem, consider rebuilding the SavedModel after "
              "running tf.compat.v1.enable_resource_variables().".format(
                  mutable_collection[index]),
              5)



# TODO(allenl): make this trackable
class WrappedFunction(function.ConcreteFunction):
  """Wraps a TF 1.x piece of code in a function."""

  def __init__(self, fn_graph, variable_holder, attrs=None, signature=None):
    self._variable_holder = variable_holder
    _lift_unlifted_variables(fn_graph, variable_holder)
    # We call __init__ after lifting variables so that the function's signature
    # properly reflects the new captured inputs.
    for f in fn_graph.as_graph_def().library.function:
      context.context().add_function_def(f)
    self._signature = signature
    super(WrappedFunction, self).__init__(fn_graph, attrs=attrs)

  def _call_impl(self, args, kwargs):
    if self._arg_keywords is None:
      if kwargs:
        raise NotImplementedError(
            "Keyword arguments are not supported when calling a "
            f"wrap_function-decorated function. Got {kwargs}.")
      if self._signature is not None:
        args = list(args)
        for i, arg in enumerate(args):
          if isinstance(self._signature[i], tensor_spec.DenseSpec):
            args[i] = ops.convert_to_tensor(arg, self._signature[i].dtype)
      return self._call_flat(args, self.captured_inputs)
    else:
      return super(WrappedFunction, self)._call_impl(args, kwargs)

  def prune(self, feeds, fetches, name=None, input_signature=None):
    """Extract a subgraph of this function's underlying graph.

    Wraps the subgraph in a new `WrappedFunction` object.

    Args:
      feeds: Input tensors to the subgraph to extract, as `Tensor` objects.
      fetches: Possibly-nested Python data structure containing information
        about outputs of the target subgraph. Each entry can either be a
        `Tensor` object (for data outputs), an `Operation` object (for control
        outputs), or a `TensorInfo` proto. Any additional shape/dtype
        information provided in a `TensorInfo` and not present in the original
        graph will be added to the returned subgraph.
      name: (optional) Name to give to the underlying `FuncGraph` of the
        returned object. If no name is provided, the graph's name will be
        `"pruned"`.
      input_signature: (optional) possibly-nested Python data structure
        containing `TensorSpec` objects, with which to populate the returned
        function's `FuncGraph`'s `structured_input_signature` field.

    Returns:
      A new `WrappedFunction` object containing a copy of the portion of this
      object's graph that goes from `feeds` to `fetches`.
    """
    # TODO(b/129646028): Add support for CompositeTensors.
    name = name or "pruned"
    flat_feeds = nest.flatten(feeds, expand_composites=True)
    flat_feeds = [self.graph.as_graph_element(t) for t in flat_feeds]
    for f in flat_feeds:
      if not isinstance(f, ops.Tensor):
        raise ValueError("All members of argument `feeds` must be tensors. "
                         f"Got {f} with type {type(f)}.")

    # Ignoring all feeds that are captures allows prune to be called
    # using wrapped_func.inputs even when it uses variables.
    internal_captures = {id(c) for c in self.graph.internal_captures}
    flat_feeds = [f for f in flat_feeds if id(f) not in internal_captures]

    operation_fetches = []
    tensor_fetches = []
    tensor_infos = []

    def _fetch_preprocessing_callback(fetch):
      """Extract out lists of ops, tensors, and tensor type info.

      Turns TensorInfos into Tensors in the original `fetches` structure.
      Also extracts ops from `fetches`.

      Args:
        fetch: The fetch to preprocess: Tensor, TensorInfo, or Operation, or
          string identifying a Tensor or Operation.

      Returns:
        `fetch` converted to a Tensor.
      """
      if isinstance(fetch, ops.Operation):
        operation_fetches.append(fetch)
        return fetch
      elif isinstance(fetch, meta_graph_pb2.TensorInfo):
        tensor_infos.append(fetch)
        decoded = _get_element_from_tensor_info(fetch, self._func_graph)
        if (tensor_util.is_tf_type(decoded) or
            isinstance(decoded, composite_tensor.CompositeTensor)):
          tensor_fetches.append(decoded)
        else:
          operation_fetches.append(decoded)
        return decoded
      elif isinstance(fetch, (ops.Tensor, composite_tensor.CompositeTensor)):
        tensor_fetches.append(fetch)
        return fetch
      else:
        graph_element = self.graph.as_graph_element(fetch)
        return _fetch_preprocessing_callback(graph_element)

    fetches = nest.map_structure(_fetch_preprocessing_callback, fetches)

    # Expand composite tensors into their component dense Tensors.
    tensor_fetches = nest.flatten(tensor_fetches, expand_composites=True)

    for f in flat_feeds + tensor_fetches + operation_fetches:
      if f.graph is not self._func_graph:
        raise ValueError("Can only prune a function whose feeds and fetches "
                         f"are from graph {self._func_graph}. Input "
                         f"{f} is from a different graph {f.graph}.")
    with self._func_graph.as_default():
      pruned_graph = func_graph.FuncGraph(name)
    lift_map = lift_to_graph.lift_to_graph(
        operation_fetches + tensor_fetches,
        pruned_graph,
        sources=flat_feeds + self.graph.internal_captures,
        base_graph=self._func_graph)

    # Note that we add the component tensors of any composite tensors to the
    # returned function's outputs list; the list must contain these component
    # tensors, or the function's sparse outputs won't work properly.
    pruned_graph.outputs.extend(lift_map[x] for x in tensor_fetches)
    pruned_graph.control_outputs.extend(
        [lift_map[operation] for operation in operation_fetches])
    pruned_graph.inputs.extend(lift_map[x] for x in flat_feeds)
    for external_capture, internal_capture in self.graph.captures:
      pruned_graph.add_capture(external_capture, lift_map[internal_capture])
    for ti in tensor_infos:
      if ti.WhichOneof("encoding") == "name":  # Dense tensors only
        t = pruned_graph.as_graph_element(ti.name)
        if tensor_util.is_tf_type(t):
          t.set_shape(tensor_shape.TensorShape(ti.tensor_shape))
    # pylint: disable=protected-access
    for f in self.graph._functions.values():
      pruned_graph._add_function(f)
    # pylint: enable=protected-access

    pruned_graph.variables = self.graph.variables

    def _structured_output_mapping(fetched):
      """Callback for `nest.map_structure()`."""
      lifted = lift_map[fetched]
      if isinstance(lifted, ops.Operation):
        return None
      return lifted

    # expand_composites=True here causes composite tensors to be expanded
    # into their component dense Tensors, mapped to the new graph, and then
    # reconstituted into their original composite form.
    pruned_graph.structured_outputs = nest.map_structure(
        _structured_output_mapping, fetches, expand_composites=True)

    if input_signature:
      # Canonicalize the signature before setting it.
      args, kwargs = input_signature
      args = () if args is None else args
      input_signature = (args, kwargs)

    pruned_graph.structured_input_signature = input_signature
    pruned_fn = WrappedFunction(
        pruned_graph, variable_holder=self._variable_holder)
    pruned_fn._num_positional_args = len(flat_feeds)  # pylint: disable=protected-access
    # TODO(kathywu): Enable keyword arguments if an input signature is specified
    pruned_fn._arg_keywords = [tensor.op.name for tensor in flat_feeds]  # pylint: disable=protected-access
    return pruned_fn
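

# Usage sketch for `prune` (illustrative only; relies on `wrap_function`,
# defined later in this module, and on TensorFlow's default op naming for the
# "mul" node). Pruning picks feeds and fetches out of an existing graph, much
# like choosing the feeds/fetches of a `session.run` call:
#
#   import tensorflow as tf
#
#   def model(x):
#     y = x * 2.0   # op named "mul" by default
#     return y + 1.0
#
#   wrapped = tf.compat.v1.wrap_function(model, [tf.TensorSpec([], tf.float32)])
#   # Build a new function that stops at the intermediate tensor y.
#   double = wrapped.prune(feeds=wrapped.graph.inputs[0],
#                          fetches=wrapped.graph.as_graph_element("mul:0"))
#   assert float(double(tf.constant(3.0))) == 6.0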



def _filter_returned_ops(fn):
  """Filters out any ops returned by the function.

  Args:
    fn: a function

  Returns:
    A tuple of (
      Wrapped function that returns `None` in place of any ops,
      dict that maps the index in the flat output structure to the returned op
    )
  """
  returned_ops = {}

  def wrap_and_filter_returned_ops(*args, **kwargs):
    outputs = fn(*args, **kwargs)
    flat_outputs = nest.flatten(outputs)
    for n in range(len(flat_outputs)):
      output = flat_outputs[n]
      if isinstance(output, ops.Operation):
        returned_ops[n] = output
        flat_outputs[n] = None
    return nest.pack_sequence_as(outputs, flat_outputs)

  return wrap_and_filter_returned_ops, returned_ops
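

# A small sketch of the helper above (illustrative, hypothetical values):
# given a function returning (tensor, op), the wrapper records the op's flat
# index and substitutes None, so later tracing only sees tensors:
#
#   wrapped, returned_ops = _filter_returned_ops(fn)
#   outputs = wrapped()        # e.g. (some_tensor, None)
#   # returned_ops == {1: some_op}; 1 is the op's index in the flat outputs.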



class WrappedGraph(object):
  """Class for wrapping multiple TF 1.X functions in a single graph.

  Maintains a dictionary mapping names to wrapped functions. See
  `tf.compat.v1.wrap_function` to learn more about wrapping V1 functions.

  Functions wrapped using this class have access to variables and collections
  created in other wrapped functions, using the standard TF 1.X API
  (`tf.compat.v1.get_variable` or
  `tf.compat.v1.get_default_graph().get_collection(...)`).

  Outside a function, variables and collections may be accessed using the
  `variables` and `graph` properties.

  Example:

  ```
  def add_v1(x):
    with tf.compat.v1.variable_scope('vars', reuse=tf.compat.v1.AUTO_REUSE):
      v = tf.compat.v1.get_variable('v', shape=[], dtype=tf.int32)
    return v + x

  def increment_var_v1(x):
    with tf.compat.v1.variable_scope('vars', reuse=tf.compat.v1.AUTO_REUSE):
      v = tf.compat.v1.get_variable('v', shape=[], dtype=tf.int32)
    return v.assign_add(x)

  g = WrappedGraph()
  add = g.wrap_function(add_v1, [tf.TensorSpec([], tf.int32)])
  increment_var = g.wrap_function(increment_var_v1,
                                  [tf.TensorSpec([], tf.int32)])

  assert len(g.variables) == 1
  assert g.variables[0].numpy() == 0
  increment_var(tf.constant(5))
  assert g.variables[0].numpy() == 5
  ```
  """

  def __init__(self, variable_holder=None, **kwargs):
    self._variable_holder = (
        variable_holder or VariableHolder(share_variables=True))

    name = kwargs.pop("name", "wrapped_function_graph")
    # Always start with empty collections, unless otherwise specified. Setting
    # `collections=None` will copy the collections from the outer graph.
    collections = kwargs.pop("collections", {})
    self.graph = func_graph.FuncGraph(name, collections=collections, **kwargs)

    self._wrapped_function = WrappedFunction(self.graph, self._variable_holder)
    self._functions = {}

  @property
  def functions(self):
    return self._functions

  @property
  def variables(self):
    return self._variable_holder.variables

  def wrap_function(self, fn, signature, name=None):
    """Wraps a TF 1.X function and returns an eager-compatible function.

    All functions wrapped in the same `WrappedGraph` will have access to the
    same graph (`tf.compat.v1.get_default_graph` to get the graph object
    within a function, or `WrappedGraph.graph` to get the graph outside a
    function). Variables created within the function will be added to the
    `variables` list.

    Function inputs: All inputs to the function must be tensors (nested ok),
    with their shapes and dtypes defined in the `signature` argument.

    Function outputs:

    * The 1.X function may return tensors, variables, and ops. The wrapped
      eager-compatible function will always return tensors in the same nested
      structure.
    * Variables are replaced with a tensor containing the latest read values.
    * Returned ops are executed, and replaced with None.
    * The order of op execution and variable reads in the return is
      nondeterministic. For example:

      ```
      def update_var(x):
        v = tf.Variable(0)
        op = tf.compat.v1.assign(v, x).op
        return v, op

      g = WrappedGraph()
      fn = g.wrap_function(update_var, [tf.TensorSpec([], tf.int32)])
      read_value, _ = fn(tf.constant(3))
      print(read_value.numpy())  # could be 0 or 3
      print(g.variables[0].numpy())  # always 3
      ```

    To ensure that ops in the function are executed (e.g. ops added to the
    `tf.GraphKeys.UPDATE_OPS` collection), include them in the function
    returns.

    Args:
      fn: a 1.X tensorflow function.
      signature: a possibly nested sequence of `TensorSpecs` specifying the
        shapes and dtypes of the arguments.
      name: an optional string name for the function. The function will be
        saved with key `name` in the `functions` dictionary.

    Returns:
      An eager-compatible function.
    """
    return self._wrap_function(fn, signature=signature, name=name)

  def _wrap_function(self,
                     fn,
                     args=None,
                     kwargs=None,
                     signature=None,
                     name=None):
    """Internal wrap function method with extended func_graph arguments."""
    fn_with_filter_and_scope, returned_ops = _filter_returned_ops(
        self._variable_holder.call_with_variable_creator_scope(fn))

    func_graph.func_graph_from_py_func(
        None,  # Name is unused.
        fn_with_filter_and_scope,
        args=args,
        kwargs=kwargs,
        signature=signature,
        add_control_dependencies=False,
        func_graph=self.graph)

    # This code relies on questionable behavior from `func_graph_from_py_func`.
    # If an existing FuncGraph is passed into the `func_graph` arg, the inputs
    # and structured outputs are overwritten. Pretty sure this is a bug,
    # because structured outputs doesn't match up with the outputs...
    fn_inputs = self.graph.inputs[:-len(self.graph.captures)]

    # Return filtered ops to the flattened outputs.
    flat_fn_outputs = nest.flatten(self.graph.structured_outputs)
    for index, op in returned_ops.items():
      flat_fn_outputs[index] = op
    fn_outputs = nest.pack_sequence_as(self.graph.structured_outputs,
                                       flat_fn_outputs)

    name = name or fn.__name__
    wrapped_function = self._wrapped_function.prune(
        fn_inputs, fn_outputs, name, self.graph.structured_input_signature)
    self._functions[name] = wrapped_function
    return wrapped_function



@tf_export(v1=["wrap_function"])
def wrap_function(fn, signature, name=None):
  """Wraps the TF 1.x function `fn` into a graph function.

  The python function `fn` will be called once with symbolic arguments
  specified in the `signature`, traced, and turned into a graph function. Any
  variables created by `fn` will be owned by the object returned by
  `wrap_function`. The resulting graph function can be called with tensors
  which match the signature.

  ```python
  def f(x, do_add):
    v = tf.Variable(5.0)
    if do_add:
      op = v.assign_add(x)
    else:
      op = v.assign_sub(x)
    with tf.control_dependencies([op]):
      return v.read_value()

  f_add = tf.compat.v1.wrap_function(f, [tf.TensorSpec((), tf.float32), True])

  assert float(f_add(1.0)) == 6.0
  assert float(f_add(1.0)) == 7.0

  # Can call tf.compat.v1.wrap_function again to get a new trace, a new set
  # of variables, and possibly different non-template arguments.
  f_sub = tf.compat.v1.wrap_function(f, [tf.TensorSpec((), tf.float32), False])

  assert float(f_sub(1.0)) == 4.0
  assert float(f_sub(1.0)) == 3.0
  ```

  Both `tf.compat.v1.wrap_function` and `tf.function` create a callable
  TensorFlow graph. But while `tf.function` runs all stateful operations
  (e.g. `tf.print`) and sequences operations to provide the same semantics as
  eager execution, `wrap_function` is closer to the behavior of `session.run`
  in TensorFlow 1.x. It will not run any operations unless they are required
  to compute the function's outputs, either through a data dependency or a
  control dependency. Nor will it sequence operations.

  Unlike `tf.function`, `wrap_function` will only trace the Python function
  once. As with placeholders in TF 1.x, shapes and dtypes must be provided to
  `wrap_function`'s `signature` argument.

  Since it is only traced once, variables and state may be created inside the
  function and owned by the function wrapper object.

  Args:
    fn: python function to be wrapped
    signature: the placeholder and python arguments to be passed to the
      wrapped function
    name: Optional. The name of the function.

  Returns:
    the wrapped graph function.
  """
  holder = VariableHolder(fn)
  func_graph_name = "wrapped_function"
  if name is not None:
    func_graph_name = "wrapped_function_" + name
  return WrappedFunction(
      func_graph.func_graph_from_py_func(
          func_graph_name,
          holder,
          args=None,
          kwargs=None,
          signature=signature,
          add_control_dependencies=False,
          collections={}),
      variable_holder=holder,
      signature=signature)
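

# A follow-on sketch (illustrative, continuing the docstring example above):
# the returned WrappedFunction owns the variables created during its single
# trace, and they remain reachable through its graph:
#
#   f_add = tf.compat.v1.wrap_function(f, [tf.TensorSpec((), tf.float32), True])
#   v = f_add.graph.variables[0]  # the tf.Variable created inside `f`
#   print(v.name)                 # name assigned during tracing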



def function_from_graph_def(graph_def, inputs, outputs, captures=None):
  """Creates a ConcreteFunction from a GraphDef.

  Args:
    graph_def: A GraphDef to make a function out of.
    inputs: A Tensor name or nested structure of names in `graph_def` which
      should be inputs to the function.
    outputs: A Tensor name or nested structure of names in `graph_def` which
      should be outputs of the function.
    captures: (Optional) A dictionary mapping node names in `graph_def` that
      should be captured as inputs to tensors containing the value of the
      captured inputs.

  Returns:
    A ConcreteFunction.
  """

  def _imports_graph_def():
    importer.import_graph_def(graph_def, name="")
    graph = ops.get_default_graph()
    if captures is not None:
      for c in captures:
        graph.add_capture(captures[c], graph.get_tensor_by_name(str(c) + ":0"))

  wrapped_import = wrap_function(_imports_graph_def, [])
  import_graph = wrapped_import.graph
  return wrapped_import.prune(
      nest.map_structure(import_graph.as_graph_element, inputs),
      nest.map_structure(import_graph.as_graph_element, outputs))
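

# Usage sketch (illustrative; "add:0" assumes TensorFlow's default naming for
# the addition below): round-trip a wrapped function through its GraphDef and
# rebuild a callable from tensor names:
#
#   import tensorflow as tf
#
#   def plus_one(x):
#     return x + 1.0
#
#   wrapped = wrap_function(plus_one, [tf.TensorSpec([], tf.float32)])
#   graph_def = wrapped.graph.as_graph_def()
#   rebuilt = function_from_graph_def(
#       graph_def,
#       inputs=wrapped.graph.inputs[0].name,
#       outputs="add:0")
#   assert float(rebuilt(tf.constant(2.0))) == 3.0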