Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/data/util/options.py: 26%

70 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Utilities for tf.data options.""" 

16 

17import collections 

18 

19from absl import logging 

20 

21 

22def _internal_attr_name(name): 

23 return "_" + name 

24 

25 

class OptionsBase:
  """Base class for a collection of tf.data options.

  Concrete subclasses declare their options as class-level properties (for
  example ones built with `create_option`). Assigning to a name that is not
  already an attribute is rejected, and assignment can be disabled entirely
  with `_set_mutable(False)`.

  Because `__eq__` is defined without `__hash__`, instances are unhashable.

  Attributes:
    _options: Dictionary holding the option values, keyed by option name.
    _mutable: Whether attribute assignment is currently permitted.
  """

  def __init__(self):
    # `__setattr__` is overridden below, so these bookkeeping attributes have
    # to be installed through `object.__setattr__` directly.
    object.__setattr__(self, "_options", {})
    object.__setattr__(self, "_mutable", True)

  def __eq__(self, other):
    """Compares, attribute by attribute, every option recorded on either side."""
    if not isinstance(other, self.__class__):
      return NotImplemented
    # Build the name set before reading any attribute: getters created by
    # `create_option` record their default into `_options` on first read.
    names = set(self._options) | set(other._options)  # pylint: disable=protected-access
    return not any(
        getattr(self, option_name) != getattr(other, option_name)
        for option_name in names)

  def __ne__(self, other):
    """Negation of `__eq__` for same-class operands."""
    if not isinstance(other, self.__class__):
      return NotImplemented
    return not self.__eq__(other)

  def __setattr__(self, name, value):
    """Assigns an existing option; rejects frozen objects and unknown names.

    Raises:
      ValueError: if the object has been made immutable.
      AttributeError: if `name` is not an existing attribute of the object.
    """
    if not self._mutable:
      raise ValueError("Mutating `tf.data.Options()` returned by "
                       "`tf.data.Dataset.options()` has no effect. Use "
                       "`tf.data.Dataset.with_options(options)` to set or "
                       "update dataset options.")
    if not hasattr(self, name):
      raise AttributeError("Cannot set the property {} on {}.".format(
          name, type(self).__name__))
    object.__setattr__(self, name, value)

  def _set_mutable(self, mutable):
    """Sets whether attribute assignment is allowed on this object."""
    object.__setattr__(self, "_mutable", mutable)

  def _to_proto(self):
    """Converts these options to a protocol buffer; subclasses must override."""
    class_name = type(self).__name__
    raise NotImplementedError("{}._to_proto()".format(class_name))

  def _from_proto(self, pb):
    """Populates these options from protocol buffer `pb`; subclasses override."""
    class_name = type(self).__name__
    raise NotImplementedError("{}._from_proto()".format(class_name))

76 

77 

# Namedtuple type with three keys for optimization graph rewrite settings. It
# is created once at import time so that every call to `graph_rewrites()`
# returns the same class.
_GraphRewrites = collections.namedtuple("GraphRewrites",
                                        ["enabled", "disabled", "default"])


def graph_rewrites():
  """Returns the namedtuple type used for graph rewrite settings.

  The type has the fields `enabled`, `disabled` and `default`. Building it
  once, rather than calling `collections.namedtuple` on every invocation,
  avoids repeated class creation. It also means values produced by different
  callers share a single class, so `isinstance` checks between them work.

  Returns:
    The `GraphRewrites` namedtuple class.
  """
  return _GraphRewrites

82 

83 

def create_option(name, ty, docstring, default_factory=lambda: None):
  """Builds a `property` whose setter enforces a type.

  The value is kept in the owning object's `_options` dict under `name`.

  Args:
    name: Key under which the value is stored in the owner's `_options`.
    ty: Type (or tuple of types) that assigned values must satisfy via
      `isinstance`.
    docstring: Docstring attached to the resulting property.
    default_factory: Zero-argument callable whose result is returned, and
      stored in `_options`, the first time the option is read without having
      been set.

  Returns:
    A `property` with a type-checking setter and no deleter.
  """

  def _read(owner):
    store = owner._options  # pylint: disable=protected-access
    if name in store:
      return store[name]
    # First read of an unset option: the default is stored, so the option
    # counts as present in `_options` from now on.
    fresh = default_factory()
    store[name] = fresh
    return fresh

  def _write(owner, value):
    if not isinstance(value, ty):
      raise TypeError(
          f'Property "{name}" must be of type {ty}, got: {value} '
          f"(type: {type(value)})")
    owner._options[name] = value  # pylint: disable=protected-access

  return property(fget=_read, fset=_write, fdel=None, doc=docstring)

113 

114 

def merge_options(*options_list):
  """Merges the given options, returning the result as a new options object.

  The input arguments are expected to have a matching type that derives from
  `tf.data.OptionsBase` (and thus each represent a set of options). The method
  outputs an object of the same type created by merging the sets of options
  represented by the input arguments.

  If an option is set to different values by different options objects, the
  result will match the setting of the options object that appears in the input
  list last.

  If an option is an instance of `tf.data.OptionsBase` itself, then this method
  is applied recursively to the set of options represented by this option.
  Nested options are always copied into the result, never shared with an
  input, so mutating the result does not affect the arguments.

  Args:
    *options_list: options to merge

  Raises:
    ValueError: if no options are provided.
    TypeError: if the input arguments are incompatible or not derived from
      `tf.data.OptionsBase`

  Returns:
    A new options object which is the result of merging the given options.
  """
  if not options_list:
    raise ValueError("At least one options should be provided")
  result_type = type(options_list[0])

  for options in options_list:
    if not isinstance(options, result_type):
      raise TypeError(
          "Could not merge incompatible options of type {} and {}.".format(
              type(options), result_type))

  if not isinstance(options_list[0], OptionsBase):
    raise TypeError(
        "All options to be merged should inherit from `OptionsBase` but found "
        "option of type {} which does not.".format(type(options_list[0])))

  # A pristine instance, used only to look up each option's default value.
  default_options = result_type()
  result = result_type()
  for options in options_list:
    # Visit only options recorded in `_options`. For options built with
    # `create_option`, that means options that were set or read at least once.
    for name in options._options:  # pylint: disable=protected-access
      this = getattr(result, name)
      that = getattr(options, name)
      default = getattr(default_options, name)
      if that == default:
        # A default value never overrides what has been merged so far.
        continue
      elif this == default:
        if isinstance(that, OptionsBase):
          # Store a copy rather than the input's own nested object; otherwise
          # the "new" result would alias caller state.
          that = merge_options(that)
        setattr(result, name, that)
      elif isinstance(this, OptionsBase):
        setattr(result, name, merge_options(this, that))
      elif this != that:
        logging.warning("Changing the value of option %s from %r to %r.", name,
                        this, that)
        setattr(result, name, that)
  return result