1# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15# pylint: disable=unidiomatic-typecheck
16"""A shim layer for working with functions exported/restored from saved models.
17
18This functionality should ultimately be moved into a first-class core API.
19"""
20
21import numpy
22
23from tensorflow.python.framework import constant_op
24from tensorflow.python.framework import dtypes
25from tensorflow.python.framework import ops
26from tensorflow.python.framework import tensor_util
27from tensorflow.python.saved_model import registration
28from tensorflow.python.trackable import base as trackable
29
30
@registration.register_tf_serializable()
class TrackableConstant(trackable.Trackable):
  """Trackable class for captured constants.

  On export, the captured tensor's static value is re-created as a constant
  in the export graph. On load, a constant is rebuilt from the serialized
  constant op's `value` attribute.
  """
  __slots__ = ("capture", "function", "_exported_tensor")

  def __init__(self, capture, function):
    """Creates the trackable.

    Args:
      capture: The captured tensor. Its value must be obtainable via
        `tensor_util.constant_value` at export time.
      function: The function that captures `capture`. Within this class only
        its `name` is read, for the export error message.
    """
    self.capture = capture
    self.function = function
    # Set by `_export_to_saved_model_graph` and read by `_serialize_to_proto`.
    self._exported_tensor = None

  def _export_to_saved_model_graph(self, tensor_map, **unused_kwargs):
    """Copies the captured value into the export graph as a constant.

    Args:
      tensor_map: Mapping that is updated in place so that `self.capture`
        maps to the newly created constant tensor.
      **unused_kwargs: Ignored.

    Returns:
      A single-element list containing `self.capture`.

    Raises:
      ValueError: If the captured tensor's value cannot be computed
        statically.
    """
    capture_constant_value = tensor_util.constant_value(self.capture)
    if capture_constant_value is None:
      raise ValueError(
          f"Unable to save function {self.function.name} because it "
          f"captures graph tensor {self.capture} from a parent function which "
          "cannot be converted to a constant with `tf.get_static_value`.")

    # Take size and shape from the concrete value, not from the tensor's
    # static shape. The static shape can be unknown or only partially known
    # even when the value itself is computable. In that case
    # `shape.as_list()` raises, or yields `None` dims that `numpy.prod` and
    # `constant(shape=...)` cannot handle. When the static shape is fully
    # known, the two agree.
    value = numpy.asarray(capture_constant_value)

    if value.size > 1 and numpy.all(value == value.flat[0]):
      # For the common case of a constant array filled with the same
      # value, rebuild the constant op specifically with the shape arg,
      # since otherwise the whole array is written into the node def,
      # causing performance and graph proto size issues (protos cannot be
      # bigger than 2GB).
      copied_tensor = constant_op.constant(
          value.flat[0],
          dtype=self.capture.dtype,
          shape=value.shape)
    else:
      copied_tensor = constant_op.constant(value)

    tensor_map[self.capture] = copied_tensor
    self._exported_tensor = copied_tensor
    return [self.capture]

  def _serialize_to_proto(self, object_proto=None, **kwargs):
    """Records the exported constant's op name in `object_proto`.

    `_export_to_saved_model_graph` must have run first. Otherwise
    `_exported_tensor` is still None and this raises `AttributeError`.

    Args:
      object_proto: Proto whose `constant.operation` field is set.
      **kwargs: Ignored.
    """
    object_proto.constant.operation = self._exported_tensor.op.name

  @classmethod
  def _deserialize_from_proto(cls, object_proto, operation_attributes,
                              **kwargs):
    """Rebuilds the constant tensor from its serialized op attributes.

    Args:
      object_proto: Proto whose `constant.operation` names the constant op.
      operation_attributes: Mapping from op name to that op's attributes.
        The `value` attribute's `tensor` field is read as the TensorProto
        of the constant.
      **kwargs: Ignored.

    Returns:
      A constant tensor holding the deserialized value.
    """
    tensor_proto = (
        operation_attributes[object_proto.constant.operation]["value"].tensor)
    ndarray = tensor_util.MakeNdarray(tensor_proto)
    if dtypes.as_dtype(tensor_proto.dtype) == dtypes.string:
      # String operations should be done on the CPU.
      with ops.device("CPU"):
        return constant_op.constant(ndarray)
    return constant_op.constant(ndarray)