Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/eager/polymorphic_function/saved_model_utils.py: 44%

36 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2022 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15# pylint: disable=unidiomatic-typecheck 

16"""A shim layer for working with functions exported/restored from saved models. 

17 

18This functionality should ultimately be moved into a first-class core API. 

19""" 

20 

21import numpy 

22 

23from tensorflow.python.framework import constant_op 

24from tensorflow.python.framework import dtypes 

25from tensorflow.python.framework import ops 

26from tensorflow.python.framework import tensor_util 

27from tensorflow.python.saved_model import registration 

28from tensorflow.python.trackable import base as trackable 

29 

30 

@registration.register_tf_serializable()
class TrackableConstant(trackable.Trackable):
  """Trackable class for captured constants.

  When the owning function is exported, the captured tensor is re-created in
  the export graph as a constant op. The object is then recorded in the
  object-graph proto as a `constant` node that refers to that op by name. On
  load, the constant is rebuilt from the op's `"value"` attribute.

  Attributes:
    capture: The captured tensor whose static value is exported.
    function: The function that captures `capture`. Only its `name` is read
      here, to build the error message when export fails.
  """
  __slots__ = ("capture", "function", "_exported_tensor")

  def __init__(self, capture, function):
    self.capture = capture
    self.function = function
    # Populated by `_export_to_saved_model_graph`; read by
    # `_serialize_to_proto`.
    self._exported_tensor = None

  def _export_to_saved_model_graph(self, tensor_map, **unused_kwargs):
    """Copies the captured constant into the graph being exported.

    Args:
      tensor_map: Mutable mapping that receives an entry from `self.capture`
        to its newly created constant copy.
      **unused_kwargs: Ignored.

    Returns:
      A single-element list containing `self.capture`, the tensor whose copy
      was added to `tensor_map`.

    Raises:
      ValueError: If the captured tensor's value cannot be computed
        statically with `tensor_util.constant_value`.
    """
    capture_constant_value = tensor_util.constant_value(self.capture)
    if capture_constant_value is None:
      raise ValueError(
          f"Unable to save function {self.function.name} because it "
          f"captures graph tensor {self.capture} from a parent function which "
          "cannot be converted to a constant with `tf.get_static_value`.")

    # Derive size and shape from the concrete value rather than from
    # `self.capture.shape`. If the static shape is not fully defined,
    # `TensorShape.as_list()` either raises (unknown rank) or contains `None`
    # (which makes `numpy.prod` fail). The computed value's shape is always
    # concrete. `asarray` is a no-op when the value is already an ndarray.
    value_array = numpy.asarray(capture_constant_value)
    if value_array.size > 1 and numpy.all(value_array == value_array.flat[0]):
      # For the common case of a constant array filled with the same
      # value, rebuild the constant op specifically with the shape arg,
      # since otherwise the whole array is written into the node def,
      # causing performance and graph proto size issues (protos cannot be
      # bigger than 2GB).
      copied_tensor = constant_op.constant(
          value_array.flat[0],
          dtype=self.capture.dtype,
          shape=value_array.shape)
    else:
      copied_tensor = constant_op.constant(capture_constant_value)

    tensor_map[self.capture] = copied_tensor
    self._exported_tensor = copied_tensor
    return [self.capture]

  def _serialize_to_proto(self, object_proto=None, **kwargs):
    """Records the exported constant op's name in `object_proto.constant`.

    NOTE(review): this reads `self._exported_tensor`, which is only set by
    `_export_to_saved_model_graph`. If that has not run yet, the attribute is
    still None and this raises AttributeError. Presumably the saver always
    exports before serializing; confirm.
    """
    object_proto.constant.operation = self._exported_tensor.op.name

  @classmethod
  def _deserialize_from_proto(cls, object_proto, operation_attributes,
                              **kwargs):
    """Recreates the saved constant as a tensor.

    Args:
      object_proto: Object-graph proto whose `constant.operation` field holds
        the name of the op that stores the constant.
      operation_attributes: Mapping keyed by op name; the entry's `"value"`
        attribute carries the constant's TensorProto in `.tensor`.
      **kwargs: Ignored.

    Returns:
      A constant tensor holding the saved value. String constants are placed
      on the CPU.
    """
    tensor_proto = (
        operation_attributes[object_proto.constant.operation]["value"].tensor)
    ndarray = tensor_util.MakeNdarray(tensor_proto)
    if dtypes.as_dtype(tensor_proto.dtype) == dtypes.string:
      with ops.device("CPU"):
        # String operations should be done on the CPU.
        imported_constant = constant_op.constant(ndarray)
    else:
      imported_constant = constant_op.constant(ndarray)
    return imported_constant