Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/keras/engine/functional.py: 12%

699 statements  


1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15# pylint: disable=protected-access 

16"""A `Network` is way to compose layers: the topological form of a `Model`.""" 

17 

18import collections 

19import copy 

20import itertools 

21import warnings 

22 

23from tensorflow.python.eager import context 

24from tensorflow.python.framework import dtypes 

25from tensorflow.python.framework import ops 

26from tensorflow.python.keras import backend 

27from tensorflow.python.keras.engine import base_layer 

28from tensorflow.python.keras.engine import base_layer_utils 

29from tensorflow.python.keras.engine import input_layer as input_layer_module 

30from tensorflow.python.keras.engine import input_spec 

31from tensorflow.python.keras.engine import node as node_module 

32from tensorflow.python.keras.engine import training as training_lib 

33from tensorflow.python.keras.engine import training_utils 

34from tensorflow.python.keras.saving.saved_model import network_serialization 

35from tensorflow.python.keras.utils import generic_utils 

36from tensorflow.python.keras.utils import tf_inspect 

37from tensorflow.python.keras.utils import tf_utils 

38from tensorflow.python.ops import array_ops 

39from tensorflow.python.ops import math_ops 

40from tensorflow.python.platform import tf_logging as logging 

41from tensorflow.python.trackable import base as trackable 

42from tensorflow.python.util import nest 

43from tensorflow.tools.docs import doc_controls 

44 

45 

46# pylint: disable=g-classes-have-attributes 

47class Functional(training_lib.Model): 

48 """A `Functional` model is a `Model` defined as a directed graph of layers. 

49 

50 Three types of `Model` exist: subclassed `Model`, `Functional` model, 

51 and `Sequential` (a special case of `Functional`). 

52 In general, more Keras features are supported with `Functional` 

53 than with subclassed `Model`s, specifically: 

54 

55 - Model cloning (`keras.models.clone`) 

56 - Serialization (`model.get_config()/from_config`, `model.to_json()`) 

57 - Whole-model saving (`model.save()`) 

58 

59 A `Functional` model can be instantiated by passing two arguments to 

60 `__init__`. The first argument is the `keras.Input` Tensors that represent 

61 the inputs to the model. The second argument specifies the output 

62 tensors that represent the outputs of this model. Both arguments can be a 

63 nested structure of tensors. 

64 

65 Example: 

66 

67 ``` 

68 inputs = {'x1': keras.Input(shape=(10,)), 'x2': keras.Input(shape=(1,))} 

69 t = keras.layers.Dense(1, activation='relu')(inputs['x1']) 

70 outputs = keras.layers.Add()([t, inputs['x2']]) 

71 model = keras.Model(inputs, outputs) 

72 ``` 

73 

74 A `Functional` model constructed using the Functional API can also include raw 

75 TensorFlow functions, with the exception of functions that create Variables 

76 or assign ops. 

77 

78 Example: 

79 

80 ``` 

81 inputs = keras.Input(shape=(10,)) 

82 x = keras.layers.Dense(1)(inputs) 

83 outputs = tf.nn.relu(x) 

84 model = keras.Model(inputs, outputs) 

85 ``` 

86 

87 Args: 

88 inputs: List of input tensors (must be created via `tf.keras.Input()`). 

89 outputs: List of output tensors. 

90 name: String, optional. Name of the model. 

91 trainable: Boolean, optional. Whether the model's variables should be trainable. 

92 """ 

93 

94 # See tf.Module for the usage of this property. 

95 # The key of _layer_call_argspecs is a layer. tf.Module._flatten will fail to 

96 # flatten the key since it is trying to convert Trackable/Layer to a string. 

97 _TF_MODULE_IGNORED_PROPERTIES = frozenset(itertools.chain( 

98 ('_layer_call_argspecs', '_compiled_trainable_state', 

99 '_output_mask_cache', '_output_tensor_cache', '_output_shape_cache'), 

100 training_lib.Model._TF_MODULE_IGNORED_PROPERTIES 

101 )) 

102 

103 @trackable.no_automatic_dependency_tracking 

104 def __init__(self, inputs, outputs, name=None, trainable=True, 

105 **kwargs): 

106 # This is used by the Model class, since we have some logic to swap the 

107 # class in the __new__ method, which will lead to __init__ being invoked 

108 # twice. Use skip_init to skip one of the invocations of __init__ and 

109 # avoid any side effects. 

110 skip_init = kwargs.pop('skip_init', False) 

111 if skip_init: 

112 return 

113 generic_utils.validate_kwargs(kwargs, {}) 

114 super(Functional, self).__init__(name=name, trainable=trainable) 

115 self._init_graph_network(inputs, outputs) 

116 

117 @trackable.no_automatic_dependency_tracking 

118 def _init_graph_network(self, inputs, outputs): 

119 # This method is needed for Sequential to reinitialize graph network when 

120 # layer is added or removed. 

121 self._is_graph_network = True 

122 

123 # Normalize and set self.inputs, self.outputs. 

124 if isinstance(inputs, list) and len(nest.flatten(inputs)) == 1: 

125 inputs = inputs[0] 

126 if isinstance(outputs, list) and len(nest.flatten(outputs)) == 1: 

127 outputs = outputs[0] 

128 self._nested_inputs = inputs 

129 self._nested_outputs = outputs 

130 self.inputs = nest.flatten(inputs) 

131 self.outputs = nest.flatten(outputs) 

132 

133 # Models constructed with a single Tensor or list of Tensors can 

134 # be called with a dict, where the keys of the dict are the names 

135 # of the `Input` objects. Extra keys are ignored with a warning. 

136 if not nest.is_nested(self._nested_inputs): 

137 self._enable_dict_to_input_mapping = True 

138 elif (isinstance(self._nested_inputs, (list, tuple)) and 

139 not any(nest.is_nested(t) for t in self._nested_inputs)): 

140 self._enable_dict_to_input_mapping = True 

141 elif (isinstance(self._nested_inputs, dict) and 

142 not any(nest.is_nested(t) for t in self._nested_inputs.values())): 

143 self._enable_dict_to_input_mapping = True 

144 else: 

145 self._enable_dict_to_input_mapping = False 
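
    # Illustrative example (hypothetical names): a model built from

    # {'x1': Input(...), 'x2': Input(...)} can later be called as

    # model({'x1': a, 'x2': b}); each dict key is matched back to the

    # corresponding `Input`.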

146 

147 if not ops.executing_eagerly_outside_functions(): 

148 if any(not hasattr(tensor, '_keras_history') for tensor in self.outputs): 

149 base_layer_utils.create_keras_history(self._nested_outputs) 

150 

151 self._validate_graph_inputs_and_outputs() 

152 

153 # A Network does not create weights of its own, thus it is already 

154 # built. 

155 self.built = True 

156 self._build_input_shape = nest.map_structure(lambda x: x.shape, inputs) 

157 self._compute_output_and_mask_jointly = True 

158 # `_expects_training_arg` is True since the `training` argument is always 

159 # present in the signature of the `call` method of a graph network. 

160 self._expects_training_arg = True 

161 self._expects_mask_arg = True 

162 # A graph network does not autocast inputs, as its layers will cast them 

163 # instead. 

164 self._autocast = False 

165 

166 self._input_layers = [] 

167 self._output_layers = [] 

168 self._input_coordinates = [] 

169 self._output_coordinates = [] 

170 

171 # This is for performance optimization when calling the Network on new 

172 # inputs. Every time the Network is called on a set of input tensors, 

173 # we compute the output tensors, output masks and output shapes in one pass, 

174 # then cache them here. When any of these outputs is queried later, we 

175 # retrieve it from there instead of recomputing it. 

176 self._output_mask_cache = {} 

177 self._output_tensor_cache = {} 

178 self._output_shape_cache = {} 

179 

180 # Build self._output_layers: 

181 for x in self.outputs: 

182 layer, node_index, tensor_index = x._keras_history # pylint: disable=protected-access 

183 self._output_layers.append(layer) 

184 self._output_coordinates.append((layer, node_index, tensor_index)) 

185 

186 # Build self._input_layers: 

187 for x in self.inputs: 

188 layer, node_index, tensor_index = x._keras_history # pylint: disable=protected-access 

189 # It's supposed to be an input layer, so only one node 

190 # and one tensor output. 

191 assert node_index == 0 

192 assert tensor_index == 0 

193 self._input_layers.append(layer) 

194 self._input_coordinates.append((layer, node_index, tensor_index)) 

195 

196 # Keep track of the network's nodes and layers. 

197 nodes, nodes_by_depth, layers, _ = _map_graph_network( 

198 self.inputs, self.outputs) 

199 self._network_nodes = nodes 

200 self._nodes_by_depth = nodes_by_depth 

201 self._self_tracked_trackables = layers 

202 self._layer_call_argspecs = {} 

203 for layer in self._self_tracked_trackables: 

204 self._layer_call_argspecs[layer] = tf_inspect.getfullargspec(layer.call) 

205 

206 # Build self.input_names and self.output_names. 

207 self._set_output_names() 

208 self.input_names = [] 

209 self._feed_input_names = [] 

210 self._feed_inputs = [] 

211 self._feed_input_shapes = [] 

212 for layer in self._input_layers: 

213 self.input_names.append(layer.name) 

214 if layer.is_placeholder: 

215 self._feed_input_names.append(layer.name) 

216 # Use batch_input_shape here because non-eager composite tensors may not 

217 # have a shape attribute that's meaningful (sparse, for instance, has 

218 # a tensor that's non-constant and needs to be fed). This means that 

219 # input layers that create placeholders will need to have the 

220 # batch_input_shape attr to allow for input shape validation. 

221 self._feed_input_shapes.append(layer._batch_input_shape) 

222 self._feed_inputs.append(layer.input) 

223 

224 self._compute_tensor_usage_count() 

225 self._set_save_spec(self._nested_inputs) 

226 tf_utils.assert_no_legacy_layers(self.layers) 

227 

228 @property 

229 def input(self): 

230 """Retrieves the input tensor(s) of a layer. 

231 

232 Only applicable if the layer has exactly one input, 

233 i.e. if it is connected to one incoming layer. 

234 

235 Returns: 

236 Input tensor or list of input tensors. 

237 

238 Raises: 

239 RuntimeError: If called in Eager mode. 

240 AttributeError: If no inbound nodes are found. 

241 """ 

242 return self._nested_inputs 

243 

244 @property 

245 def input_shape(self): 

246 """Retrieves the input shape(s) of a layer. 

247 

248 Only applicable if the layer has exactly one input, 

249 i.e. if it is connected to one incoming layer, or if all inputs 

250 have the same shape. 

251 

252 Returns: 

253 Input shape, as an integer shape tuple 

254 (or list of shape tuples, one tuple per input tensor). 

255 

256 Raises: 

257 AttributeError: if the layer has no defined input_shape. 

258 RuntimeError: if called in Eager mode. 

259 """ 

260 return nest.map_structure(backend.int_shape, self.input) 

261 

262 @property 

263 def input_spec(self): 

264 if hasattr(self, '_manual_input_spec'): 

265 return self._manual_input_spec 

266 if (isinstance(self._nested_inputs, (dict, list, tuple)) and 

267 len(self._nested_inputs) != len(self.inputs)): 

268 # Case where we have a nested structure. 

269 # In such a case we can't safely run any checks. 

270 return None 

271 if isinstance(self._nested_inputs, dict): 

272 # Case where `_nested_inputs` is a plain dict of Inputs. 

273 names = sorted(self._nested_inputs.keys()) 

274 return [input_spec.InputSpec( 

275 shape=shape_with_no_batch_size(self._nested_inputs[name]), 

276 allow_last_axis_squeeze=True, name=name) for name in names] 

277 else: 

278 # Single input, or list / tuple of inputs. 

279 # The data may be passed as a dict keyed by input name. 

280 return [input_spec.InputSpec( 

281 shape=shape_with_no_batch_size(x), allow_last_axis_squeeze=True, 

282 name=x._keras_history.layer.name) for x in self.inputs] 

283 

284 @input_spec.setter 

285 def input_spec(self, value): 

286 self._manual_input_spec = value 

287 

288 @property 

289 def output(self): 

290 """Retrieves the output tensor(s) of a layer. 

291 

292 Only applicable if the layer has exactly one output, 

293 i.e. if it is connected to one incoming layer. 

294 

295 Returns: 

296 Output tensor or list of output tensors. 

297 

298 Raises: 

299 AttributeError: if the layer is connected to more than one incoming 

300 layer. 

301 RuntimeError: if called in Eager mode. 

302 """ 

303 return self._nested_outputs 

304 

305 @property 

306 def output_shape(self): 

307 """Retrieves the output shape(s) of a layer. 

308 

309 Only applicable if the layer has one output, 

310 or if all outputs have the same shape. 

311 

312 Returns: 

313 Output shape, as an integer shape tuple 

314 (or list of shape tuples, one tuple per output tensor). 

315 

316 Raises: 

317 AttributeError: if the layer has no defined output shape. 

318 RuntimeError: if called in Eager mode. 

319 """ 

320 return nest.map_structure(backend.int_shape, self.output) 

321 

322 def _set_output_names(self): 

323 """Assigns unique names to the Network's outputs. 

324 

325 Output layers with multiple output tensors would otherwise lead to duplicate 

326 names in self.output_names. 

327 """ 

328 uniquified = [] 

329 output_names = set() 

330 prefix_count = {} 

331 for layer in self._output_layers: 

332 proposal = layer.name 

333 while proposal in output_names: 

334 existing_count = prefix_count.get(layer.name, 1) 

335 proposal = '{}_{}'.format(layer.name, existing_count) 

336 prefix_count[layer.name] = existing_count + 1 

337 output_names.add(proposal) 

338 uniquified.append(proposal) 

339 self.output_names = uniquified 

340 

341 @property 

342 def _layer_checkpoint_dependencies(self): 

343 """Dictionary of layer dependencies to be included in the checkpoint.""" 

344 weight_layer_index = 0 

345 

346 dependencies = collections.OrderedDict() 

347 for layer_index, layer in enumerate(self.layers): 

348 try: 

349 if layer.weights: 

350 # Keep a separate index for layers which have weights. This allows 

351 # users to insert Layers without weights anywhere in the network 

352 # without breaking checkpoints. 

353 dependencies['layer_with_weights-%d' % weight_layer_index] = layer 

354 weight_layer_index += 1 

355 except ValueError: 

356 # The layer might have weights, but may not be built yet. We just treat 

357 # it as a layer without weights. 

358 pass 

359 

360 # Even if it doesn't have weights, we should still track everything in 

361 # case it has/will have Trackable dependencies. 

362 dependencies['layer-%d' % layer_index] = layer 

363 return dependencies 

364 

365 def _trackable_children(self, 

366 save_type=trackable.SaveType.CHECKPOINT, 

367 **kwargs): 

368 dependencies = self._layer_checkpoint_dependencies 

369 dependencies.update( 

370 super(Functional, self)._trackable_children(save_type, **kwargs)) 

371 return dependencies 

372 

373 def _lookup_dependency(self, name): 

374 layer_dependencies = self._layer_checkpoint_dependencies 

375 if name in layer_dependencies: 

376 return layer_dependencies[name] 

377 return super(Functional, self)._lookup_dependency(name) 

378 

379 def _handle_deferred_layer_dependencies(self, layers): 

380 """Handles layer checkpoint dependencies that are added after init.""" 

381 layer_checkpoint_dependencies = self._layer_checkpoint_dependencies 

382 layer_to_name = {v: k for k, v in layer_checkpoint_dependencies.items()} 

383 for layer in layers: 

384 if layer in layer_to_name: 

385 self._handle_deferred_dependencies(name=layer_to_name[layer], 

386 trackable=layer) 

387 

388 @property 

389 def _should_compute_mask(self): 

390 return True 

391 

392 def compute_mask(self, inputs, mask): 

393 # TODO(omalleyt): b/123540974 This function is not really safe to call 

394 # by itself because it will duplicate any updates and losses in graph 

395 # mode by `call`ing the Layers again. 

396 output_tensors = self._run_internal_graph(inputs, mask=mask) 

397 return nest.map_structure(lambda t: getattr(t, '_keras_mask', None), 

398 output_tensors) 

399 

400 @doc_controls.do_not_doc_inheritable 

401 def call(self, inputs, training=None, mask=None): 

402 """Calls the model on new inputs. 

403 

404 In this case `call` just reapplies 

405 all ops in the graph to the new inputs 

406 (e.g. build a new computational graph from the provided inputs). 

407 

408 Args: 

409 inputs: A tensor or list of tensors. 

410 training: Boolean or boolean scalar tensor, indicating whether to run 

411 the `Network` in training mode or inference mode. 

412 mask: A mask or list of masks. A mask can be 

413 either a tensor or None (no mask). 

414 

415 Returns: 

416 A tensor if there is a single output, or 

417 a list of tensors if there is more than one output. 

418 """ 

419 return self._run_internal_graph( 

420 inputs, training=training, mask=mask) 

421 

422 def compute_output_shape(self, input_shape): 

423 # Convert any shapes in tuple format to TensorShapes. 

424 input_shape = tf_utils.convert_shapes(input_shape, to_tuples=False) 

425 

426 if len(nest.flatten(input_shape)) != len(nest.flatten(self._input_layers)): 

427 raise ValueError('Invalid input_shape argument ' + str(input_shape) + 

428 ': model has ' + str(len(self._input_layers)) + 

429 ' tensor inputs.') 

430 

431 # Use the tuple of TensorShapes as the cache key, since a tuple is 

432 # hashable and can be used as a dict key. 

433 try: 

434 cache_key = tuple(tf_utils.convert_shapes(input_shape, to_tuples=True)) 

435 if cache_key in self._output_shape_cache: 

436 # Cache hit. Return shapes as TensorShapes. 

437 return self._output_shape_cache[cache_key] 

438 except ValueError: 

439 # In case there are unknown TensorShapes, e.g. for sparse tensor inputs, 

440 # we skip the caching since the shape is unknown. 

441 pass 

442 

443 layers_to_output_shapes = {} 

444 for layer, shape in zip(self._input_layers, nest.flatten(input_shape)): 

445 # It's an input layer: `compute_output_shape` is the identity, 

446 # and there is only one node and one tensor. 

447 shape_key = layer.name + '_0_0' 

448 layers_to_output_shapes[shape_key] = shape 

449 

450 depth_keys = list(self._nodes_by_depth.keys()) 

451 depth_keys.sort(reverse=True) 

452 # Iterate over nodes, by depth level. 

453 if len(depth_keys) > 1: 

454 for depth in depth_keys: 

455 nodes = self._nodes_by_depth[depth] 

456 for node in nodes: 

457 layer = node.layer 

458 if layer in self._input_layers: 

459 # We've already covered the input layers 

460 # a few lines above. 

461 continue 

462 # Get the input shapes for the first argument of the node 

463 layer_input_shapes = [] 

464 layer_inputs = node.call_args[0] 

465 for layer_input in nest.flatten(layer_inputs): 

466 kh = layer_input._keras_history 

467 input_layer_key = kh.layer.name + '_%s_%s' % (kh.node_index, 

468 kh.tensor_index) 

469 layer_input_shapes.append(layers_to_output_shapes[input_layer_key]) 

470 layer_input_shapes = nest.pack_sequence_as(layer_inputs, 

471 layer_input_shapes) 

472 # Layers expect shapes to be tuples for `compute_output_shape`. 

473 layer_input_shapes = tf_utils.convert_shapes( 

474 layer_input_shapes, to_tuples=True) 

475 layer_output_shapes = layer.compute_output_shape(layer_input_shapes) 

476 # Convert back to TensorShapes. 

477 layer_output_shapes = tf_utils.convert_shapes( 

478 layer_output_shapes, to_tuples=False) 

479 

480 node_index = layer._inbound_nodes.index(node) # pylint: disable=protected-access 

481 for j, shape in enumerate(nest.flatten(layer_output_shapes)): 

482 shape_key = layer.name + '_%s_%s' % (node_index, j) 

483 layers_to_output_shapes[shape_key] = shape 

484 

485 # Read final output shapes from layers_to_output_shapes. 

486 output_shapes = [] 

487 for i in range(len(self._output_layers)): 

488 layer, node_index, tensor_index = self._output_coordinates[i] 

489 shape_key = layer.name + '_%s_%s' % (node_index, tensor_index) 

490 output_shapes.append(layers_to_output_shapes[shape_key]) 

491 output_shapes = nest.pack_sequence_as(self._nested_outputs, output_shapes) 

492 # Store in cache. 

493 self._output_shape_cache[cache_key] = output_shapes 

494 

495 # Return shapes as TensorShapes. 

496 return output_shapes 

497 

498 def _init_set_name(self, name, zero_based=True): 

499 if not name: 

500 cls_name = self.__class__.__name__ 

501 if self.__class__ == Functional: 

502 # Hide the functional class name from the user, since it is not a 

503 # publicly visible class. Use "Model" instead. 

504 cls_name = 'Model' 

505 self._name = backend.unique_object_name( 

506 generic_utils.to_snake_case(cls_name), 

507 zero_based=zero_based) 

508 else: 

509 self._name = name 

510 

511 def _run_internal_graph(self, inputs, training=None, mask=None): 

512 """Computes output tensors for new inputs. 

513 

514 # Note: 

515 - Can be run on non-Keras tensors. 

516 

517 Args: 

518 inputs: Tensor or nested structure of Tensors. 

519 training: Boolean learning phase. 

520 mask: (Optional) Tensor or nested structure of Tensors. 

521 

522 Returns: 

523 output_tensors 

524 """ 

525 inputs = self._flatten_to_reference_inputs(inputs) 

526 if mask is None: 

527 masks = [None] * len(inputs) 

528 else: 

529 masks = self._flatten_to_reference_inputs(mask) 

530 for input_t, mask in zip(inputs, masks): 

531 input_t._keras_mask = mask 

532 

533 # Dictionary mapping reference tensors to computed tensors. 

534 tensor_dict = {} 

535 tensor_usage_count = self._tensor_usage_count 

536 for x, y in zip(self.inputs, inputs): 

537 y = self._conform_to_reference_input(y, ref_input=x) 

538 x_id = str(id(x)) 

539 tensor_dict[x_id] = [y] * tensor_usage_count[x_id] 

540 

541 nodes_by_depth = self._nodes_by_depth 

542 depth_keys = list(nodes_by_depth.keys()) 

543 depth_keys.sort(reverse=True) 

544 

545 for depth in depth_keys: 

546 nodes = nodes_by_depth[depth] 

547 for node in nodes: 

548 if node.is_input: 

549 continue # Input tensors already exist. 

550 

551 if any(t_id not in tensor_dict for t_id in node.flat_input_ids): 

552 continue # Node is not computable, try skipping. 

553 

554 args, kwargs = node.map_arguments(tensor_dict) 

555 outputs = node.layer(*args, **kwargs) 

556 

557 # Update tensor_dict. 

558 for x_id, y in zip(node.flat_output_ids, nest.flatten(outputs)): 

559 tensor_dict[x_id] = [y] * tensor_usage_count[x_id] 

560 

561 output_tensors = [] 

562 for x in self.outputs: 

563 x_id = str(id(x)) 

564 assert x_id in tensor_dict, 'Could not compute output ' + str(x) 

565 output_tensors.append(tensor_dict[x_id].pop()) 

566 

567 return nest.pack_sequence_as(self._nested_outputs, output_tensors) 

568 

569 def _flatten_to_reference_inputs(self, tensors): 

570 """Maps `tensors` to their respective `keras.Input`.""" 

571 if self._enable_dict_to_input_mapping and isinstance(tensors, dict): 

572 ref_inputs = self._nested_inputs 

573 if not nest.is_nested(ref_inputs): 

574 ref_inputs = [self._nested_inputs] 

575 if isinstance(ref_inputs, dict): 

576 # In the case that the graph is constructed with dict input tensors, 

577 # we use the original dict keys to match the keys in the input 

578 # data. Note that model.inputs uses nest.flatten to process the 

579 # input tensors, which means the dict input tensors are ordered by their 

580 # keys. 

581 ref_input_names = sorted(ref_inputs.keys()) 

582 else: 

583 ref_input_names = [inp._keras_history.layer.name for inp in ref_inputs] 

584 

585 # Raise a warning if there are more input data entries than input tensors. 

586 if len(tensors) > len(ref_input_names): 

587 warnings.warn( 

588 'Input dict contained keys {} which did not match any model input. ' 

589 'They will be ignored by the model.'.format( 

590 [n for n in tensors.keys() if n not in ref_input_names]) 

591 ) 

592 

593 try: 

594 # Flatten in the order `Input`s were passed during Model construction. 

595 return [tensors[n] for n in ref_input_names] 

596 except KeyError: 

597 # TODO(b/151582614) 

598 return nest.flatten(tensors) 

599 

600 # Otherwise both self.inputs and tensors will already be in the same order. 

601 return nest.flatten(tensors) 

602 

603 def _conform_to_reference_input(self, tensor, ref_input): 

604 """Set shape and dtype based on `keras.Input`s.""" 

605 if isinstance(tensor, ops.Tensor): 

606 # Allow (None,) and (None, 1) Tensors to be passed interchangeably. Use 

607 # the shape specified by the `keras.Input`. 

608 t_shape = tensor.shape 

609 t_rank = t_shape.rank 

610 ref_shape = ref_input.shape 

611 ref_rank = ref_shape.rank 

612 keras_history = getattr(tensor, '_keras_history', None) 

613 if t_rank is not None and ref_rank is not None: 

614 # Should squeeze last dimension. 

615 # True if tensor is (BATCH, ..., 1) and reference is (BATCH, ...). 

616 if (t_rank == ref_rank + 1 and t_shape[-1] == 1): 

617 tensor = array_ops.squeeze_v2(tensor, axis=-1) 

618 # Should expand last dimension. 

619 # True if tensor is (BATCH, ...) and reference is (BATCH, ..., 1). 

620 elif (t_rank == ref_rank - 1 and ref_shape[-1] == 1): 

621 tensor = array_ops.expand_dims_v2(tensor, axis=-1) 

622 if keras_history is not None: # Restore keras history. 

623 tensor._keras_history = keras_history 

624 

625 # Add shape hints to Tensors that may have None shape dims but have shapes 

626 # defined by the `keras.Input` (not applicable in eager mode). 

627 if not context.executing_eagerly(): 

628 try: 

629 tensor.set_shape(tensor.shape.merge_with(ref_input.shape)) 

630 except ValueError: 

631 logging.warning( 

632 'Model was constructed with shape {} for input {}, but it was ' 

633 'called on an input with incompatible shape {}.'.format( 

634 ref_input.shape, ref_input, tensor.shape)) 

635 

636 # Dtype casting. 

637 tensor = math_ops.cast(tensor, dtype=ref_input.dtype) 

638 elif tf_utils.is_extension_type(tensor): 

639 # Dtype casting (If the extension type has a non-variant dtype and 

640 # supports being cast) 

641 ref_input_dtype = getattr(ref_input, 'dtype', None) 

642 if ref_input_dtype is not None and ref_input_dtype != dtypes.variant: 

643 tensor = math_ops.cast(tensor, dtype=ref_input_dtype) 

644 

645 return tensor 

646 

647 def get_config(self): 

648 return copy.deepcopy(get_network_config(self)) 

649 

650 @classmethod 

651 def from_config(cls, config, custom_objects=None): 

652 """Instantiates a Model from its config (output of `get_config()`). 

653 

654 Args: 

655 config: Model config dictionary. 

656 custom_objects: Optional dictionary mapping names 

657 (strings) to custom classes or functions to be 

658 considered during deserialization. 

659 

660 Returns: 

661 A model instance. 

662 

663 Raises: 

664 ValueError: In case of improperly formatted config dict. 
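
    Example (an illustrative sketch; assumes `model` is an existing

    functional `Model`):

    ```

    config = model.get_config()

    new_model = keras.Model.from_config(config)

    ```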

665 """ 

666 with generic_utils.SharedObjectLoadingScope(): 

667 input_tensors, output_tensors, created_layers = reconstruct_from_config( 

668 config, custom_objects) 

669 model = cls(inputs=input_tensors, outputs=output_tensors, 

670 name=config.get('name')) 

671 connect_ancillary_layers(model, created_layers) 

672 return model 

673 

674 def _validate_graph_inputs_and_outputs(self): 

675 """Validates the inputs and outputs of a Graph Network.""" 

676 # Check for redundancy in inputs. 

677 if len({id(i) for i in self.inputs}) != len(self.inputs): 

678 raise ValueError('The list of inputs passed to the model ' 

679 'is redundant. ' 

680 'All inputs should only appear once.' 

681 ' Found: ' + str(self.inputs)) 

682 

683 for x in self.inputs: 

684 # Check that x has appropriate `_keras_history` metadata. 

685 if not hasattr(x, '_keras_history'): 

686 cls_name = self.__class__.__name__ 

687 raise ValueError('Input tensors to a ' + cls_name + ' ' + 

688 'must come from `tf.keras.Input`. ' 

689 'Received: ' + str(x) + 

690 ' (missing previous layer metadata).') 

691 # Check that x is an input tensor. 

692 # pylint: disable=protected-access 

693 layer = x._keras_history.layer 

694 if len(layer._inbound_nodes) > 1 or ( 

695 layer._inbound_nodes and not layer._inbound_nodes[0].is_input): 

696 cls_name = self.__class__.__name__ 

697 logging.warning(cls_name + ' model inputs must come from ' 

698 '`tf.keras.Input` (thus holding past layer metadata), ' 

699 'they cannot be the output of ' 

700 'a previous non-Input layer. ' 

701 'Here, a tensor specified as ' 

702 'input to "' + self.name + '" was not an Input tensor, ' 

703 'it was generated by layer ' + layer.name + '.\n' 

704 'Note that input tensors are ' 

705 'instantiated via `tensor = tf.keras.Input(shape)`.\n' 

706 'The tensor that caused the issue was: ' + str(x.name)) 

707 

708 # Check compatibility of batch sizes of Input Layers. 

709 input_batch_sizes = [ 

710 training_utils.get_static_batch_size(x._keras_history.layer) 

711 for x in self.inputs 

712 ] 

713 consistent_batch_size = None 

714 for batch_size in input_batch_sizes: 

715 if batch_size is not None: 

716 if (consistent_batch_size is not None and 

717 batch_size != consistent_batch_size): 

718 raise ValueError('The specified batch sizes of the Input Layers' 

719 ' are incompatible. Found batch sizes: {}'.format( 

720 input_batch_sizes)) 

721 consistent_batch_size = batch_size 

722 

723 for x in self.outputs: 

724 if not hasattr(x, '_keras_history'): 

725 cls_name = self.__class__.__name__ 

726 raise ValueError('Output tensors of a ' + cls_name + ' model must be ' 

727 'the output of a TensorFlow `Layer` ' 

728 '(thus holding past layer metadata). Found: ' + str(x)) 

729 

730 def _insert_layers(self, layers, relevant_nodes=None): 

731 """Inserts Layers into the Network after Network creation. 

732 

733 This is only valid for Keras Graph Networks. Layers added via this function 

734 will be included in the `call` computation and `get_config` of this Network. 

735 They will not be added to the Network's outputs. 

736 

737 

738 Args: 

739 layers: Arbitrary nested structure of Layers. Layers must be reachable 

740 from one or more of the `keras.Input` Tensors that correspond to this 

741 Network's inputs. 

742 relevant_nodes: Nodes from the Layers that should be considered part of 

743 this Network. If `None`, all Nodes will be considered part of this 

744 Network. 

745 

746 Raises: 

747 ValueError: If the layers depend on `Input`s not found in this Model. 

748 """ 

749 layers = nest.flatten(layers) 

750 tf_utils.assert_no_legacy_layers(layers) 

751 node_to_depth = {} 

752 for depth, nodes in self._nodes_by_depth.items(): 

753 node_to_depth.update({node: depth for node in nodes}) 

754 # The nodes of these Layers that are relevant to this Network. If not 

755 # provided, assume all Nodes are relevant 

756 if not relevant_nodes: 

757 relevant_nodes = nest.flatten([layer._inbound_nodes for layer in layers]) 

758 network_nodes = set(relevant_nodes + list(node_to_depth.keys())) 

759 

760 def _get_min_depth(node): 

761 """Gets the minimum depth at which node can be computed.""" 

762 min_depth = 0 

763 for layer, node_id, _, _ in node.iterate_inbound(): 

764 inbound_node = layer._inbound_nodes[node_id] 

765 if inbound_node in node_to_depth: 

766 min_depth = min(min_depth, node_to_depth[inbound_node]) 

767 elif inbound_node not in network_nodes: 

768 continue 

769 else: 

770 # Previous relevant nodes haven't been processed yet. 

771 return None 

772 # New node is one shallower than its shallowest input. 

773 return min_depth - 1 

774 

775 # Insert nodes into `_nodes_by_depth` and other node attrs. 

776 unprocessed_nodes = copy.copy(relevant_nodes) 

777 i = 0 

778 while unprocessed_nodes: 

779 i += 1 

780 # Do a sanity check. This can occur if `Input`s from outside this Model 

781 # are being relied on. 

782 if i > 10000: 

783 raise ValueError('Layers could not be added due to missing ' 

784 'dependencies.') 

785 

786 node = unprocessed_nodes.pop(0) 

787 depth = _get_min_depth(node) 

788 if depth is None: # Defer until inbound nodes are processed. 

789 unprocessed_nodes.append(node) 

790 continue 

791 node_key = _make_node_key(node.layer.name, 

792 node.layer._inbound_nodes.index(node)) 

793 if node_key not in self._network_nodes: 

794 node_to_depth[node] = depth 

795 self._network_nodes.add(node_key) 

796 self._nodes_by_depth[depth].append(node) 

797 

798 # Insert layers and update other layer attrs. 

799 layer_set = set(self._self_tracked_trackables) 

800 deferred_layers = [] 

801 for layer in layers: 

802 if layer not in layer_set: 

803 self._self_tracked_trackables.append(layer) 

804 deferred_layers.append(layer) 

805 self._layer_call_argspecs[layer] = tf_inspect.getfullargspec(layer.call) 

806 layer_set.add(layer) 

807 self._handle_deferred_layer_dependencies(deferred_layers) 

808 

809 self._compute_tensor_usage_count() 

810 

811 def _compute_tensor_usage_count(self): 

812 """Compute the #. of tensor usages for all the output tensors of layers. 

813 

814 The computed tensor usage count is saved as `self._tensor_usage_count`. This 

815 is later used for saving memory in eager computation by releasing 

816 no-longer-needed tensors as early as possible. 

817 """ 

818 tensor_usage_count = collections.Counter() 

819 available_tensors = set(str(id(tensor)) for tensor in self.inputs) 

820 

821 depth_keys = list(self._nodes_by_depth.keys()) 

822 depth_keys.sort(reverse=True) 

823 depth_keys = depth_keys[1:] 

824 

825 for depth in depth_keys: 

826 for node in self._nodes_by_depth[depth]: 

827 input_tensors = { 

828 str(id(tensor)) for tensor in nest.flatten(node.keras_inputs) 

829 } 

830 if input_tensors.issubset(available_tensors): 

831 for tensor in nest.flatten(node.keras_inputs): 

832 tensor_usage_count[str(id(tensor))] += 1 

833 

834 for output_tensor in nest.flatten(node.outputs): 

835 available_tensors.add(str(id(output_tensor))) 

836 

837 for tensor in self.outputs: 

838 tensor_usage_count[str(id(tensor))] += 1 

839 

840 self._tensor_usage_count = tensor_usage_count 

841 

842 def _assert_weights_created(self): 

843 # Override the implementation in Model. 

844 # The Functional model should always have weights created already. 

845 return 

846 

847 def _graph_network_add_loss(self, symbolic_loss): 

848 new_nodes, new_layers = _map_subgraph_network(self.inputs, [symbolic_loss]) 

849 # Losses must be keyed on inputs no matter what in order to be supported in 

850 # DistributionStrategy. 

851 add_loss_layer = base_layer.AddLoss( 

852 unconditional=False, dtype=symbolic_loss.dtype) 

853 add_loss_layer(symbolic_loss) 

854 new_nodes.extend(add_loss_layer.inbound_nodes) 

855 new_layers.append(add_loss_layer) 

856 self._insert_layers(new_layers, new_nodes) 

857 

858 def _graph_network_add_metric(self, value, aggregation, name): 

859 new_nodes, new_layers = _map_subgraph_network(self.inputs, [value]) 

860 add_metric_layer = base_layer.AddMetric( 

861 aggregation, name, dtype=value.dtype) 

862 add_metric_layer(value) 

863 new_nodes.extend(add_metric_layer.inbound_nodes) 

864 new_layers.append(add_metric_layer) 

865 self._insert_layers(new_layers, new_nodes) 

866 

867 @property 

868 def _trackable_saved_model_saver(self): 

869 return network_serialization.NetworkSavedModelSaver(self) 

870 

871 def _get_save_spec(self, dynamic_batch=True): 

872 if getattr(self, '_has_explicit_input_shape', True): 

873 # Functional models and Sequential models that have an explicit input 

874 # shape should use the batch size set by the input layer. 

875 dynamic_batch = False 

876 return super(Functional, self)._get_save_spec(dynamic_batch) 

877 

878 

879def _make_node_key(layer_name, node_index): 

880 return layer_name + '_ib-' + str(node_index) 
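
# E.g. (illustrative): _make_node_key('dense', 0) returns 'dense_ib-0'.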

881 

882 

883def _map_graph_network(inputs, outputs): 

884 """Validates a network's topology and gather its layers and nodes. 

885 

886 Args: 

887 inputs: List of input tensors. 

888 outputs: List of output tensors. 

889 

890 Returns: 

891 A tuple `(nodes, nodes_by_depth, layers, layers_by_depth)`. 

892 - nodes: list of Node instances. 

893 - nodes_by_depth: dict mapping ints (depth) to lists of node instances. 

894 - layers: list of Layer instances. 

895 - layers_by_depth: dict mapping ints (depth) to lists of layer instances. 

896 

897 Raises: 

898 ValueError: In case the network is not valid (e.g. disconnected graph). 

899 """ 

900 # "depth" is number of layers between output Node and the Node. 

901 # Nodes are ordered from inputs -> outputs. 

902 nodes_in_decreasing_depth, layer_indices = _build_map(outputs) 

903 network_nodes = { 

904 _make_node_key(node.layer.name, node.layer._inbound_nodes.index(node)) 

905 for node in nodes_in_decreasing_depth 

906 } 

907 

908 nodes_depths = {} # dict {node: depth value} 

909 layers_depths = {} # dict {layer: depth value} 

910 

911 for node in reversed(nodes_in_decreasing_depth): 

912 # If the depth is not set, the node has no outbound nodes (depth 0). 

913 depth = nodes_depths.setdefault(node, 0) 

914 

915 # Update the depth of the corresponding layer 

916 previous_depth = layers_depths.get(node.layer, 0) 

917 # If we've seen this layer before at a higher depth, 

918 # we should use that depth instead of the node depth. 

919 # This is necessary for shared layers that have inputs at different 

920 # depth levels in the graph. 

921 depth = max(depth, previous_depth) 

922 layers_depths[node.layer] = depth 

923 nodes_depths[node] = depth 

924 

925 # Update the depth of inbound nodes. 

926 # The "depth" of a node is the max of the depths 

927 # of all nodes it is connected to + 1. 

928 for node_dep in node.parent_nodes: 

929 previous_depth = nodes_depths.get(node_dep, 0) 

930 nodes_depths[node_dep] = max(depth + 1, previous_depth) 
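
  # E.g. (illustrative) for `Input -> Dense`: the `Dense` output node gets

  # depth 0 and the `Input` node gets depth 1.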

931 

932 # Handle inputs that are not connected to outputs. 

933 # We do not error out here because the inputs may be used to compute losses 

934 # and metrics. 

935 for input_t in inputs: 

936 input_layer = input_t._keras_history[0] 

937 if input_layer not in layers_depths: 

938 layers_depths[input_layer] = 0 

939 layer_indices[input_layer] = -1 

940 nodes_depths[input_layer._inbound_nodes[0]] = 0 

941 network_nodes.add(_make_node_key(input_layer.name, 0)) 

942 

943 # Build a dict {depth: list of nodes with this depth} 

944 nodes_by_depth = collections.defaultdict(list) 

945 for node, depth in nodes_depths.items(): 

946 nodes_by_depth[depth].append(node) 

947 

948 # Build a dict {depth: list of layers with this depth} 

949 layers_by_depth = collections.defaultdict(list) 

950 for layer, depth in layers_depths.items(): 

951 layers_by_depth[depth].append(layer) 

952 

953 # Get sorted list of layer depths. 

954 depth_keys = list(layers_by_depth.keys()) 

955 depth_keys.sort(reverse=True) 

956 

957 # Set self.layers ordered by depth. 

958 layers = [] 

959 for depth in depth_keys: 

960 layers_for_depth = layers_by_depth[depth] 

961 # Network.layers needs to have a deterministic order: 

962 # here we order them by traversal order. 

963 layers_for_depth.sort(key=lambda x: layer_indices[x]) 

964 layers.extend(layers_for_depth) 

965 

966 # Get sorted list of node depths. 

967 depth_keys = list(nodes_by_depth.keys()) 

968 depth_keys.sort(reverse=True) 

969 

970 # Check that all tensors required are computable. 

971 # computable_tensors: all tensors in the graph 

972 # that can be computed from the inputs provided. 

973 computable_tensors = set() 

974 for x in inputs: 

975 computable_tensors.add(id(x)) 

976 

977 layers_with_complete_input = [] # To provide a better error msg. 

978 for depth in depth_keys: 

979 for node in nodes_by_depth[depth]: 

980 layer = node.layer 

981 if layer and not node.is_input: 

982 for x in nest.flatten(node.keras_inputs): 

983 if id(x) not in computable_tensors: 

984 raise ValueError('Graph disconnected: ' 

985 'cannot obtain value for tensor ' + str(x) + 

986 ' at layer "' + layer.name + '". ' 

987 'The following previous layers ' 

988 'were accessed without issue: ' + 

989 str(layers_with_complete_input)) 

990 for x in nest.flatten(node.outputs): 

991 computable_tensors.add(id(x)) 

992 layers_with_complete_input.append(layer.name) 

993 

994 # Ensure name unicity, which will be crucial for serialization 

995 # (since serialized nodes refer to layers by their name). 

996 all_names = [layer.name for layer in layers] 

997 for name in all_names: 

998 if all_names.count(name) != 1: 

999 raise ValueError('The name "' + name + '" is used ' + 

1000 str(all_names.count(name)) + ' times in the model. ' 

1001 'All layer names should be unique.') 

1002 return network_nodes, nodes_by_depth, layers, layers_by_depth 

1003 

1004 

1005def _build_map(outputs): 

1006 """This method topologically sorts nodes in order from inputs to outputs. 

1007 

1008 It uses a depth-first search to topologically sort nodes that appear in the 

1009 _keras_history connectivity metadata of `outputs`. 

1010 

1011 Args: 

1012 outputs: the output tensors whose _keras_history metadata should be walked. 

1013 This may be an arbitrary nested structure. 

1014 

1015 Returns: 

1016 A tuple like (ordered_nodes, layer_to_first_traversal_index) 

1017 ordered_nodes: list of nodes appearing in the keras history, topologically 

1018 sorted from original inputs to the `outputs`. 

1019 (If outputs have different sets of ancestors, the inputs to one output 

1020 may appear after a different output). 

1021 layer_to_first_traversal_index: 

1022 A dict mapping layer to the traversal index in the DFS where it is 

1023 seen. Note: if a layer is shared by several nodes, the dict will only 

1024 store the index corresponding to the *first* time the layer is seen. 

1025 """ 

1026 finished_nodes = set() 

1027 nodes_in_progress = set() 

1028 nodes_in_decreasing_depth = [] # nodes from inputs -> outputs. 

1029 layer_indices = {} # layer -> index in traversal order. 

1030 for output in nest.flatten(outputs): 

1031 _build_map_helper(output, finished_nodes, nodes_in_progress, 

1032 nodes_in_decreasing_depth, layer_indices) 

1033 return nodes_in_decreasing_depth, layer_indices 

1034 

1035 

1036def _build_map_helper(tensor, finished_nodes, nodes_in_progress, 

1037 nodes_in_decreasing_depth, layer_indices): 

1038 """Recursive helper for `_build_map`.""" 

1039 layer, node_index, _ = tensor._keras_history # pylint: disable=protected-access 

1040 node = layer._inbound_nodes[node_index] # pylint: disable=protected-access 

1041 

1042 # Don't repeat work for shared subgraphs 

1043 if node in finished_nodes: 

1044 return 

1045 

1046 # Prevent cycles. 

1047 if node in nodes_in_progress: 

1048 raise ValueError('The tensor ' + str(tensor) + ' at layer "' + layer.name + 

1049 '" is part of a cycle.') 

1050 

1051 # Store the traversal order for layer sorting. 

1052 if layer not in layer_indices: 

1053 layer_indices[layer] = len(layer_indices) 

1054 

1055 # Propagate to all previous tensors connected to this node. 

1056 nodes_in_progress.add(node) 

1057 if not node.is_input: 

1058 for tensor in node.keras_inputs: 

1059 _build_map_helper(tensor, finished_nodes, nodes_in_progress, 

1060 nodes_in_decreasing_depth, layer_indices) 

1061 

1062 finished_nodes.add(node) 

1063 nodes_in_progress.remove(node) 

1064 nodes_in_decreasing_depth.append(node) 

1065 

1066 

1067def _map_subgraph_network(inputs, outputs): 

1068 """Returns the nodes and layers in the topology from `inputs` to `outputs`. 

1069 

1070 Args: 

1071 inputs: List of input tensors. 

1072 outputs: List of output tensors. 

1073 

1074 Returns: 

1075 A tuple of (List[Node], List[Layer]). 

1076 """ 

1077 if not ops.executing_eagerly_outside_functions(): 

1078 base_layer_utils.create_keras_history(outputs) 

1079 # Keep only nodes and layers in the topology between inputs and outputs. 

1080 _, nodes_by_depth, layers, _ = _map_graph_network(inputs, outputs) 

1081 return nest.flatten([nodes for nodes in nodes_by_depth.values()]), layers 

1082 

1083 

1084def _should_skip_first_node(layer): 

1085 """Returns True if the first layer node should not be saved or loaded.""" 

1086 # Networks that are constructed with an Input layer/shape start with a 

1087 # pre-existing node linking their input to output. This node is excluded from 

1088 # the network config. 

1089 if layer._self_tracked_trackables: 

1090 return (isinstance(layer, Functional) and 

1091 # Filter out Sequential models without an input shape. 

1092 isinstance(layer._self_tracked_trackables[0], 

1093 input_layer_module.InputLayer)) 

1094 else: 

1095 return isinstance(layer, Functional) 

1096 

1097 

1098def connect_ancillary_layers(model, created_layers): 

1099 """Adds layers that are not connected to the outputs to the model.""" 

1100 # Layers not connected to outputs, such as those added in `add_loss`. 

1101 ancillary_layers = [ 

1102 layer for layer in created_layers.values() if layer not in model.layers 

1103 ] 

1104 if ancillary_layers: 

1105 relevant_nodes = nest.flatten([ 

1106 layer.inbound_nodes[1:] 

1107 if _should_skip_first_node(layer) else layer.inbound_nodes 

1108 for layer in created_layers.values() 

1109 ]) 

1110 model._insert_layers(ancillary_layers, relevant_nodes) 

1111 return model 

1112 

1113 

1114def reconstruct_from_config(config, custom_objects=None, created_layers=None): 

1115 """Reconstructs graph from config object. 

1116 

1117 Args: 

1118 config: Dictionary returned from Network.get_config() 

1119 custom_objects: Optional dictionary mapping names (strings) to custom 

1120 classes or functions to be considered during deserialization. 

1121 created_layers: Optional dictionary mapping names to Layer objects. Any 

1122 layer not in this dictionary will be created and added to the dict. 

1123 This function will add new nodes to all layers (excluding InputLayers), 

1124 instead of re-using pre-existing nodes in the layers. 

1125 

1126 Returns: 

1127 Tuple of (input tensors, output tensors, dictionary of created layers) 

1128 """ 

1129 # Layer instances created during the graph reconstruction process. 

1130 created_layers = created_layers or collections.OrderedDict() 

1131 

1132 # Maps input data (tuple of inbound layer name, node index) from the config 

1133 # to node indices in the newly generated model. The node indices may be 

1134 # different if the layers have already been called previously. 

1135 node_index_map = {} 

1136 node_count_by_layer = {} 

1137 

1138 # Dictionary mapping layer instances to 

1139 # node data that specifies a layer call. 

1140 # It acts as a queue that maintains any unprocessed 

1141 # layer call until it becomes possible to process it 

1142 # (i.e. until the input tensors to the call all exist). 

1143 unprocessed_nodes = {} 

1144 

1145 def add_unprocessed_node(layer, node_data): 

1146 if layer not in unprocessed_nodes: 

1147 unprocessed_nodes[layer] = [node_data] 

1148 else: 

1149 unprocessed_nodes[layer].append(node_data) 

1150 

1151 def get_node_index(layer, config_node_index): 

1152 """Returns node index in layer (might differ from config_node_index).""" 

1153 if isinstance(layer, input_layer_module.InputLayer): 

1154 return 0 

1155 return node_index_map.get((layer.name, config_node_index), None) 

1156 

1157 def _deserialize_keras_tensors(kwargs, layer_map): 

1158 """Deserializes Keras Tensors passed to `call`..""" 

1159 

1160 def _deserialize_keras_tensor(t): 

1161 """Deserializes a single Keras Tensor passed to `call`.""" 

1162 if isinstance(t, tf_utils.ListWrapper): 

1163 t = t.as_list() 

1164 layer_name = t[0] 

1165 node_index = t[1] 

1166 tensor_index = t[2] 

1167 

1168 layer = layer_map[layer_name] 

1169 new_node_index = get_node_index(layer, node_index) 

1170 if new_node_index is None: 

1171 # The inbound node may not have been processed yet, 

1172 # (This can happen e.g. if it depends on a different set 

1173 # of inputs than those that have been processed already). 

1174 # raise an IndexError so that the current node puts itself 

1175 # back on the unprocessed queue. 

1176 # Caution: This may lead to infinite loops for malformed 

1177 # network configurations! (or when there is a bug in 

1178 # the network config loading code). 

1179 raise IndexError 

1180 node = layer._inbound_nodes[new_node_index] 

1181 return nest.flatten(node.outputs)[tensor_index] 

1182 return t 

1183 

1184 kwargs = tf_utils.convert_inner_node_data(kwargs, wrap=True) 

1185 return nest.map_structure(_deserialize_keras_tensor, kwargs) 

1186 

1187 def process_node(layer, node_data): 

1188 """Deserialize a node. 

1189 

1190 Args: 

1191 layer: layer instance. 

1192 node_data: Nested structure of `ListWrapper`. 

1193 

1194 Raises: 

1195 ValueError: In case of improperly formatted `node_data`. 

1196 """ 

1197 input_tensors = [] 

1198 for input_data in nest.flatten(node_data): 

1199 input_data = input_data.as_list() 

1200 inbound_layer_name = input_data[0] 

1201 inbound_node_index = input_data[1] 

1202 inbound_tensor_index = input_data[2] 

1203 if len(input_data) == 3: 

1204 kwargs = {} 

1205 elif len(input_data) == 4: 

1206 kwargs = input_data[3] 

1207 try: 

1208 kwargs = _deserialize_keras_tensors(kwargs, created_layers) 

1209 except IndexError: 

1210 # Happens if keras tensors in kwargs are still unprocessed 

1211 add_unprocessed_node(layer, node_data) 

1212 return 

1213 else: 

1214 raise ValueError('Improperly formatted model config.') 

1215 

1216 if inbound_layer_name != node_module._CONSTANT_VALUE: 

1217 inbound_layer = created_layers[inbound_layer_name] 

1218 inbound_node_index = get_node_index(inbound_layer, inbound_node_index) 

1219 

1220 if inbound_node_index is None: 

1221 add_unprocessed_node(layer, node_data) 

1222 return 

1223 inbound_node = inbound_layer._inbound_nodes[inbound_node_index] 

1224 input_tensors.append( 

1225 nest.flatten(inbound_node.outputs)[inbound_tensor_index]) 

1226 else: 

1227 # We received a constant with no Keras history attached. 

1228 input_tensors.append(inbound_tensor_index) 

1229 input_tensors = nest.pack_sequence_as(node_data, input_tensors) 

1230 # Call layer on its inputs, thus creating the node 

1231 # and building the layer if needed. 

1232 if input_tensors is not None: 

1233 if not layer._preserve_input_structure_in_config: 

1234 input_tensors = ( 

1235 base_layer_utils.unnest_if_single_tensor(input_tensors)) 

1236 output_tensors = layer(input_tensors, **kwargs) 

1237 

1238 # Update node index map. 

1239 output_index = nest.flatten(output_tensors)[0]._keras_history.node_index 

1240 node_index_map[(layer.name, node_count_by_layer[layer])] = output_index 

1241 node_count_by_layer[layer] += 1 

1242 

1243 def process_layer(layer_data): 

1244 """Deserializes a layer, then call it on appropriate inputs. 

1245 

1246 Args: 

1247 layer_data: layer config dict. 

1248 

1249 Raises: 

1250 ValueError: In case of improperly formatted `layer_data` dict. 

1251 """ 

1252 layer_name = layer_data['name'] 

1253 

1254 if layer_name in created_layers: 

1255 layer = created_layers[layer_name] 

1256 else: 

1257 # Instantiate layer. 

1258 from tensorflow.python.keras.layers import deserialize as deserialize_layer # pylint: disable=g-import-not-at-top 

1259 

1260 layer = deserialize_layer(layer_data, custom_objects=custom_objects) 

1261 created_layers[layer_name] = layer 

1262 

1263 node_count_by_layer[layer] = int(_should_skip_first_node(layer)) 

1264 

1265 # Gather layer inputs and convert to `ListWrapper` objects. 

1266 inbound_nodes_data = layer_data['inbound_nodes'] 

1267 inbound_nodes_data = tf_utils.convert_inner_node_data( 

1268 inbound_nodes_data, wrap=True) 

1269 for node_data in inbound_nodes_data: 

1270 # We don't process nodes (i.e. make layer calls) 

1271 # on the fly because the inbound node may not yet exist, 

1272 # in the case of layers shared at different topological depths 

1273 # (e.g. a model such as A(B(A(B(x))))) 

1274 add_unprocessed_node(layer, node_data) 

1275 

1276 # First, we create all layers and enqueue nodes to be processed 

1277 for layer_data in config['layers']: 

1278 process_layer(layer_data) 

1279 # Then we process nodes in order of layer depth. 

1280 # Nodes that cannot yet be processed (if the inbound node 

1281 # does not yet exist) are re-enqueued, and the process 

1282 # is repeated until all nodes are processed. 

1283 while unprocessed_nodes: 

1284 for layer_data in config['layers']: 

1285 layer = created_layers[layer_data['name']] 

1286 if layer in unprocessed_nodes: 

1287 for node_data in unprocessed_nodes.pop(layer): 

1288 process_node(layer, node_data) 

1289 

1290 input_tensors = [] 

1291 output_tensors = [] 

1292 

1293 input_layers = tf_utils.convert_inner_node_data( 

1294 config['input_layers'], wrap=True) 

1295 for layer_data in nest.flatten(input_layers): 

1296 layer_name, node_index, tensor_index = layer_data.as_list() 

1297 assert layer_name in created_layers 

1298 layer = created_layers[layer_name] 

1299 node_index = get_node_index(layer, node_index) 

1300 layer_output_tensors = layer._inbound_nodes[node_index].output_tensors 

1301 input_tensors.append(nest.flatten(layer_output_tensors)[tensor_index]) 

1302 

1303 output_layers = tf_utils.convert_inner_node_data( 

1304 config['output_layers'], wrap=True) 

1305 for layer_data in nest.flatten(output_layers): 

1306 layer_name, node_index, tensor_index = layer_data.as_list() 

1307 assert layer_name in created_layers 

1308 layer = created_layers[layer_name] 

1309 node_index = get_node_index(layer, node_index) 

1310 layer_output_tensors = layer._inbound_nodes[node_index].output_tensors 

1311 output_tensors.append(nest.flatten(layer_output_tensors)[tensor_index]) 

1312 

1313 input_tensors = nest.pack_sequence_as(input_layers, input_tensors) 

1314 output_tensors = nest.pack_sequence_as(output_layers, output_tensors) 

1315 return input_tensors, output_tensors, created_layers 

1316 

1317 

1318def get_network_config(network, serialize_layer_fn=None): 

1319 """Builds the config, which consists of the node graph and serialized layers. 

1320 

1321 Args: 

1322 network: A Network object. 

1323 serialize_layer_fn: Function used to serialize layers. 

1324 

1325 Returns: 

1326 Config dictionary. 

1327 """ 

1328 serialize_layer_fn = ( 

1329 serialize_layer_fn or generic_utils.serialize_keras_object) 

1330 config = { 

1331 'name': network.name, 

1332 } 

1333 node_conversion_map = {} 

1334 for layer in network.layers: 

1335 kept_nodes = 1 if _should_skip_first_node(layer) else 0 

1336 for original_node_index, node in enumerate(layer._inbound_nodes): 

1337 node_key = _make_node_key(layer.name, original_node_index) 

1338 if node_key in network._network_nodes: 

1339 node_conversion_map[node_key] = kept_nodes 

1340 kept_nodes += 1 

1341 layer_configs = [] 

1342 

1343 with generic_utils.SharedObjectSavingScope(): 

1344 for layer in network.layers: # From the earliest layers on. 

1345 filtered_inbound_nodes = [] 

1346 for original_node_index, node in enumerate(layer._inbound_nodes): 

1347 node_key = _make_node_key(layer.name, original_node_index) 

1348 if node_key in network._network_nodes and not node.is_input: 

1349 # The node is relevant to the model: 

1350 # add to filtered_inbound_nodes. 

1351 node_data = node.serialize(_make_node_key, node_conversion_map) 

1352 filtered_inbound_nodes.append(node_data) 

1353 

1354 layer_config = serialize_layer_fn(layer) 

1355 layer_config['name'] = layer.name 

1356 layer_config['inbound_nodes'] = filtered_inbound_nodes 

1357 layer_configs.append(layer_config) 

1358 config['layers'] = layer_configs 

1359 

1360 # Gather info about inputs and outputs. 

1361 model_inputs = [] 

1362 for i in range(len(network._input_layers)): 

1363 layer, node_index, tensor_index = network._input_coordinates[i] 

1364 node_key = _make_node_key(layer.name, node_index) 

1365 if node_key not in network._network_nodes: 

1366 continue 

1367 new_node_index = node_conversion_map[node_key] 

1368 model_inputs.append( 

1369 tf_utils.ListWrapper([layer.name, new_node_index, tensor_index])) 

1370 model_inputs = nest.pack_sequence_as(network._nested_inputs, model_inputs) 

1371 # Preserve external Keras compat for Models with a single input. 

1372 if not nest.is_nested(model_inputs): 

1373 model_inputs = [model_inputs] 

1374 model_inputs = tf_utils.convert_inner_node_data(model_inputs) 

1375 config['input_layers'] = model_inputs 

1376 

1377 model_outputs = [] 

1378 for i in range(len(network._output_layers)): 

1379 layer, node_index, tensor_index = network._output_coordinates[i] 

1380 node_key = _make_node_key(layer.name, node_index) 

1381 if node_key not in network._network_nodes: 

1382 continue 

1383 new_node_index = node_conversion_map[node_key] 

1384 model_outputs.append( 

1385 tf_utils.ListWrapper([layer.name, new_node_index, tensor_index])) 

1386 model_outputs = nest.pack_sequence_as(network._nested_outputs, model_outputs) 

1387 # Preserve external Keras compat for Models with a single output. 

1388 if not nest.is_nested(model_outputs): 

1389 model_outputs = [model_outputs] 

1390 model_outputs = tf_utils.convert_inner_node_data(model_outputs) 

1391 config['output_layers'] = model_outputs 

1392 return config 

1393 

1394 

1395def shape_with_no_batch_size(x): 

1396 if x.shape.rank is None: 

1397 return None 

1398 shape = x.shape.as_list() 

1399 if shape: 

1400 shape[0] = None 

1401 return shape 
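
# E.g. (illustrative): an input of shape [32, 10] yields [None, 10].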

1402 

1403 

1404class ModuleWrapper(base_layer.Layer): 

1405 """Wrapper for `tf.Module`s to support the Functional and Sequential API.""" 

1406 
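  # Illustrative usage (hypothetical `my_module`): wrap a plain `tf.Module`

  # so it can participate in the Functional/Sequential APIs, e.g.

  #   layer = ModuleWrapper(my_module, method_name='__call__')

  #   outputs = layer(inputs)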

1407 def __init__(self, module, method_name=None, **kwargs): 

1408 """Initializes the wrapper Layer for this module. 

1409 

1410 Args: 

1411 module: The `tf.Module` instance to be wrapped. 

1412 method_name: (Optional) str. The name of the method to use as the forward 

1413 pass of the module. If not set, defaults to '__call__' if defined, or 

1414 'call'. 

1415 **kwargs: Additional keyword arguments. See `tf.keras.layers.Layer`. 

1416 

1417 Raises: 

1418 ValueError: If `method` is not defined on `module`. 

1419 """ 

1420 super(ModuleWrapper, self).__init__(**kwargs) 

1421 if method_name is None: 

1422 if hasattr(module, '__call__'): 

1423 method_name = '__call__' 

1424 elif hasattr(module, 'call'): 

1425 method_name = 'call' 

1426 if method_name is None or not hasattr(module, method_name): 

1427 raise ValueError('{} is not defined on object {}'.format( 

1428 method_name, module)) 

1429 

1430 self._module = module 

1431 self._method_name = method_name 

1432 

1433 # Check if module.__call__ has a `training` arg or accepts `**kwargs`. 

1434 method = getattr(module, method_name) 

1435 method_arg_spec = tf_inspect.getfullargspec(method) 

1436 self._expects_training_arg = ('training' in method_arg_spec.args or 

1437 method_arg_spec.varkw is not None) 

1438 self._expects_mask_arg = ('mask' in method_arg_spec.args or 

1439 method_arg_spec.varkw is not None) 

1440 

1441 def call(self, *args, **kwargs): 

1442 if 'training' in kwargs and not self._expects_training_arg: 

1443 kwargs.pop('training') 

1444 if 'mask' in kwargs and not self._expects_mask_arg: 

1445 kwargs.pop('mask') 

1446 return getattr(self._module, self._method_name)(*args, **kwargs)