Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/framework/device.py: 46%

52 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15 

16"""Class to represent a device.""" 

17 

18from tensorflow.python import tf2 

19from tensorflow.python.framework import device_spec 

20 

# Select the DeviceSpec flavor matching the active API version: the V2 class
# when tf2.enabled() reports True, otherwise the V1 class.
DeviceSpec = (
    device_spec.DeviceSpecV2 if tf2.enabled() else device_spec.DeviceSpecV1)

25 

26 

def check_valid(spec):
  """Validates a device specification string by parsing it.

  Args:
    spec: a string.

  Raises:
    An exception if the spec is invalid.
  """
  # Parsing is the whole check: DeviceSpec.from_string fails on an invalid
  # spec, and the parsed object itself is not needed afterwards.
  _ = DeviceSpec.from_string(spec)

38 

39 

def is_device_spec(obj):
  """Returns True if `obj` is a DeviceSpec of either API version.

  DeviceSpecV2 is the base class, so one isinstance check against it covers
  every DeviceSpec flavor; callers need not know about the hierarchy.
  """
  spec_base_class = device_spec.DeviceSpecV2
  return isinstance(obj, spec_base_class)

43 

44 

def canonical_name(device):
  """Returns the canonical string form of a device.

  Args:
    device: A `DeviceSpec`, a device name string, or None.

  Returns:
    "" when `device` is None. Otherwise `to_string()` of the spec; a string
    argument is first parsed with `DeviceSpec.from_string`.
  """
  if device is None:
    return ""
  spec = device if is_device_spec(device) else DeviceSpec.from_string(device)
  return spec.to_string()

54 

55 

# Performance caches. Nothing in this module evicts entries from either dict.
# spec passed to merge_device() -> MergeDevice built for it.
_cached_mergers = dict()
# (MergeDevice._spec, node_def.device or "") -> merged device string, filled
# by MergeDevice.shortcut_string_merge().
_string_merge_cache = dict()

59 

60 

def merge_device(spec):
  """Returns a device function that merges devices specifications.

  This can be used to merge partial specifications of devices. The
  innermost setting for a device field takes precedence. For example:

    with tf.device(merge_device("/device:GPU:0")):
      # Nodes created here have device "/device:GPU:0"
      with tf.device(merge_device("/job:worker")):
        # Nodes created here have device "/job:worker/device:GPU:0"
        with tf.device(merge_device("/device:CPU:0")):
          # Nodes created here have device "/job:worker/device:CPU:0"
          with tf.device(merge_device("/job:ps")):
            # Nodes created here have device "/job:ps/device:CPU:0"

  Args:
    spec: A `DeviceSpec` or a device spec string (partially) describing the
      device that should be used for all nodes created in the scope of
      the returned device function's with block. An existing `MergeDevice`
      is also accepted and returned as-is.

  Returns:
    A MergeDevice object with the above-described behavior. Mergers built
    here are memoized in `_cached_mergers`, keyed by `spec`.

  Raises:
    ValueError: if the spec was not valid.
  """
  if isinstance(spec, MergeDevice):
    return spec

  cached = _cached_mergers.get(spec)
  if cached:
    return cached

  # Deliberately unlocked: updates are stateless, so a racing thread can at
  # worst overwrite the entry with another merger built from the same spec.
  fresh = MergeDevice(spec)
  _cached_mergers[spec] = fresh
  return fresh

97 return merger 

98 

99 

class MergeDevice(object):
  """Wraps a device specification (DeviceSpec or str) with merge functionality.

  When called, this class will merge a node_def with its own spec. It also
  exposes a `shortcut_string_merge` method which can significantly improve
  performance of device placement.
  """

  __slots__ = ["_spec"]

  def __init__(self, spec):
    """Stores `spec` (parsed or snapshotted as needed) in `self._spec`.

    Args:
      spec: A `DeviceSpecV1`, a `DeviceSpecV2`, or a device spec string.
    """
    # DeviceSpecV1 is a subclass of DeviceSpecV2 (V2 is the base class, as
    # is_device_spec notes). V1 must therefore be tested first. Testing V2
    # first makes the snapshot branch unreachable, and a caller's V1 object
    # would be kept by reference, including inside _string_merge_cache keys.
    # NOTE(review): the snapshot presumably guards against the caller
    # modifying its V1 object after the merger is built.
    if isinstance(spec, device_spec.DeviceSpecV1):
      # Capture a snapshot of spec.
      self._spec = spec.__class__.from_string(spec.to_string())
    elif isinstance(spec, device_spec.DeviceSpecV2):
      self._spec = spec
    else:
      self._spec = DeviceSpec.from_string(spec)

  def __call__(self, node_def):
    """Merges `node_def.device` with this object's spec.

    Args:
      node_def: An Operation (or Operation-like) object. Only its `device`
        attribute is read; a falsy device is treated as "".

    Returns:
      The result of `self._spec.make_merged_spec` applied to the parsed
      `node_def.device`.
    """
    # In general a user may create a device function which takes into account
    # arbitrary properties of an op. (For instance dynamically placing ops based
    # on type.) So even though the standard DeviceSpec route only uses the
    # device attribute, we take an entire node_def to maintain a consistent
    # signature with general device functions.
    current_device = DeviceSpec.from_string(node_def.device or "")
    return self._spec.make_merged_spec(current_device)

  def shortcut_string_merge(self, node_def):
    """Merge a node def without materializing a full DeviceSpec object.

    Often a device merge is invoked in order to generate a string which can be
    passed into the c api. In such a case, we can cache the

      node_def.device -> merge_result_string

    map, and in most cases avoid:
      - Materializing a copy of self._spec (In the case of DeviceSpecV1)
      - Materializing a DeviceSpec for node_def.device
      - A DeviceSpec.merge_from invocation

    In practice the cache hit rate for this function is very high, because the
    number of invocations when iterating through the device stack is much
    larger than the number of devices.

    Args:
      node_def: An Operation (or Operation-like) to merge device constraints
        with self._spec

    Returns:
      A string containing the merged device specification.
    """
    device = node_def.device or ""

    merge_key = (self._spec, device)
    result = _string_merge_cache.get(merge_key)
    if result is None:
      # This update is not atomic, however because the merge is stateless
      # we don't need to lock when updating the cache.
      result = self(node_def).to_string()
      _string_merge_cache[merge_key] = result

    return result

  def __repr__(self):
    """Default object repr followed by the wrapped spec's string form."""
    return "{} (spec: {})".format(
        super(MergeDevice, self).__repr__(), self._spec.to_string())

  @property
  def is_null_merge(self):
    """Indicate whether the wrapped spec is empty.

    In the degenerate case where self._spec is an empty specification, a caller
    may wish to skip a merge step entirely. (However this class does not have
    enough information to make that determination.)

    Returns:
      A boolean indicating whether a device merge will be trivial.
    """
    # An empty spec serializes to "", which is falsy.
    return not bool(self._spec.to_string())